In [77]:
import torch
import pandas as pd
import numpy as np
from typing import Literal

from datasets import load_dataset, Dataset
from sentence_transformers import SentenceTransformer, CrossEncoder
from transformers import BertTokenizer
import faiss

from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import plotly.express as px

In [78]:
# Hugging Face identifiers for the dataset, the retriever and the reranker.
DATASET_NAME = 'sentence-transformers/squad'
RETRIEVER_NAME = 'multi-qa-mpnet-base-dot-v1'
RERANKER_NAME = 'cross-encoder/ms-marco-MiniLM-L-6-v2'

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

In [79]:
dataset = load_dataset(DATASET_NAME)

# Materialise the question column as a plain list: later cells call
# `questions.index(...)` and slice it, which a list supports regardless of the
# sequence type the datasets library hands back for a column.
questions = list(dataset['train']['question'])

# De-duplicate the answers while keeping first-occurrence order.
# `list(set(...))` would also de-duplicate, but the iteration order of a set of
# strings changes between interpreter sessions (hash randomisation), so the
# position of each answer -- and therefore the ids stored in the FAISS index --
# would differ from run to run.
answers = list(dict.fromkeys(dataset['train']['answer']))

In [80]:
# Retriever: sentence-embedding model loaded by name and then moved to DEVICE.
retriever = SentenceTransformer(RETRIEVER_NAME).to(DEVICE)
# Reranker: cross-encoder constructed with max_length = 512 on DEVICE.
reranker = CrossEncoder(RERANKER_NAME, max_length = 512, device = DEVICE)
# Tokenizer loaded from the reranker checkpoint; it is not referenced by the
# cells shown in this part of the notebook.
tokenizer = BertTokenizer.from_pretrained(RERANKER_NAME, max_length = 512)

### PIPELINE

In [81]:
# Embed every question and every unique answer with the retriever.
# NOTE(review): both calls run over the full lists and nothing is cached, so
# re-running this cell repeats the whole encoding.
questions_vec = retriever.encode(questions)
# answers_vec[i] is used as the embedding of answers[i]: the FAISS index is
# built from this array and retrive() indexes both with the same id.
answers_vec = retriever.encode(answers)

In [82]:
# Flat inner-product index over the answer embeddings; the id of each vector is
# its row number in answers_vec, which is also its position in `answers`.
faiss_index = faiss.IndexFlatIP(answers_vec.shape[1])
faiss_index.add(np.array(answers_vec, dtype = np.float32))

# Move the index to GPU 0 only when the notebook is actually running on CUDA and
# the installed faiss build ships the GPU classes. DEVICE already falls back to
# the CPU, so the index must be able to stay on the CPU as well instead of
# failing on machines without a GPU (or with the CPU-only faiss package).
if DEVICE.type == 'cuda' and hasattr(faiss, 'StandardGpuResources'):
    # Held in a module-level name rather than a temporary so the resources
    # object stays referenced for as long as the GPU index is in use.
    faiss_gpu_resources = faiss.StandardGpuResources()
    faiss_index = faiss.index_cpu_to_gpu(faiss_gpu_resources, 0, faiss_index)

def retrive(query: str, top_k: int, answers: list[str] = answers, answers_vec: np.ndarray = answers_vec) -> list[dict]:
    """Return up to ``top_k`` candidate answers for ``query`` from the FAISS index.

    The query is embedded with the module-level ``retriever`` and searched in the
    module-level ``faiss_index``.

    Args:
        query: Question text to search for.
        top_k: Number of neighbours requested from the index.
        answers: Answer texts, positionally aligned with the vectors stored in
            the index.
        answers_vec: Answer embeddings, aligned with ``answers``.

    Returns:
        A list of dicts in index order, each with the keys ``'rank'``
        (1-based), ``'text'``, ``'distance'`` and ``'vector'``. ``'distance'``
        is the value FAISS reports for the hit; the index is an IndexFlatIP, so
        it is an inner product rather than a metric distance.

    Note:
        The defaults for ``answers`` and ``answers_vec`` are bound when this
        ``def`` is executed, whereas ``faiss_index`` is looked up on every call.
        After rebuilding the answers or the index, re-run this cell too so all
        three stay aligned.
    """
    retrive_list = list()

    query_vec = retriever.encode([query])[0]
    distances, indices = faiss_index.search(np.array([query_vec], dtype=np.float32), top_k)

    for distance, index in zip(distances[0], indices[0]):
        # FAISS pads the result with id -1 when it finds fewer than top_k
        # vectors. Indexing with -1 would silently return the LAST answer, so
        # padding slots are skipped instead.
        if index < 0:
            continue

        retrive_list.append({
            'rank': len(retrive_list) + 1,
            'text': answers[index],
            'distance': distance,
            'vector': answers_vec[index],
        })

    return retrive_list



def rerank(query: str, retrive_list: list[dict], top_k: int) -> list[dict]:
    """Re-order retrieved candidates for ``query`` with the module-level cross-encoder.

    Args:
        query: Question text.
        retrive_list: Candidates as produced by ``retrive``; only their
            ``'text'`` field is read.
        top_k: Forwarded to ``reranker.rank`` as ``top_k``.

    Returns:
        Whatever ``reranker.rank`` returns when called with
        ``return_documents = True``; other cells in this notebook read the
        ``'text'`` and ``'score'`` keys of each entry.
    """
    candidate_texts = []
    for candidate in retrive_list:
        candidate_texts.append(candidate['text'])

    return reranker.rank(query, candidate_texts, top_k = top_k, return_documents = True)

        
        
def print_results(query, rerank_answers):
    """Print the query, its reference answer and the reranked answers with scores.

    ``query`` must be one of the entries of the module-level ``questions`` list;
    its first position there selects the reference answer from the training
    split. Each entry of ``rerank_answers`` must provide ``'text'`` and
    ``'score'``.

    Raises:
        ValueError: If ``query`` is not in ``questions``.
    """
    print(f'Query: {query}\n')

    real_answer = dataset['train']['answer'][questions.index(query)]
    print(f"Real answer: {real_answer}\n\n")

    print(f'Top {len(rerank_answers)} answers:\n')
    for position, entry in enumerate(rerank_answers, start = 1):
        print(f"Answer {position}: {entry['text']}, {entry['score']}\n")



def get_answer_rating(query, retrive_answers, rerank_answers):
    """Locate the reference answer of ``query`` in the retrieved and reranked lists.

    Args:
        query: A question present in the module-level ``questions`` list; its
            first position there selects the reference answer.
        retrive_answers: Retriever output, a list of dicts with a ``'text'`` key.
        rerank_answers: Reranker output, a list of dicts with a ``'text'`` key.

    Returns:
        ``(retriver_idx, rerank_idx)``: the zero-based position of the first
        entry whose text equals the reference answer in each list. When the
        answer is absent from a list, that list's length is returned instead.

    Raises:
        ValueError: If ``query`` is not in ``questions``.

    Note:
        Because a miss is reported as the list length, a consumer that accepts
        positions greater than or equal to that length would count misses as
        hits.
    """
    # Read just the matching row. `dataset['train']['answer']` would build the
    # complete answer column on every call, which dominates the run time when
    # this function is invoked once per evaluated query.
    correct_answer = dataset['train'][questions.index(query)]['answer']

    def _first_position(candidates):
        # Zero-based position of the first exact text match, or the list length.
        return next(
            (position for position, dct in enumerate(candidates) if dct['text'] == correct_answer),
            len(candidates),
        )

    return _first_position(retrive_answers), _first_position(rerank_answers)



def get_metrics_base(queries, top_k_retriver, top_k_reranker):
    """Run retrieval and reranking for every query and collect the answer positions.

    Args:
        queries: Iterable of question texts accepted by ``get_answer_rating``.
        top_k_retriver: ``top_k`` passed to ``retrive``.
        top_k_reranker: ``top_k`` passed to ``rerank``.

    Returns:
        Two lists aligned with ``queries``: the position of the reference answer
        in the retriever output and in the reranker output, as reported by
        ``get_answer_rating``.
    """
    positions = []

    for question in queries:
        candidates = retrive(question, top_k = top_k_retriver)
        reordered = rerank(question, candidates, top_k = top_k_reranker)
        positions.append(get_answer_rating(question, candidates, reordered))

    retriver_idxs = [pair[0] for pair in positions]
    rerank_idxs = [pair[1] for pair in positions]
    return retriver_idxs, rerank_idxs


def recall_k(idxs, k):
    """Fraction of the positions in ``idxs`` that are less than or equal to ``k``.

    Positions are zero-based, so recall@10 is obtained with ``k = 9``.
    Raises ZeroDivisionError when ``idxs`` is empty.
    """
    hits = 0
    for position in idxs:
        if position <= k:
            hits += 1
    return hits / len(idxs)


def mrr(idxs):
    """Mean of ``1 / (position + 1)`` over the zero-based positions in ``idxs``.

    Every entry contributes, including positions that stand for "not found".
    """
    reciprocal_ranks = []
    for position in idxs:
        reciprocal_ranks.append(1 / (position + 1))
    return np.mean(reciprocal_ranks)

In [91]:
# Evaluate on the first 10000 training questions, asking the retriever and the
# reranker for 50 candidates each.
retrive_idxs, rerank_idxs = get_metrics_base(
    questions[:10000],
    top_k_retriver = 50,
    top_k_reranker = 50,
)

In [92]:
# recall_k counts positions <= its second argument and the positions are
# zero-based, so recall@1 / @10 / @50 are requested with 0 / 9 / 49.
print(f'Retriver:\t\trecall@1 = {recall_k(retrive_idxs, 0):.2f}\t\trecall@10 = {recall_k(retrive_idxs, 9):.2f}\t\trecall@50 = {recall_k(retrive_idxs, 49):.2f}')
print(f'Reranker:\t\trecall@1 = {recall_k(rerank_idxs, 0):.2f}\t\trecall@10 = {recall_k(rerank_idxs, 9):.2f}\t\trecall@50 = {recall_k(rerank_idxs, 49):.2f}')

Retriver:		recall@1 = 0.66		recall@10 = 0.94		recall@50 = 0.98
Reranker:		recall@1 = 0.81		recall@10 = 0.96		recall@50 = 0.98


In [93]:
# mrr averages 1 / (position + 1) over zero-based positions.
# NOTE(review): get_answer_rating reports a miss as the candidate-list length,
# so a query whose answer was not found still adds 1 / (length + 1) rather than 0.
print('Mean Reciprocal Rank')
print(f'Retriver: {mrr(retrive_idxs):.2f}')
print(f'Reranker: {mrr(rerank_idxs):.2f}')

Mean Reciprocal Rank
Retriver: 0.77
Reranker: 0.87
