# RAGs for Open Domain Complex QA
TU Delft, EEMCS, Natural Language Processing 2025, Group 30

## Contents

Contents of this notebook.

1. RAGs for Open Domain Complex QA
    - Contents
    - Setup
        - 1.1 Dependencies
        - 1.2 Imports
        - 1.3 Preparing the general dataset 
        - 1.4 Setting up the Llama model for QA
2. Experiments
    - 2.1 Experiment 1
    - 2.2 Experiment 2
    - 2.3 Experiment 3
    - 2.4 Experiment 4
    - 2.5 Experiment 5

3. ADORE
    - 1. Imports
    - 2. Load Data
    - 3. Fine-tuning query encoder using ADORE
    - 4. Evaluation of QA performance

## 1. Setup

First, install the dependencies. 

### 1.1 Dependencies

In [None]:
% uv venv
% source venv/bin/activate
% uv sync

### 1.2 Imports

In [None]:
# STD LIB
import os
import json
import random
import logging
from time import time
from pprint import pprint
from itertools import islice
from collections import defaultdict
from json.decoder import JSONDecodeError

# THIRD PARTY
from bert_score import score
import torch
import torch.nn as nn
import torch.optim as optim
import transformers
from transformers import (
    AutoTokenizer,
    LlamaForCausalLM,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
)
from tqdm.notebook import tqdm
from dexter.config.constants import Split
from dexter.data.loaders.RetrieverDataset import RetrieverDataset
from dexter.utils.metrics.SimilarityMatch import CosineSimilarity
from dexter.utils.metrics.retrieval.RetrievalMetrics import RetrievalMetrics
from dexter.data.datastructures.hyperparameters.dpr import DenseHyperParams

# LOCAL LIB
from utils import (
    AdoreRetriever,
    ContrieverRetriever,
    plot_accuracy_bar_chart,
    prepare_prompt,
    get_answer_from_model_output,
    exact_match_score,
    cover_exact_match_score,
)

In [None]:
# Select CUDA when it is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Placeholder value: replace it with a real Hugging Face access token.
os.environ["HF_TOKEN"] = "<HF TOKEN HERE>"

# Configuration file handed to the dataset loader.
config_path = "config.ini"

# Load the DEV split; qrels() yields the queries, the relevance judgements
# and the corpus.
loader = RetrieverDataset(
    "wikimultihopqa", "wikimultihopqa-corpus", config_path, Split.DEV, tokenizer=None
)
queries, qrels, corpus = loader.qrels()

# The same Contriever checkpoint encodes both queries and documents.
config_instance = DenseHyperParams(
    query_encoder_path="facebook/contriever",
    document_encoder_path="facebook/contriever",
    batch_size=32,
    show_progress_bar=True,
)

# NOTE(review): "indices" / "index_1" look like an index folder and index
# name — confirm against ContrieverRetriever in utils.
contrvr_search: ContrieverRetriever = ContrieverRetriever(
    config_instance, "indices", "index_1"
)
similarity_measure = CosineSimilarity()
# NOTE(review): 100 is presumably the number of documents kept per query —
# confirm. The result is consumed below as {query_id: {doc_id: score}}.
retriever_response = contrvr_search.retrieve(
    corpus, queries, 100, similarity_measure, chunk=False, chunksize=400000
)

metrics = RetrievalMetrics(k_values=[1, 3, 5])  # Evaluate retrieval metrics
print(metrics.evaluate_retrieval(qrels=qrels, results=retriever_response))

### 1.3 Preparing the general dataset based on retrieval results
This dataset contains the dev.json entries (questions and answers) together with all the documents the retriever returned for each question. 
In combination with `corpus_dict` (useful, for example, for drawing random samples), this should be enough for all experiments up to ADORE.

In [None]:
# Load data into kernel: the raw dev set and the document corpus, both JSON.

dataset_dir = "data"
dev_path = f"{dataset_dir}/musique/dev.json"
corpus_path = f"{dataset_dir}/corpus/wiki_musique_corpus.json"

# dev_data is iterated below as a sequence of records that carry an "_id".
with open(dev_path, "r", encoding="utf-8") as f:
    dev_data = json.load(f)

# corpus_dict is indexed by document id below; each value is accessed via
# its "title" and "text" keys.
with open(corpus_path, "r", encoding="utf-8") as f:
    corpus_dict = json.load(f)

In [None]:
# Index the dev records by their "_id"; records without one are reported
# and left out of the index.

dev_dict = defaultdict(list)

for item in dev_data:
    if "_id" not in item:
        print("Warning: JSON object missing '_id' field:", item)
        continue
    dev_dict[item["_id"]].append(item)

In [None]:
# Create dataset from dev data and retrieved contexts

print("Creating dataset...")

# One record per retrieved query. Each record has the keys:
#   dev_id       - the query id
#   dev_full     - the first dev.json entry stored under that id
#   context_list - retrieved contexts, best score first; each entry has
#                  context_id, context_score and context_full
dataset = []
for dev_key, retrieved_contexts in retriever_response.items():
    # .get() avoids the defaultdict side effect of inserting an empty list
    # for an unknown id, which previously surfaced as an opaque IndexError.
    dev_entries = dev_dict.get(dev_key)
    if not dev_entries:
        raise KeyError(f"No dev.json entry found for retrieved query id {dev_key!r}")

    outer_dict = defaultdict(list)
    outer_dict["dev_id"] = dev_key
    outer_dict["dev_full"] = dev_entries[0]

    # Contexts are ordered from best to worst according to the retriever score.
    sorted_contexts = sorted(
        retrieved_contexts.items(), key=lambda item: item[1], reverse=True
    )

    context_list = []
    for context_key, context_score in sorted_contexts:
        context_dict = defaultdict(list)
        context_dict["context_id"] = context_key
        context_dict["context_score"] = context_score
        context_dict["context_full"] = corpus_dict[context_key]
        context_list.append(context_dict)

    outer_dict["context_list"] = context_list
    dataset.append(outer_dict)

print("Dataset created.")

### 1.4 Setting up the llama model for question answering (+ a usage example)

First, we set up the llama model for question answering, with a BitsAndBytesConfig.

In [None]:
# NOTE:
# To load a different Hugging Face model, use AutoModelForCausalLM instead of
# LlamaForCausalLM.

model_id = "meta-llama/Llama-3.1-8B-Instruct"

# 4-bit NF4 quantisation with double quantisation; compute runs in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,  # bfloat16 for A6000
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)

# Fall back to the EOS token when the tokenizer defines no padding token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

model = LlamaForCausalLM.from_pretrained(
    model_id, quantization_config=bnb_config, device_map="auto"
)

# The compiled wrapper replaces `model` and is what the pipeline below uses.
model = torch.compile(model)

Usage example:

In [None]:
# Text-generation pipeline shared by every experiment cell below.
text_generator = transformers.pipeline(
    "text-generation", model=model, tokenizer=tokenizer
)

# Toy prompt: one question plus two hand-written evidence sentences.
joe_biden_prompt = prepare_prompt(
    question="Who directed movies called Example Movie?",
    evidences=[
        "Example Movie was directed by Joe Biden.",
        "Example Movies is a thriller.",
    ],
)

t1 = time()
response = text_generator(joe_biden_prompt, max_new_tokens=100)
t2 = time()

# The format spec must sit inside the braces; previously ":.2f" was printed
# literally after the unformatted float.
print(f"\nRESPONSE achieved in {t2 - t1:.2f} seconds...\n")
pprint(response)

In [None]:
# Extract the answer from the raw pipeline output of the usage example.
# NOTE(review): the experiment cells guard against a None result from this
# helper, so the `str` annotation may be too narrow — confirm in utils.
answer: str = get_answer_from_model_output(response)
print(answer)

# Experiments

## 2.1. Experiment 1

The cell below contains the non-multithreaded version.

In [None]:
# Experiment 1: answer every dev question from its top-k Contriever contexts
# (k taken from `counts`) and report exact match, cover exact match and
# BERTScore F1 per k.

# Suppress specific logging
logging.getLogger("transformers").setLevel(logging.ERROR)
CHECKPOINT_FOLDER = "outputs"

dataset_copy = dataset

# Per-k running sums; they are divided by the number of questions after the
# loop, so skipped or unparsable answers count as 0.
exact_match_counts = defaultdict(int)
exact_match_cover_counts = defaultdict(int)
bertscore_f1_scores = defaultdict(float)

precision_scores = []
recall_scores = []
f1_scores = []


elem_count = 0
count_parsing_errors = 0

counts = [5, 15]
print(f"Starting run on {len(dataset_copy)} entries...")

t_start = time()
count_iters = 0

for element in dataset_copy:
    count_iters += 1
    elem_count += 1

    # Progress report every 100 questions.
    if count_iters == 100:
        t_curr = time()
        print(f"\nCompleted iterations: {elem_count}")
        print(f"Time ellapsed: {t_curr - t_start}\n")
        count_iters = 0

    # Deliberately not named `dev_dict`: that name is the module-level
    # id -> records index used to build `dataset`, and rebinding it here made
    # the dataset-building cell fail when re-run.
    dev_entry = element["dev_full"]
    all_contexts = element["context_list"]

    question = dev_entry["question"]
    correct_answer = dev_entry["answer"]

    for count in counts:
        # Texts of the k best-ranked contexts.
        evidences = [
            context_item["context_full"]["text"]
            for context_item in all_contexts[:count]
        ]

        prompt = prepare_prompt(question, evidences)
        output = text_generator(prompt, max_new_tokens=50)

        try:
            model_answer = get_answer_from_model_output(output)
        except JSONDecodeError:
            count_parsing_errors += 1
            continue

        if model_answer is None or correct_answer is None:
            print(
                f"Skipping invalid example. model_answer: {model_answer}, correct_answer: {correct_answer}"
            )
            continue

        EM_score = exact_match_score(model_answer, correct_answer)
        cover_score = cover_exact_match_score(model_answer, correct_answer)
        P, R, F1 = score([model_answer], [correct_answer], lang="en")
        bert_f1 = F1.mean().item()

        bertscore_f1_scores[count] += bert_f1
        exact_match_counts[count] += EM_score
        exact_match_cover_counts[count] += cover_score

# Turn the running sums into per-question averages.
for count in counts:
    bertscore_f1_scores[count] /= elem_count
    exact_match_counts[count] /= elem_count
    exact_match_cover_counts[count] /= elem_count

pprint(exact_match_counts)
pprint(exact_match_cover_counts)
pprint(bertscore_f1_scores)
print(f"decoding errors: {count_parsing_errors}")

t_end = time()
total_time = t_end - t_start

print("Total elapsed time: ", total_time)

plot_accuracy_bar_chart(
    exact_match_counts,
    title="Exact Match Performance - Contriever Baseline",
    save_path="plots/contrv_baseline_exact_match_2.png",
)
plot_accuracy_bar_chart(
    exact_match_cover_counts,
    title="Cover Exact Match Performance - Contriever Baseline",
    save_path="plots/contrv_baseline_cover_exact_match_2.png",
)
plot_accuracy_bar_chart(
    bertscore_f1_scores,
    title="Bertscore Performance - Contriever Baseline",
    save_path="plots/contrv_baseline_bertscore_2.png",
)


## 2.2. Experiment 2

In [None]:
def load_oracle_contexts(dev_data):
    """Collect the gold ("oracle") context sentences for every dev question.

    For each item, the titles named in ``supporting_facts`` are matched
    against the ``(title, sentences)`` pairs in ``context``; the sentences of
    every matching context are gathered into one flat list.

    A title that appears in several supporting facts is used only once, so
    its sentences are not duplicated in the result.

    Parameters:
    - dev_data: iterable of dicts with an "_id" key and optional
      "supporting_facts" (pairs whose first element is a title) and
      "context" (pairs of title and list of sentences) keys

    Returns:
    - dict mapping each question id to its list of oracle sentences

    Raises:
    - KeyError if an item has no "_id"
    """
    oracle_contexts = {}
    for item in dev_data:
        question_id = item["_id"]
        contexts = item.get("context", [])

        # dict.fromkeys de-duplicates the titles while keeping first-seen order.
        gold_titles = dict.fromkeys(
            fact_title for fact_title, _ in item.get("supporting_facts", [])
        )

        filtered_contexts = []
        for gold_title in gold_titles:
            for context_title, context_texts in contexts:
                if context_title == gold_title:
                    filtered_contexts.extend(context_texts)

        oracle_contexts[question_id] = filtered_contexts

    return oracle_contexts

In [None]:
# Gold context sentences per question id.
oracle_contexts = load_oracle_contexts(dev_data)

# One record per dev item, with the keys dev_id, dev_full and context_list;
# every context_list entry wraps a single oracle sentence under "text".
oracle_dataset = []

for item in dev_data:
    question_id = item["_id"]
    outer_dict = defaultdict(list)
    outer_dict.update(
        dev_id=question_id,
        dev_full=item,
        context_list=[
            {"text": context} for context in oracle_contexts[question_id]
        ],
    )
    oracle_dataset.append(outer_dict)

In [None]:
# Experiment 2: answer each question from its oracle (gold) contexts only.
# All metrics are stored under the single key 1 so that the dictionaries can
# be passed to plot_accuracy_bar_chart like in the other experiments.
from itertools import islice
from pprint import pprint
from json.decoder import JSONDecodeError

# Running sums; divided by the number of processed questions after the loop.
oracle_exact_match_counts = defaultdict(int)
oracle_cover_match_counts = defaultdict(int)
oracle_bertscore_f1_scores = defaultdict(float)

elem_count = 0
count_parsing_errors = 0

t_start = time()
count_iters = 0

# Only the first 1200 oracle records are evaluated.
oracle_dataset_copy = oracle_dataset[:1200]
print(f"Starting run on {len(oracle_dataset_copy)} entries...")

for element in oracle_dataset_copy:
    count_iters += 1
    elem_count += 1

    # Progress report every 100 questions.
    if count_iters == 100:
        t_curr = time()
        print(f"\nCompleted iterations: {elem_count}")
        print(f"Time ellapsed: {t_curr - t_start}\n")
        count_iters = 0

    dev_id = element["dev_id"]
    question = element["dev_full"]["question"]
    correct_answer = element["dev_full"]["answer"]
    oracle_context_texts = [context["text"] for context in element["context_list"]]

    prompt = prepare_prompt(question, oracle_context_texts)
    output = text_generator(prompt, max_new_tokens=50)

    try:
        model_answer = get_answer_from_model_output(output)
        # Skipped examples still count in elem_count, i.e. they score 0.
        if model_answer is None or correct_answer is None:
            print(
                f"Skipping invalid example. model_answer: {model_answer}, correct_answer: {correct_answer}"
            )
            continue

        EM_score = exact_match_score(model_answer, correct_answer)
        cover_score = cover_exact_match_score(model_answer, correct_answer)
        P, R, F1 = score([model_answer], [correct_answer], lang="en")
        bert_f1 = F1.mean().item()

        oracle_bertscore_f1_scores[1] += bert_f1
        oracle_exact_match_counts[1] += EM_score
        oracle_cover_match_counts[1] += cover_score

    except JSONDecodeError:
        # The model output could not be parsed; count it and move on.
        # print(f'Error extracting answer for item {elem_count}')
        # pprint(output)
        count_parsing_errors += 1

# Turn the running sums into per-question averages.
oracle_bertscore_f1_scores[1] /= elem_count
oracle_exact_match_counts[1] /= elem_count
oracle_cover_match_counts[1] /= elem_count

pprint(oracle_exact_match_counts)
pprint(oracle_cover_match_counts)
print(f"Parsing errors: {count_parsing_errors}")
print("Total elapsed time:", time() - t_start)

plot_accuracy_bar_chart(
    oracle_exact_match_counts,
    title="Exact Match Performance - Oracle Contexts",
    save_path="plots/oracle_exact_match_2.png",
)
plot_accuracy_bar_chart(
    oracle_cover_match_counts,
    title="Cover Exact Match Performance - Oracle Contexts",
    save_path="plots/oracle_cover_exact_match_2.png",
)
plot_accuracy_bar_chart(
    oracle_bertscore_f1_scores,
    title="Bert Score Performance - Oracle Contexts",
    save_path="plots/oracle_bertscore_2.png",
)

## 2.3 Experiment 3

First let's try combining them in a 1:1 ratio

In [None]:
# Print the first corpus entry (key and value) to inspect its structure.
for key, value in islice(corpus_dict.items(), 1):
    print(key)
    pprint(value)

In [None]:
def sample_documents(corpus_dict, exclude_titles, n):
    """Draw up to ``n`` random document texts whose title is not excluded.

    Parameters:
    - corpus_dict: dict whose values are dicts with "title" and "text" keys
    - exclude_titles: iterable of titles that must not be sampled
    - n: int, number of texts wanted; it is clamped to the range from 0 to
      the number of eligible documents

    Returns:
    - list of sampled texts, drawn without replacement
    """
    # A set gives O(1) membership tests; the corpus is scanned on every call.
    excluded = set(exclude_titles)
    filtered_documents = [
        doc["text"] for doc in corpus_dict.values() if doc["title"] not in excluded
    ]
    # random.sample rejects negative sizes and sizes above the population.
    n = max(0, min(n, len(filtered_documents)))
    return random.sample(filtered_documents, n)


def halve_or_one(n):
    """Return ``n // 2``, except that an input of exactly 1 stays 1."""
    if n == 1:
        return 1
    return n // 2

In [None]:
# Experiment 3 (1:1): answer each question from its top-k Contriever contexts
# plus k randomly sampled documents whose titles differ from the selected ones.
CHECKPOINT_FOLDER = "outputs"

dataset_copy = dataset

# Per-k running sums; they are divided by the number of questions after the
# loop, so skipped or unparsable answers count as 0.
exact_match_counts = defaultdict(int)
exact_match_cover_counts = defaultdict(int)
bertscore_f1_scores = defaultdict(float)

elem_count = 0
count_parsing_errors = 0

counts = [1, 5, 15]

print(f"Starting run on {len(dataset_copy)} entries...")

t_start = time()
count_iters = 0

for element in dataset_copy:
    count_iters += 1
    elem_count += 1

    # Progress report every 100 questions.
    if count_iters == 100:
        t_curr = time()
        print(f"\nCompleted iterations: {elem_count}")
        print(f"Time ellapsed: {t_curr - t_start}\n")
        count_iters = 0

    # Deliberately not named `dev_dict`: that name is the module-level
    # id -> records index used to build `dataset`, and rebinding it here made
    # the dataset-building cell fail when re-run.
    dev_entry = element["dev_full"]
    all_contexts = element["context_list"]

    question = dev_entry["question"]
    correct_answer = dev_entry["answer"]

    for count in counts:
        top_k_contexts = all_contexts[:count]
        evidences = [
            context_item["context_full"]["text"] for context_item in top_k_contexts
        ]
        selected_contexts_titles = [
            context_item["context_full"]["title"] for context_item in top_k_contexts
        ]

        # sample random documents different than the selected ones and append to evidences
        sampled_docs = sample_documents(corpus_dict, selected_contexts_titles, count)
        evidences = [*evidences, *sampled_docs]

        prompt = prepare_prompt(question, evidences)
        output = text_generator(prompt, max_new_tokens=50)

        try:
            model_answer = get_answer_from_model_output(output)
        except JSONDecodeError:
            count_parsing_errors += 1
            continue

        if model_answer is None or correct_answer is None:
            print(
                f"Skipping invalid example. model_answer: {model_answer}, correct_answer: {correct_answer}"
            )
            continue

        EM_score = exact_match_score(model_answer, correct_answer)
        cover_score = cover_exact_match_score(model_answer, correct_answer)
        P, R, F1 = score([model_answer], [correct_answer], lang="en")
        bert_f1 = F1.mean().item()

        bertscore_f1_scores[count] += bert_f1
        exact_match_counts[count] += EM_score
        exact_match_cover_counts[count] += cover_score

# Turn the running sums into per-question averages.
for count in counts:
    bertscore_f1_scores[count] /= elem_count
    exact_match_counts[count] /= elem_count
    exact_match_cover_counts[count] /= elem_count

pprint(exact_match_counts)
pprint(exact_match_cover_counts)
pprint(bertscore_f1_scores)
print(f"decoding errors: {count_parsing_errors}")

t_end = time()
total_time = t_end - t_start

print("Total elapsed time: ", total_time)

plot_accuracy_bar_chart(
    exact_match_counts,
    title="Exact Match Performance - Contriever + Random Contexts (1:1 ratio)",
    save_path="plots/contrv_and_random_1_to_1_exact_match.png",
)
plot_accuracy_bar_chart(
    exact_match_cover_counts,
    title="Cover Exact Match Performance - Contriever + Random Contexts (1:1 ratio)",
    save_path="plots/contrv_and_random_1_to_1_cover_exact_match.png",
)
plot_accuracy_bar_chart(
    bertscore_f1_scores,
    title="Bertscore Performance - Contriever + Random Contexts (1:1 ratio)",
    save_path="plots/contrv_and_random_1_to_1_bertscore.png",
)

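Next, a 2:1 ratio of retrieved to random contexts. This version generates all responses in a single batched call, so it should run a bit faster.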

In [None]:
# Experiment 3 (2:1): top-k Contriever contexts plus halve_or_one(k) random
# documents. All prompts are built first and then sent to the pipeline in a
# single batched call.
CHECKPOINT_FOLDER = "outputs"

dataset_copy = dataset

# Per-k running sums; turned into per-question averages after evaluation.
exact_match_counts = defaultdict(int)
exact_match_cover_counts = defaultdict(int)
bertscore_f1_scores = defaultdict(float)

counts = [1, 5, 15]

print(f"Starting run on {len(dataset_copy)} entries...")

t_start = time()

# NOTE: `queries` holds the generated prompts here; this rebinds the name
# that loader.qrels() assigned in the retrieval setup cell.
queries = []
expected_answers = []
dev_ids = []
counts_mapping = []

# Prepare dataset
for element in dataset_copy:
    dev_id = element["dev_id"]
    # Deliberately not named `dev_dict`: that name is the module-level
    # id -> records index used to build `dataset`.
    dev_entry = element["dev_full"]
    all_contexts = element["context_list"]

    question = dev_entry["question"]
    correct_answer = dev_entry["answer"]

    for count in counts:
        top_k_contexts = all_contexts[:count]
        evidences = [
            context_item["context_full"]["text"] for context_item in top_k_contexts
        ]
        selected_contexts_titles = [
            context_item["context_full"]["title"] for context_item in top_k_contexts
        ]

        # Sample random documents different than the selected ones and append to evidences
        random_doc_num = halve_or_one(count)
        sampled_docs = sample_documents(
            corpus_dict, selected_contexts_titles, random_doc_num
        )
        evidences.extend(sampled_docs)

        # The four lists stay index-aligned: one entry per (question, k) pair.
        queries.append(prepare_prompt(question, evidences))
        expected_answers.append(correct_answer)
        dev_ids.append(dev_id)
        counts_mapping.append(count)

# Generate responses in batch
outputs = text_generator(queries, max_new_tokens=50)

# Evaluate responses
count_parsing_errors = 0
for count, correct_answer, output in zip(counts_mapping, expected_answers, outputs):
    try:
        model_answer = get_answer_from_model_output(output)
    except JSONDecodeError:
        count_parsing_errors += 1
        continue

    if model_answer is None or correct_answer is None:
        print(
            f"Skipping invalid example. model_answer: {model_answer}, correct_answer: {correct_answer}"
        )
        continue

    EM_score = exact_match_score(model_answer, correct_answer)
    cover_score = cover_exact_match_score(model_answer, correct_answer)
    P, R, F1 = score([model_answer], [correct_answer], lang="en")
    bert_f1 = F1.mean().item()

    bertscore_f1_scores[count] += bert_f1
    exact_match_counts[count] += EM_score
    exact_match_cover_counts[count] += cover_score

# Normalize scores by the number of questions in THIS run. The previous code
# divided by `elem_count`, a leftover from whichever cell was executed before.
num_entries = len(dataset_copy)
for count in counts:
    bertscore_f1_scores[count] /= num_entries
    exact_match_counts[count] /= num_entries
    exact_match_cover_counts[count] /= num_entries

# Print results
pprint(exact_match_counts)
pprint(exact_match_cover_counts)
pprint(bertscore_f1_scores)
print(f"decoding errors: {count_parsing_errors}")

t_end = time()
total_time = t_end - t_start
print("Total elapsed time: ", total_time)

# Plot results
plot_accuracy_bar_chart(
    exact_match_counts,
    title="Exact Match Performance - Contriever + Random Contexts (2:1 ratio)",
    save_path="plots/contrv_and_random_2_to_1_exact_match.png",
)
plot_accuracy_bar_chart(
    exact_match_cover_counts,
    title="Cover Exact Match Performance - Contriever + Random Contexts (2:1 ratio)",
    save_path="plots/contrv_and_random_2_to_1_cover_exact_match.png",
)
plot_accuracy_bar_chart(
    bertscore_f1_scores,
    title="Bertscore Performance - Contriever + Random Contexts (2:1 ratio)",
    save_path="plots/contrv_and_random_2_to_1_bertscore.png",
)

## 2.4 Experiment 4: Hard Negatives

### 2.4.1 Helper Functions

In [None]:
from sentence_transformers import SentenceTransformer, util


def sample_hard_negative_documents(
    query, corpus_dict, exclude_titles, n, embedding_model
):
    """
    Sample top n hard negatives by similarity score but exclude ground truths.

    Parameters:
    - query: str, the query for which to sample hard negatives
    - corpus_dict: dict, where keys are titles and values are texts
    - exclude_titles: set or list, titles to exclude from consideration
    - n: int, number of hard negatives to sample
    - embedding_model: an embedding model with `.encode()` method to generate embeddings

    Returns:
    - List of top n hard negatives by similarity score (texts only), most
      similar first. Empty when no candidate is left or n is not positive.
    """
    # Filter documents to exclude ground truth titles
    excluded = set(exclude_titles)
    filtered_documents = [
        text for title, text in corpus_dict.items() if title not in excluded
    ]
    top_k = min(n, len(filtered_documents))
    if top_k <= 0:
        return []

    # Compute embeddings for the query and the candidate documents
    query_embedding = embedding_model.encode(query, convert_to_tensor=True)
    doc_embeddings = embedding_model.encode(filtered_documents, convert_to_tensor=True)

    # Cosine similarity of the query against every candidate. Reshaping to
    # (1, dim) and (num_docs, dim) keeps the result 1-D even for a single
    # candidate, which a bare squeeze() would collapse to a scalar.
    similarities = torch.nn.functional.cosine_similarity(
        query_embedding.reshape(1, -1),
        doc_embeddings.reshape(len(filtered_documents), -1),
        dim=1,
    )

    # topk returns the indices in descending order of similarity. This also
    # replaces the former np.argsort call; `np` was never imported.
    top_n_indices = torch.topk(similarities, k=top_k).indices.tolist()
    return [filtered_documents[idx] for idx in top_n_indices]

### 2.4.2 Diagnostic printout: which documents are considered and which are selected

In [None]:
# Diagnostic cell: prints the structure of one dataset element and builds
# hard-negative prompts for that single element only; nothing is sent to the
# language model here.
CHECKPOINT_FOLDER = "outputs"

dataset_copy = dataset

exact_match_counts = defaultdict(int)
exact_match_cover_counts = defaultdict(int)

counts = [1, 5, 15]

print(f"Starting run on {len(dataset_copy)} entries...")

t_start = time()
print("len of corpus dict", len(corpus_dict))
print("Example of 1 corpus dict k,v pair")
# NOTE(review): `iter` shadows the Python builtin of the same name for the
# rest of the notebook session.
iter = 1
# This loop body runs once: it bumps `iter` to 2 and stops.
for k, v in corpus_dict.items():
    # print(f"Key: {k}, Value: {v}\n\n")
    iter += 1
    break
# Hard-coded document id, printed as a sample corpus entry.
print(corpus_dict["126271"])
queries = []
expected_answers = []
dev_ids = []
counts_mapping = []

embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
# Prepare dataset
for element in dataset_copy:
    dev_id = element["dev_id"]
    dev_dict = element["dev_full"]
    all_contexts = element["context_list"]
    question = dev_dict["question"]
    correct_answer = dev_dict["answer"]
    dev_ctxt_titles = []
    all_titles = []
    # defaultdict() without a factory behaves like a plain dict.
    result_dict = defaultdict()
    # Union of all_contexts and dev_dict context into 1 dictionary
    # (title -> text); on a title clash the dev_dict text wins because it is
    # written last.
    # Process `all_contexts`
    for context in all_contexts:
        title = context["context_full"]["title"]
        text = context["context_full"]["text"]
        result_dict[title] = text
    # Process `dev_dict['context']`
    for entry in dev_dict["context"]:
        title = entry[0]
        text = " ".join(entry[1])  # Join the list of strings into a single text
        result_dict[title] = text
    # Titles of the supporting facts, treated as the ground truth.
    ground_truth_titles = [elem[0] for elem in dev_dict["supporting_facts"]]
    for i in range(len(dev_dict["context"])):
        dev_ctxt_titles.append(dev_dict["context"][i][0])
    for context_item in all_contexts:
        all_titles.append(context_item["context_full"]["title"])
    print("Everything below is for 1 example, i.e, 1 of the 1200 samples.\n")
    element_keys = [k for k, v in element.items()]
    print("The keys for every element:", element_keys, "\n")
    print(
        "Length of the context list (superset, all_contexts=dataset_element[context_list]:",
        len(all_contexts),
    )
    print(
        "Example of one of the ",
        len(all_contexts),
        " all_contexts elements:\n",
        all_contexts[0],
        "\n",
    )
    dev_dict_keys = [k for k, v in dev_dict.items()]
    print("The keys in (dev_dict) dataset_element[dev_full]:", dev_dict_keys)
    print(
        "dev_dict[context] length and type:",
        len(dev_dict["context"]),
        type(dev_dict["context"]),
        "\n",
    )
    print("One example of dev_dict[context]:", dev_dict["context"][:2], "\n")
    print("List of all titles in dev_dict[context]", set(dev_ctxt_titles), "\n")
    print("The question we have is: ", question)
    print(
        "List of all elements of dev_dict[supporting_facts], considered to be the ground truth :",
        dev_dict["supporting_facts"],
    )
    print("List of ground truth titles: ", ground_truth_titles)

    for count in counts:
        evidences = []
        top_k_contexts = all_contexts[:count]

        selected_contexts_titles = []
        for context_item in top_k_contexts:
            context_score = context_item["context_score"]
            context_full = context_item["context_full"]
            context_title = context_full["title"]
            selected_contexts_titles.append(context_title)
            context_text = context_full["text"]
            evidences.append(context_text)

        # Pick the `count` texts in result_dict most similar to the question,
        # excluding the titles already selected above, and append them to
        # the evidences.
        # hard_neg_doc_num = halve_or_one(count)
        # sampled_docs = sample_documents(corpus_dict, selected_contexts_titles, random_doc_num)
        sampled_docs = sample_hard_negative_documents(
            question, result_dict, selected_contexts_titles, count, embedding_model
        )
        evidences.extend(sampled_docs)

        # Prepare prompt
        prompt = prepare_prompt(question, evidences)

        queries.append(prompt)
        expected_answers.append(correct_answer)
        dev_ids.append(dev_id)
        counts_mapping.append(count)

    # `iter` is already 2 here, so this check stops after the first element.
    iter = iter + 1
    if iter > 1:
        break
# print(queries)

### 2.4.3 Top-k contriever + hard_negatives from (all_contexts U dev_dict) (1:1)

In [None]:
# Experiment 4 (2.4.3): top-k Contriever contexts plus an equal number of
# hard negatives. The hard negatives are the next-ranked retrieved contexts
# whose title is neither among the top-k nor among the ground-truth titles.
CHECKPOINT_FOLDER = "outputs"

dataset_copy = dataset

# Per-k running sums; turned into per-question averages after evaluation.
exact_match_counts = defaultdict(int)
exact_match_cover_counts = defaultdict(int)
bertscore_f1_scores = defaultdict(float)

counts = [1, 5, 15]

print(f"Starting run on {len(dataset_copy)} entries...")

t_start = time()
# NOTE: `queries` holds the generated prompts here; this rebinds the name
# that loader.qrels() assigned in the retrieval setup cell.
queries = []
expected_answers = []
dev_ids = []
counts_mapping = []

# Not used by the selection logic in this cell; the assignment is kept so the
# notebook namespace stays as before.
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

# Prepare dataset. enumerate() replaces the manual counter that was called
# `iter` (shadowing the builtin) and reported progress one element early.
for num_processed, element in enumerate(dataset_copy, start=1):
    dev_id = element["dev_id"]
    # Deliberately not named `dev_dict`: that name is the module-level
    # id -> records index used to build `dataset`.
    dev_entry = element["dev_full"]
    all_contexts = element["context_list"]
    question = dev_entry["question"]
    correct_answer = dev_entry["answer"]

    ground_truth_titles = [elem[0] for elem in dev_entry["supporting_facts"]]

    for count in counts:
        top_k_contexts = all_contexts[:count]
        evidences = [
            context_item["context_full"]["text"] for context_item in top_k_contexts
        ]
        top_k_titles = [
            context_item["context_full"]["title"] for context_item in top_k_contexts
        ]

        # Sample the hard negatives: as many as there are top-k contexts,
        # taken in retriever rank order.
        num_hard_neg = len(top_k_titles)
        excluded_titles = set(top_k_titles) | set(ground_truth_titles)
        hard_negatives = islice(
            (
                context_item["context_full"]["text"]
                for context_item in all_contexts
                if context_item["context_full"]["title"] not in excluded_titles
            ),
            num_hard_neg,
        )
        evidences.extend(hard_negatives)

        # The four lists stay index-aligned: one entry per (question, k) pair.
        queries.append(prepare_prompt(question, evidences))
        expected_answers.append(correct_answer)
        dev_ids.append(dev_id)
        counts_mapping.append(count)

    if num_processed % 400 == 0:
        print(num_processed)

# Generate responses in batch
outputs = text_generator(queries, max_new_tokens=50)
print("outputs generated")

# Evaluate responses
count_parsing_errors = 0
for count, correct_answer, output in zip(counts_mapping, expected_answers, outputs):
    try:
        model_answer = get_answer_from_model_output(output)
    except JSONDecodeError:
        count_parsing_errors += 1
        continue

    if model_answer is None or correct_answer is None:
        print(
            f"Skipping invalid example. model_answer: {model_answer}, correct_answer: {correct_answer}"
        )
        continue

    EM_score = exact_match_score(model_answer, correct_answer)
    cover_score = cover_exact_match_score(model_answer, correct_answer)
    P, R, F1 = score([model_answer], [correct_answer], lang="en")
    bert_f1 = F1.mean().item()

    bertscore_f1_scores[count] += bert_f1
    exact_match_counts[count] += EM_score
    exact_match_cover_counts[count] += cover_score

# Normalize scores by the number of questions in THIS run. The previous code
# divided by `elem_count`, a leftover from whichever cell was executed before.
num_entries = len(dataset_copy)
for count in counts:
    bertscore_f1_scores[count] /= num_entries
    exact_match_counts[count] /= num_entries
    exact_match_cover_counts[count] /= num_entries

# Print results
pprint(exact_match_counts)
pprint(exact_match_cover_counts)
pprint(bertscore_f1_scores)
print(f"decoding errors: {count_parsing_errors}")

t_end = time()
total_time = t_end - t_start
print("Total elapsed time: ", total_time)

# Plot results
plot_accuracy_bar_chart(
    exact_match_counts,
    title="Exact Match Performance - Contriever + Hard Negative Contexts (1:1 ratio)",
    save_path="plots/contrv_and_neg_1_to_1_exact_match.png",
)
plot_accuracy_bar_chart(
    exact_match_cover_counts,
    title="Cover Exact Match Performance - Contriever + Hard Negative Contexts (1:1 ratio)",
    save_path="plots/contrv_and_neg_1_to_1_cover_exact_match.png",
)
plot_accuracy_bar_chart(
    bertscore_f1_scores,
    title="Bertscore Performance - Contriever + Hard Negative Contexts (1:1 ratio)",
    save_path="plots/contrv_and_neg_1_to_1_bertscore.png",
)


### 2.4.4 Ground truth (supporting facts only) + hard_negatives

In [None]:
# Experiment 2.4.4: answer every question from its supporting-fact paragraphs
# plus retrieved paragraphs whose title is NOT a supporting fact ("hard
# negatives"), then score the answers with exact match, cover exact match and
# BERTScore F1.
CHECKPOINT_FOLDER = "outputs"

dataset_copy = dataset

# Running sums of the three metrics; turned into averages after the loop.
exact_match_counts = 0.0
exact_match_cover_counts = 0.0
bertscore_f1_scores = 0.0

print(f"Starting run on {len(dataset_copy)} entries...")

t_start = time()
queries = []
expected_answers = []
dev_ids = []

# Prepare dataset
# NOTE: the loop deliberately avoids the names `iter` (a builtin) and
# `dev_dict` (the module-level `_id -> entries` index used when the dataset is
# built) so that re-running other cells afterwards keeps working.
for processed, element in enumerate(dataset_copy, start=1):
    dev_id = element["dev_id"]
    dev_entry = element["dev_full"]
    all_contexts = element["context_list"]
    question = dev_entry["question"]
    correct_answer = dev_entry["answer"]

    # Ground-truth evidence: every dev context whose title is a supporting fact.
    ground_truth_titles = [fact[0] for fact in dev_entry["supporting_facts"]]
    evidences = [
        " ".join(context[1])
        for context in dev_entry["context"]
        if context[0] in ground_truth_titles
    ]

    # Sample the hard negatives: walk the retrieved contexts in their stored
    # order and keep those whose title is not a supporting fact, stopping once
    # as many negatives as supporting-fact entries were collected.
    num_hard_neg = len(ground_truth_titles)
    sampled_docs = []
    for context_item in all_contexts:
        context_full = context_item["context_full"]
        if context_full["title"] not in ground_truth_titles:
            sampled_docs.append(context_full["text"])
            if len(sampled_docs) >= num_hard_neg:
                break

    evidences.extend(sampled_docs)

    # Prepare prompt
    queries.append(prepare_prompt(question, evidences))
    expected_answers.append(correct_answer)
    dev_ids.append(dev_id)

    if processed % 400 == 0:
        print(processed)

# Generate responses in batch
outputs = text_generator(queries, max_new_tokens=50)

# Evaluate responses
count_parsing_errors = 0
for output, correct_answer in zip(outputs, expected_answers):
    try:
        model_answer = get_answer_from_model_output(output)
        if model_answer is None or correct_answer is None:
            print(
                f"Skipping invalid example. model_answer: {model_answer}, correct_answer: {correct_answer}"
            )
            continue

        EM_score = exact_match_score(model_answer, correct_answer)
        cover_score = cover_exact_match_score(model_answer, correct_answer)
        _, _, F1 = score([model_answer], [correct_answer], lang="en")
        bert_f1 = F1.mean().item()

        bertscore_f1_scores += bert_f1
        exact_match_counts += EM_score
        exact_match_cover_counts += cover_score
    except JSONDecodeError:
        # Unparseable model output: counted, and contributes 0 to every metric.
        count_parsing_errors += 1

# Normalize scores (skipped and unparseable examples stay in the denominator)
num_entries = len(dataset_copy)
bertscore_f1_scores /= num_entries
exact_match_counts /= num_entries
exact_match_cover_counts /= num_entries

# Print results
pprint(exact_match_counts)
pprint(exact_match_cover_counts)
pprint(bertscore_f1_scores)
print(f"decoding errors: {count_parsing_errors}")

t_end = time()
total_time = t_end - t_start
print("Total elapsed time: ", total_time)

# The plotting helper is given a mapping, so wrap each scalar in a one-entry dict.
exact_match_dict = {"2*|ground truth contexts|": exact_match_counts}
exact_match_cover_dict = {"2*|ground truth contexts|": exact_match_cover_counts}
bertscore_dict = {"2*|ground truth contexts|": bertscore_f1_scores}

# Plot results
plot_accuracy_bar_chart(
    exact_match_dict,
    title="Exact Match Performance - ground truth + Hard Negative Contexts (1:1 ratio)",
    save_path="plots/groundtruth_and_neg_1_to_1_exact_match.png",
)
plot_accuracy_bar_chart(
    exact_match_cover_dict,
    title="Cover Exact Match Performance - ground truth + Hard Negative Contexts (1:1 ratio)",
    save_path="plots/groundtruth_and_neg_1_to_1_cover_exact_match.png",
)
plot_accuracy_bar_chart(
    bertscore_dict,
    title="Bertscore Performance - ground truth + Hard Negative Contexts (1:1 ratio)",
    save_path="plots/groundtruth_and_neg_1_to_1_bertscore.png",
)

## 2.5 Experiment 5 : Oracle + remaining dev contexts baseline

In [None]:
# Experiment 5: answer every question from ALL paragraphs listed in its own
# dev.json "context" field, then score the answers with exact match, cover
# exact match and BERTScore F1.
CHECKPOINT_FOLDER = "outputs"

dataset_copy = dataset

# Running sums of the three metrics; turned into averages after the loop.
exact_match_counts = 0.0
exact_match_cover_counts = 0.0
bertscore_f1_scores = 0.0

print(f"Starting run on {len(dataset_copy)} entries...")

t_start = time()
queries = []
expected_answers = []
dev_ids = []

# Prepare dataset
# NOTE: the loop deliberately avoids the names `iter` (a builtin) and
# `dev_dict` (the module-level `_id -> entries` index used when the dataset is
# built) so that re-running other cells afterwards keeps working.
for processed, element in enumerate(dataset_copy, start=1):
    dev_id = element["dev_id"]
    dev_entry = element["dev_full"]
    question = dev_entry["question"]
    correct_answer = dev_entry["answer"]

    # Each dev context is indexed as [0] = title, [1] = list of sentences; the
    # sentences are joined into one evidence string per context.
    evidences = [" ".join(context[1]) for context in dev_entry["context"]]

    # Prepare prompt
    queries.append(prepare_prompt(question, evidences))
    expected_answers.append(correct_answer)
    dev_ids.append(dev_id)

    if processed % 100 == 0:
        print(processed)

# Generate responses in batch
outputs = text_generator(queries, max_new_tokens=50)

# Evaluate responses
count_parsing_errors = 0
for output, correct_answer in zip(outputs, expected_answers):
    try:
        model_answer = get_answer_from_model_output(output)
        if model_answer is None or correct_answer is None:
            print(
                f"Skipping invalid example. model_answer: {model_answer}, correct_answer: {correct_answer}"
            )
            continue

        EM_score = exact_match_score(model_answer, correct_answer)
        cover_score = cover_exact_match_score(model_answer, correct_answer)
        _, _, F1 = score([model_answer], [correct_answer], lang="en")
        bert_f1 = F1.mean().item()

        bertscore_f1_scores += bert_f1
        exact_match_counts += EM_score
        exact_match_cover_counts += cover_score
    except JSONDecodeError:
        # Unparseable model output: counted, and contributes 0 to every metric.
        count_parsing_errors += 1

# Normalize scores (skipped and unparseable examples stay in the denominator)
num_entries = len(dataset_copy)

bertscore_f1_scores /= num_entries
exact_match_counts /= num_entries
exact_match_cover_counts /= num_entries

# Print results
pprint(exact_match_counts)
pprint(exact_match_cover_counts)
pprint(bertscore_f1_scores)
print(f"decoding errors: {count_parsing_errors}")

t_end = time()
total_time = t_end - t_start
print("Total elapsed time: ", total_time)

# The plotting helper is given a mapping, so wrap each scalar in a one-entry dict.
exact_match_dict = {"|dev.json context list|": exact_match_counts}
exact_match_cover_dict = {"|dev.json context list|": exact_match_cover_counts}
bertscore_dict = {"|dev.json context list|": bertscore_f1_scores}

# Plot results
plot_accuracy_bar_chart(
    exact_match_dict,
    title="Exact Match Performance - Dev.json contexts",
    save_path="plots/dev_contexts_exact_match.png",
)
plot_accuracy_bar_chart(
    exact_match_cover_dict,
    title="Cover Exact Match Performance - Dev.json contexts",
    save_path="plots/dev_contexts_cover_exact_match.png",
)
plot_accuracy_bar_chart(
    bertscore_dict,
    title="Bertscore Performance - Dev.json contexts",
    save_path="plots/dev_contexts_bertscore.png",
)

# ADORE

We have data for the dev, test and training splits, stored in `dev.json`, `test.json` and `train.json`. Each entry has the format `dict_keys(['_id', 'type', 'question', 'context', 'supporting_facts', 'evidences', 'answer'])`. We build an ADORE implementation on the `Retriever` base class, called the `AdoreRetriever`. Specifically, we train a retrieval model using the **ADORE method** and compare it to the `HfRetriever` implementation used above. We use the `HfRetriever` as a base (the changes are specified later), but with a different `self.question_encoder`, i.e. our own trained model.


## 1. Imports

Make sure that either you ran all imports at the top of the notebook or you run the following cell.

In [None]:
# STD LIB
import os
import json
import itertools
from pprint import pprint
from itertools import islice
from collections import defaultdict


# THIRD PARTY
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoModel, AutoTokenizer
from tqdm.notebook import tqdm
from dexter.config.constants import Split
from dexter.data.loaders.RetrieverDataset import RetrieverDataset
from dexter.utils.metrics.SimilarityMatch import CosineSimilarity
from dexter.utils.metrics.retrieval.RetrievalMetrics import RetrievalMetrics


# LOCAL LIB
from utils import AdoreRetriever, ContrieverRetriever

## 2. Load Data

We load corpus and training data into the runtime. The dev/train/test.json files have formats: `dict_keys(['_id', 'type', 'question', 'context', 'supporting_facts', 'evidences', 'answer'])`.

In [None]:
# Locations of the dataset splits and of the retrieval corpus on disk.
dataset_dir = "data"
musique_dir = f"{dataset_dir}/musique"
dev_path = f"{musique_dir}/dev.json"
test_path = f"{musique_dir}/test.json"
train_path = f"{musique_dir}/train.json"
corpus_path = f"{dataset_dir}/corpus/wiki_musique_corpus.json"

# Parse the three JSON files that the following cells work with
# (the test split path is defined above but not loaded here).
with open(train_path, encoding="utf-8") as train_file:
    train_data = json.load(train_file)

with open(corpus_path, encoding="utf-8") as corpus_file:
    corpus_dict = json.load(corpus_file)

with open(dev_path, encoding="utf-8") as dev_file:
    dev_data = json.load(dev_file)

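Here we prepare a dictionary that maps each question id to its question and the respective positive contexts. Note that the cell below builds it from the dev split (`dev_data`).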

In [None]:
def load_positives_full(dev_data):
    """Map every question id to its question text and positive contexts.

    A context counts as positive when its title is named by one of the
    item's supporting facts. A title that is named by several supporting
    facts contributes its context(s) only once, so the returned positives
    contain no repeats caused by repeated facts.

    Args:
        dev_data: Iterable of dataset items. Each item must provide "_id" and
            "question"; "supporting_facts" (sequences whose first element is
            a title) and "context" ((title, texts) pairs) are optional and
            default to empty.

    Returns:
        dict: ``{_id: {"question": str, "positives": [{"title", "text"}, ...]}}``
        with positives ordered by first appearance of their title in the
        supporting facts, then by context order.
    """
    oracle_contexts = {}
    for item in dev_data:
        contexts = item.get("context", [])

        # dict.fromkeys de-duplicates the titles while keeping first-seen order.
        fact_titles = dict.fromkeys(
            fact[0] for fact in item.get("supporting_facts", [])
        )

        positives = [
            {"title": context_title, "text": context_texts}
            for fact_title in fact_titles
            for context_title, context_texts in contexts
            if context_title == fact_title
        ]

        oracle_contexts[item["_id"]] = {
            "question": item["question"],
            "positives": positives,
        }

    return oracle_contexts


# Build the question-id -> positives mapping from the dev split.
positives_dict = load_positives_full(dev_data)

print(type(positives_dict))

# Sanity check: show only the first (question id, entry) pair.
for first_id, first_entry in islice(positives_dict.items(), 1):
    pprint(first_id)
    pprint(first_entry)

## 3. Fine-tuning the query encoder following the ADORE approach

Here, we train the query encoder using the ADORE method. As a reminder, we start from the same model as the `HfRetriever`: the `facebook/contriever` model.

In [None]:
# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Placeholder value; presumably meant to be replaced by a real Hugging Face token.
os.environ["HF_TOKEN"] = "<HF TOKEN HERE>"

### 3.1 ADORE training data formatting

To use the data for training the query encoder using ADORE, we first need a specific format.

In [None]:
def generate_positive_negative_pairs(contexts: list[dict]) -> list[dict]:
    """
    Build every distinct (positive, negative) pairing of the given contexts.

    Each context is a dictionary with at least:
        {
            "is_positive": bool,   # a missing key is treated as negative
            "text": str,
            "index": int,
        }

    The result is a list of dictionaries of the form:
        {
            "positive_text": str,
            "negative_text": str,
            "positive_index": int,
            "negative_index": int
        }

    Pairs are emitted positive-major, in input order. Two pairs are
    duplicates when both their texts are equal; only the first occurrence is
    kept, so the indices reported are those of the first matching contexts.
    """

    # Partition the contexts in a single pass, preserving their input order.
    positives: list[dict] = []
    negatives: list[dict] = []
    for ctx in contexts:
        bucket = positives if ctx.get("is_positive", False) else negatives
        bucket.append(ctx)

    emitted = set()  # (positive_text, negative_text) pairs already produced
    pairs = []

    for pos in positives:
        for neg in negatives:
            key = (pos["text"], neg["text"])
            if key in emitted:
                continue
            emitted.add(key)
            pairs.append(
                {
                    "positive_text": key[0],
                    "negative_text": key[1],
                    "positive_index": pos["index"],
                    "negative_index": neg["index"],
                }
            )

    return pairs


def prepare_positives_negatives_dataset(
    positives_dict: dict, retriever_response: dict, corpus_dict: dict
) -> list[dict]:
    """Turn a retriever response into per-question lists of labelled contexts.

    Args:
        positives_dict: ``{query_id: {"question": ..., "positives": [{"title": ...}, ...]}}``.
        retriever_response: ``{query_id: {context_id: score}}``.
        corpus_dict: ``{context_id: {"title": ..., "text": ...}}``.

    Returns:
        One ``{"question", "contexts"}`` dictionary per query, in the
        iteration order of ``retriever_response``. ``contexts`` is ordered by
        descending score; every entry carries its rank as ``index`` and
        ``is_positive`` tells whether its title is one of the query's
        positive titles.
    """
    training_instances = []
    for query_id, scored_contexts in retriever_response.items():
        entry = positives_dict[query_id]
        question = entry["question"]
        gold_titles = [positive["title"] for positive in entry["positives"]]

        # Best-scoring context first.
        ranked = sorted(
            scored_contexts.items(), key=lambda item: item[1], reverse=True
        )

        question_contexts = []
        for rank, (context_id, _score) in enumerate(ranked):
            document = corpus_dict[context_id]
            title = document["title"]
            text = document["text"]
            question_contexts.append(
                {
                    "is_positive": title in gold_titles,
                    "index": rank,
                    "title": title,
                    "text": text,
                }
            )

        training_instances.append({"question": question, "contexts": question_contexts})

    return training_instances

### 3.2 ADORE loss and helpers

First, some helper functions that we need to calculate the ADORE loss. These are for the Mean Average Precision (or MAP) and the Delta_MAP.

In [None]:
def calculate_map_at_k_using_is_positive(ctx_list, top_k=10):
    """Average the precision values at the positive ranks within the top ``top_k``.

    Args:
        ctx_list: Ranked contexts (best first); an entry is relevant when its
            "is_positive" value is truthy (missing means not relevant).
        top_k: Number of leading ranks that are inspected.

    Returns:
        float: Mean of precision@rank over the relevant entries found in the
        first ``top_k`` ranks, or 0.0 when the list is empty or contains no
        relevant entry in that window.
    """
    if not ctx_list:
        return 0.0

    precisions_at_hits = []
    for rank, ctx in enumerate(ctx_list, start=1):
        if rank > top_k:
            break
        if not ctx.get("is_positive", False):
            continue
        # One more hit than recorded so far, divided by the 1-based rank.
        precisions_at_hits.append((len(precisions_at_hits) + 1) / rank)

    if precisions_at_hits:
        return sum(precisions_at_hits) / len(precisions_at_hits)
    return 0.0


def calculate_delta_map_using_is_positive(contexts, idx1, idx2):
    """Return |delta MAP@10| caused by swapping two entries of a ranking.

    Args:
        contexts: Ranked list of context dictionaries; it is not modified.
        idx1: Position of the first entry to swap.
        idx2: Position of the second entry to swap.

    Returns:
        Absolute difference between MAP@10 of the swapped ranking and MAP@10
        of the original ranking.

    Raises:
        ValueError: If either index lies outside ``[0, len(contexts) - 1]``.
    """
    max_index = len(contexts) - 1
    if any(not (0 <= position <= max_index) for position in (idx1, idx2)):
        raise ValueError(
            f"Indices must be within [0..{max_index}], but got {idx1} and {idx2}."
        )

    map_before = calculate_map_at_k_using_is_positive(contexts, top_k=10)

    # Swap on a shallow copy so the caller's ranking stays untouched.
    swapped = contexts.copy()
    first, second = swapped[idx1], swapped[idx2]
    swapped[idx1] = second
    swapped[idx2] = first

    map_after = calculate_map_at_k_using_is_positive(swapped, top_k=10)

    return abs(map_after - map_before)


And now here, the actual ADORE loss function.

In [None]:
def adore_loss(
    f_q_d_pos: torch.Tensor, f_q_d_neg: torch.Tensor, delta_M: torch.Tensor
) -> torch.Tensor:
    """
    ADORE loss from Eq. (15), adjusted to use MAP@10.

    Multiplies the standard pairwise RankNet loss
    ``log(1 + exp(f(q, d-) - f(q, d+)))`` by the magnitude of the change in
    MAP@10 (``delta_M``).

    Args:
        f_q_d_pos: Relevance scores of the query with the positive documents.
        f_q_d_neg: Relevance scores of the query with the negative documents.
        delta_M: Per-pair weights (absolute MAP@10 change).

    Returns:
        The element-wise weighted loss (not reduced).
    """
    # softplus(x) == log(1 + exp(x)), but it does not overflow to inf when the
    # score difference is large.
    return delta_M * torch.nn.functional.softplus(f_q_d_neg - f_q_d_pos)

### 3.3. Retriever Loading

In [None]:
config_path = "config.ini"

# Dev split of the dataset: queries, relevance judgements and the corpus.
loader = RetrieverDataset(
    "wikimultihopqa",
    "wikimultihopqa-corpus",
    config_path,
    Split.DEV,
    tokenizer=None,
)
queries, qrels, corpus = loader.qrels()
print(f"Loader initialized with {len(queries)} queries and {len(corpus)} documents.")

# Patched config for the ADORE retriever. Both encoders start from
# facebook/contriever; the query encoder is replaced later by our own.
adore_config = DenseHyperParams(
    query_encoder_path="facebook/contriever",
    document_encoder_path="facebook/contriever",
    batch_size=32,
    show_progress_bar=True,
)

print("Adore config: ", adore_config)

# Create the retriever instance with the same config as the contriever.
# NOTE: this is the same INDEX that we used for the contriever; both use the
# same corpus and the same document encoder.
# NOTE(review): the evaluation cell further down passes
# corpus_folder="indices/corpus" instead of "indices" — confirm which is right.
adore_retriever: AdoreRetriever = AdoreRetriever(
    config=adore_config,
    corpus_folder="indices",
    corpus_file="index_1",
)

### 3.4 ADORE Training Loop

In [None]:
# Query encoder that is fine-tuned in this cell: only its parameters are given
# to the optimizer, and every document encoding below is computed under
# torch.no_grad().
adore_query_encoder = AutoModel.from_pretrained("facebook/contriever").cuda()
adore_tokenizer = AutoTokenizer.from_pretrained("facebook/contriever")

# training loop settings
batch_size = 1200
epochs = 10

# checkpoint settings
# Set to None to skip intra-epoch checkpoints (saves some disk space).
save_every_n_batches = 1000
checkpoint_dir = "checkpoints2"  # define the dir for the checkpoints
os.makedirs(checkpoint_dir, exist_ok=True)

TOP_K = 10

# Created once, outside the epoch loop, so that AdamW keeps its running moment
# estimates across epochs instead of restarting from scratch every epoch.
# NOTE(review): lr=5e-2 differs from the learning rates listed in the
# evaluation notes further down (5e-6 / 5e-4) — confirm the intended value.
optimizer = optim.AdamW(adore_query_encoder.parameters(), lr=5e-2)

for epoch in tqdm(range(epochs), desc="Training epoch: "):
    # Re-retrieve with the current query encoder (it is swapped into the
    # retriever at the end of every epoch), then label the retrieved contexts.
    retrieved_evidence = adore_retriever.retrieve(
        corpus=corpus,
        queries=queries,
        top_k=TOP_K,
        score_function=CosineSimilarity(),
        qrels=qrels,
    )

    train_data: list[dict[str, str | list[dict]]] = prepare_positives_negatives_dataset(
        positives_dict, retrieved_evidence, corpus_dict
    )

    print(f"Epoch {epoch + 1}/{epochs}")
    for i in tqdm(range(0, len(train_data), batch_size), desc="Training batch:"):
        batch = train_data[i : i + batch_size]

        # One entry per (positive, negative) pair of the whole batch.
        query_embeddings = []
        positive_embeddings = []
        negative_embeddings = []
        m_abs_values = []

        for elem in tqdm(batch, desc="Element in batch...: "):
            question_contexts = elem.get("contexts")
            pos_neg_pairs = generate_positive_negative_pairs(question_contexts)
            if not pos_neg_pairs:
                # Nothing to rank for this question; skip the forward pass.
                continue

            # encode the question (gradients flow through this encoding)
            question_tokens = adore_tokenizer(
                elem.get("question"),
                padding=True,
                truncation=True,
                return_tensors="pt",
            ).to("cuda")
            query_encoding = adore_query_encoder(
                **question_tokens
            ).last_hidden_state.mean(dim=1)

            # Every document text is encoded at most once per question rather
            # than once per pair it appears in.
            # NOTE(review): assumes the context encoder is deterministic
            # (eval mode), so reusing an encoding equals recomputing it.
            doc_encodings = {}
            for pos_neg in pos_neg_pairs:
                pos_text = pos_neg["positive_text"]
                neg_text = pos_neg["negative_text"]

                for doc_text in (pos_text, neg_text):
                    if doc_text in doc_encodings:
                        continue
                    doc_tokens = adore_tokenizer(
                        doc_text, padding=True, truncation=True, return_tensors="pt"
                    ).to("cuda")
                    with torch.no_grad():
                        doc_encodings[doc_text] = adore_retriever.context_encoder(
                            **doc_tokens
                        ).last_hidden_state.mean(dim=1)

                # Weight of the pair: |delta MAP@10| when both documents swap ranks.
                M_abs = calculate_delta_map_using_is_positive(
                    contexts=question_contexts,
                    idx1=pos_neg["positive_index"],
                    idx2=pos_neg["negative_index"],
                )

                query_embeddings.append(query_encoding)
                positive_embeddings.append(doc_encodings[pos_text])
                negative_embeddings.append(doc_encodings[neg_text])
                m_abs_values.append(M_abs)

        if query_embeddings:
            # Convert everything to tensors. The document embeddings and the
            # delta-MAP weights are constants, so they need no gradient.
            query_embeddings = torch.vstack(query_embeddings)
            positive_embeddings = torch.vstack(positive_embeddings)
            negative_embeddings = torch.vstack(negative_embeddings)
            relevance_score_deltas = torch.tensor(
                m_abs_values, dtype=torch.float32
            ).to("cuda")

            # Compute similarity scores
            f_q_d_pos = torch.cosine_similarity(
                query_embeddings, positive_embeddings, dim=1
            )
            f_q_d_neg = torch.cosine_similarity(
                query_embeddings, negative_embeddings, dim=1
            )

            loss = adore_loss(f_q_d_pos, f_q_d_neg, relevance_score_deltas).mean()

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            print(f"Batch {i // batch_size + 1}, Loss: {loss.item()}")
        else:
            # torch.vstack([]) would raise; there is simply nothing to learn from.
            print(
                f"Batch {i // batch_size + 1}: no positive/negative pairs, skipping update"
            )

        if save_every_n_batches:
            if (i // batch_size) % save_every_n_batches == 0:
                checkpoint_path = (
                    f"{checkpoint_dir}/query_encoder_epoch{epoch + 1}_batch{i}.pt"
                )
                print(f"Saving checkpoint: {checkpoint_path}")
                adore_query_encoder.save_pretrained(checkpoint_path)
                adore_tokenizer.save_pretrained(checkpoint_path)

    # Save at the end of each epoch and point the retriever at the new weights,
    # so the next epoch retrieves with the updated query encoder.
    checkpoint_path = f"{checkpoint_dir}/query_encoder_epoch{epoch + 1}.pt"
    print(f"Saving epoch checkpoint: {checkpoint_path}")
    adore_query_encoder.save_pretrained(checkpoint_path)
    adore_tokenizer.save_pretrained(checkpoint_path)
    adore_retriever.set_query_encoder(checkpoint_path)
    adore_retriever.set_query_tokenizer(checkpoint_path)

## 4. Evaluation of QA Performance for ADORE

Here we evaluate the performance of our model and compare it to the `ContrieverRetriever` model. Note: make sure that all imports from the earlier parts are still available. If not, please go back to the imports cell and run it again. Also check that CUDA is available and that the `HF_TOKEN` environment variable is set.


### 4.1 Setup new retriever 

Also calculates some metrics for the new retriever at the end of the cell.

In [None]:
config_path = "config.ini"

# Dev split of the dataset: queries, relevance judgements and the corpus.
loader = RetrieverDataset(
    "wikimultihopqa",
    "wikimultihopqa-corpus",
    config_path,
    Split.DEV,
    tokenizer=None,
)
queries, qrels, corpus = loader.qrels()
print(f"Loader initialized with {len(queries)} queries and {len(corpus)} documents.")

# Patched config for the ADORE retriever. Both encoders start from
# facebook/contriever; the query encoder is replaced right below.
adore_config = DenseHyperParams(
    query_encoder_path="facebook/contriever",
    document_encoder_path="facebook/contriever",
    batch_size=32,
    show_progress_bar=True,
)

print("Adore config: ", adore_config)

# Create the retriever instance, with the same config as the contriever.
adore_retriever: AdoreRetriever = AdoreRetriever(
    config=adore_config,
    corpus_folder="indices/corpus",
    corpus_file="index_1",
)

# Swap in the fine-tuned query encoder and the tokenizer saved alongside it.
checkpoint_path = "checkpoints/query_encoder_epoch50.pt"
# checkpoint_path = "checkpoints2/query_encoder_epoch2.pt"
adore_retriever.set_query_encoder(checkpoint_path)
adore_retriever.set_query_tokenizer(checkpoint_path)

# Retrieve the top 100 contexts per query with the fine-tuned encoder.
adore_retriever_response = adore_retriever.retrieve(
    corpus=corpus,
    queries=queries,
    top_k=100,
    score_function=CosineSimilarity(),
    chunksize=400000,
)

# Evaluate retrieval metrics at cut-offs 1, 3 and 5.
metrics = RetrievalMetrics(k_values=[1, 3, 5])
retrieval_report = metrics.evaluate_retrieval(
    qrels=qrels, results=adore_retriever_response
)
print(retrieval_report)

# CONTRIEVER BASELINE:
# dev: ({'NDCG@1': 0.4225, 'NDCG@3': 0.33825, 'NDCG@5': 0.27994}, {'MAP@1': 0.0425, 'MAP@3': 0.07537, 'MAP@5': 0.08785}, {'Recall@1': 0.0425, 'Recall@3': 0.09387, 'Recall@5': 0.11961}, {'P@1': 0.4225, 'P@3': 0.31111, 'P@5': 0.23783})

# ADORE:
# 10 epochs, lr 5e-6, dev: ({'NDCG@1': 0.58583, 'NDCG@3': 0.45797, 'NDCG@5': 0.37441}, {'MAP@1': 0.05888, 'MAP@3': 0.10462, 'MAP@5': 0.12159}, {'Recall@1': 0.05888, 'Recall@3': 0.12597, 'Recall@5': 0.1572}, {'P@1': 0.58583, 'P@3': 0.4175, 'P@5': 0.31267})
# 5 epochs, lr 5e-4, dev: ({'NDCG@1': 0.6775, 'NDCG@3': 0.49926, 'NDCG@5': 0.39481}, {'MAP@1': 0.0681, 'MAP@3': 0.11416, 'MAP@5': 0.12831}, {'Recall@1': 0.0681, 'Recall@3': 0.13408, 'Recall@5': 0.15869}, {'P@1': 0.6775, 'P@3': 0.44472, 'P@5': 0.31567})
# 10 epochs, lr 5e-4, dev: ({'NDCG@1': 0.69833, 'NDCG@3': 0.51237, 'NDCG@5': 0.40646}, {'MAP@1': 0.07021, 'MAP@3': 0.1179, 'MAP@5': 0.13367}, {'Recall@1': 0.07021, 'Recall@3': 0.13757, 'Recall@5': 0.1637}, {'P@1': 0.69833, 'P@3': 0.45611, 'P@5': 0.32567})
# 50 epochs, lr 5e-4, dev: ({'NDCG@1': 0.69917, 'NDCG@3': 0.54798, 'NDCG@5': 0.42863}, {'MAP@1': 0.07027, 'MAP@3': 0.12902, 'MAP@5': 0.14332}, {'Recall@1': 0.07027, 'Recall@3': 0.15009, 'Recall@5': 0.17327}, {'P@1': 0.69917, 'P@3': 0.49778, 'P@5': 0.34483})

### 4.2 Some configuration

In [None]:
# NOTE:
# To load any other Hugging Face model, use AutoModelForCausalLM instead of
# LlamaForCausalLM.

model_id = "meta-llama/Llama-3.1-8B-Instruct"

# 4-bit NF4 quantization with double quantization; compute runs in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)

# Fall back to the end-of-sequence token when the tokenizer has no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

model = LlamaForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
)

# Wrap the loaded model with torch.compile.
model = torch.compile(model)

In [None]:
# Reload the dev split and the corpus from disk.
dataset_dir = "data"
dev_path = f"{dataset_dir}/musique/dev.json"
corpus_path = f"{dataset_dir}/corpus/wiki_musique_corpus.json"

with open(dev_path, encoding="utf-8") as dev_file:
    dev_data = json.load(dev_file)

with open(corpus_path, encoding="utf-8") as corpus_file:
    corpus_dict = json.load(corpus_file)

# Text-generation pipeline built from the model and tokenizer loaded above.
text_generator = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)

# Index the dev entries by their "_id" (each id maps to a list of entries).
dev_dict = defaultdict(list)

for item in dev_data:
    if "_id" not in item:
        print("Warning: JSON object missing '_id' field:", item)
        continue
    dev_dict[item["_id"]].append(item)

### 4.3 Create dataset

In [None]:
print("creating dataset...")

# One record per dev question that has a retriever response. Each record holds
#   dev_id       - the question id
#   dev_full     - the first dev entry stored under that id
#   context_list - the retrieved contexts, best score first; each context has
#                  the keys context_id, context_score and context_full
dataset = []
for dev_key, retrieved_contexts in adore_retriever_response.items():
    dev_full = dev_dict[dev_key][0]

    # Contexts are ordered from best to worst according to the retriever score.
    ranked_contexts = sorted(
        retrieved_contexts.items(), key=lambda item: item[1], reverse=True
    )
    context_list = [
        defaultdict(
            list,
            context_id=context_key,
            context_score=context_score,
            context_full=corpus_dict[context_key],
        )
        for context_key, context_score in ranked_contexts
    ]

    dataset.append(
        defaultdict(
            list,
            dev_id=dev_key,
            dev_full=dev_full,
            context_list=context_list,
        )
    )

print("dataset created")

### 4.4 Create plots


In [None]:
# QA evaluation with ADORE-retrieved contexts: for every question and every
# k in `counts`, prompt the model with the top-k retrieved contexts and score
# the answer with exact match, cover exact match and BERTScore F1.

# Suppress specific logging
logging.getLogger("transformers").setLevel(logging.ERROR)
CHECKPOINT_FOLDER = "outputs"

dataset_copy = dataset

# Per-k running sums of the metrics; turned into averages after the loop.
exact_match_counts = defaultdict(int)
exact_match_cover_counts = defaultdict(int)
bertscore_f1_scores = defaultdict(float)

elem_count = 0
count_parsing_errors = 0

counts = [1, 5, 15]
print(f"Starting run on {len(dataset_copy)} entries...")

t_start = time()

for element in dataset_copy:
    elem_count += 1

    # Progress report every 100 questions.
    if elem_count % 100 == 0:
        t_curr = time()
        print(f"\nCompleted iterations: {elem_count}")
        print(f"Time ellapsed: {t_curr - t_start}\n")

    dev_id = element["dev_id"]
    # NOTE: deliberately not named `dev_dict`, which is the module-level
    # `_id -> entries` index that the dataset-creation cell reads.
    dev_entry = element["dev_full"]
    all_contexts = element["context_list"]

    question = dev_entry["question"]
    correct_answer = dev_entry["answer"]

    for count in counts:
        # The context list is stored best-first, so a prefix is the top-k.
        evidences = [
            context_item["context_full"]["text"]
            for context_item in all_contexts[:count]
        ]

        prompt = prepare_prompt(question, evidences)
        output = text_generator(prompt, max_new_tokens=50)

        try:
            model_answer = get_answer_from_model_output(output)
            if model_answer is None or correct_answer is None:
                print(
                    f"Skipping invalid example. model_answer: {model_answer}, correct_answer: {correct_answer}"
                )
                continue

            EM_score = exact_match_score(model_answer, correct_answer)
            cover_score = cover_exact_match_score(model_answer, correct_answer)
            _, _, F1 = score([model_answer], [correct_answer], lang="en")
            bert_f1 = F1.mean().item()

            bertscore_f1_scores[count] += bert_f1
            exact_match_counts[count] += EM_score
            exact_match_cover_counts[count] += cover_score
        except JSONDecodeError:
            # Unparseable model output: counted, and contributes 0 to every metric.
            count_parsing_errors += 1

# Normalize by the number of questions (skipped and unparseable examples stay
# in the denominator).
for count in counts:
    bertscore_f1_scores[count] /= elem_count
    exact_match_counts[count] /= elem_count
    exact_match_cover_counts[count] /= elem_count

pprint(exact_match_counts)
pprint(exact_match_cover_counts)
pprint(bertscore_f1_scores)
print(f"decoding errors: {count_parsing_errors}")

t_end = time()
total_time = t_end - t_start

print("Total elapsed time: ", total_time)

plot_accuracy_bar_chart(
    exact_match_counts,
    title="Exact Match Performance - ADORE",
    save_path="plots/adore_exact_match.png",
)
plot_accuracy_bar_chart(
    exact_match_cover_counts,
    title="Cover Exact Match Performance - ADORE",
    save_path="plots/adore_cover_exact_match.png",
)
plot_accuracy_bar_chart(
    bertscore_f1_scores,
    title="Bert score Performance - ADORE",
    save_path="plots/adore_bertscore.png",
)
