# Embedding comparison

We need text embeddings to search the knowledge base for relevant documents. The questions were generated from the text chunks in `generate_data.ipynb`. In this notebook I use them to compute retrieval metrics for several embedding models and compare the models.

In [1]:
import numpy as np
import pandas as pd
import torch
import tqdm
from transformers import AutoModel, AutoTokenizer, T5EncoderModel
import time
import torch.nn.functional as F
import faiss


In [2]:
df_db = pd.read_csv("df.csv", index_col=0)  # loading knowledge base df

In [3]:
df_db

,section,subsection,question,answer,text,hash_answer
0,Classical models,Linear Regression,Regression _1,Regression in machine learning refers to a sup...,Classical models\nLinear Regression\nRegressio...,8f8499b5f59e9390a87f7d2b183cc8bd
1,Classical models,Linear Regression,Regression _2,regression.\n4. Ridge & Lasso Regression\nRidg...,Classical models\nLinear Regression\nRegressio...,a37096af9620af5eca2a696c03a4b397
2,Classical models,Linear Regression,What Is a Linear Regression Model? List Its Dr...,A linear regression model is a model in which ...,Classical models\nLinear Regression\nWhat Is a...,376cf3108393d26d6d09952af3a4f1b8
3,Classical models,Linear Regression,What are various assumptions used in linear re...,Linear regression is done under the following ...,Classical models\nLinear Regression\nWhat are ...,cc89d249384cd42bccf680fb513ae05c
4,Classical models,Linear Regression,What methods for solving linear regression do ...,"To solve linear regression, you need to find t...",Classical models\nLinear Regression\nWhat meth...,c7811418f1a69095d8bd9c190adac605
...,...,...,...,...,...,...
94,Probability and Statistics,Miscellaneous,How do you identify if a coin is biased?_1,We collect data by flipping the coin 200 times...,Probability and Statistics\nMiscellaneous\nHow...,da4450cba0a58f49d04514b63d6c662d
95,Probability and Statistics,Miscellaneous,How do you identify if a coin is biased?_2,observed value arising by chance is only 1 in ...,Probability and Statistics\nMiscellaneous\nHow...,7227d237c94a82f26a43bcbcf9214ffb
96,Probability and Statistics,Miscellaneous,What does Design of Experiments mean?_1,"Design of experiments also known as DOE, it is...",Probability and Statistics\nMiscellaneous\nWha...,1b68b299f67ada223609edc3dc1a7bc1
97,Probability and Statistics,Miscellaneous,"Given uniform distribution X and Y (mean 0, SD...",0.5,Probability and Statistics\nMiscellaneous\nGiv...,d310cb367d993fb6fb584b198a2fd72c


In [4]:
df_qa = pd.read_csv("df_qa.csv", index_col=0)  # loading generated questions

In [5]:
df_qa = df_qa[:-500].dropna()

In [6]:
df_chunks = df_db.reset_index(names=['chunk_id'])[['chunk_id', 'text']]
df_chunks

,chunk_id,text
0,0,Classical models\nLinear Regression\nRegressio...
1,1,Classical models\nLinear Regression\nRegressio...
2,2,Classical models\nLinear Regression\nWhat Is a...
3,3,Classical models\nLinear Regression\nWhat are ...
4,4,Classical models\nLinear Regression\nWhat meth...
...,...,...
640,94,Probability and Statistics\nMiscellaneous\nHow...
641,95,Probability and Statistics\nMiscellaneous\nHow...
642,96,Probability and Statistics\nMiscellaneous\nWha...
643,97,Probability and Statistics\nMiscellaneous\nGiv...
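
The output above shows row 640 with `chunk_id` 94, which suggests that the index stored in `df.csv` restarts somewhere and is therefore not unique. If that is the case, several chunks share an id, and an id match in the metrics further down does not necessarily mean that the right chunk was retrieved. A minimal sketch of a check and one possible fix on a toy frame (the toy data and the name `df_toy` are made up for illustration):

```python
import pandas as pd

# Toy knowledge base whose stored index restarts, as it would after
# concatenating several files without resetting the index (assumption).
df_toy = pd.DataFrame(
    {"text": ["chunk a", "chunk b", "chunk c", "chunk d"]},
    index=[0, 1, 0, 1],
)

# Same construction as in the notebook: ids come from the stored index.
ids_from_index = df_toy.reset_index(names=["chunk_id"])[["chunk_id", "text"]]
has_duplicates = bool(ids_from_index["chunk_id"].duplicated().any())

# Possible fix: drop the stored index and use the row position as the id.
ids_positional = (
    df_toy.reset_index(drop=True)
    .reset_index(names=["chunk_id"])[["chunk_id", "text"]]
)
```

Running the `duplicated()` check on the real `df_chunks` would show whether this applies here.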


Now we can merge them to get each question together with the id of its source text chunk.

In [7]:
df_q = pd.merge(df_qa, df_chunks, left_on='Context', right_on='text').reset_index(names=["query_id"])[['query_id', 'Question', 'chunk_id']]
df_q

,query_id,Question,chunk_id
0,0,What is the main goal of regression in machine...,0
1,1,What are the two types of variables present in...,0
2,2,What type of regression is used when there is ...,0
3,3,What type of regression is used to model non-l...,0
4,4,What are the extensions of linear regression t...,0
...,...,...,...
3219,3219,What does entropy measure in a probability dis...,98
3220,3220,What is the relationship between entropy and p...,98
3221,3221,What is the formula for calculating entropy in...,98
3222,3222,What does cross-entropy measure?,98
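
An inner merge silently drops questions whose `Context` does not match any chunk text exactly, and it duplicates a question if the same text occurs in more than one chunk. A small sketch of how this could be checked with the `indicator` option of `pd.merge`; the toy frames are illustrative and only the column names follow the notebook:

```python
import pandas as pd

qa = pd.DataFrame({
    "Question": ["q1", "q2", "q3"],
    "Context": ["text a", "text b", "text missing"],
})
chunks = pd.DataFrame({"chunk_id": [0, 1], "text": ["text a", "text b"]})

# A left merge keeps every question; `_merge` tells whether a chunk was found.
merged = pd.merge(
    qa, chunks, left_on="Context", right_on="text", how="left", indicator=True
)
unmatched = merged.loc[merged["_merge"] == "left_only", "Question"].tolist()
n_matched = int((merged["_merge"] == "both").sum())
```

If `unmatched` is empty and `n_matched` equals the number of questions, the inner merge used above loses nothing.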


In [15]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [39]:
model_name = "sentence-transformers/all-MiniLM-L6-v2"

In [41]:
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [43]:
model = AutoModel.from_pretrained(model_name).to(device)

In [45]:
model.num_parameters()

22713216

In [17]:
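def pool(hidden_state, mask, pooling_method="cls"):
    # hidden_state: (batch, seq_len, hidden); mask: (batch, seq_len) attention mask
    if pooling_method == "mean":
        # average the token vectors, ignoring padding positions
        s = torch.sum(hidden_state * mask.unsqueeze(-1).float(), dim=1)
        d = mask.sum(dim=1, keepdim=True).float()
        return s / d
    elif pooling_method == "cls":
        # use the vector of the first ([CLS]) token
        return hidden_state[:, 0]
    raise ValueError(f"Unknown pooling_method: {pooling_method}")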
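
To make the two pooling modes concrete, here is the same arithmetic on a tiny hand-made batch, written with NumPy instead of PyTorch so that it runs without a model. The arrays are invented for illustration.

```python
import numpy as np

# One sequence with 3 token positions and hidden size 2; the last position is padding.
hidden_state = np.array([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])
mask = np.array([[1, 1, 0]])

# "cls" pooling: take the vector of the first token.
cls_vec = hidden_state[:, 0]

# "mean" pooling: average over real tokens only, so the padding vector is ignored.
summed = (hidden_state * mask[..., np.newaxis]).sum(axis=1)
counts = mask.sum(axis=1, keepdims=True)
mean_vec = summed / counts
```

Here `cls_vec` is `[[1, 2]]` and `mean_vec` is `[[2, 3]]`; the padding vector `[100, 100]` contributes to neither.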

In [19]:
def get_embeddings(model, tokenizer, texts, batch_size=8):
    # relies on the global `device` defined above
    embeddings = []
    for i in tqdm.tqdm(range(0, len(texts), batch_size)):
        batch_texts = texts[i : i + batch_size]
        tokenized = tokenizer(batch_texts, padding=True, truncation=True, return_tensors="pt", max_length=512)
        tokenized = {k: v.to(device) for k, v in tokenized.items()}
        with torch.no_grad():  # inference only, so skip building the autograd graph
            output = model(**tokenized)
        embedding = pool(
            output.last_hidden_state,
            tokenized["attention_mask"],
            pooling_method="cls"
        )
        embedding = F.normalize(embedding, p=2, dim=1)  # unit-length vectors
        embeddings.append(embedding.cpu().detach().numpy())
    return np.concatenate(embeddings)

In [47]:
questions = df_q['Question'].tolist()

In [49]:
texts = df_chunks["text"].to_list()

In [51]:
ch_embeddings = get_embeddings(model, tokenizer, texts)  # texts chunks embeddings

100%|██████████| 81/81 [00:48<00:00,  1.68it/s]


In [52]:
q_embeddings = get_embeddings(model, tokenizer, questions)  # question embeddings

100%|██████████| 403/403 [00:14<00:00, 27.55it/s]


In [53]:
dim = ch_embeddings.shape[1]

In [54]:
faiss_index = faiss.IndexIDMap(faiss.IndexFlatL2(dim))
faiss_index.add_with_ids(ch_embeddings, df_chunks["chunk_id"])
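
Because `get_embeddings` L2-normalizes every vector, ranking by the distance of `IndexFlatL2` gives the same order as ranking by cosine similarity: for unit vectors, ||a - b||² = 2 - 2 (a · b). A quick numerical check of that identity on random vectors (NumPy only, no FAISS needed; the data is random and purely illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)

def normalize(x):
    # Scale each row to unit L2 norm.
    return x / np.linalg.norm(x, axis=1, keepdims=True)

chunks = normalize(rng.normal(size=(6, 4)))
query = normalize(rng.normal(size=(1, 4)))

sq_l2 = ((chunks - query) ** 2).sum(axis=1)  # squared L2 distance to each chunk
cosine = chunks @ query[0]                   # cosine similarity (unit vectors)

identity_holds = bool(np.allclose(sq_l2, 2 - 2 * cosine))
same_ranking = bool((np.argsort(sq_l2) == np.argsort(-cosine)).all())
```

So an inner-product index over the same normalized vectors should return the same neighbours, only with similarities instead of distances.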

In [55]:
def search(faiss_index, q_embeddings, top_k=5, batch_size=128):
    # find the top_k chunks closest to each query; for IndexFlatL2 the returned
    # scores are squared L2 distances, so smaller means closer
    scores_list = []
    ids_list = []
    for i in tqdm.tqdm(range(0, len(q_embeddings), batch_size)):
        scores, ids = faiss_index.search(q_embeddings[i: i + batch_size], top_k)
        scores_list.append(scores)
        ids_list.append(ids)
    return np.concatenate(scores_list), np.concatenate(ids_list)
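
Since the index is flat (exact search), its results can be cross-checked against a brute-force NumPy search on a small sample. A sketch, where `brute_force_search` is a hypothetical helper that is not part of the notebook:

```python
import numpy as np

def brute_force_search(chunk_embeddings, chunk_ids, q_embeddings, top_k=5):
    """Exact nearest-neighbour search by squared L2 distance."""
    # Pairwise squared distances, shape (n_queries, n_chunks).
    diff = q_embeddings[:, np.newaxis, :] - chunk_embeddings[np.newaxis, :, :]
    dists = (diff ** 2).sum(axis=-1)
    # Stable sort so that ties keep the original chunk order.
    order = np.argsort(dists, axis=1, kind="stable")[:, :top_k]
    return np.take_along_axis(dists, order, axis=1), np.asarray(chunk_ids)[order]

chunk_embeddings = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])
chunk_ids = [10, 20, 30]
q_embeddings = np.array([[0.9, 0.1], [0.0, -1.0]])

dists, ids = brute_force_search(chunk_embeddings, chunk_ids, q_embeddings, top_k=2)
```

This materializes the full distance matrix, so it only makes sense for a few hundred queries at a time.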

In [64]:
scores, ids = search(faiss_index, q_embeddings)

100%|██████████| 26/26 [00:00<00:00, 1386.74it/s]


In [66]:
ids.shape

(3224, 5)

In [68]:
gt = df_q["chunk_id"].to_numpy()

In [70]:
def get_recall_at_k(ids, gt):
    return float((ids == gt[:, np.newaxis]).any(axis=1).mean())

In [72]:
recall_at_k = get_recall_at_k(ids, gt)
recall_at_k

0.875
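
`get_recall_at_k` counts a query as a hit if the ground-truth chunk id appears anywhere among its top-k retrieved ids, then averages over all queries. A toy example with three queries and k = 3 (the arrays are invented; the function is repeated so the block runs on its own):

```python
import numpy as np

def get_recall_at_k(ids, gt):
    return float((ids == gt[:, np.newaxis]).any(axis=1).mean())

ids = np.array([
    [7, 2, 9],  # ground truth 7 is ranked first  -> hit
    [4, 5, 6],  # ground truth 6 is ranked third  -> hit
    [1, 2, 3],  # ground truth 8 is not retrieved -> miss
])
gt = np.array([7, 6, 8])

recall = get_recall_at_k(ids, gt)  # 2 hits out of 3 queries
```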

In [74]:
def get_mrr(ids, gt):
    matches = ids == gt[:, np.newaxis]
    ranks = np.argmax(matches, axis=1) + 1
    reciprocal_ranks = 1 / ranks
    reciprocal_ranks[~matches.any(axis=1)] = 0
    return float(np.mean(reciprocal_ranks))

In [76]:
mrr = get_mrr(ids, gt)
mrr

0.7225909842845326
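
Unlike recall, MRR also rewards ranking the correct chunk higher: each query contributes 1/rank of the ground-truth chunk, or 0 if the chunk is missing from the top k. A toy example with three queries (the arrays are invented; the function is repeated so the block runs on its own):

```python
import numpy as np

def get_mrr(ids, gt):
    matches = ids == gt[:, np.newaxis]
    ranks = np.argmax(matches, axis=1) + 1      # 1-based position of the first match
    reciprocal_ranks = 1 / ranks
    reciprocal_ranks[~matches.any(axis=1)] = 0  # no match in the top k
    return float(np.mean(reciprocal_ranks))

ids = np.array([
    [7, 2, 9],  # ground truth 7 at rank 1 -> 1
    [4, 5, 6],  # ground truth 6 at rank 3 -> 1/3
    [1, 2, 3],  # ground truth 8 missing   -> 0
])
gt = np.array([7, 6, 8])

mrr = get_mrr(ids, gt)  # (1 + 1/3 + 0) / 3
```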

All together:

In [29]:
def get_metrics(model_name, df_chunks, df_q, k=5):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(device)
    params_num = model.num_parameters()
    
    questions = df_q['Question'].tolist()
    texts = df_chunks["text"].to_list()
    
    ch_embeddings = get_embeddings(model, tokenizer, texts)
    t = time.time()
    q_embeddings = get_embeddings(model, tokenizer, questions)
    mean_emb_t = (time.time() - t) / len(questions)
    dim = ch_embeddings.shape[1]
        
    faiss_index = faiss.IndexIDMap(faiss.IndexFlatL2(dim))
    faiss_index.add_with_ids(ch_embeddings, df_chunks["chunk_id"])

    t = time.time()
    scores, ids = search(faiss_index, q_embeddings, top_k=k)  # use the k argument
    mean_search_time = (time.time() - t) / len(questions)
    gt = df_q["chunk_id"].to_numpy()
    recall_at_k = get_recall_at_k(ids, gt)
    mrr = get_mrr(ids, gt)

    return {
        "model_name": model_name,
        "dim": dim,
        "params_num": params_num,
        "recall_at_k": recall_at_k,
        "mrr": mrr,
        "mean_emb_time": mean_emb_t,  # seconds per question
        "mean_search_time": mean_search_time,  # seconds per question
    }
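
`time.time()` is fine for a rough comparison, but `time.perf_counter()` is the clock intended for measuring intervals, and GPU work runs asynchronously, so one would normally wait for it to finish before reading the clock (in PyTorch, `torch.cuda.synchronize()`). A minimal sketch; `timed` and its `sync` argument are hypothetical helpers, not part of the notebook:

```python
import time

def timed(fn, *args, sync=None, **kwargs):
    """Run fn and return (result, elapsed seconds).

    sync: optional zero-argument callable run before each clock reading,
    for example torch.cuda.synchronize when the model is on a GPU.
    """
    if sync is not None:
        sync()
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    if sync is not None:
        sync()
    return result, time.perf_counter() - start

# Toy usage with a plain Python function instead of a model.
result, elapsed = timed(sum, range(1000))
mean_time = elapsed / 1000
```

Inside `get_metrics` this could wrap the `get_embeddings` and `search` calls.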

In [31]:
model_names = [
    "BAAI/bge-m3",
    "intfloat/e5-small-v2",
    "sentence-transformers/all-MiniLM-L6-v2"
]
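
Note that `get_embeddings` applies CLS pooling and no text prefix to every model. To my knowledge the model cards recommend different settings: mean pooling for all-MiniLM-L6-v2, mean pooling plus `query: ` / `passage: ` prefixes for e5-small-v2, and CLS pooling for the dense vectors of BGE-M3. These settings are stated from memory and should be checked against the model cards before relying on them. A sketch of how they could be kept per model; `MODEL_SETTINGS` and `prepare_texts` are hypothetical names:

```python
# Assumed per-model settings; verify each entry against the model card.
MODEL_SETTINGS = {
    "BAAI/bge-m3": {
        "pooling": "cls", "query_prefix": "", "passage_prefix": "",
    },
    "intfloat/e5-small-v2": {
        "pooling": "mean", "query_prefix": "query: ", "passage_prefix": "passage: ",
    },
    "sentence-transformers/all-MiniLM-L6-v2": {
        "pooling": "mean", "query_prefix": "", "passage_prefix": "",
    },
}

def prepare_texts(model_name, texts, kind):
    """Prepend the model's prefix to each text; kind is 'query' or 'passage'."""
    prefix = MODEL_SETTINGS[model_name][f"{kind}_prefix"]
    return [prefix + t for t in texts]

example = prepare_texts("intfloat/e5-small-v2", ["what is regression?"], "query")
```

With such a table, the comparison would evaluate each model in the configuration it was trained for, which may change the ranking of the small models.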

In [33]:
metric_list = []

for model_name in model_names:
    metric_list.append(get_metrics(model_name, df_chunks, df_q))

100%|██████████| 81/81 [07:04<00:00,  5.24s/it]
100%|██████████| 403/403 [02:21<00:00,  2.84it/s]
100%|██████████| 26/26 [00:00<00:00, 545.58it/s]
100%|██████████| 81/81 [01:39<00:00,  1.23s/it]
100%|██████████| 403/403 [00:27<00:00, 14.62it/s]
100%|██████████| 26/26 [00:00<00:00, 1530.49it/s]
100%|██████████| 81/81 [00:44<00:00,  1.84it/s]
100%|██████████| 403/403 [00:14<00:00, 27.66it/s]
100%|██████████| 26/26 [00:00<00:00, 1536.09it/s]


In [35]:
df_metrics = pd.DataFrame(metric_list)

In [37]:
df_metrics

,model_name,dim,params_num,recall_at_k,mrr,mean_emb_time,mean_search_time
0,BAAI/bge-m3,1024,567754752,0.957196,0.84909,0.044,1.5e-05
1,intfloat/e5-small-v2,384,33360000,0.880583,0.723056,0.008555,5e-06
2,sentence-transformers/all-MiniLM-L6-v2,384,22713216,0.875,0.722591,0.004522,6e-06


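All three models are now measured on the same questions. BGE-M3 has the best retrieval quality (recall@5 of 0.957, MRR of 0.849), but it is by far the largest model (about 568M parameters) and embeds a question roughly 5 to 10 times slower than the two small models. e5-small-v2 and all-MiniLM-L6-v2 are almost tied on quality (recall@5 of about 0.88, MRR of about 0.72), and all-MiniLM-L6-v2 is the smaller and faster of the two. Search time is negligible compared with embedding time for all three, so the choice comes down to retrieval quality versus model size and embedding speed.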
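
The trade-off can be put into numbers directly from the metrics table. A small sketch that recomputes the relative slowdown and the recall gap; the values are copied from the table above, and the two derived column names are made up for illustration:

```python
import pandas as pd

# Values copied from the metrics table above.
df_tradeoff = pd.DataFrame({
    "model_name": [
        "BAAI/bge-m3",
        "intfloat/e5-small-v2",
        "sentence-transformers/all-MiniLM-L6-v2",
    ],
    "recall_at_k": [0.957196, 0.880583, 0.875],
    "mean_emb_time": [0.044, 0.008555, 0.004522],
}).set_index("model_name")

# How many times slower each model embeds a question than the fastest one.
fastest = df_tradeoff["mean_emb_time"].min()
df_tradeoff["slowdown_vs_fastest"] = df_tradeoff["mean_emb_time"] / fastest

# How much recall each model gives up relative to the best one.
best_recall = df_tradeoff["recall_at_k"].max()
df_tradeoff["recall_gap_to_best"] = best_recall - df_tradeoff["recall_at_k"]
```

The timings depend on the hardware used for this run, so the ratios are indicative rather than exact.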