In [1]:
# Load environment variables from a local .env file. NOTE(review): presumably these hold
# the endpoints/credentials for the LLM services called later — confirm against the .env template.
from dotenv import load_dotenv
load_dotenv()

True

In [2]:
import json
import pandas as pd
from pathlib import Path
from copy import deepcopy
from functools import partial

from bellem.musique.qa import answer_question_standard
from bellem.utils import set_seed, jprint
from bellem.musique.multihop import benchmark

# Seed local RNGs via bellem.utils.set_seed. NOTE(review): triplet extraction below samples
# from a remote LLM at temperature 0.1 and the per-run scores differ, so this seed presumably
# does not make the experiments fully deterministic.
set_seed(89)

In [3]:
from tqdm.auto import tqdm
# Register tqdm with pandas so `DataFrame.progress_apply` (used in the experiment loops) is available.
tqdm.pandas()

# Dataset

In [4]:
from bellem.musique.constants import ABLATION_RECORD_IDS

# All inputs come from the generated MuSiQue evaluation directory.
DATA_DIR = Path('../../data/generated/musique-evaluation')

df = pd.read_json(DATA_DIR / 'dataset.jsonl', orient='records', lines=True)

# Take ablation subset (row order follows ABLATION_RECORD_IDS)
df = df.set_index('id', drop=False).loc[ABLATION_RECORD_IDS].reset_index(drop=True)
n_ablation_records = len(df)

# Load question decomposition and use it in place of the dataset's own question/decomposition.
qd_df = pd.read_json(DATA_DIR / 'question-decomposition.jsonl', orient='records', lines=True)
# suffixes=('', '') makes pandas raise if any non-key column overlaps;
# validate='many_to_one' fails fast if an id has more than one decomposition record.
df = pd.merge(
    df.drop(columns=['question', 'question_decomposition']),
    qd_df,
    on='id',
    suffixes=('', ''),
    validate='many_to_one',
)
# The inner merge would silently drop ids without a decomposition; make that loud instead.
assert len(df) == n_ablation_records, (
    f"{n_ablation_records - len(df)} ablation records have no question decomposition"
)

print(len(df))
df.head()

100


Unnamed: 0,id,paragraphs,answer,answer_aliases,answerable,answers,question,question_decomposition
0,2hop__575188_342798,"[{'idx': 0, 'title': 'Liliana Mumy', 'paragrap...",Ahmad Shah Qajar,[Ahmad Shah Qajar],True,[Ahmad Shah Qajar],Who is the child of Mahmoud Mirza's father?,"[{'id': 575188, 'question': 'Who is Mahmoud Mi..."
1,2hop__731584_700117,"[{'idx': 0, 'title': 'KAPE', 'paragraph_text':...",Berrien County,[Berrien County],True,[Berrien County],In which county is the city to which KKVU is l...,"[{'id': 731584, 'question': 'To which city is ..."
2,2hop__690412_526810,"[{'idx': 0, 'title': 'Cabramatta Creek', 'para...",Chao Phraya River,[Chao Phraya River],True,[Chao Phraya River],For what river does the river on which Pa Sak ...,"[{'id': 690412, 'question': 'On which river is..."
3,2hop__263638_69048,"[{'idx': 0, 'title': 'Michael J. Barron', 'par...",Honorable Justice Abiodun Smith,[Honorable Justice Abiodun Smith],True,[Honorable Justice Abiodun Smith],Who is the Chief Judge of the Tebesa Nemine's ...,"[{'id': 263638, 'question': 'Where was Tebesa ..."
4,2hop__142842_68489,"[{'idx': 0, 'title': 'Perfect Night: Live in L...",Snapper Foster,[Snapper Foster],True,[Snapper Foster],Who did the performer of Night Rocker play on ...,"[{'id': 142842, 'question': 'Who performed Nig..."


In [5]:
# Oracle setting: keep only the supporting paragraphs, so triplets are extracted from gold evidence only.
df['paragraphs'] = df['paragraphs'].map(
    lambda paragraphs: list(filter(lambda para: para['is_supporting'], paragraphs))
)

# Definitions

In [6]:
def perfect_retrieval_func(docs, query):
    """Oracle retriever: return only the documents flagged as supporting.

    `query` is unused; the (docs, query) signature is kept because this function is
    passed to `benchmark` as the retrieval function.
    """
    return list(filter(lambda doc: doc['is_supporting'], docs))

In [7]:
# QA function handed to `benchmark` in every experiment below; it is held fixed so that
# only the triplet extractor varies between configurations.
qa_func = answer_question_standard

In [8]:
from bellem.jerx.fewshot.llm import DEFAULT_JERX_SYSTEM_MESSAGE_FOR_LLAMA, DEFAULT_FEW_SHOT_EXAMPLE_MESSAGES
from bellem.jerx.fewshot.llm import make_kg_triplet_extract_fn

# Chat prefix for the triplet extractor: the system instructions followed by the few-shot
# demonstrations. The experiments use PREFIX_MESSAGES[:1] (system message only) for zero-shot.
PREFIX_MESSAGES = [{"role": "system", "content": DEFAULT_JERX_SYSTEM_MESSAGE_FOR_LLAMA}]
PREFIX_MESSAGES.extend(DEFAULT_FEW_SHOT_EXAMPLE_MESSAGES)

In [9]:
default_completion_params = {
    "temperature": 0.1
}

def make_paragraph_replacer(model: str, prefix_messages: list[dict], completion_params: dict = default_completion_params):
    """Build a row transform that replaces each paragraph's text with extracted KG triplets.

    Args:
        model: Model identifier forwarded to `make_kg_triplet_extract_fn`.
        prefix_messages: Chat messages prepended to every extraction request (this notebook
            passes the system message alone for zero-shot, system + examples for few-shot).
        completion_params: Completion options for the extractor. A deep copy is passed on,
            so the shared module-level default can never be altered by the extractor.

    Returns:
        `replace_paragraphs(row)`, intended for `DataFrame.apply(..., axis=1)`. It returns a
        copy of `row` whose 'paragraphs' entries have 'paragraph_text' rewritten as a
        "# Entity-relation-entity triplets" header line followed by one " | "-joined triplet
        per line. The input row and its paragraph dicts are left unmodified.
    """
    # Copy so an extractor that mutates its params cannot leak changes into the shared default.
    extract_kg_triplets = make_kg_triplet_extract_fn(
        model=model,
        prefix_messages=prefix_messages,
        completion_params=deepcopy(completion_params),
    )

    def replace_paragraphs(row):
        new_paragraphs = []
        for paragraph in row['paragraphs']:
            # Deep-copy so the source DataFrame's paragraph dicts are not modified.
            new_paragraph = deepcopy(paragraph)
            triplets_str = '\n'.join(" | ".join(triplet) for triplet in extract_kg_triplets(paragraph['paragraph_text']))
            new_paragraph['paragraph_text'] = '\n'.join(["# Entity-relation-entity triplets", triplets_str])
            new_paragraphs.append(new_paragraph)
        # pandas does not support mutating the object passed to apply(); return a modified copy.
        updated_row = row.copy()
        updated_row['paragraphs'] = new_paragraphs
        return updated_row

    return replace_paragraphs

# Experiments

In [10]:
# Each configuration is run N_RUNS times because triplet extraction is sampled
# (temperature 0.1); per-run score records are collected here and tabulated in the Report section.
N_RUNS = 3

results: list[dict] = []

## llama-zero-shot

In [11]:
# Zero-shot prompting: the extractor only sees the system message.
rp_zs = make_paragraph_replacer('llama-3-8b-tgi', PREFIX_MESSAGES[:1])

for i in range(1, N_RUNS+1):
    # Re-extract triplets on every run; extraction is sampled, so runs differ.
    df_llama_zs = df.progress_apply(rp_zs, axis=1)
    # ignore_errors=True matches the few-shot and SFT cells so all configurations are scored
    # under the same error policy. NOTE(review): the semantics live in
    # bellem.musique.multihop.benchmark — confirm earlier zero-shot numbers are unaffected.
    df_llama_zs, scores = benchmark(df_llama_zs, qa_func, perfect_retrieval_func, ignore_errors=True)
    results.append({**scores, "retrieval": "groundtruth", "context": "triplets", "jerx": "llama-zero-shot", "run": i})
    jprint(scores)

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

{
  "exact_match": 0.5,
  "f1": 0.584095238095238,
  "fuzzy_match": 0.57
}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

{
  "exact_match": 0.54,
  "f1": 0.6053809523809524,
  "fuzzy_match": 0.59
}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

{
  "exact_match": 0.55,
  "f1": 0.6066839826839826,
  "fuzzy_match": 0.6
}


## llama-few-shot

In [12]:
# Few-shot prompting: system instructions plus the worked extraction examples.
rp_fs = make_paragraph_replacer('llama-3-8b-tgi', PREFIX_MESSAGES)
fs_labels = {"retrieval": "groundtruth", "context": "triplets", "jerx": "llama-few-shot"}

for run_no in range(1, N_RUNS + 1):
    df_llama_fs, scores = benchmark(
        df.progress_apply(rp_fs, axis=1),  # fresh triplet extraction each run
        qa_func,
        perfect_retrieval_func,
        ignore_errors=True,
    )
    results.append({**scores, **fs_labels, "run": run_no})
    jprint(scores)

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

{
  "exact_match": 0.51,
  "f1": 0.6468936063936064,
  "fuzzy_match": 0.68
}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

{
  "exact_match": 0.55,
  "f1": 0.6456158286158287,
  "fuzzy_match": 0.63
}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

{
  "exact_match": 0.56,
  "f1": 0.63805772005772,
  "fuzzy_match": 0.65
}


## llama-sft (zero-shot and few-shot prompts)

In [13]:
# Supervised fine-tuned extractor, prompted with the system message only.
rp_sft_zs = make_paragraph_replacer('llama-3-8b-sft-tgi', PREFIX_MESSAGES[:1])
sft_zs_labels = {"retrieval": "groundtruth", "context": "triplets", "jerx": "llama-sft-zs"}

for run_no in range(1, N_RUNS + 1):
    df_llama_sft_zs, scores = benchmark(
        df.progress_apply(rp_sft_zs, axis=1),  # fresh triplet extraction each run
        qa_func,
        perfect_retrieval_func,
        ignore_errors=True,
    )
    results.append({**scores, **sft_zs_labels, "run": run_no})
    jprint(scores)

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

{
  "exact_match": 0.45,
  "f1": 0.5350476190476189,
  "fuzzy_match": 0.54
}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

{
  "exact_match": 0.46,
  "f1": 0.5345,
  "fuzzy_match": 0.53
}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

{
  "exact_match": 0.47,
  "f1": 0.5353809523809523,
  "fuzzy_match": 0.54
}


In [14]:
# Supervised fine-tuned extractor, additionally given the few-shot demonstrations.
rp_sft_fs = make_paragraph_replacer('llama-3-8b-sft-tgi', PREFIX_MESSAGES)
sft_fs_labels = {"retrieval": "groundtruth", "context": "triplets", "jerx": "llama-sft-fs"}

for run_no in range(1, N_RUNS + 1):
    df_llama_sft_fs, scores = benchmark(
        df.progress_apply(rp_sft_fs, axis=1),  # fresh triplet extraction each run
        qa_func,
        perfect_retrieval_func,
        ignore_errors=True,
    )
    results.append({**scores, **sft_fs_labels, "run": run_no})
    jprint(scores)

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

{
  "exact_match": 0.45,
  "f1": 0.5087142857142857,
  "fuzzy_match": 0.51
}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

{
  "exact_match": 0.43,
  "f1": 0.4956923076923077,
  "fuzzy_match": 0.48
}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

{
  "exact_match": 0.41,
  "f1": 0.490047619047619,
  "fuzzy_match": 0.46
}


# Report

In [33]:
# Render floats in displayed frames with thousands separators and three decimals.
pd.set_option('display.float_format', '{:,.3f}'.format)

In [34]:
# Per-run table: keep only the model label, run number, EM and F1 (fuzzy_match and the
# constant retrieval/context tags are dropped), with display-friendly column names.
report_df = (
    pd.DataFrame.from_records(results, columns=['jerx', 'run', 'exact_match', 'f1'])
    .rename(columns={'jerx': 'Model', 'exact_match': 'EM', 'f1': 'F1'})
)

In [35]:
# Human-readable labels for the report tables.
model_name_map = {
    'llama-zero-shot': 'Zero-shot prompted',
    'llama-few-shot': 'Few-shot prompted',
    'llama-sft-zs': 'Supervised fine-tuned',
    'llama-sft-fs': 'Supervised fine-tuned (FS)',
}
# Idempotent relabelling: labels that are already display names pass through unchanged, so
# re-running this cell no longer raises KeyError on the second run.
report_df['Model'] = report_df['Model'].map(lambda label: model_name_map.get(label, label))
# Still fail loudly on any label that is neither a known key nor a known display name.
known_labels = list(model_name_map.values())
assert report_df['Model'].isin(known_labels).all(), (
    f"Unmapped model labels: {sorted(set(report_df['Model']) - set(known_labels))}"
)

In [36]:
# Per-run EM/F1 for every extractor configuration.
report_df

Unnamed: 0,Model,run,EM,F1
0,Zero-shot prompted,1,0.5,0.584
1,Zero-shot prompted,2,0.54,0.605
2,Zero-shot prompted,3,0.55,0.607
3,Few-shot prompted,1,0.51,0.647
4,Few-shot prompted,2,0.55,0.646
5,Few-shot prompted,3,0.56,0.638
6,Supervised fine-tuned,1,0.45,0.535
7,Supervised fine-tuned,2,0.46,0.534
8,Supervised fine-tuned,3,0.47,0.535
9,Supervised fine-tuned (FS),1,0.45,0.509


In [37]:
# Export the per-run table (no index column) as LaTeX.
latex_all = report_df.to_latex(index=False, float_format='%.3f')
with open("ablation-jerx-llama-mhqa-results-all.tex", 'w') as tex_file:
    tex_file.write(latex_all)

In [40]:
# Mean EM/F1 over the N_RUNS runs of each configuration. Rows are ordered like
# model_name_map, derived from it rather than repeating the display names by hand.
agg_report_df = (
    report_df[['Model', 'EM', 'F1']]
    .groupby('Model')
    .mean()
    .loc[list(model_name_map.values())]
    .reset_index()
)
agg_report_df

Unnamed: 0,Model,EM,F1
0,Zero-shot prompted,0.53,0.599
1,Few-shot prompted,0.54,0.644
2,Supervised fine-tuned,0.46,0.535
3,Supervised fine-tuned (FS),0.43,0.498


In [39]:
# NOTE(review): in the saved notebook this cell (In [39]) ran before the aggregation cell
# (In [40]), so the written table may predate the current agg_report_df — re-run in order.
# index=False, as for the per-run table: after reset_index() the index is a meaningless
# 0..n-1 range that would otherwise be written as an extra LaTeX column.
with open("ablation-jerx-llama-mhqa-results-agg.tex", 'w') as f:
    f.write(agg_report_df.to_latex(index=False, float_format='%.3f'))