In [None]:
from dotenv import load_dotenv
# Load settings from a local .env file into the environment before any other cell runs.
load_dotenv()

In [None]:
import json
import pandas as pd
from pathlib import Path
from copy import deepcopy
from functools import partial

from bellek.qa.ablation import answer_question_standard, answer_question_cte
from bellek.utils import set_seed, jprint
from bellek.musique.singlehop import benchmark as benchmark_single
from bellek.musique.multihop import benchmark as benchmark_multi

# Fix the seed through bellek's helper before any data loading or benchmarking happens.
set_seed(89)

In [None]:
# Show floats in displayed frames with thousands separators and three decimals.
pd.options.display.float_format = '{:,.3f}'.format

In [None]:
from tqdm.auto import tqdm
# Register tqdm with pandas so the `progress_apply` family of methods becomes available.
tqdm.pandas()

In [None]:
# Number of times every experiment configuration is repeated (loops use range(1, N_RUNS + 1)).
N_RUNS = 2
# Upper bound on the number of dataset records kept for evaluation (applied with `head`).
SAMPLE_SIZE = 100

In [None]:
from bellek.musique.constants import ABLATION_RECORD_IDS

# Read the evaluation set, select the ablation records by id, and cap the
# result at SAMPLE_SIZE rows. The `id` column is kept alongside the temporary index.
df = (
    pd.read_json('../../data/generated/musique-evaluation/dataset.jsonl', orient='records', lines=True)
    .set_index('id', drop=False)
    .loc[ABLATION_RECORD_IDS]
    .copy()
    .reset_index(drop=True)
    .head(SAMPLE_SIZE)
)

print(len(df))
df.head()

In [None]:
qd_path = '../../data/generated/musique-evaluation/question-decomposition.jsonl'
qd_df = pd.read_json(qd_path, orient='records', lines=True)

# Drop the dataset's own question columns, then join the decomposition file on `id`
# (presumably it supplies their replacements — verify against the file's schema).
df = (
    df
    .drop(columns=['question', 'question_decomposition'])
    .merge(qd_df, on='id', suffixes=('', ''))
)
print(df.shape)
df.head()

In [None]:
# JSON-lines file of JERX inferences; later cells read its `id`, `paragraph_idx`
# and `generation` columns.
jerx_file = Path("../../data/raw/musique-evaluation/jerx-inferences/llama3-base.jsonl")
jerx_df = pd.read_json(jerx_file, lines=True)
jerx_df.head()

In [None]:
# Lookup table: (record id, paragraph index) -> generated triplet text.
jerx_mapping = {
    (record_id, paragraph_idx): generation
    for record_id, paragraph_idx, generation in zip(
        jerx_df['id'], jerx_df['paragraph_idx'], jerx_df['generation']
    )
}

def extract_triplets(example: dict):
    """Attach the stripped triplet text of every paragraph to ``example``.

    For each entry of ``example['paragraphs']`` the generation stored in
    ``jerx_mapping`` under ``(example['id'], paragraph['idx'])`` is looked up
    and stripped of surrounding whitespace. The resulting list is written to
    ``example['triplets_str']`` (in paragraph order) and the same object is
    returned.

    Raises:
        KeyError: if a paragraph has no entry in ``jerx_mapping``.
    """
    collected = []
    for paragraph in example['paragraphs']:
        generation = jerx_mapping[(example['id'], paragraph['idx'])]
        collected.append(generation.strip())
    example["triplets_str"] = collected
    return example

In [None]:
# Row-wise apply: adds a `triplets_str` column holding one triplet string per paragraph.
df = df.apply(extract_triplets, axis=1)
print(len(df))
df.head()

In [None]:
import bm25s
import logging

# Only let ERROR-level (and above) records from the "bm25s" logger through.
logging.getLogger("bm25s").setLevel(logging.ERROR)

def bm25_retrieval(docs: list[dict], query: str, top_k: int = 5):
    """Return up to ``top_k`` entries of ``docs`` selected by BM25 for ``query``.

    A fresh BM25 index is built over the ``'text'`` field of ``docs`` on every
    call, so the function is stateless.

    Args:
        docs: Candidate documents; each must have a ``'text'`` key.
        query: The search query.
        top_k: Maximum number of documents to return. It is clamped to
            ``len(docs)``.

    Returns:
        A list with the selected documents, in the order bm25s returns them.
        An empty list when ``docs`` is empty.
    """
    if not docs:
        # Nothing to index or rank; skip building an index over an empty corpus.
        return []
    top_k = min(top_k, len(docs))
    retriever = bm25s.BM25(corpus=docs)
    tokenized_corpus = bm25s.tokenize([doc['text'] for doc in docs], show_progress=False)
    retriever.index(tokenized_corpus, show_progress=False)
    # Silence the progress bar for the query too, like every other bm25s call here;
    # otherwise each call made by the benchmark loops prints its own bar.
    tokenized_query = bm25s.tokenize(query, show_progress=False)
    results, _ = retriever.retrieve(tokenized_query, k=top_k, show_progress=False)
    return results[0].tolist()

In [None]:
from sentence_transformers import SentenceTransformer

# Embedding model shared by every `semantic_retrieval` call.
model = SentenceTransformer("all-MiniLM-L6-v2")

def semantic_retrieval(docs: list[dict], query: str, top_k: int = 5):
    """Return up to ``top_k`` entries of ``docs`` ranked by embedding similarity.

    The ``'text'`` field of every document and the query are encoded with the
    module-level ``model``; documents are ordered by descending similarity to
    the query.

    Args:
        docs: Candidate documents; each must have a ``'text'`` key.
        query: The search query.
        top_k: Maximum number of documents to return.

    Returns:
        A list with the selected documents, most similar first. An empty list
        when ``docs`` is empty.
    """
    if not docs:
        # Nothing to encode or rank.
        return []
    embeddings = model.encode([doc['text'] for doc in docs])
    query_vectors = model.encode([query])
    # One row per document and a single column for the only query.
    similarities = model.similarity(embeddings, query_vectors)
    sorted_indices = similarities.argsort(dim=0, descending=True)
    # Take the single query column and convert to plain ints, rather than
    # indexing the list with one-element tensors.
    top_indices = sorted_indices[:top_k, 0].tolist()
    return [docs[i] for i in top_indices]

In [None]:
def dummy_retrieval_func(docs, query):
    """Pass-through retriever: hand back ``docs`` untouched, ignoring ``query``."""
    return docs


def perfect_retrieval_func(docs, query):
    """Oracle retriever: keep only the documents whose ``'is_supporting'`` is truthy."""
    return list(filter(lambda doc: doc['is_supporting'], docs))

In [None]:
# Collects one record (scores plus experiment settings) per benchmark call in the cells below.
results = []

## Only paragraphs

In [None]:
# Experiment grid: run x benchmark (without / with question decomposition)
# x QA technique x top_k x retriever, evaluated on the plain paragraphs.
benchmark_variants = [(False, benchmark_single), (True, benchmark_multi)]
qa_variants = [('standard', answer_question_standard), ('cte', answer_question_cte)]

for run in range(1, N_RUNS + 1):
    for qdecomp, benchmark in benchmark_variants:
        for qa_technique, qa_func in qa_variants:
            for top_k in (3, 5, 10):
                # Bind the current top_k into each retriever before handing it to the benchmark.
                retrievers = {
                    'bm25': partial(bm25_retrieval, top_k=top_k),
                    'semantic': partial(semantic_retrieval, top_k=top_k),
                }
                for retriever_name, retriever in retrievers.items():
                    _, scores = benchmark(df, qa_func, retriever, ignore_errors=True)
                    record = {**scores}
                    record.update(
                        retrieval=retriever_name,
                        top_k=top_k,
                        context="paragraphs",
                        qa=qa_technique,
                        qdecomp=qdecomp,
                        run=run,
                    )
                    results.append(record)

## Paragraphs + Triplets

In [None]:
def enhance_paragraphs(row):
    """Append the generated triplets to the text of every paragraph of ``row``.

    Each paragraph is deep-copied, and its ``'paragraph_text'`` becomes the
    original text, a ``# Entity-relation-entity triplets`` header line and the
    stringified ``jerx_mapping`` entry for ``(row['id'], paragraph['idx'])``,
    joined by newlines. ``row['paragraphs']`` is replaced by the list of
    copies and ``row`` is returned; the original paragraph dicts are left
    untouched.

    Raises:
        KeyError: if a paragraph has no entry in ``jerx_mapping``.
    """
    header = "# Entity-relation-entity triplets"
    enriched = []
    for source in row['paragraphs']:
        paragraph = deepcopy(source)
        triplets_text = str(jerx_mapping[(row['id'], paragraph['idx'])])
        paragraph['paragraph_text'] = '\n'.join(
            (paragraph['paragraph_text'], header, triplets_text)
        )
        enriched.append(paragraph)
    row['paragraphs'] = enriched
    return row

# Variant of the dataset whose paragraphs carry their triplets appended to the text.
df_paragraph_triplets = df.apply(enhance_paragraphs, axis=1) 
# NOTE(review): this head() is not the cell's last expression, so its value is not displayed.
df_paragraph_triplets.head()
# Spot-check: print the enriched text of the third paragraph of the first record.
print(df_paragraph_triplets.iloc[0]['paragraphs'][2]['paragraph_text'])

In [None]:
# Experiment grid over the triplet-enriched paragraphs: run x benchmark
# (without / with question decomposition) x top_k x retriever, standard QA only.
for run in range(1, N_RUNS + 1):
    for qdecomp, benchmark in [(False, benchmark_single), (True, benchmark_multi)]:
        for qa_technique, qa_func in [('standard', answer_question_standard)]:
            for top_k in [3, 5, 10]:
                # Bind top_k into the retrievers. Passing the bare functions would run
                # every configuration with their default top_k=5 while the record below
                # still reports the loop's top_k, mislabelling the results.
                for retriever_name, retriever in [
                        ('bm25', partial(bm25_retrieval, top_k=top_k)),
                        ('semantic', partial(semantic_retrieval, top_k=top_k)),
                    ]:
                    _, scores = benchmark(df_paragraph_triplets, qa_func, retriever, ignore_errors=True)
                    results.append({**scores, "retrieval": retriever_name, "top_k": top_k, "context": "paragraphs+triplets", "qa": qa_technique, "qdecomp": qdecomp, "run": run})

## Only triplets

In [None]:
def replace_paragraphs(row):
    """Replace the paragraphs of ``row`` with one pseudo-paragraph per triplet.

    For every paragraph, the ``jerx_mapping`` entry for
    ``(row['id'], paragraph['idx'])`` is split into lines. Each non-blank line
    yields a deep copy of the paragraph whose ``'title'`` is emptied and whose
    ``'paragraph_text'`` is the stripped line; all other fields (e.g. ``'idx'``)
    are carried over from the source paragraph. Blank or whitespace-only lines
    are skipped so they do not turn into empty pseudo-paragraphs.

    ``row['paragraphs']`` is replaced by the new list and ``row`` is returned;
    the original paragraph dicts are left untouched.

    Raises:
        KeyError: if a paragraph has no entry in ``jerx_mapping``.
    """
    triplet_paragraphs = []
    for paragraph in row['paragraphs']:
        triplets_str = str(jerx_mapping[(row['id'], paragraph['idx'])])
        for line in triplets_str.splitlines():
            triplet = line.strip()
            if not triplet:
                # A blank line carries no triplet; emitting it would add an empty document.
                continue
            # Always copy from the source paragraph rather than from the previous copy.
            entry = deepcopy(paragraph)
            entry['title'] = ""
            entry['paragraph_text'] = triplet
            triplet_paragraphs.append(entry)
    row['paragraphs'] = triplet_paragraphs
    return row

# Variant of the dataset whose paragraphs are replaced by individual triplet lines.
df_only_triplets = df.apply(replace_paragraphs, axis=1) 
# NOTE(review): this head() is not the cell's last expression, so its value is not displayed.
df_only_triplets.head()
# Spot-check: print the text of the first pseudo-paragraph of the first record.
print(df_only_triplets.iloc[0]['paragraphs'][0]['paragraph_text'])

In [None]:
# Experiment grid over the triplet-only documents: run x benchmark
# (without / with question decomposition) x top_k x retriever, standard QA only.
for run in range(1, N_RUNS + 1):
    for qdecomp, benchmark in [(False, benchmark_single), (True, benchmark_multi)]:
        # The retrieval budget is scaled up: x3 with question decomposition, x5 without.
        multiplier = 3 if qdecomp else 5
        for qa_technique, qa_func in [('standard', answer_question_standard)]:
            for top_k in (3, 5):
                top_k_effective = top_k * multiplier
                retrievers = {
                    'bm25': partial(bm25_retrieval, top_k=top_k_effective),
                    'semantic': partial(semantic_retrieval, top_k=top_k_effective),
                }
                for retriever_name, retriever in retrievers.items():
                    _, scores = benchmark(df_only_triplets, qa_func, retriever, ignore_errors=True)
                    record = {**scores}
                    # Note that the scaled value, not the base top_k, is what gets recorded.
                    record.update(
                        retrieval=retriever_name,
                        top_k=top_k_effective,
                        context="triplets",
                        qa=qa_technique,
                        qdecomp=qdecomp,
                        run=run,
                    )
                    results.append(record)

# Report

In [None]:
# Settings columns first, then the two reported metrics.
report_columns = ['qdecomp', 'context', 'retrieval', 'top_k', 'qa', 'run', 'exact_match', 'f1']
report_df = pd.DataFrame.from_records(results, columns=report_columns)
# Best exact-match configurations at the top of the displayed table.
report_df.sort_values(by='exact_match', ascending=False)

In [None]:
from datetime import datetime, timezone

# Timestamp the report file name in UTC. `datetime.utcnow()` is deprecated since
# Python 3.12; an aware UTC datetime formats to exactly the same string here.
suffix = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
report_df.to_json(f'./our-method-report-{suffix}.jsonl', orient='records', lines=True)

## Retrieval impact

In [None]:
# Markdown table of the paragraph-only results (`to_markdown` needs the optional `tabulate` package).
print(report_df[report_df['context']=='paragraphs'].to_markdown(index=False))