In [1]:
import torch
import pandas as pd
import numpy as np
from typing import Literal

from datasets import load_dataset, Dataset
from sentence_transformers import SentenceTransformer, CrossEncoder
from transformers import BertTokenizer
import faiss

from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import plotly.express as px

In [2]:
# Names of the dataset, the bi-encoder retriever and the cross-encoder
# re-ranker used throughout this notebook.
DATASET_NAME = 'sentence-transformers/squad'
RETRIEVER_NAME = 'multi-qa-mpnet-base-dot-v1'
RERANKER_NAME = 'cross-encoder/ms-marco-MiniLM-L-6-v2'

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# Load the dataset named in the configuration cell rather than repeating the literal,
# so changing DATASET_NAME actually changes what is loaded.
dataset = load_dataset(DATASET_NAME)

# Materialise the column as a plain list: `print_results` relies on `questions.index(...)`.
questions = list(dataset['train']['question'])

# Deduplicate while keeping first-seen order. Iterating a `set` of strings yields an
# order that can differ between kernel sessions, which would make the row order of
# `answers` / `answer_embeddings` (and therefore tie-breaking and any saved index)
# irreproducible. `dict.fromkeys` gives the same unique values in a stable order.
answers = list(dict.fromkeys(dataset['train']['answer']))

In [4]:
# Bi-encoder that embeds questions and answers for the first-stage retrieval.
retriever = SentenceTransformer(RETRIEVER_NAME).to(DEVICE)
# Cross-encoder that re-scores (query, answer) pairs; inputs are capped at 512 tokens.
reranker = CrossEncoder(RERANKER_NAME, max_length = 512, device = DEVICE)
# NOTE(review): `tokenizer` is not referenced by any cell visible in this part of the
# notebook — confirm it is used further down, otherwise it can be dropped.
tokenizer = BertTokenizer.from_pretrained(RERANKER_NAME, max_length = 512)

### PIPELINE

In [5]:
# Embed every training question with the retriever.
# NOTE(review): `question_embeddings` is not used by the cells visible here — confirm
# it is needed later, since encoding the full question list is expensive.
question_embeddings = retriever.encode(questions)
# Embed every unique answer; `retrive` pairs `answer_embeddings[i]` with `answers[i]`,
# so the two must always be recomputed together.
answer_embeddings = retriever.encode(answers)

In [23]:
def retrive(query: str, answers: list[str], answers_vec: np.array, top_k: int) -> list[dict]:
    """Return the answers whose embeddings have the largest inner product with the query.

    Args:
        query: Question text; it is embedded with the module-level `retriever`.
        answers: Candidate answer texts.
        answers_vec: 2-D array of answer embeddings where row ``i`` belongs to
            ``answers[i]``.
        top_k: Maximum number of candidates to return. Values larger than the
            number of candidates are clamped.

    Returns:
        A list of dicts ordered from best to worst match, each with the keys
        ``'rank'`` (1-based), ``'text'``, ``'distance'`` (the inner-product
        score reported by FAISS, so larger means closer) and ``'vector'``
        (the matching row of ``answers_vec``). The list is empty when there
        are no candidates or ``top_k`` is not positive.
    """
    # FAISS pads its result with index -1 when asked for more neighbours than the
    # index holds; `answers[-1]` would then silently return the *last* answer.
    # Clamping the request avoids that, and the early return covers empty input.
    top_k = min(top_k, len(answers), len(answers_vec))
    if top_k <= 0:
        return []

    query_vec = retriever.encode([query])[0]

    # FAISS needs C-contiguous float32; `ascontiguousarray` only copies when the
    # input does not already satisfy that, unlike `np.array(..., dtype=...)`.
    corpus = np.ascontiguousarray(answers_vec, dtype = np.float32)
    faiss_index = faiss.IndexFlatIP(corpus.shape[1])
    faiss_index.add(corpus)
    distances, indices = faiss_index.search(np.array([query_vec], dtype = np.float32), top_k)

    retrive_list = []
    for index, distance in zip(indices[0], distances[0]):
        if index < 0:
            # Defensive: a negative index means "no result" in FAISS.
            continue
        index = int(index)
        retrive_list.append({
            'rank': len(retrive_list) + 1,
            'text': answers[index],
            'distance': distance,
            'vector': answers_vec[index],
        })

    return retrive_list



def rerank(query: str, retrive_list: list[dict], top_k: int) -> list[dict]:
    """Re-score retrieved candidates against the query with the cross-encoder.

    Args:
        query: Question text.
        retrive_list: Candidates as produced by `retrive`; only each entry's
            ``'text'`` value is used.
        top_k: Number of best-scoring candidates to request from the re-ranker.

    Returns:
        Whatever ``reranker.rank`` returns when called with
        ``return_documents = True``; callers in this notebook read the
        ``'text'`` and ``'score'`` keys of each entry.
    """
    candidate_texts = []
    for candidate in retrive_list:
        candidate_texts.append(candidate['text'])

    return reranker.rank(query, candidate_texts, top_k = top_k, return_documents = True)



def print_results(query, top_k_retriver, top_k_reranker, answers = answers, answer_embeddings = answer_embeddings):
    """Run retrieve-then-rerank for `query` and print the best answers.

    Args:
        query: Question text. If it is one of the training questions, its
            reference answer is printed as well.
        top_k_retriver: Number of candidates to fetch with `retrive`.
        top_k_reranker: Number of re-ranked answers to print.
        answers: Candidate answer texts. The default is bound to the global
            `answers` as it existed when this cell was executed, so re-run this
            cell after recomputing the corpus.
        answer_embeddings: Embeddings aligned row-by-row with `answers`; the
            same definition-time binding caveat applies.

    Returns:
        None. The results are written to stdout.
    """
    retrive_list = retrive(query, answers, answer_embeddings, top_k = top_k_retriver)
    reranker_list = rerank(query, retrive_list, top_k = top_k_reranker)

    # A free-form query is not in the training set, so there is no reference
    # answer to show; report that instead of raising ValueError.
    try:
        row = questions.index(query)
    except ValueError:
        real_answer = '<query is not one of the training questions>'
    else:
        # Index the row first: this reads one record instead of pulling the
        # whole 'answer' column just to pick a single element from it.
        real_answer = dataset['train'][row]['answer']

    print(f'Query: {query}\n')
    print(f"Real answer: {real_answer}\n\n")
    print(f'Top {top_k_reranker} answers:\n')
    # Iterate over what was actually returned: the re-ranker yields fewer than
    # `top_k_reranker` entries when fewer candidates were retrieved.
    for i, hit in enumerate(reranker_list[:top_k_reranker]):
        print(f"Answer {i + 1}: {hit['text']}, {hit['score']}\n")

In [24]:
# Smoke test: use the first training question so the reference answer is known.
query = questions[0]
# Retrieve a wide candidate pool with the bi-encoder, then keep only the
# best few after cross-encoder re-ranking.
top_k_retriver = 1000
top_k_reranker = 10

print_results(query, top_k_retriver, top_k_reranker)

Query: To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?

Real answer: Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.


Top 10 answers:

Answer 1: Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing i