# Reranker

A pipeline for RAG normally looks like this:
* Retriever → Reranker → Top-k → generator (LLM)

I want to compare retrieval metrics with and without a reranker

In [3]:
import numpy as np
import pandas as pd
import torch
import tqdm
from transformers import AutoModel, AutoTokenizer
import time
import torch.nn.functional as F
import faiss

from sentence_transformers import SentenceTransformer, CrossEncoder
import faiss
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [5]:
df_db = pd.read_csv("df.csv", index_col=0)  # loading knowledge base df
df_qa = pd.read_csv("df_qa.csv", index_col=0)  # loading generated questions
df_qa = df_qa[:-500].dropna()
df_chunks = df_db.reset_index(names=['chunk_id'])[['chunk_id', 'text']]
df_q = pd.merge(df_qa, df_chunks, left_on='Context', right_on='text').reset_index(names=["query_id"])[['query_id', 'Question', 'chunk_id']]

In [6]:
emb_model_name = "sentence-transformers/all-MiniLM-L6-v2" # embedding model for retrieval

In [7]:
reranker_model_name = "cross-encoder/ms-marco-MiniLM-L6-v2"  # reranker model

In [8]:
df_db

Unnamed: 0_level_0,section,subsection,question,answer,text,hash_answer
chunk_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
0,Classical models,Linear Regression,Regression _1,Regression in machine learning refers to a sup...,Classical models\nLinear Regression\nRegressio...,8f8499b5f59e9390a87f7d2b183cc8bd
1,Classical models,Linear Regression,Regression _2,regression.\n4. Ridge & Lasso Regression\nRidg...,Classical models\nLinear Regression\nRegressio...,a37096af9620af5eca2a696c03a4b397
2,Classical models,Linear Regression,What Is a Linear Regression Model? List Its Dr...,A linear regression model is a model in which ...,Classical models\nLinear Regression\nWhat Is a...,376cf3108393d26d6d09952af3a4f1b8
3,Classical models,Linear Regression,What are various assumptions used in linear re...,Linear regression is done under the following ...,Classical models\nLinear Regression\nWhat are ...,cc89d249384cd42bccf680fb513ae05c
4,Classical models,Linear Regression,What methods for solving linear regression do ...,"To solve linear regression, you need to find t...",Classical models\nLinear Regression\nWhat meth...,c7811418f1a69095d8bd9c190adac605
...,...,...,...,...,...,...
640,Probability and Statistics,Miscellaneous,How do you identify if a coin is biased?_1,We collect data by flipping the coin 200 times...,Probability and Statistics\nMiscellaneous\nHow...,da4450cba0a58f49d04514b63d6c662d
641,Probability and Statistics,Miscellaneous,How do you identify if a coin is biased?_2,observed value arising by chance is only 1 in ...,Probability and Statistics\nMiscellaneous\nHow...,7227d237c94a82f26a43bcbcf9214ffb
642,Probability and Statistics,Miscellaneous,What does Design of Experiments mean?_1,"Design of experiments also known as DOE, it is...",Probability and Statistics\nMiscellaneous\nWha...,1b68b299f67ada223609edc3dc1a7bc1
643,Probability and Statistics,Miscellaneous,"Given uniform distribution X and Y (mean 0, SD...",0.5,Probability and Statistics\nMiscellaneous\nGiv...,d310cb367d993fb6fb584b198a2fd72c


In [9]:
bi_encoder = SentenceTransformer(emb_model_name, device=device)

In [10]:
corpus_embeddings = bi_encoder.encode(df_db["text"].to_list(), normalize_embeddings=True, show_progress_bar=True)

Batches: 100%|██████████| 21/21 [00:01<00:00, 19.76it/s]


In [11]:
corpus_embeddings.shape

(645, 384)

Create a FAISS index:

In [13]:
dim = corpus_embeddings.shape[1]
index = faiss.IndexFlatIP(dim)  # Inner product = cosine similarity
index.add(corpus_embeddings)

In [14]:
questions = df_q['Question'].tolist()

In [15]:
query = questions[0]
query  # example question

'What is the main goal of regression in machine learning?'

In [16]:
query_emb = bi_encoder.encode([query], normalize_embeddings=True)

In [17]:
# Top-N candidates
N = 10
scores, indices = index.search(query_emb, N)
indices

array([[  0,   4, 520, 194,   6,   1, 552,   8,  29, 521]])

In [18]:
indices = indices.tolist()[0]

print("Retriever candidates indices:", indices)

Retriever candidates indices: [0, 4, 520, 194, 6, 1, 552, 8, 29, 521]


In [19]:
df_db.iloc[indices]

Unnamed: 0_level_0,section,subsection,question,answer,text,hash_answer
chunk_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
0,Classical models,Linear Regression,Regression _1,Regression in machine learning refers to a sup...,Classical models\nLinear Regression\nRegressio...,8f8499b5f59e9390a87f7d2b183cc8bd
4,Classical models,Linear Regression,What methods for solving linear regression do ...,"To solve linear regression, you need to find t...",Classical models\nLinear Regression\nWhat meth...,c7811418f1a69095d8bd9c190adac605
520,Metrics,Regression,Prediction Intervals in Forecasting: Quantile ...,"In most real world prediction problems, the un...",Metrics\nRegression\nPrediction Intervals in F...,6d753fe83941d13c3078d7de7376f847
194,Data,Bias,Define and explain the concept of Inductive Bi...,Inductive Bias is a set of assumptions that hu...,Data\nBias\nDefine and explain the concept of ...,6fa2cd39c9804937d7c12a547d5a72e3
6,Classical models,Linear Regression,Ordinary least squares_1,The ordinary least squares (OLS) method can be...,Classical models\nLinear Regression\nOrdinary ...,3c9097198e3738bb62c3b32659a2e950
1,Classical models,Linear Regression,Regression _2,regression.\n4. Ridge & Lasso Regression\nRidg...,Classical models\nLinear Regression\nRegressio...,a37096af9620af5eca2a696c03a4b397
552,Probability and Statistics,Probability,What is difference between Probability and Sta...,Probability theory and statistics are often pr...,Probability and Statistics\nProbability \nWhat...,9680e186d103d93183f9182b7e0591ec
8,Classical models,Linear Regression,Bayesian Linear Regression_1,Bayesian linear regression pushes the idea of ...,Classical models\nLinear Regression\nBayesian ...,30566a34751c4ec895b4c0e52114251a
29,Classical models,Support Vector Machine (SVM) Algorithm,Kernel function_1,Kernel functions are generalized dot product f...,Classical models\nSupport Vector Machine (SVM)...,3605cac27cae713f70c7d7b55ea1229b
521,Metrics,Regression,Prediction Intervals in Forecasting: Quantile ...,by a factor of 0.25. The model will then try t...,Metrics\nRegression\nPrediction Intervals in F...,5506761035fc20aa868205f6b99a217e


In [20]:
candidates = df_db.iloc[indices]['text']

In [21]:
# Reranker:
reranker = CrossEncoder(reranker_model_name, device=device)

In [22]:
pairs = [(query, doc) for doc in candidates]
rerank_scores = reranker.predict(pairs)

In [23]:
rerank_scores

array([ 5.5456843, -6.027544 , -2.5661113, -5.963826 , -6.138963 ,
       -5.4645014, -2.6347544, -6.973322 , -7.245926 , -8.837734 ],
      dtype=float32)

In [24]:
# Sort by relevance:
reranked = [doc for _, doc in sorted(zip(rerank_scores, candidates), reverse=True)]

In [25]:
print(reranked[0])

Classical models
Linear Regression
Regression _1
Regression in machine learning refers to a supervised learning technique where the goal is to predict a continuous numerical value based on one or more independent features. It finds relationships between variables so that predictions can be made. we have two types of variables present in regression:
Dependent Variable (Target): The variable we are trying to predict e.g house price.
Independent Variables (Features): The input variables that influence the prediction e.g locality, number of rooms.
Regression analysis problem works with if output variable is a real or continuous value such as “salary” or “weight”. Many different regression models can be used but the simplest model in them is linear regression.
Types of Regression
Regression can be classified into different types based on the number of predictor variables and the nature of the relationship between variables:
1. Simple Linear Regression
Linear regression is one of the simples

Now let's compute Recall and MRR metric:

In [26]:
def get_recall_at_k(ids, gt):
    return float((ids == gt[:, np.newaxis]).any(axis=1).mean())

In [27]:
def get_mrr(ids, gt):
    matches = ids == gt[:, np.newaxis]
    ranks = np.argmax(matches, axis=1) + 1
    reciprocal_ranks = 1 / ranks
    reciprocal_ranks[~matches.any(axis=1)] = 0
    return float(np.mean(reciprocal_ranks))

In [28]:
N = 20

In [29]:
q_embeddings = bi_encoder.encode(questions, normalize_embeddings=True, show_progress_bar=True)

Batches: 100%|██████████| 101/101 [00:00<00:00, 159.94it/s]


We have a ground truth - numbers of relevant chunks

In [30]:
gt = df_q["chunk_id"].to_numpy()

In [32]:
def search(faiss_index, queries, top_N, batch_size=128):
    # finding closest to the query chunks in the index
    scores_list = []
    ids_list = []
    for i in tqdm.tqdm(range(0, len(q_embeddings), batch_size)):
        scores, ids = faiss_index.search(q_embeddings[i: i + batch_size], top_N)
        scores_list.append(scores)
        ids_list.append(ids)
    return np.concatenate(scores_list), np.concatenate(ids_list)

In [33]:
scores, ids = search(index, q_embeddings, top_N=N)

100%|██████████| 26/26 [00:00<00:00, 954.83it/s]


In [34]:
recall_at_k = get_recall_at_k(ids, gt)
recall_at_k

0.9748759305210918

In [35]:
mrr = get_mrr(ids, gt)
mrr

0.7934068797628261

MRR metric depends on the order of elements. Let's see if reranking can improve it

In [86]:
def get_reranked_ids(reranker, query, texts):
    pairs = [(query, doc) for doc in texts]
    rerank_scores = reranker.predict(pairs)
    return rerank_scores

In [97]:
def rerank_candidates(queries, ids):
    reranked = []
    for i, query in tqdm.tqdm( enumerate(queries)):
        candidate_ids = ids[i]
        candidate_texts = df_db.iloc[candidate_ids]['text'].tolist()
        rerank_scores = get_reranked_ids(reranker, query, candidate_texts)
        reranked.append([int(doc) for _, doc in sorted(zip(rerank_scores, candidate_ids), reverse=True)])
    return reranked

In [101]:
reranked = rerank_candidates(questions, ids)

3224it [03:45, 14.32it/s]


In [105]:
mrr_reranked = get_mrr(reranked, gt)
mrr_reranked

0.904644287032807

In [None]:
It shows that reranking improves MRR; 