# Funnel Search with Matryoshka Embeddings

Let's revisit the example of semantic search over the Wikipedia Movie Plots dataset from the [Milvus and Sentence Transformers doc](https://milvus.io/docs/integrate_with_sentencetransformers.md). First, our imports:

In [1]:
import functools

from datasets import load_dataset
import numpy as np
import pandas as pd
import pymilvus
from pymilvus import MilvusClient, connections
from pymilvus import FieldSchema, CollectionSchema, DataType, Collection
from sentence_transformers import SentenceTransformer
import torch
import torch.nn.functional as F
from tqdm import tqdm

## Load Matryoshka Embedding Model

Instead of the regular embedding model [`sentence-transformers/all-MiniLM-L12-v2`](https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2) used in the previous example, we use [a model from Nomic](https://huggingface.co/nomic-ai/nomic-embed-text-v1.5) trained to produce Matryoshka embeddings.

In [2]:
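model = SentenceTransformer(
    "nomic-ai/nomic-embed-text-v1.5",
    trust_remote_code=True,
    # "mps" targets Apple silicon; fall back to CUDA or CPU on other machines
    device="mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu",
)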

<All keys matched successfully>


## Loading Dataset, Embedding Items, and Building Vector Database

We'll abbreviate the discussion of loading and embedding the data; refer to the docs for more details. We have made two changes: the embedding dimension is now 768, and we prepend the prefix `search_document: ` to each document before embedding it, which the model requires because of how it was trained.

Milvus does not currently support searching over subsets of embeddings, so we break each embedding into two parts: the head is the initial prefix of the vector that we index and search, and the tail is the remainder. The model is trained for cosine similarity search, so we normalize the head embeddings. However, to calculate similarities over larger prefixes later on, we also store the norm of each head embedding, so that we can unnormalize the head before joining it to the tail.
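The head/tail bookkeeping can be sketched in a few lines of NumPy. This is a minimal illustration on random vectors, not the tutorial's pipeline, and the helper names `split_embedding`, `rejoin_embedding`, and `prefix_cosine` are hypothetical. It shows why the head norm has to be stored: rescaling the unit-norm head by its norm restores the original vector, so cosine similarities over any prefix longer than 128 dimensions come out right.

```python
import numpy as np

# Illustrative sizes mirroring embedding_dim and search_dim in the notebook
EMBEDDING_DIM = 768
SEARCH_DIM = 128


def split_embedding(embedding: np.ndarray, search_dim: int = SEARCH_DIM):
    """Hypothetical helper: return (unit-norm head, head norm, raw tail)."""
    head = embedding[:search_dim]
    head_norm = float(np.linalg.norm(head))
    return head / head_norm, head_norm, embedding[search_dim:]


def rejoin_embedding(head_unit: np.ndarray, head_norm: float, tail: np.ndarray) -> np.ndarray:
    """Hypothetical helper: rescale the head by its stored norm and reattach the tail."""
    return np.concatenate([head_unit * head_norm, tail])


def prefix_cosine(query: np.ndarray, doc: np.ndarray, dims: int) -> float:
    """Cosine similarity using only the first `dims` dimensions of both vectors."""
    q, d = query[:dims], doc[:dims]
    return float(q @ d / (np.linalg.norm(q) * np.linalg.norm(d)))


rng = np.random.default_rng(0)
doc = rng.normal(size=EMBEDDING_DIM).astype(np.float32)
query = rng.normal(size=EMBEDDING_DIM).astype(np.float32)

head_unit, head_norm, tail = split_embedding(doc)
restored = rejoin_embedding(head_unit, head_norm, tail)

# With the stored norm, the 256-dim prefix similarity matches the original vector's;
# naively joining the unit head to the raw tail would reweight the two parts
print(prefix_cosine(query, restored, 256), prefix_cosine(query, doc, 256))
print(prefix_cosine(query, np.concatenate([head_unit, tail]), 256))
```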

In [3]:
embedding_dim = 768
search_dim = 128
collection_name = "movie_embeddings"

ds = load_dataset("vishnupriyavr/wiki-movie-plots-with-summaries", split="train")
print(ds)

client = MilvusClient(uri="./wiki-movie-plots-matryoshka.db")

fields = [
    FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
    FieldSchema(name="title", dtype=DataType.VARCHAR, max_length=256),
    FieldSchema(name="head_embedding", dtype=DataType.FLOAT_VECTOR, dim=search_dim),
    FieldSchema(name="head_norm", dtype=DataType.FLOAT),
    FieldSchema(
        name="tail_embedding",
        dtype=DataType.FLOAT_VECTOR,
        dim=embedding_dim - search_dim,
    ),
    # These fields are not used for funnel search, only for comparing to alternative methods
    FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=embedding_dim),
    FieldSchema(
        name="flip_head_embedding", dtype=DataType.FLOAT_VECTOR, dim=search_dim
    ),
    FieldSchema(name="flip_head_norm", dtype=DataType.FLOAT),
    FieldSchema(
        name="flip_tail_embedding",
        dtype=DataType.FLOAT_VECTOR,
        dim=embedding_dim - search_dim,
    ),
]

schema = CollectionSchema(fields=fields, enable_dynamic_field=False)
client.create_collection(collection_name=collection_name, schema=schema)

index_params = client.prepare_index_params()
index_params.add_index(field_name="head_embedding", index_type="FLAT", metric_type="IP")
index_params.add_index(field_name="embedding", index_type="FLAT", metric_type="IP")
index_params.add_index(
    field_name="flip_head_embedding", index_type="FLAT", metric_type="IP"
)
client.create_index(collection_name, index_params)

for batch in tqdm(ds.batch(batch_size=512)):
    plot_summary = ["search_document: " + x.strip() for x in batch["PlotSummary"]]
    embeddings = model.encode(plot_summary, convert_to_tensor=True)

    # Used by funnel search
    head_embeddings = embeddings[:, :search_dim]
    head_norms = torch.linalg.vector_norm(head_embeddings, dim=1)
    head_embeddings = F.normalize(head_embeddings, p=2, dim=1)
    tail_embeddings = embeddings[:, search_dim:]

    # For method comparison
    embeddings_normed = F.normalize(embeddings, p=2, dim=1)
    flip_embeddings = torch.flip(embeddings, [1])
    flip_head_embeddings = flip_embeddings[:, :search_dim]
    flip_head_norms = torch.linalg.vector_norm(flip_head_embeddings, dim=1)
    flip_head_embeddings = F.normalize(flip_head_embeddings, p=2, dim=1)
    flip_tail_embeddings = flip_embeddings[:, search_dim:]

    data = [
        {
            "title": title,
            "head_embedding": head.cpu().numpy(),
            "head_norm": float(head_norm),
            "tail_embedding": tail.cpu().numpy(),
            # For method comparison
            "embedding": embedding.cpu().numpy(),
            "flip_head_embedding": flip_head.cpu().numpy(),
            "flip_head_norm": float(flip_head_norm),
            "flip_tail_embedding": flip_tail.cpu().numpy(),
        }
        for title, head, head_norm, tail, embedding, flip_head, flip_head_norm, flip_tail in zip(
            batch["Title"],
            head_embeddings,
            head_norms,
            tail_embeddings,
            embeddings_normed,
            flip_head_embeddings,
            flip_head_norms,
            flip_tail_embeddings,
        )
    ]
    res = client.insert(collection_name=collection_name, data=data)

Dataset({
    features: ['Release Year', 'Title', 'Origin/Ethnicity', 'Director', 'Cast', 'Genre', 'Wiki Page', 'Plot', 'PlotSummary'],
    num_rows: 34886
})


100%|██████████| 69/69 [06:31<00:00,  5.67s/it]


## Performing Funnel Search
Let's now implement funnel search using the smaller prefix of the Matryoshka embeddings: we retrieve candidates with the first 128 dimensions, then score, re-rank, and prune them with progressively larger prefixes. Along the way, we also obtain the results of a plain search over the 128-dimension prefix, without the score, re-rank, and prune steps.

In [5]:
queries = [
    "A movie about a shark that terrorizes an LA beach.",
    "An archaeologist searches for ancient artifacts while fighting Nazis.",
    "Teenagers in detention learn about themselves.",
    "A teenager fakes illness to get off school and have adventures with two friends.",
    "A young couple with a kid look after a hotel during winter and the husband goes insane.",
    "Four turtles fight bad guys.",
]


# Search the database based on input text
def embed_search(data):
    embeds = model.encode(data)
    return [x for x in embeds]


instruct_queries = ["search_query: " + q.strip() for q in queries]
search_data = embed_search(instruct_queries)

# Normalize head embeddings
head_search = [x[:search_dim] / np.linalg.norm(x[:search_dim]) for x in search_data]

# Perform standard vector search on subset of embeddings
res = client.search(
    collection_name=collection_name,
    data=head_search,
    anns_field="head_embedding",
    limit=128,
    output_fields=["title", "head_embedding", "head_norm", "tail_embedding"],
)

We'll print the search results before performing the funnel operations for comparison to the other methods.

In [7]:
for query, hits in zip(queries, res):
    rows = [x['entity'] for x in hits][:7]

    print("Query:", query)
    print("Results:")
    for row in rows:
        print(row['title'].strip())
    print()

Query: A movie about a shark that terrorizes an LA beach.
Results:
The Shallows
Bait 3D
Tintorera
2-Headed Shark Attack
The Life Aquatic with Steve Zissou
Tiger Shark
Deliver Us from Evil

Query: An archaeologist searches for ancient artifacts while fighting Nazis.
Results:
"Pimpernel" Smith
Black Hunters
The Passage
Counterblast
Dominion: Prequel to the Exorcist
A Yank in Libya
Charlie Chan in Egypt

Query: Teenagers in detention learn about themselves.
Results:
18
Dusari Goshta
Bad Kids Go to Hell
Amar Bondhu Rashed
Punishment Park
For You I Die
Nowhere to Run

Query: A teenager fakes illness to get off school and have adventures with two friends.
Results:
How to Deal
Shorts
Blackbird
Valentine
Unfriended
Dear Friends
Texas Chainsaw Massacre: The Next Generation

Query: A young couple with a kid look after a hotel during winter and the husband goes insane.
Results:
Ghostkeeper
Our Vines Have Tender Grapes
The Ref
Impact
The House in Marsh Road
Daddy's Home 2
Tyrannosaur

Query: Four 

For ease of exposition of the funnel search algorithm, we convert the Milvus search hits for each query into a Pandas dataframe.

In [16]:
def hits_to_dataframe(hits: pymilvus.client.abstract.Hits) -> pd.DataFrame:
    """
    Convert a Milvus search result to a Pandas dataframe. This function is specific to our data schema.

    NOTE: We have to unnormalize head embedding so we can correctly normalize superset Matryoshka embeddings. This
    is why we stored head_norm field.
    """
    rows = [x['entity'] for x in hits]
    rows_dict = [
        {
            "title": x['title'],
            "embedding": torch.concat(
                [
                    torch.tensor(x['head_embedding']) * x['head_norm'],
                    torch.tensor(x['tail_embedding']),
                ],
                dim=0,
            ),
        }
        for x in rows
    ]
    return pd.DataFrame.from_records(rows_dict)


dfs = [hits_to_dataframe(hits) for hits in res]

In [17]:
dfs[0]

|     | title | embedding |
|-----|-------|-----------|
| 0 | The Shallows | [tensor(0.9796), tensor(0.3610), tensor(-3.848... |
| 1 | Bait 3D | [tensor(-0.2942), tensor(0.3814), tensor(-4.21... |
| 2 | Tintorera | [tensor(0.6867), tensor(0.7650), tensor(-4.588... |
| 3 | 2-Headed Shark Attack | [tensor(-0.4027), tensor(0.9793), tensor(-3.92... |
| 4 | The Life Aquatic with Steve Zissou | [tensor(-0.1876), tensor(0.8757), tensor(-4.20... |
| ... | ... | ... |
| 123 | Meatballs Part II | [tensor(-0.7218), tensor(0.5323), tensor(-3.96... |
| 124 | Ravana Desam | [tensor(0.3206), tensor(0.0211), tensor(-3.998... |
| 125 | Combat Shock | [tensor(-0.3900), tensor(-0.1212), tensor(-4.4... |
| 126 | Three Godfathers | [tensor(-1.1591), tensor(-0.2948), tensor(-4.4... |


Now, to perform funnel search, we iterate over increasingly large prefixes of the embeddings. At each scale, we score the remaining candidates against the query, re-rank them, and prune the lowest-scoring half.

In [18]:
# An optimized implementation would vectorize the calculation of similarity scores across rows (using a matrix)
def calculate_score(row, query_emb=None, dims=768):
    # Cosine similarity between the first `dims` dimensions of the candidate and the pre-normalized query
    emb = F.normalize(row["embedding"][:dims], dim=-1)
    return (emb @ query_emb).item()


# You could also add a top-K parameter as a termination condition
def funnel_search(
    df: pd.DataFrame, query_emb, scales=(256, 512, 768), prune_ratio=0.5
) -> pd.DataFrame:
    # Work on a copy so the caller's dataframe is not modified
    df = df.copy()
    for dims in scales:
        # Query vector must be normalized for each new dimensionality
        emb = torch.tensor(query_emb[:dims] / np.linalg.norm(query_emb[:dims]))

        # Score
        df["scores"] = df.apply(
            functools.partial(calculate_score, query_emb=emb, dims=dims), axis=1
        )

        # Re-rank
        df = df.sort_values(by="scores", ascending=False)

        # Prune (in our case, remove half of candidates at each step); copy to avoid
        # pandas' SettingWithCopyWarning when adding scores at the next scale
        df = df.head(int(prune_ratio * len(df))).copy()

    return df


dfs_results = [
    {"query": query, "results": funnel_search(df, query_emb)}
    for query, df, query_emb in zip(queries, dfs, search_data)
]

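The comment on `calculate_score` notes that an optimized implementation would vectorize the scoring. A minimal NumPy sketch of that idea follows. It assumes the candidates' restored embeddings (head rescaled by its norm, joined to the tail) are stacked into an `(n, 768)` array, and it mirrors the `scales` and `prune_ratio` defaults above. `funnel_search_vectorized` is a hypothetical name, and the random data only exercises the code.

```python
import numpy as np


def funnel_search_vectorized(
    candidates: np.ndarray,
    query: np.ndarray,
    scales=(256, 512, 768),
    prune_ratio: float = 0.5,
):
    """Score, re-rank, and prune candidate rows with growing embedding prefixes.

    candidates: (n, d) matrix of unnormalized full embeddings (head norm restored).
    query: (d,) query embedding.
    Returns surviving candidate indices (best first) and their final scores.
    """
    idx = np.arange(len(candidates))
    scores = np.empty(0, dtype=candidates.dtype)
    for dims in scales:
        # Cosine similarity on the first `dims` dimensions, for all rows at once
        sub = candidates[idx, :dims]
        sub = sub / np.linalg.norm(sub, axis=1, keepdims=True)
        q = query[:dims] / np.linalg.norm(query[:dims])
        scores = sub @ q
        # Re-rank: sort by descending score
        order = np.argsort(-scores, kind="stable")
        # Prune: keep the top fraction of the current candidates
        keep = int(prune_ratio * len(idx))
        idx, scores = idx[order][:keep], scores[order][:keep]
    return idx, scores


rng = np.random.default_rng(0)
cands = rng.normal(size=(128, 768)).astype(np.float32)
query = rng.normal(size=768).astype(np.float32)
top_idx, top_scores = funnel_search_vectorized(cands, query)
print(len(top_idx))  # 16
```

With 128 candidates and `prune_ratio=0.5`, the sketch scores 128, 64, and 32 candidates at 256, 512, and 768 dimensions respectively and returns 16. Each stage is a single matrix-vector product rather than a per-row `apply`.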

In [20]:
for d in dfs_results:
    print(d["query"], "\n", d["results"][:7]["title"], "\n")

A movie about a shark that terrorizes an LA beach. 
 0                           The Shallows
3                  2-Headed Shark Attack
1                                Bait 3D
14                     Jaws: The Revenge
4     The Life Aquatic with Steve Zissou
69                                  Jaws
2                              Tintorera
Name: title, dtype: object 

An archaeologist searches for ancient artifacts while fighting Nazis. 
 0            "Pimpernel" Smith
1                Black Hunters
29     Raiders of the Lost Ark
34              The Master Key
51             My Gun Is Quick
2                  The Passage
109                   Primeval
Name: title, dtype: object 

Teenagers in detention learn about themselves. 
 9            The Breakfast Club
100    The Explosive Generation
2           Bad Kids Go to Hell
11             Bratz: The Movie
39             Band of the Hand
72               Up the Academy
26                 Jail Busters
Name: title, dtype: object 

A teenager 

Qualitatively, these results seem to have higher recall than the standard vector search in the tutorial ["Movie Search using Milvus and Sentence Transformers"](https://milvus.io/docs/integrate_with_sentencetransformers.md), although that tutorial uses a different embedding model. (See the [Matryoshka](https://arxiv.org/abs/2205.13147) paper for further experiments and benchmarks.)

## Comparing Funnel Search to Regular Search

Let's compare the results of our funnel search to a standard vector search *on the same dataset with the same embedding model*.

In [24]:
# Normalize the full query embeddings
search_embs = [x / np.linalg.norm(x) for x in search_data]

# Perform standard vector search on the full 768-dimension embeddings
res = client.search(
    collection_name=collection_name,
    data=search_embs,
    anns_field="embedding",
    limit=16,
    output_fields=["title", "embedding"],
)

In [25]:
for query, hits in zip(queries, res):
    rows = [x['entity'] for x in hits][:7]

    print("Query:", query)
    print("Results:")
    for row in rows:
        print(row['title'].strip())
    print()

Query: A movie about a shark that terrorizes an LA beach.
Results:
The Shallows
2-Headed Shark Attack
Bait 3D
Jaws: The Revenge
The Life Aquatic with Steve Zissou
Jaws
Tintorera

Query: An archaeologist searches for ancient artifacts while fighting Nazis.
Results:
"Pimpernel" Smith
Black Hunters
Raiders of the Lost Ark
The Master Key
My Gun Is Quick
The Passage
The Mole People

Query: Teenagers in detention learn about themselves.
Results:
The Breakfast Club
The Explosive Generation
Bad Kids Go to Hell
Bratz: The Movie
Band of the Hand
Up the Academy
Jail Busters

Query: A teenager fakes illness to get off school and have adventures with two friends.
Results:
A Walk to Remember
Ferris Bueller's Day Off
How I Live Now
On the Edge of Innocence
Bratz: The Movie
Unfriended
Simon Says

Query: A young couple with a kid look after a hotel during winter and the husband goes insane.
Results:
The Shining
Ghostkeeper
Fast and Loose
Killing Ground
Home Alone
Home Alone 2: Lost in New York
Leopard 

With the exception of the results for "A teenager fakes illness to get off school...", the funnel search results are almost identical to those of the full search, even though funnel search retrieved its candidates using only the first 128 dimensions, versus all 768 for the regular search.

## Does the order matter? Suffix vs prefix embeddings.

The model was trained to perform well matching recursively smaller prefixes of the embeddings. Does the order of the dimensions we use matter? For instance, could we also take subsets of the embeddings that are suffixes? In this experiment, we reverse the order of the dimensions in the Matryoshka embeddings and perform a funnel search.
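To see what reversing the dimensions does, here is a small NumPy check on random vectors (illustrative only, not model output). The 128-dimension head of a flipped vector is the original 128-dimension suffix in reverse order, so the flipped search retrieves candidates by comparing suffixes. At the full 768 dimensions, however, flipping both query and document merely permutes the coordinates, which leaves cosine similarity unchanged. The final stage of a funnel search over flipped embeddings therefore scores candidates the same way as the unflipped version.

```python
import numpy as np

rng = np.random.default_rng(0)
search_dim = 128
emb = rng.normal(size=768).astype(np.float32)
query = rng.normal(size=768).astype(np.float32)


def cosine(a: np.ndarray, b: np.ndarray) -> float:
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))


flip_emb, flip_query = emb[::-1], query[::-1]

# The head of the flipped vector is the original suffix, reversed
head_is_suffix = np.array_equal(flip_emb[:search_dim], emb[-search_dim:][::-1])

# 128-dim search: the flipped method compares the last 128 dimensions
prefix_score = cosine(query[:search_dim], emb[:search_dim])
suffix_score = cosine(flip_query[:search_dim], flip_emb[:search_dim])

# 768-dim re-ranking: a shared permutation leaves cosine similarity unchanged
full_score = cosine(query, emb)
flip_full_score = cosine(flip_query, flip_emb)
print(head_is_suffix, prefix_score, suffix_score, full_score, flip_full_score)
```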

In [36]:
# Reverse the query embeddings, then normalize the 128-dimension head of each
flip_search_data = [torch.flip(torch.tensor(x), dims=[-1]).cpu().numpy() for x in search_data]
flip_head_search = [x[:search_dim] / np.linalg.norm(x[:search_dim]) for x in flip_search_data]

# Perform vector search on the heads of the reversed embeddings
res = client.search(
    collection_name=collection_name,
    data=flip_head_search,
    anns_field="flip_head_embedding",
    limit=128,
    output_fields=["title", "flip_head_embedding", "flip_head_norm", "flip_tail_embedding"],
)

In [37]:
def hits_to_dataframe(hits: pymilvus.client.abstract.Hits) -> pd.DataFrame:
    """
    Convert a Milvus search result to a Pandas dataframe. This function is specific to our data schema.

    NOTE: We have to unnormalize the flipped head embedding so we can correctly normalize superset Matryoshka
    embeddings. This is why we stored the flip_head_norm field.
    """
    rows = [x['entity'] for x in hits]
    rows_dict = [
        {
            "title": x['title'],
            "embedding": torch.concat(
                [
                    torch.tensor(x['flip_head_embedding']) * x['flip_head_norm'],
                    torch.tensor(x['flip_tail_embedding']),
                ],
                dim=0,
            ),
        }
        for x in rows
    ]
    return pd.DataFrame.from_records(rows_dict)


dfs = [hits_to_dataframe(hits) for hits in res]

In [38]:
dfs_results = [
    {"query": query, "results": funnel_search(df, query_emb)}
    for query, df, query_emb in zip(queries, dfs, flip_search_data)
]


In [39]:
for d in dfs_results:
    print(d["query"], "\n", d["results"][:7]["title"], "\n")

A movie about a shark that terrorizes an LA beach. 
 1                           The Shallows
10                 2-Headed Shark Attack
49                               Bait 3D
17                     Jaws: The Revenge
3     The Life Aquatic with Steve Zissou
8                                   Jaws
93                             Tintorera
Name: title, dtype: object 

An archaeologist searches for ancient artifacts while fighting Nazis. 
 1            "Pimpernel" Smith
18               Black Hunters
3      Raiders of the Lost Ark
48              The Master Key
112            My Gun Is Quick
43                 The Passage
56             The Mole People
Name: title, dtype: object 

Teenagers in detention learn about themselves. 
 0          Bad Kids Go to Hell
23            Band of the Hand
48               Take the Lead
22                 Silent Fall
16     Riot in Juvenile Prison
123         Within These Walls
29                       Pluto
Name: title, dtype: object 

A teenager fakes i

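As expected, recall is poorer than with prefix-based funnel search or regular search: for the detention query, for example, The Breakfast Club no longer appears in the top results. Reversing the dimensions does not affect the final 768-dimension re-ranking (cosine similarity is invariant to permuting coordinates), so the difference comes from which candidates the 128-dimension suffix retrieves and which survive the intermediate pruning steps.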

## Investigating Funnel Search Recall Failure for Ferris Bueller's Day Off

Why didn't funnel search retrieve Ferris Bueller's Day Off? Let's check whether it was missing from the initial candidate list or was pruned during the funnel stages.

In [50]:
queries = [
    "A teenager fakes illness to get off school and have adventures with two friends."
]


# Search the database based on input text
def embed_search(data):
    embeds = model.encode(data)
    return [x for x in embeds]

instruct_queries = ["search_query: " + q.strip() for q in queries]
search_data = embed_search(instruct_queries)

# Normalize head embeddings
head_search = [x[:search_dim] / np.linalg.norm(x[:search_dim]) for x in search_data]

# Perform standard vector search on subset of embeddings
res = client.search(
    collection_name=collection_name,
    data=head_search,
    anns_field="head_embedding",
    limit=256,
    output_fields=["title", "head_embedding", "head_norm", "tail_embedding"],
)

In [54]:
for query, hits in zip(queries, res):
    rows = [x['entity'] for x in hits]

    print("Query:", query)
    for idx, row in enumerate(rows):
        if row['title'].strip() == "Ferris Bueller's Day Off":
            print(f"Row {idx}: Ferris Bueller's Day Off")

Query: A teenager fakes illness to get off school and have adventures with two friends.
Row 228: Ferris Bueller's Day Off


We see that the initial candidate list was not large enough: at the coarsest granularity (the 128-dimension prefix), the desired hit is not similar enough to the query to rank in the top 128 (it appears at index 228). Increasing the candidate limit from `128` to `256` results in successful retrieval. In practice, the number of candidates should be set empirically on a held-out set, trading off recall against latency.
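One way to pick the candidate limit, sketched under assumptions: on a held-out set of queries, compute what fraction of each query's top-k results under the full 768-dimension embeddings also appear among the top `limit` results under the 128-dimension head, then choose the smallest limit whose recall is acceptable. `candidate_recall` is a hypothetical helper, and the random vectors below carry no Matryoshka structure, so the numbers only exercise the code. In practice you would use real query and document embeddings from the model.

```python
import numpy as np


def candidate_recall(head_scores: np.ndarray, full_scores: np.ndarray, limit: int, k: int = 10) -> float:
    """Average fraction of each query's top-k documents (by full-dimension score)
    that appear among its `limit` best documents by head-only score."""
    # Stable sorts make the top-`limit` sets nested as `limit` grows
    true_top = np.argsort(-full_scores, axis=1, kind="stable")[:, :k]
    candidates = np.argsort(-head_scores, axis=1, kind="stable")[:, :limit]
    hits = [len(np.intersect1d(t, c)) for t, c in zip(true_top, candidates)]
    return float(np.mean(hits)) / k


def unit(x: np.ndarray) -> np.ndarray:
    return x / np.linalg.norm(x, axis=-1, keepdims=True)


# Synthetic stand-ins for a held-out query set and the document collection
rng = np.random.default_rng(0)
held_out_docs = rng.normal(size=(2000, 768)).astype(np.float32)
held_out_queries = rng.normal(size=(20, 768)).astype(np.float32)

head_scores = unit(held_out_queries[:, :128]) @ unit(held_out_docs[:, :128]).T
full_scores = unit(held_out_queries) @ unit(held_out_docs).T

limits = (64, 128, 256, 512, 2000)
recalls = {limit: candidate_recall(head_scores, full_scores, limit) for limit in limits}
for limit, recall in recalls.items():
    print(f"limit={limit:5d}  recall@10={recall:.2f}")
```

Measuring search latency at each limit alongside these recall values gives both sides of the trade-off.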

In [55]:
def hits_to_dataframe(hits: pymilvus.client.abstract.Hits) -> pd.DataFrame:
    """
    Convert a Milvus search result to a Pandas dataframe. This function is specific to our data schema.

    NOTE: We have to unnormalize head embedding so we can correctly normalize superset Matryoshka embeddings. This
    is why we stored head_norm field.
    """
    rows = [x['entity'] for x in hits]
    rows_dict = [
        {
            "title": x['title'],
            "embedding": torch.concat(
                [
                    torch.tensor(x['head_embedding']) * x['head_norm'],
                    torch.tensor(x['tail_embedding']),
                ],
                dim=0,
            ),
        }
        for x in rows
    ]
    return pd.DataFrame.from_records(rows_dict)


dfs = [hits_to_dataframe(hits) for hits in res]

dfs_results = [
    {"query": query, "results": funnel_search(df, query_emb)}
    for query, df, query_emb in zip(queries, dfs, search_data)
]



In [57]:
for d in dfs_results:
    print(d["query"], "\n", d["results"][:7]["title"], "\n")

A teenager fakes illness to get off school and have adventures with two friends. 
 137          A Walk to Remember
228    Ferris Bueller's Day Off
21               How I Live Now
32     On the Edge of Innocence
77             Bratz: The Movie
4                    Unfriended
108                  Simon Says
Name: title, dtype: object 

