In [1]:
import duckdb

# Open (creating on first use) the file-backed database that later cells use for
# the embedding tables, the HNSW index and the similarity results.
conn = duckdb.connect(database="embeddings.db")

In [2]:
import os

from FlagEmbedding import BGEM3FlagModel
import torch

# Model location: a local directory or a Hugging Face hub id. Set the BGE_M3_MODEL
# environment variable to point at a local copy instead of hard-coding a
# user-specific absolute path in the notebook.
MODEL_PATH = os.environ.get("BGE_M3_MODEL", "BAAI/bge-m3")

# Use a GPU if available to speed up the embedding computation.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"  # Nvidia GPU
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device = "mps"  # Apple silicon GPU

print(f'Using device: {device}')

# Half precision is only worthwhile on an accelerator; on CPU fp16 inference is
# typically slow and some ops may lack fp16 kernels, so stay in fp32 there.
model = BGEM3FlagModel(MODEL_PATH, use_fp16=(device != "cpu"), device=device)

  Referenced from: <CFED5F8E-EC3F-36FD-AAA3-2C6C7F8D3DD9> /Users/yamingdeng/miniconda3/lib/python3.11/site-packages/torchvision/image.so
  warn(


Using device: mps


In [3]:
# Sanity check: each query is expected to score highest against its own document.
queries = ["What is BGE M3?", "What is DuckDB?"]
documents = [
    "BGE M3 is an embedding model supporting dense retrieval, lexical matching and multi-vector interaction.",
    "DuckDB is a fast in-process analytical database. It supports a feature-rich SQL dialect complemented with deep integrations into client APIs",
]

# Encode queries first, then documents, keeping only the dense vectors.
query_embeddings, document_embeddings = (
    model.encode(texts)["dense_vecs"] for texts in (queries, documents)
)

# Entry [i, j] is the inner product of query i with document j.
similarity = query_embeddings @ document_embeddings.transpose()
similarity

array([[0.626 , 0.1918],
       [0.3362, 0.732 ]], dtype=float16)

In [3]:
import duckdb
from duckdb.typing import VARCHAR
import pyarrow as pa
import numpy as np

# Width of the BGE-M3 dense vectors; must match the FLOAT[1024] columns used below.
EMBEDDING_DIM = 1024


def embed(sentence: str) -> np.ndarray:
    """Return the BGE-M3 dense embedding of a single text.

    Registered below as the DuckDB scalar UDF ``embed(VARCHAR) -> FLOAT[1024]``
    (a native UDF, so DuckDB calls it once per input row). The model may run in
    fp16, so the vector is cast to float32 to match the declared FLOAT elements.
    """
    return model.encode(sentence)["dense_vecs"].astype(np.float32)


# Re-running this cell would otherwise fail because `embed` is already
# registered on this connection; drop any previous registration first.
try:
    conn.remove_function("embed")
except duckdb.Error:
    pass  # first run: nothing registered yet

conn.create_function("embed", embed, [VARCHAR], f"FLOAT[{EMBEDDING_DIM}]")

<duckdb.duckdb.DuckDBPyConnection at 0x1041dfcb0>

In [None]:
# Smoke test for the `embed` UDF: a single row holding one query vector.
sql = (
    "\n"
    "SELECT embed('Who was the first human on the moon?') AS query_embedding;\n"
)

conn.execute(sql).fetchall()

In [4]:
# Make the vector similarity search extension (HNSW indexes) available.
sql = "\n".join(["", "INSTALL vss;", "LOAD vss;", ""])
conn.execute(sql)

<duckdb.duckdb.DuckDBPyConnection at 0x1041dfcb0>

In [5]:
# Opt in to the vss extension's experimental on-disk persistence of HNSW
# indexes, since `conn` is backed by a database file.
sql = "\nSET GLOBAL hnsw_enable_experimental_persistence = true;\n"
conn.execute(sql)

<duckdb.duckdb.DuckDBPyConnection at 0x1041dfcb0>

In [6]:
# Dense vectors for the first product file. IF NOT EXISTS keeps the cell
# re-runnable against the persistent database file.
sql = """
CREATE TABLE IF NOT EXISTS embeddings(
     doc_id VARCHAR,
     embedding FLOAT[1024]
);
"""
conn.execute(sql)

<duckdb.duckdb.DuckDBPyConnection at 0x1041dfcb0>

In [9]:
# HNSW index using the inner-product metric over the embedding column.
# IF NOT EXISTS keeps the cell re-runnable against the persistent database file.
sql = """
CREATE INDEX IF NOT EXISTS ip_idx ON embeddings USING HNSW (embedding)
WITH (metric = 'ip');
"""
conn.execute(sql)

<duckdb.duckdb.DuckDBPyConnection at 0x1080c2170>

In [10]:
# Materialize SKU, description and price from the first product file.
# CREATE OR REPLACE rebuilds it from the CSV instead of failing on re-run.
sql = """
CREATE OR REPLACE TABLE data_01 AS
SELECT SKU, description, price FROM 'generated_data_01.csv'
"""

conn.sql(sql)
# 0.0s

In [7]:
# Embed every description of the first product file (one UDF call per row).
# SKUs already present in `embeddings` are skipped, so re-running the cell
# neither duplicates rows nor repeats the expensive model calls.
sql = """
INSERT INTO embeddings(doc_id, embedding)
SELECT src.SKU, embed(src.description) AS embedding
FROM 'generated_data_01.csv' AS src
WHERE NOT EXISTS (SELECT 1 FROM embeddings e WHERE e.doc_id = src.SKU)
"""
conn.sql(sql)
# 3m 27.7s

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

In [12]:
# Top-5 products most similar to a free-text query, by inner product of dense
# vectors (higher = more similar). `embeddings` holds generated_data_01 SKUs,
# so results are joined back to data_01 for their descriptions.
# NOTE: ORDER BY ... DESC is correct on every DuckDB version but is a brute-force
# scan; on DuckDB >= 1.1, `ORDER BY array_negative_inner_product(...)` (ascending)
# is the form the 'ip' HNSW index can accelerate.
sql = """
WITH q_emb AS (
    SELECT embed($q) AS vec              -- embed the query text exactly once
),
top_k AS (
    SELECT e.doc_id, array_inner_product(e.embedding, q_emb.vec) AS similarity
    FROM embeddings e, q_emb
    ORDER BY similarity DESC             -- most similar first
    LIMIT 5
)
SELECT d.SKU, d.description, top_k.similarity
FROM top_k JOIN data_01 d ON (d.SKU = top_k.doc_id)
ORDER BY top_k.similarity DESC
"""

conn.sql(sql, params={'q': 'Cross-group 12 executive parallelism'}).fetchall()
# 0.5s

[('886-35-2858', 'Cross-group executive parallelism', 0.9070870280265808),
 ('679-47-9529', 'Cross-group executive array', 0.7563570141792297),
 ('292-50-7938', 'Cross-group executive array', 0.7563570141792297),
 ('438-90-7050', 'Persistent executive parallelism', 0.7523760199546814),
 ('273-75-3630', 'Cross-group reciprocal parallelism', 0.7474920153617859)]

In [8]:
# Dense vectors for the second product file. IF NOT EXISTS keeps the cell
# re-runnable against the persistent database file.
sql = """
CREATE TABLE IF NOT EXISTS embeddings2(
     doc_id VARCHAR,
     embedding FLOAT[1024]
);
"""
conn.execute(sql)

<duckdb.duckdb.DuckDBPyConnection at 0x1041dfcb0>

In [9]:
# Embed every description of the second product file (one UDF call per row).
# SKUs already present in `embeddings2` are skipped, so re-running the cell
# neither duplicates rows nor repeats the expensive model calls.
sql = """
INSERT INTO embeddings2(doc_id, embedding)
SELECT src.SKU, embed(src.description) AS embedding
FROM 'generated_data_02.csv' AS src
WHERE NOT EXISTS (SELECT 1 FROM embeddings2 e WHERE e.doc_id = src.SKU)
"""
conn.sql(sql)
# 4m 19.3s

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

In [10]:
# Minimum inner product for a cross-file pair to be kept.
SIMILARITY_THRESHOLD = 0.80

# Every (embeddings x embeddings2) pair scoring at least SIMILARITY_THRESHOLD.
# CREATE OR REPLACE rebuilds it from the current embeddings instead of failing
# because the table already exists.
sql = f"""
CREATE OR REPLACE TABLE similarity_matrix AS
SELECT e1.doc_id AS doc_id1, e2.doc_id AS doc_id2,
       array_inner_product(e1.embedding, e2.embedding) AS ratio
FROM embeddings e1, embeddings2 e2
WHERE array_inner_product(e1.embedding, e2.embedding) >= {SIMILARITY_THRESHOLD}
"""

conn.sql(sql)
# 1m 37.3s

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

In [14]:
# For each doc_id1 in similarity_matrix, export its highest-ratio partner with
# both descriptions to CSV. Ties on ratio are broken by the lexicographically
# smallest doc_id2 so the exported match is deterministic.
sql = """
COPY (
    select * from (
        select
            d1.sku as sku1, d1.description as description1, d2.sku as sku2, d2.description as description2, sm.ratio,
            row_number() OVER (PARTITION BY sm.doc_id1 ORDER BY sm.ratio desc, sm.doc_id2) as rn
        from similarity_matrix sm
        join (SELECT SKU, description FROM 'generated_data_01.csv') d1 on sm.doc_id1 = d1.SKU
        join (SELECT SKU, description FROM 'generated_data_02.csv') d2 on sm.doc_id2 = d2.SKU
    ) where rn = 1
) TO 'similarity_matrix_output.csv' (HEADER, DELIMITER ',');
"""

conn.sql(sql)
# 0.0s