In [None]:
import os
import re
import string
from typing import Optional, List, Tuple

import datasets
import matplotlib.pyplot as plt
import pandas as pd
from langchain.docstore.document import Document as LangchainDocument
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from ragatouille import RAGPretrainedModel
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, Pipeline, pipeline

# Show every column when DataFrames are displayed (no "..." truncation of wide frames).
pd.options.display.max_columns = None

In [None]:
# Interactive Hugging Face Hub login: prompts for an access token rather than hardcoding one.
# NOTE(review): presumably required for gated/private checkpoints used below (e.g. the Mistral
# tokenizer) — confirm; this prompt also blocks unattended "Run All" executions.
from huggingface_hub import interpreter_login
interpreter_login()

In [None]:
# Reader: COCOM checkpoint that generates the answer from retrieved passages (max context length: 8192).
READER_MODEL_NAME = "naver/cocom-v1-128-mistral-7b"
# NOTE(review): READER_TOKENIZER is not referenced by any cell shown here — confirm it is still needed.
READER_TOKENIZER = "mistralai/Mistral-7B-Instruct-v0.2"

# Retriever: query encoder handed to HuggingFaceEmbeddings; presumably the same model the FAISS
# index was built with (query and document vectors must come from one encoder).
# NOTE(review): naver/splade-v3 is a SPLADE (sparse, masked-LM) checkpoint; wrapped by
# HuggingFaceEmbeddings it is presumably pooled into a dense vector — confirm this is intended.
EMBEDDING_MODEL_NAME = "naver/splade-v3"

# Reranker: ColBERT checkpoint loaded through RAGatouille.
RERANKER_MODEL_NAME = "colbert-ir/colbertv2.0"

## Load the vector database

In [None]:
# Query encoder for the FAISS store loaded below.
# NOTE(review): with multi_process=True, langchain's HuggingFaceEmbeddings presumably starts and
# stops a sentence-transformers process pool on every embed call (i.e. for every single query),
# and some versions do not forward encode_kwargs in that mode — confirm normalization is really
# applied, and consider multi_process=False for one-query-at-a-time retrieval.
embedding_model = HuggingFaceEmbeddings(
    model_name=EMBEDDING_MODEL_NAME,
    multi_process=True,
    # Unit-length vectors make inner product equal to cosine similarity.
    encode_kwargs=dict(normalize_embeddings=True),
    model_kwargs=dict(device="cuda"),
)

In [None]:
# Location of the prebuilt FAISS index; set the FAISS_INDEX_PATH environment variable to
# point at a different copy without editing the notebook.
FAISS_INDEX_PATH = os.environ.get(
    "FAISS_INDEX_PATH", "/group/jmearlesgrp/data/KILT/faiss_dbs/faiss_index_large"
)

# allow_dangerous_deserialization=True lets load_local unpickle the docstore saved next to the
# index. Unpickling can execute arbitrary code, so only load indexes from a trusted location.
vdb = FAISS.load_local(FAISS_INDEX_PATH, embedding_model, allow_dangerous_deserialization=True)

## Load the reader model

In [None]:
# The reader exposes `generate_from_text` (called by answer_with_rag), which is not a standard
# transformers method, so its modelling code comes from the Hub repository. trust_remote_code=True
# opts in explicitly instead of relying on an interactive confirmation prompt, which would stall
# an unattended Restart & Run All.
# NOTE(review): remote code is executed — consider pinning `revision=` to a known commit; also
# confirm AutoModelForCausalLM (vs AutoModel) is the auto class the repo maps to its custom model.
reader_model = AutoModelForCausalLM.from_pretrained(READER_MODEL_NAME, trust_remote_code=True).to("cuda")

## Initialize the reranker

In [None]:
# ColBERT reranker loaded through RAGatouille; answer_with_rag calls its
# rerank(question, docs, k=...) to reorder the FAISS candidates before generation.
reranker = RAGPretrainedModel.from_pretrained(RERANKER_MODEL_NAME)

## Answer questions and evaluate

In [None]:
def answer_with_rag(
    question: str,
    llm,
    knowledge_index: FAISS,
    reranker: Optional[RAGPretrainedModel] = None,
    num_retrieved_docs: int = 30,
    num_docs_final: int = 5,
    max_new_tokens: int = 128,
    verbose: bool = True,
) -> Tuple[str, List[str]]:
    """Answer ``question`` by retrieving passages, optionally reranking them, then generating.

    Args:
        question: The question text; used both as the retrieval query and as the
            question handed to the reader.
        llm: Reader model exposing ``generate_from_text(contexts=..., questions=...,
            max_new_tokens=...)``; the first element of its result is returned.
        knowledge_index: FAISS store queried with ``similarity_search``.
        reranker: Optional RAGatouille model; when given, ``rerank`` selects the
            top ``num_docs_final`` candidates (each result's ``"content"`` is kept).
        num_retrieved_docs: Number of candidates fetched from the index.
        num_docs_final: Maximum number of passages passed to the reader.
        max_new_tokens: Generation budget for the reader (default 128).
        verbose: Print progress messages (default True).

    Returns:
        ``(answer, passages)`` where ``passages`` is the list of passage *strings*
        that were given to the reader.
    """
    # Retrieve candidates and keep only their text.
    if verbose:
        print("=> Retrieving documents...")
    retrieved = knowledge_index.similarity_search(query=question, k=num_retrieved_docs)
    passages = [doc.page_content for doc in retrieved]

    # Optionally rerank; skip when nothing was retrieved, and never ask for more
    # results than there are candidates.
    if reranker is not None and passages:
        if verbose:
            print("=> Reranking documents...")
        top_k = min(num_docs_final, len(passages))
        reranked = reranker.rerank(question, passages, k=top_k)
        passages = [doc["content"] for doc in reranked]

    passages = passages[:num_docs_final]

    # The reader takes batched inputs: one list of contexts per question.
    if verbose:
        print("=> Generating answer...")
    answers = llm.generate_from_text(
        contexts=[passages], questions=[question], max_new_tokens=max_new_tokens
    )

    return answers[0], passages

def normalize_answer(s):
    """Normalize an answer string for exact-match comparison (SQuAD-style).

    Lower-cases the text, removes every character in ``string.punctuation``,
    replaces the whole words "a", "an" and "the" with spaces, and collapses all
    whitespace runs to single spaces (trimming both ends).
    """
    # Built once per call instead of once per character inside remove_punc.
    punctuation = set(string.punctuation)

    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        return ''.join(ch for ch in text if ch not in punctuation)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))

def get_ground_truth(example):
    """Collect every accepted gold answer for an example.

    ``example['output']`` is a list of output records; each record's ``"answer"``
    may be a string, a list of strings, or a dict with an ``"aliases"`` list.
    Answers are gathered from *all* records (not only the first), expanded into
    their individual strings, with empty answers (records without an answer) and
    duplicates removed. First-seen order is preserved.
    NOTE(review): presumably KILT stores alternative answers as separate output
    records — confirm against the dataset.

    Returns:
        List of answer strings; an empty list when the example has no answer.
    """
    answers = []
    for record in example.get("output") or []:
        gt = record.get("answer", "") if isinstance(record, dict) else ""
        if isinstance(gt, dict) and "aliases" in gt:
            candidates = gt["aliases"]
        elif isinstance(gt, list):
            candidates = gt
        else:
            candidates = [gt]
        for candidate in candidates:
            if candidate and candidate not in answers:
                answers.append(candidate)
    return answers

def exact_match_score(prediction, ground_truth_aliases):
    """Binary exact-match score.

    Both ``prediction`` and each alias are passed through ``normalize_answer``;
    the score is 1 as soon as one normalized alias equals the normalized
    prediction, and 0 otherwise (including when there are no aliases).
    """
    target = normalize_answer(prediction)
    hit = any(normalize_answer(alias) == target for alias in ground_truth_aliases)
    return 1 if hit else 0

In [None]:
# dataset to evaluate
# Evaluation data: KILT HotpotQA validation split (each row's "input" is the question and
# "output" holds the gold answer records).
dataset = datasets.load_dataset(
    path="facebook/kilt_tasks",
    name="hotpotqa",
    split="validation",
)

In [None]:
# Number of validation examples to score: None scores the whole split,
# a small int (e.g. 1) gives a quick smoke test.
EVAL_LIMIT = None
# When True, print each question, prediction, gold answers and the passages used.
VERBOSE = False

if EVAL_LIMIT is None:
    eval_split = dataset
else:
    eval_split = dataset.select(range(min(EVAL_LIMIT, len(dataset))))

total_em = 0
total = 0
eval_records = []  # one dict per example, turned into a DataFrame after the loop

for example in tqdm(eval_split, desc="Evaluating EM"):
    question = example.get("input", "")
    ground_truth_aliases = get_ground_truth(example)  # all accepted gold answers

    prediction, used_docs = answer_with_rag(
        question,
        reader_model,
        vdb,
        reranker=reranker,
        num_retrieved_docs=30,
        num_docs_final=5,
    )

    em = exact_match_score(prediction, ground_truth_aliases)
    total_em += em
    total += 1
    eval_records.append(
        {
            "question": question,
            "prediction": prediction,
            "ground_truth": ground_truth_aliases,
            "em": em,
        }
    )

    if VERBOSE:
        print(f"Question: {question}")
        print(f"Prediction: {prediction}")
        print(f"Ground Truth Aliases: {ground_truth_aliases}")
        print(f"Used Documents: {used_docs}")

em_score = total_em / total if total > 0 else 0
print(f"Exact Match (EM) score: {em_score:.2f}")
eval_df = pd.DataFrame(eval_records)