# An intro to Matryoshka Embeddings

In [None]:
# Install third-party dependencies into the kernel's environment.
# NOTE(review): llama_cpp, numpy and rich are imported further down but are not
# listed in this install line — confirm the environment already provides them.
%pip install ranx duckdb

In [1]:
import duckdb

In [2]:
# Open the DuckDB database file; the `olympics` table queried below lives in it.
con = duckdb.connect("olympics.duckdb")

In [3]:
# Show the schema of the `olympics` table.
con.sql("DESCRIBE olympics")


┌─────────────┬─────────────┬─────────┬─────────┬─────────┬─────────┐
│ column_name │ column_type │  null   │   key   │ default │  extra  │
│   varchar   │   varchar   │ varchar │ varchar │ varchar │ varchar │
├─────────────┼─────────────┼─────────┼─────────┼─────────┼─────────┤
│ index       │ BIGINT      │ YES     │ NULL    │ NULL    │ NULL    │
│ text        │ VARCHAR     │ YES     │ NULL    │ NULL    │ NULL    │
│ url         │ VARCHAR     │ YES     │ NULL    │ NULL    │ NULL    │
│ title       │ VARCHAR     │ YES     │ NULL    │ NULL    │ NULL    │
└─────────────┴─────────────┴─────────┴─────────┴─────────┴─────────┘

In [4]:
# Peek at the first 10 rows before any embedding columns are added.
con.sql("SELECT * FROM olympics LIMIT 10")


┌───────┬──────────────────────┬──────────────────────┬────────────────────────────────────────────────────────────────┐
│ index │         text         │         url          │                             title                              │
│ int64 │       varchar        │       varchar        │                            varchar                             │
├───────┼──────────────────────┼──────────────────────┼────────────────────────────────────────────────────────────────┤
│     [1;36m0[0m │ The [1;36m2024[0m Olympics …  │ [4;94mhttps://www.bbc.co[0m…  │ Paris [1;36m2024[0m Olympic opening ceremony kicks off Games in uniqu…  │
│     [1;36m1[0m │ Swapping a stadium…  │ [4;94mhttps://www.bbc.co[0m…  │ Paris [1;36m2024[0m Olympic opening ceremony kicks off Games in uniqu…  │
│     [1;36m2[0m │ Blue, white and re…  │ [4;94mhttps://www.bbc.co[0m…  │ Paris [1;36m2024[0m Olympic opening ceremony kicks off Games in uniqu…  │
│     [1;36m3[0m │ There were surpris…  

## Create Embeddings

In [5]:
# Add a fixed-size array column for the full 1024-dimension embeddings.
# IF NOT EXISTS keeps this cell re-runnable: the database is persisted on disk,
# so the column outlives the kernel and a plain ADD COLUMN would fail on re-run.
con.sql("""
ALTER TABLE olympics 
ADD COLUMN IF NOT EXISTS embeddings_1024 FLOAT[1024];
""")

In [6]:
import llama_cpp

In [7]:
# Load the local GGUF model in embedding mode, with verbose output turned off.
llm = llama_cpp.Llama(
  model_path="./models/mxbai-embed-large-v1-f16.gguf", 
  embedding=True, 
  verbose=False
)

In [8]:
# Fetch every row's id and text as (index, text) tuples so they can be embedded.
rows = con.sql("SELECT index, text FROM olympics").fetchall()

In [9]:
# Embed all texts in a single call. The 'data' list is zipped with `rows` below,
# so it is expected to line up one-to-one with the inputs.
embeddings = llm.create_embedding([text for index, text in rows])['data']

In [10]:
# Write each row's embedding back to the table, matching rows by their index.
update_sql = "UPDATE olympics SET embeddings_1024 = ? WHERE index = ?"
for (index, text), embedding in zip(rows, embeddings):
  vector = embedding['embedding']
  con.execute(update_sql, [vector, index])

In [11]:
# Preview the table again now that `embeddings_1024` is populated.
con.sql("SELECT * FROM olympics LIMIT 10")


┌───────┬──────────────────────┬──────────────────────┬──────────────────────┬─────────────────────────────────────────┐
│ index │         text         │         url          │        title         │             embeddings_1024             │
│ int64 │       varchar        │       varchar        │       varchar        │               float[1m[[0m[1;36m1024[0m[1m][0m               │
├───────┼──────────────────────┼──────────────────────┼──────────────────────┼─────────────────────────────────────────┤
│     [1;36m0[0m │ The [1;36m2024[0m Olympics …  │ [4;94mhttps://www.bbc.co[0m…  │ Paris [1;36m2024[0m Olympic…  │ [1m[[0m[1;36m-0.65678865[0m, [1;36m0.375833[0m, [1;36m-0.2038792[0m, [1;36m0[0m…  │
│     [1;36m1[0m │ Swapping a stadium…  │ [4;94mhttps://www.bbc.co[0m…  │ Paris [1;36m2024[0m Olympic…  │ [1m[[0m[1;36m-0.3371891[0m, [1;36m0.25818387[0m, [1;36m0.2695669[0m, [1;36m0[0m…  │
│     [1;36m2[0m │ Blue, white and re…  │ [4;94mhttps://www.bb

In [12]:
import numpy as np

## Create truncated embeddings

In [13]:
def normalize(vec: list[float]) -> list[float]:
    """Scale `vec` to unit Euclidean (L2) length.

    Args:
        vec: Vector components.

    Returns:
        A new list with the same direction as `vec` and L2 norm 1. A vector
        whose norm is zero (including an empty vector) is returned unchanged,
        because dividing by zero would otherwise fill it with NaNs.
    """
    arr = np.asarray(vec, dtype=float)
    norm = np.linalg.norm(arr)
    if norm == 0:
        # Nothing to rescale; avoid the 0/0 division and its RuntimeWarning.
        return arr.tolist()
    return (arr / norm).tolist()

In [14]:
# Register the Python `normalize` function with DuckDB so SQL queries can call it.
# The value echoed below is the connection object itself.
con.create_function(name="normalize", function=normalize)

[1m<[0m[1;95mduckdb.duckdb.DuckDBPyConnection[0m[39m object at [0m[1;36m0x1058ae2b0[0m[1m>[0m

In [15]:
# Preview truncation on a few rows: keep the first 512 components, re-normalise
# them through the `normalize` UDF, and cast the result to a fixed-size array.
con.sql("""
SELECT index, text, 
       embeddings_1024, 
       normalize(embeddings_1024[:512])::float[512] AS embeddings_512
FROM olympics 
LIMIT 5
""")


┌───────┬──────────────────────┬──────────────────────┬────────────────────────────────────────────────────────────────┐
│ index │         text         │   embeddings_1024    │                         embeddings_512                         │
│ int64 │       varchar        │     float[1m[[0m[1;36m1024[0m[1m][0m      │                           float[1m[[0m[1;36m512[0m[1m][0m                           │
├───────┼──────────────────────┼──────────────────────┼────────────────────────────────────────────────────────────────┤
│     [1;36m0[0m │ The [1;36m2024[0m Olympics …  │ [1m[[0m[1;36m-0.65678865[0m, [1;36m0.37[0m…  │ [1m[[0m[1;36m-0.05296788[0m, [1;36m0.030309716[0m, [1;36m-0.016442198[0m, [1;36m0.010034871[0m, [1;36m-0.023[0m…  │
│     [1;36m1[0m │ Swapping a stadium…  │ [1m[[0m[1;36m-0.3371891[0m, [1;36m0.258[0m…  │ [1m[[0m[1;36m-0.030085044[0m, [1;36m0.023035955[0m, [1;36m0.024051582[0m, [1;36m0.048912287[0m, [1;36m-0.044[0m…  │


In [16]:
# Store truncated, re-normalised copies of the 1024-d embeddings, one column
# per smaller size.
dimensions = [16, 32, 64, 128, 256, 512]
for dimension in dimensions:
  # IF NOT EXISTS keeps this cell re-runnable against the persisted database;
  # a plain ADD COLUMN fails once the column has been created.
  con.sql(f"""
  ALTER TABLE olympics 
  ADD COLUMN IF NOT EXISTS embeddings_{dimension} FLOAT[{dimension}];
  """)

  # Keep the first `dimension` components and rescale them via the `normalize` UDF.
  con.sql(f"""
  UPDATE olympics 
  SET embeddings_{dimension} = normalize(embeddings_1024[:{dimension}])
  """)

In [17]:
def vector_search(query, dimension=1024, limit=3):
  """Return the rows of `olympics` most similar to `query`.

  The query is embedded with `llm`. When `dimension` is smaller than the
  embedding's length, the vector is cut to its first `dimension` components
  and re-normalised with `normalize`, the same treatment the truncated
  embedding columns receive.

  Args:
    query: Text to search for.
    dimension: Embedding size to search with; selects the column named
      "embeddings_<dimension>".
    limit: Maximum number of rows to return (defaults to 3).

  Returns:
    The result of `con.sql(...)` with columns `index`, `text` and `score`
    (cosine similarity), ordered by descending score.
  """
  # Both values are interpolated into the SQL text, so force them to integers.
  dimension = int(dimension)
  limit = int(limit)

  raw_embedding = llm.create_embedding([query])
  search_vector = raw_embedding['data'][0]['embedding']

  if dimension < len(search_vector):
    search_vector = normalize(search_vector[:dimension])

  return con.sql(f"""
  SELECT index, text,
          array_cosine_similarity(
            "embeddings_{dimension}", $searchVector::FLOAT[{dimension}]
          ) AS score
  FROM olympics
  ORDER BY score DESC
  LIMIT {limit}
  """, params={"searchVector": search_vector})

## Query embeddings

In [18]:
# Try the search with the default dimension (the full 1024-d embeddings).
query = 'Where did the opening ceremony take place?'
vector_search(query)


┌───────┬──────────────────────────────────────────────────────────────────────────────────────────────────┬───────────┐
│ index │                                               text                                               │   score   │
│ int64 │                                             varchar                                              │   float   │
├───────┼──────────────────────────────────────────────────────────────────────────────────────────────────┼───────────┤
│     [1;36m0[0m │ The [1;36m2024[0m Olympics opened in Paris in spectacular style with thousands of athletes sailing alon…  │ [1;36m0.6873059[0m │
│    [1;36m29[0m │ The peace anthem, part of all Olympic opening ceremonies, is aligned with the message of unity…  │ [1;36m0.6783502[0m │
│    [1;36m10[0m │ When organisers first revealed plans to hold the opening ceremony along the river in the heart…  │ [1;36m0.6705272[0m │
└───────┴──────────────────────────────────────────────────────────────────

In [19]:
from rich.console import Console
# Rich console; used below to print results through a pager.
c = Console()

In [20]:
# Run the same query at every embedding size, largest first, and page the output.
# `dimensions` is redefined here to include 1024; the evaluation cells further
# down iterate over this extended list.
with c.pager(styles=True):
  dimensions = [16, 32, 64, 128, 256, 512, 1024]
  for dimension in reversed(dimensions):
    c.print(dimension)
    results = vector_search(query, dimension)
    c.print(results)

[1;36m1024[0m
┌───────┬──────────────────────────────────────────────────────────────────────────────────────────────────┬───────
────┐
│ index │                                               text                                               │   
score   │
│ int64 │                                             varchar                                              │   
float   │
├───────┼──────────────────────────────────────────────────────────────────────────────────────────────────┼───────
────┤
│     [1;36m0[0m │ The [1;36m2024[0m Olympics opened in Paris in spectacular style with thousands of athletes sailing alon…  │ 
[1;36m0.6873059[0m │
│    [1;36m29[0m │ The peace anthem, part of all Olympic opening ceremonies, is aligned with the message of unity…  │ 
[1;36m0.6783502[0m │
│    [1;36m10[0m │ When organisers first revealed plans to hold the opening ceremony along the river in the heart…  │ 
[1;36m0.6705272[0m │
└───────┴────────────────────────────────────────────

In [24]:
from ranx import Qrels, Run, compare
from functools import partial

## Evaluate embeddings

In [25]:
# Load the relevance judgements used for evaluation.
# NOTE(review): presumably keyed by question text, since create_run passes each
# key of qrels.to_dict() to the search function — confirm against the JSON file.
qrels = Qrels.from_file("data/questions.json")

In [26]:
# One (search function, label) pair per embedding size; the label is the size itself.
functions = [
  (partial(vector_search, dimension=size), size) for size in dimensions
]

In [27]:
def create_run(qrels, retrieval_fn, name):
  """Build a `Run` by executing `retrieval_fn` for every query in `qrels`.

  Args:
    qrels: Relevance judgements; the keys of `qrels.to_dict()` are used as
      the queries.
    retrieval_fn: Callable that takes a query and returns a relation
      exposing `index` and `score` columns.
    name: Label attached to the resulting run.

  Returns:
    A `Run` mapping each query to `{str(index): score}` for its retrieved rows.
  """
  run_dict = {}
  for question in qrels.to_dict():
    hits = retrieval_fn(question).select("index, score").fetchall()
    # Document ids are stringified before being used as keys.
    run_dict[question] = {str(index): score for index, score in hits}
  return Run(run_dict, name=name)

In [28]:
%%time
# Build one run per embedding size by searching with every question in `qrels`.
runs = [create_run(qrels, search_fn, label) for search_fn, label in functions]

CPU times: user 2min 21s, sys: 3.56 s, total: 2min 25s
Wall time: 21.8 s


In [29]:
# Score every run against the relevance judgements using the hit-rate metric.
comparison = compare(
    qrels,
    runs=runs,
    metrics=["hit_rate"],
)

  scores[i] = _hit_rate(qrels[i], run[i], k, rel_lvl)


In [30]:
# Display the comparison table (one row per embedding size).
comparison


#      Model  Hit Rate
---  -------  ----------
a         [1;36m16[0m  [1;36m0.550[0m
b         [1;36m32[0m  [1;36m0.350[0m
c         [1;36m64[0m  [1;36m0.[0m700ᵇ
d        [1;36m128[0m  [1;36m0.[0m850ᵇ
e        [1;36m256[0m  [1;36m0.[0m800ᵇ
f        [1;36m512[0m  [1;36m0.[0m950ᵃᵇ
g       [1;36m1024[0m  [1;36m0.[0m750ᵇ

In [31]:
# Page through the pairwise win/tie/loss counts between runs.
with c.pager(styles=True):
  c.print(comparison.win_tie_loss)

[1m{[0m
    [1m([0m[1;36m16[0m, [1;36m32[0m[1m)[0m: [1m{[0m[32m'hit_rate'[0m: [1m{[0m[32m'W'[0m: [1;36m5[0m, [32m'T'[0m: [1;36m14[0m, [32m'L'[0m: [1;36m1[0m[1m}[0m[1m}[0m,
    [1m([0m[1;36m16[0m, [1;36m64[0m[1m)[0m: [1m{[0m[32m'hit_rate'[0m: [1m{[0m[32m'W'[0m: [1;36m3[0m, [32m'T'[0m: [1;36m11[0m, [32m'L'[0m: [1;36m6[0m[1m}[0m[1m}[0m,
    [1m([0m[1;36m16[0m, [1;36m128[0m[1m)[0m: [1m{[0m[32m'hit_rate'[0m: [1m{[0m[32m'W'[0m: [1;36m2[0m, [32m'T'[0m: [1;36m10[0m, [32m'L'[0m: [1;36m8[0m[1m}[0m[1m}[0m,
    [1m([0m[1;36m16[0m, [1;36m256[0m[1m)[0m: [1m{[0m[32m'hit_rate'[0m: [1m{[0m[32m'W'[0m: [1;36m3[0m, [32m'T'[0m: [1;36m9[0m, [32m'L'[0m: [1;36m8[0m[1m}[0m[1m}[0m,
    [1m([0m[1;36m16[0m, [1;36m512[0m[1m)[0m: [1m{[0m[32m'hit_rate'[0m: [1m{[0m[32m'W'[0m: [1;36m1[0m, [32m'T'[0m: [1;36m10[0m, [32m'L'[0m: [1;36m9[0m[1m}[0m[1m}[0m,
    [1m([0m[1;3