In [1]:
from dotenv import load_dotenv
load_dotenv()

True

In [2]:
import json
import pandas as pd
from pathlib import Path
from copy import deepcopy

from bellek.qa.llm import make_question_answer_func
from bellek.utils import set_seed, jprint
from bellek.musique.multihop import benchmark

set_seed(89)

In [3]:
def silence(exc_cls):
    def decorator(func):
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except exc_cls as e:
                return None
        return wrapper
    return decorator

In [4]:
df = pd.read_json('../../data/generated/musique-evaluation/dataset.jsonl', orient='records', lines=True)
qd_df = pd.read_json('../../data/generated/musique-evaluation/question-decomposition.jsonl', orient='records', lines=True)
df = pd.merge(df.drop(columns=['question', 'question_decomposition']), qd_df, on='id', suffixes=('', ''))
df.head()

Unnamed: 0,id,paragraphs,answer,answer_aliases,answerable,answers,question,question_decomposition
0,2hop__131818_161450,"[{'idx': 0, 'title': 'Maria Carrillo High Scho...",in the north-east of the country south of the ...,"[Caspian Sea, in the north-east of the country...",True,"[Caspian Sea, in the north-east of the country...",Where is the Voshmgir District located?,"[{'id': 131818, 'question': 'In which oblast i..."
1,2hop__444265_82341,"[{'idx': 0, 'title': 'Ocala, Florida', 'paragr...",in Northern Florida,"[in Northern Florida, Northern Florida]",True,"[in Northern Florida, Northern Florida]",In what part of Florida is Tom Denney's birthp...,"[{'id': 444265, 'question': 'Where was Tom Den..."
2,2hop__711946_269414,"[{'idx': 0, 'title': 'Wild Thing (Tone Lōc son...",Kill Rock Stars,[Kill Rock Stars],True,[Kill Rock Stars],What record label is the performer who release...,"[{'id': 711946, 'question': 'Who is the perfor..."
3,2hop__311931_417706,"[{'idx': 0, 'title': 'The Main Attraction (alb...",Attic Records,"[Attic, Attic Records]",True,"[Attic, Attic Records]",What record label does the performer of Emotio...,"[{'id': 311931, 'question': 'Who performs Emot..."
4,2hop__809785_606637,"[{'idx': 0, 'title': 'The Main Attraction (alb...",Secret City Records,[Secret City Records],True,[Secret City Records],What record label does the performer of Advent...,"[{'id': 809785, 'question': 'Who is the perfor..."


In [5]:
jerx_file = Path("../../data/raw/musique-evaluation/jerx-inferences/llama3-base.jsonl")
jerx_df = pd.read_json(jerx_file, lines=True)
jerx_df.head()

Unnamed: 0,id,paragraph_idx,paragraph_text,paragraph_title,is_supporting,text,input,generation
0,2hop__131818_161450,0,Maria Carrillo High School is a public high sc...,Maria Carrillo High School,False,# Maria Carrillo High School\nMaria Carrillo H...,[{'content': 'You are an excellent knowledge g...,Maria Carrillo High School | location | Santa ...
1,2hop__131818_161450,1,"Golestān Province (Persian: استان گلستان‎, Ost...",Golestan Province,True,# Golestan Province\nGolestān Province (Persia...,[{'content': 'You are an excellent knowledge g...,Golestan Province | location | north-east of I...
2,2hop__131818_161450,2,Voshmgir District () is a district (bakhsh) in...,Voshmgir District,True,# Voshmgir District\nVoshmgir District () is a...,[{'content': 'You are an excellent knowledge g...,"Voshmgir District | location | Aqqala County, ..."
3,2hop__131818_161450,3,52 Heroor is a village in the southern state o...,52 Heroor,False,# 52 Heroor\n52 Heroor is a village in the sou...,[{'content': 'You are an excellent knowledge g...,"52 Heroor | location | Karnataka, India\n52 He..."
4,2hop__131818_161450,4,Vennaimalai is a village of Karur District loc...,Vennaimalai,False,# Vennaimalai\nVennaimalai is a village of Kar...,[{'content': 'You are an excellent knowledge g...,Vennaimalai | location | Karur District\nVenna...


In [6]:
jerx_mapping = {(row['id'], row['paragraph_idx']): row['generation'] for _, row in jerx_df.iterrows()}

def extract_triplets(example: dict):
    example["triplets_str"] = [jerx_mapping[(example['id'], p['idx'])].strip() for p in example['paragraphs']]
    return example

In [7]:
df = df.apply(extract_triplets, axis=1)
print(len(df))
df.head()

200


Unnamed: 0,id,paragraphs,answer,answer_aliases,answerable,answers,question,question_decomposition,triplets_str
0,2hop__131818_161450,"[{'idx': 0, 'title': 'Maria Carrillo High Scho...",in the north-east of the country south of the ...,"[Caspian Sea, in the north-east of the country...",True,"[Caspian Sea, in the north-east of the country...",Where is the Voshmgir District located?,"[{'id': 131818, 'question': 'In which oblast i...",[Maria Carrillo High School | location | Santa...
1,2hop__444265_82341,"[{'idx': 0, 'title': 'Ocala, Florida', 'paragr...",in Northern Florida,"[in Northern Florida, Northern Florida]",True,"[in Northern Florida, Northern Florida]",In what part of Florida is Tom Denney's birthp...,"[{'id': 444265, 'question': 'Where was Tom Den...",[Ocala | location | Florida\nOcala | location ...
2,2hop__711946_269414,"[{'idx': 0, 'title': 'Wild Thing (Tone Lōc son...",Kill Rock Stars,[Kill Rock Stars],True,[Kill Rock Stars],What record label is the performer who release...,"[{'id': 711946, 'question': 'Who is the perfor...",[Wild Thing | song by | Tone Lōc\nWild Thing |...
3,2hop__311931_417706,"[{'idx': 0, 'title': 'The Main Attraction (alb...",Attic Records,"[Attic, Attic Records]",True,"[Attic, Attic Records]",What record label does the performer of Emotio...,"[{'id': 311931, 'question': 'Who performs Emot...",[The Main Attraction | type | album\nThe Main ...
4,2hop__809785_606637,"[{'idx': 0, 'title': 'The Main Attraction (alb...",Secret City Records,[Secret City Records],True,[Secret City Records],What record label does the performer of Advent...,"[{'id': 809785, 'question': 'Who is the perfor...",[The Main Attraction (album) | artist | Grant ...


In [8]:
dummy_retrieval_func = lambda docs,query: docs
perfect_retrieval_func = lambda docs,query: [doc for doc in docs if doc['is_supporting']]

In [9]:
import bm25s

def bm25_retrieval(docs: list[dict], query: str, top_k: int = 10):
    top_k = min(top_k, len(docs))
    retriever = bm25s.BM25(corpus=docs)
    tokenized_corpus = bm25s.tokenize([doc['text'] for doc in docs])
    retriever.index(tokenized_corpus)
    results, _ = retriever.retrieve(bm25s.tokenize(query), k=top_k)
    return results[0].tolist()

In [10]:
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")

def semantic_retrieval(docs: list[dict], query: str, top_k: int = 10):
    embeddings = model.encode([doc['text'] for doc in docs])
    query_vectors = model.encode([query])
    similarities = model.similarity(embeddings, query_vectors)
    sorted_indices = similarities.argsort(dim=0)
    return [docs[i] for i in sorted_indices[:top_k]]



In [11]:
completion_kwargs={"temperature": 0.0, "max_tokens": 1024}
qa_func = make_question_answer_func("gpt-3.5-turbo", completion_kwargs=completion_kwargs)

In [12]:
results = []

## Only paragraphs

In [13]:
_, scores = benchmark(df, qa_func, dummy_retrieval_func, ignore_errors=True)
results.append({**scores, "retrieval": "none", "context": "paragraphs"})
jprint(scores)

  0%|          | 0/200 [00:00<?, ?it/s]

{
  "exact_match": 0.5,
  "f1": 0.6250534188034188,
  "fuzzy_match": 0.63
}


In [14]:
_, scores = benchmark(df, qa_func, perfect_retrieval_func, ignore_errors=True)
results.append({**scores, "retrieval": "groundtruth", "context": "paragraphs"})
jprint(scores)

  0%|          | 0/200 [00:00<?, ?it/s]

Failed to answer the question 2hop__549146_223121
Unterminated string starting at: line 1 column 15 (char 14)
{
  "exact_match": 0.6,
  "f1": 0.7035177045177045,
  "fuzzy_match": 0.71
}


In [15]:
%%capture
_, scores = benchmark(df, qa_func, bm25_retrieval, ignore_errors=True)
results.append({**scores, "retrieval": "bm25", "context": "paragraphs"})
jprint(scores)

In [16]:
_, scores = benchmark(df, qa_func, semantic_retrieval, ignore_errors=True)
results.append({**scores, "retrieval": "semantic", "context": "paragraphs"})
jprint(scores)

  0%|          | 0/200 [00:00<?, ?it/s]

{
  "exact_match": 0.09,
  "f1": 0.15600432900432903,
  "fuzzy_match": 0.125
}


## Paragraphs + Triplets

In [17]:
def enhance_paragraphs(row):
    paragraphs_with_triplets = []
    for p in row['paragraphs']:
        p = deepcopy(p) 
        triplets_str = str(jerx_mapping[(row['id'], p['idx'])])
        p['paragraph_text'] = '\n'.join([p['paragraph_text'], "# Entity-relation-entity triplets", triplets_str])
        paragraphs_with_triplets.append(p)
    row['paragraphs'] = paragraphs_with_triplets
    return row

df_paragraph_triplets = df.apply(enhance_paragraphs, axis=1) 
df_paragraph_triplets.head()
print(df_paragraph_triplets.iloc[0]['paragraphs'][2]['paragraph_text'])

Voshmgir District () is a district (bakhsh) in Aqqala County, Golestan Province, Iran. At the 2006 census, its population was 25,149, in 5,266 families. The District has one city: Anbar Olum. The District has two rural districts ("dehestan"): Mazraeh-ye Jonubi Rural District and Mazraeh-ye Shomali Rural District.
# Entity-relation-entity triplets
Voshmgir District | location | Aqqala County, Golestan Province, Iran
Voshmgir District | population | 25,149
Voshmgir District | population in families | 5,266
Voshmgir District | city | Anbar Olum
Voshmgir District | rural districts | Mazraeh-ye Jonubi Rural District, Mazraeh-ye Shomali Rural District


In [18]:
_, scores = benchmark(df_paragraph_triplets, qa_func, dummy_retrieval_func)
results.append({**scores, "retrieval": "none", "context": "paragraphs+triplets"})
jprint(scores)

  0%|          | 0/200 [00:00<?, ?it/s]

{
  "exact_match": 0.51,
  "f1": 0.6187155622155622,
  "fuzzy_match": 0.655
}


In [19]:
_, scores = benchmark(df_paragraph_triplets, qa_func, perfect_retrieval_func)
results.append({**scores, "retrieval": "groundtruth", "context": "paragraphs+triplets"})
jprint(scores)

  0%|          | 0/200 [00:00<?, ?it/s]

{
  "exact_match": 0.58,
  "f1": 0.6958762903762903,
  "fuzzy_match": 0.72
}


In [None]:
%%capture
_, scores = benchmark(df_paragraph_triplets, qa_func, bm25_retrieval)
results.append({**scores, "retrieval": "bm25", "context": "paragraphs+triplets"})
jprint(scores)

In [21]:
_, scores = benchmark(df_paragraph_triplets, qa_func, semantic_retrieval)
results.append({**scores, "retrieval": "semantic", "context": "paragraphs+triplets"})
jprint(scores)

  0%|          | 0/200 [00:00<?, ?it/s]

{
  "exact_match": 0.07,
  "f1": 0.13750000000000007,
  "fuzzy_match": 0.13
}


## Only triplets

In [22]:
def replace_paragraphs(row):
    paragraphs_with_triplets = []
    for p in row['paragraphs']:
        p = deepcopy(p) 
        triplets_str = str(jerx_mapping[(row['id'], p['idx'])])
        p['paragraph_text'] = '\n'.join(["# Entity-relation-entity triplets", triplets_str])
        paragraphs_with_triplets.append(p)
    row['paragraphs'] = paragraphs_with_triplets
    return row

df_only_triplets = df.apply(replace_paragraphs, axis=1) 
df_only_triplets.head()
print(df_only_triplets.iloc[0]['paragraphs'][2]['paragraph_text'])

# Entity-relation-entity triplets
Voshmgir District | location | Aqqala County, Golestan Province, Iran
Voshmgir District | population | 25,149
Voshmgir District | population in families | 5,266
Voshmgir District | city | Anbar Olum
Voshmgir District | rural districts | Mazraeh-ye Jonubi Rural District, Mazraeh-ye Shomali Rural District


In [23]:
_, scores = benchmark(df_only_triplets, qa_func, dummy_retrieval_func)
results.append({**scores, "retrieval": "none", "context": "triplets"})
jprint(scores)

  0%|          | 0/200 [00:00<?, ?it/s]

{
  "exact_match": 0.495,
  "f1": 0.5833064713064714,
  "fuzzy_match": 0.605
}


In [24]:
_, scores = benchmark(df_only_triplets, qa_func, perfect_retrieval_func)
results.append({**scores, "retrieval": "groundtruth", "context": "triplets"})
jprint(scores)

  0%|          | 0/200 [00:00<?, ?it/s]

{
  "exact_match": 0.54,
  "f1": 0.6704738594738595,
  "fuzzy_match": 0.685
}


In [25]:
%%capture
_, scores = benchmark(df_only_triplets, qa_func, bm25_retrieval)
results.append({**scores, "retrieval": "bm25", "context": "triplets"})
jprint(scores)

In [26]:
_, scores = benchmark(df_only_triplets, qa_func, semantic_retrieval)
results.append({**scores, "retrieval": "semantic", "context": "triplets"})
jprint(scores)

  0%|          | 0/200 [00:00<?, ?it/s]

{
  "exact_match": 0.12,
  "f1": 0.1835119047619047,
  "fuzzy_match": 0.175
}


# Report

In [27]:
report_df = pd.DataFrame.from_records(results, columns=['context', 'retrieval', 'exact_match', 'fuzzy_match', 'f1'])
report_df

Unnamed: 0,context,retrieval,exact_match,fuzzy_match,f1
0,paragraphs,none,0.5,0.63,0.625053
1,paragraphs,groundtruth,0.6,0.71,0.703518
2,paragraphs,bm25,0.46,0.58,0.586549
3,paragraphs,semantic,0.09,0.125,0.156004
4,paragraphs+triplets,none,0.51,0.655,0.618716
5,paragraphs+triplets,groundtruth,0.58,0.72,0.695876
6,paragraphs+triplets,bm25,0.48,0.615,0.599
7,paragraphs+triplets,semantic,0.07,0.13,0.1375
8,triplets,none,0.495,0.605,0.583306
9,triplets,groundtruth,0.54,0.685,0.670474


In [35]:
from datetime import datetime
suffix = datetime.utcnow().strftime("%Y%m%d-%H%M%S")
report_df.to_json(f'../../data/generated/musique-evaluation/baseline-report-{suffix}.jsonl', orient='records', lines=True)

In [28]:
print(report_df[report_df['retrieval']=='none'].drop(columns=['retrieval']).to_markdown(index=False))

| context             |   exact_match |   fuzzy_match |       f1 |
|:--------------------|--------------:|--------------:|---------:|
| paragraphs          |         0.5   |         0.63  | 0.625053 |
| paragraphs+triplets |         0.51  |         0.655 | 0.618716 |
| triplets            |         0.495 |         0.605 | 0.583306 |


In [29]:
print(report_df[report_df['retrieval']=='groundtruth'].drop(columns=['retrieval']).to_markdown(index=False))

| context             |   exact_match |   fuzzy_match |       f1 |
|:--------------------|--------------:|--------------:|---------:|
| paragraphs          |          0.6  |         0.71  | 0.703518 |
| paragraphs+triplets |          0.58 |         0.72  | 0.695876 |
| triplets            |          0.54 |         0.685 | 0.670474 |


In [30]:
print(report_df[report_df['retrieval']=='bm25'].drop(columns=['retrieval']).to_markdown(index=False))

| context             |   exact_match |   fuzzy_match |       f1 |
|:--------------------|--------------:|--------------:|---------:|
| paragraphs          |         0.46  |         0.58  | 0.586549 |
| paragraphs+triplets |         0.48  |         0.615 | 0.599    |
| triplets            |         0.465 |         0.585 | 0.578074 |


In [31]:
print(report_df[report_df['retrieval']=='semantic'].drop(columns=['retrieval']).to_markdown(index=False))

| context             |   exact_match |   fuzzy_match |       f1 |
|:--------------------|--------------:|--------------:|---------:|
| paragraphs          |          0.09 |         0.125 | 0.156004 |
| paragraphs+triplets |          0.07 |         0.13  | 0.1375   |
| triplets            |          0.12 |         0.175 | 0.183512 |
