In [1]:
from dotenv import load_dotenv
load_dotenv()

True

In [2]:
import json
import pandas as pd
from pathlib import Path
from copy import deepcopy

from bellek.mhqa.llm import make_question_answer_func
from bellek.utils import set_seed, jprint
from bellek.musique.baseline import benchmark

# Seed once, before any data loading or model calls.
# NOTE(review): which random number generators set_seed covers is defined in
# bellek.utils, not in this notebook — confirm there.
set_seed(89)

In [3]:
def silence(exc_cls):
    """Build a decorator that converts ``exc_cls`` exceptions into ``None``.

    The decorated function is called normally; if it raises ``exc_cls`` (or a
    subclass, or any member of a tuple of exception classes) the exception is
    swallowed and ``None`` is returned instead. Any other exception propagates
    unchanged.

    Args:
        exc_cls: An exception class, or tuple of exception classes, to silence.

    Returns:
        A decorator. The function it produces keeps the wrapped function's
        ``__name__``, ``__doc__`` and other metadata.
    """
    # Imported locally so this cell does not depend on the notebook's import cell.
    from functools import wraps

    def decorator(func):
        @wraps(func)  # keep the original name/docstring for tracebacks and help()
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except exc_cls:
                # Deliberate best-effort: the caller receives None on failure.
                return None
        return wrapper
    return decorator

In [4]:
# Load the evaluation dataset from a JSON Lines file (one record per line).
df = pd.read_json('../../data/generated/musique-evaluation/dataset.jsonl', orient='records', lines=True)

In [5]:
# Load the JERX inference results (JSON Lines). Each record is keyed by
# example `id` and `paragraph_idx`; the `generation` column holds the text
# produced for that paragraph (pipe-separated triplets in the preview below).
jerx_file = Path("../../data/raw/musique-evaluation/jerx-inferences/llama3-base.jsonl")
jerx_df = pd.read_json(jerx_file, lines=True)
jerx_df.head()

Unnamed: 0,id,paragraph_idx,paragraph_text,paragraph_title,is_supporting,text,input,generation
0,2hop__131818_161450,0,Maria Carrillo High School is a public high sc...,Maria Carrillo High School,False,# Maria Carrillo High School\nMaria Carrillo H...,[{'content': 'You are an excellent knowledge g...,Maria Carrillo High School | location | Santa ...
1,2hop__131818_161450,1,"Golestān Province (Persian: استان گلستان‎, Ost...",Golestan Province,True,# Golestan Province\nGolestān Province (Persia...,[{'content': 'You are an excellent knowledge g...,Golestan Province | location | north-east of I...
2,2hop__131818_161450,2,Voshmgir District () is a district (bakhsh) in...,Voshmgir District,True,# Voshmgir District\nVoshmgir District () is a...,[{'content': 'You are an excellent knowledge g...,"Voshmgir District | location | Aqqala County, ..."
3,2hop__131818_161450,3,52 Heroor is a village in the southern state o...,52 Heroor,False,# 52 Heroor\n52 Heroor is a village in the sou...,[{'content': 'You are an excellent knowledge g...,"52 Heroor | location | Karnataka, India\n52 He..."
4,2hop__131818_161450,4,Vennaimalai is a village of Karur District loc...,Vennaimalai,False,# Vennaimalai\nVennaimalai is a village of Kar...,[{'content': 'You are an excellent knowledge g...,Vennaimalai | location | Karur District\nVenna...


In [6]:
# Lookup table: (example id, paragraph index) -> generated text for that paragraph.
# If a key occurs more than once, the last occurrence wins.
jerx_mapping = dict(
    zip(
        zip(jerx_df['id'], jerx_df['paragraph_idx']),
        jerx_df['generation'],
    )
)

def extract_triplets(example: dict):
    """Return a copy of ``example`` with a ``triplets_str`` entry added.

    For every paragraph in ``example['paragraphs']`` the text stored in the
    module-level ``jerx_mapping`` under ``(example['id'], paragraph['idx'])``
    is looked up and stripped of surrounding whitespace. The resulting list
    (one string per paragraph, in paragraph order) is stored under
    ``triplets_str``.

    The new key is assigned on a copy of the input: this function is used with
    ``DataFrame.apply(axis=1)``, and pandas does not support functions that
    mutate the object they are handed.

    Args:
        example: A record (dict or DataFrame row) with ``id`` and
            ``paragraphs`` entries; each paragraph must provide ``idx``.

    Returns:
        A copy of ``example`` (made with ``.copy()``) that additionally
        contains ``triplets_str``.

    Raises:
        KeyError: If an ``(id, idx)`` pair is not present in ``jerx_mapping``.
    """
    example_id = example['id']
    triplet_texts = [
        jerx_mapping[(example_id, paragraph['idx'])].strip()
        for paragraph in example['paragraphs']
    ]
    # Assign onto a copy so the caller's record is left untouched.
    example = example.copy()
    example["triplets_str"] = triplet_texts
    return example

In [7]:
# Add a `triplets_str` column (one triplet string per paragraph) to every example.
df = df.apply(extract_triplets, axis=1)
# Print the number of examples, then preview the augmented frame.
print(len(df))
df.head()

200


Unnamed: 0,id,paragraphs,question,question_decomposition,answer,answer_aliases,answerable,answers,triplets_str
0,2hop__131818_161450,"[{'idx': 0, 'title': 'Maria Carrillo High Scho...",Where is the Voshmgir District located?,"[{'id': 131818, 'question': 'Which state is Vo...",in the north-east of the country south of the ...,"[Caspian Sea, in the north-east of the country...",True,"[Caspian Sea, in the north-east of the country...",[Maria Carrillo High School | location | Santa...
1,2hop__444265_82341,"[{'idx': 0, 'title': 'Ocala, Florida', 'paragr...",In what part of Florida is Tom Denney's birthp...,"[{'id': 444265, 'question': 'Tom Denney >> pla...",in Northern Florida,"[in Northern Florida, Northern Florida]",True,"[in Northern Florida, Northern Florida]",[Ocala | location | Florida\nOcala | location ...
2,2hop__711946_269414,"[{'idx': 0, 'title': 'Wild Thing (Tone Lōc son...",What record label is the performer who release...,"[{'id': 711946, 'question': 'All Your Faded Th...",Kill Rock Stars,[Kill Rock Stars],True,[Kill Rock Stars],[Wild Thing | song by | Tone Lōc\nWild Thing |...
3,2hop__311931_417706,"[{'idx': 0, 'title': 'The Main Attraction (alb...",What record label does the performer of Emotio...,"[{'id': 311931, 'question': 'Emotional Rain >>...",Attic Records,"[Attic, Attic Records]",True,"[Attic, Attic Records]",[The Main Attraction | type | album\nThe Main ...
4,2hop__809785_606637,"[{'idx': 0, 'title': 'The Main Attraction (alb...",What record label does the performer of Advent...,"[{'id': 809785, 'question': 'Adventures in You...",Secret City Records,[Secret City Records],True,[Secret City Records],[The Main Attraction (album) | artist | Grant ...


In [8]:
# Completion settings handed to the question-answering function factory:
# temperature 0.0 and a cap of 2048 tokens.
COMPLETION_KWARGS={"temperature": 0.0, "max_tokens": 2048}
# The same qa_func (model name "gpt-3.5-turbo") is reused by every benchmark run below.
qa_func = make_question_answer_func("gpt-3.5-turbo", completion_kwargs=COMPLETION_KWARGS)

In [9]:
# Accumulates one score record per benchmark run; each benchmark cell below
# appends to it and the report section turns it into a DataFrame.
# NOTE: re-running a benchmark cell without re-running this one appends a
# duplicate record.
results = []

## Only paragraphs

In [10]:
# Paragraph context, only_supporting=False (recorded as retrieval "none").
# The first element returned by benchmark() is discarded; only `scores` is kept.
_, scores = benchmark(df, qa_func, only_supporting=False)
results.append({**scores, "retrieval": "none", "context": "paragraphs"})
jprint(scores)

  0%|          | 0/200 [00:00<?, ?it/s]

{
  "exact_match": 0.5,
  "f1": 0.5822423687423688,
  "fuzzy_match": 0.595
}


In [11]:
# Paragraph context, only_supporting=True (recorded as retrieval "groundtruth").
# NOTE(review): presumably restricts the context to paragraphs flagged as
# supporting — confirm in bellek.musique.baseline.
_, scores = benchmark(df, qa_func, only_supporting=True)
results.append({**scores, "retrieval": "groundtruth", "context": "paragraphs"})
jprint(scores)

  0%|          | 0/200 [00:00<?, ?it/s]

{
  "exact_match": 0.615,
  "f1": 0.7115097680097681,
  "fuzzy_match": 0.71
}


## Paragraphs + Triplets

In [12]:
def enhance_paragraphs(row):
    """Return a copy of ``row`` whose paragraphs have their triplets appended.

    For each paragraph, the text stored in the module-level ``jerx_mapping``
    under ``(row['id'], paragraph['idx'])`` is converted with ``str`` and
    appended to ``paragraph_text`` after a
    ``# Entity-relation-entity triplets`` header line. Paragraph dicts are
    deep-copied first, so the originals are not modified.

    The new paragraph list is assigned on a copy of ``row``: this function is
    used with ``DataFrame.apply(axis=1)``, and pandas does not support
    functions that mutate the object they are handed.

    Args:
        row: A record (DataFrame row or dict) with ``id`` and ``paragraphs``
            entries; each paragraph must provide ``idx`` and ``paragraph_text``.

    Returns:
        A copy of ``row`` (made with ``.copy()``) with ``paragraphs`` replaced
        by the enhanced paragraph list.

    Raises:
        KeyError: If an ``(id, idx)`` pair is not present in ``jerx_mapping``.
    """
    enhanced = []
    for paragraph in row['paragraphs']:
        paragraph = deepcopy(paragraph)  # keep the source paragraph dict untouched
        triplets_str = str(jerx_mapping[(row['id'], paragraph['idx'])])
        paragraph['paragraph_text'] = '\n'.join(
            [paragraph['paragraph_text'], "# Entity-relation-entity triplets", triplets_str]
        )
        enhanced.append(paragraph)
    # Assign onto a copy so the row passed in by apply() is not mutated.
    row = row.copy()
    row['paragraphs'] = enhanced
    return row

# Build the "paragraphs + triplets" variant of the dataset, then print one
# enhanced paragraph as a spot check.
df_paragraph_triplets = df.apply(enhance_paragraphs, axis=1)
print(df_paragraph_triplets.iloc[0]['paragraphs'][2]['paragraph_text'])

Voshmgir District () is a district (bakhsh) in Aqqala County, Golestan Province, Iran. At the 2006 census, its population was 25,149, in 5,266 families. The District has one city: Anbar Olum. The District has two rural districts ("dehestan"): Mazraeh-ye Jonubi Rural District and Mazraeh-ye Shomali Rural District.
# Entity-relation-entity triplets
Voshmgir District | location | Aqqala County, Golestan Province, Iran
Voshmgir District | population | 25,149
Voshmgir District | population in families | 5,266
Voshmgir District | city | Anbar Olum
Voshmgir District | rural districts | Mazraeh-ye Jonubi Rural District, Mazraeh-ye Shomali Rural District


In [13]:
# Paragraphs with appended triplets, only_supporting=False (retrieval "none").
# The first element returned by benchmark() is discarded; only `scores` is kept.
_, scores = benchmark(df_paragraph_triplets, qa_func, only_supporting=False)
results.append({**scores, "retrieval": "none", "context": "paragraphs+triplets"})
jprint(scores)

  0%|          | 0/200 [00:00<?, ?it/s]

{
  "exact_match": 0.47,
  "f1": 0.5765849705849706,
  "fuzzy_match": 0.58
}


In [14]:
# Paragraphs with appended triplets, only_supporting=True (retrieval "groundtruth").
# The first element returned by benchmark() is discarded; only `scores` is kept.
_, scores = benchmark(df_paragraph_triplets, qa_func, only_supporting=True)
results.append({**scores, "retrieval": "groundtruth", "context": "paragraphs+triplets"})
jprint(scores)

  0%|          | 0/200 [00:00<?, ?it/s]

{
  "exact_match": 0.58,
  "f1": 0.7009486069486072,
  "fuzzy_match": 0.71
}


## Only triplets

In [15]:
def replace_paragraphs(row):
    """Return a copy of ``row`` whose paragraph texts are replaced by triplets.

    For each paragraph, ``paragraph_text`` is overwritten with a
    ``# Entity-relation-entity triplets`` header line followed by the text
    stored in the module-level ``jerx_mapping`` under
    ``(row['id'], paragraph['idx'])`` (converted with ``str``). Paragraph
    dicts are deep-copied first, so the originals are not modified.

    The new paragraph list is assigned on a copy of ``row``: this function is
    used with ``DataFrame.apply(axis=1)``, and pandas does not support
    functions that mutate the object they are handed.

    Args:
        row: A record (DataFrame row or dict) with ``id`` and ``paragraphs``
            entries; each paragraph must provide ``idx``.

    Returns:
        A copy of ``row`` (made with ``.copy()``) with ``paragraphs`` replaced
        by the triplets-only paragraph list.

    Raises:
        KeyError: If an ``(id, idx)`` pair is not present in ``jerx_mapping``.
    """
    triplet_paragraphs = []
    for paragraph in row['paragraphs']:
        paragraph = deepcopy(paragraph)  # keep the source paragraph dict untouched
        triplets_str = str(jerx_mapping[(row['id'], paragraph['idx'])])
        paragraph['paragraph_text'] = '\n'.join(
            ["# Entity-relation-entity triplets", triplets_str]
        )
        triplet_paragraphs.append(paragraph)
    # Assign onto a copy so the row passed in by apply() is not mutated.
    row = row.copy()
    row['paragraphs'] = triplet_paragraphs
    return row

# Build the "triplets only" variant of the dataset, then print one replaced
# paragraph as a spot check.
df_only_triplets = df.apply(replace_paragraphs, axis=1)
print(df_only_triplets.iloc[0]['paragraphs'][2]['paragraph_text'])

# Entity-relation-entity triplets
Voshmgir District | location | Aqqala County, Golestan Province, Iran
Voshmgir District | population | 25,149
Voshmgir District | population in families | 5,266
Voshmgir District | city | Anbar Olum
Voshmgir District | rural districts | Mazraeh-ye Jonubi Rural District, Mazraeh-ye Shomali Rural District


In [16]:
# Triplets-only context, only_supporting=False (retrieval "none").
# NOTE(review): unlike the paragraph-based runs, this call passes
# ignore_errors=True. How failed questions enter the scores is not visible in
# this notebook — confirm in bellek.musique.baseline before comparing rows.
_, scores = benchmark(df_only_triplets, qa_func, only_supporting=False, ignore_errors=True)
results.append({**scores, "retrieval": "none", "context": "triplets"})
jprint(scores)

  0%|          | 0/200 [00:00<?, ?it/s]

{
  "exact_match": 0.425,
  "f1": 0.5433461538461539,
  "fuzzy_match": 0.55
}


In [17]:
# Triplets-only context, only_supporting=True (retrieval "groundtruth").
# NOTE(review): ignore_errors=True is passed here but not in the
# paragraph-based runs, and the saved output of this cell reports one question
# that failed to be answered. How that failure is scored is not visible in
# this notebook — confirm in bellek.musique.baseline.
_, scores = benchmark(df_only_triplets, qa_func, only_supporting=True, ignore_errors=True)
results.append({**scores, "retrieval": "groundtruth", "context": "triplets"})
jprint(scores)

  0%|          | 0/200 [00:00<?, ?it/s]

Failed to answer the question 2hop__199513_13732
Unterminated string starting at: line 1 column 15 (char 14)
{
  "exact_match": 0.55,
  "f1": 0.6674594294594296,
  "fuzzy_match": 0.67
}


# Report

In [18]:
# Collect all recorded runs into one table; `columns` fixes the column order
# (labels first, then the three metrics).
report_df = pd.DataFrame.from_records(results, columns=['context', 'retrieval', 'exact_match', 'fuzzy_match', 'f1'])
report_df

Unnamed: 0,context,retrieval,exact_match,fuzzy_match,f1
0,paragraphs,none,0.5,0.595,0.582242
1,paragraphs,groundtruth,0.615,0.71,0.71151
2,paragraphs+triplets,none,0.47,0.58,0.576585
3,paragraphs+triplets,groundtruth,0.58,0.71,0.700949
4,triplets,none,0.425,0.55,0.543346
5,triplets,groundtruth,0.55,0.67,0.667459


In [19]:
# Markdown table of the runs recorded with retrieval "none" (row index omitted).
print(report_df[report_df['retrieval']=='none'].to_markdown(index=False))

| context             | retrieval   |   exact_match |   fuzzy_match |       f1 |
|:--------------------|:------------|--------------:|--------------:|---------:|
| paragraphs          | none        |         0.5   |         0.595 | 0.582242 |
| paragraphs+triplets | none        |         0.47  |         0.58  | 0.576585 |
| triplets            | none        |         0.425 |         0.55  | 0.543346 |


In [20]:
# Markdown table of the runs recorded with retrieval "groundtruth" (row index omitted).
print(report_df[report_df['retrieval']=='groundtruth'].to_markdown(index=False))

| context             | retrieval   |   exact_match |   fuzzy_match |       f1 |
|:--------------------|:------------|--------------:|--------------:|---------:|
| paragraphs          | groundtruth |         0.615 |          0.71 | 0.71151  |
| paragraphs+triplets | groundtruth |         0.58  |          0.71 | 0.700949 |
| triplets            | groundtruth |         0.55  |          0.67 | 0.667459 |
