# MuSiQue baseline

In [None]:
#|default_exp musique.baseline

In [None]:
#|hide
from fastcore.test import *
from nbdev.showdoc import *

In [None]:
#|export
from typing import Callable
import pandas as pd
from bellek.jerx.reward.llm import make_question_answer_func, QuestionAnsweringResult
from bellek.musique.eval import calculate_metrics, compare_answers

In [None]:
#|export

def make_docs(example, only_supporting=False):
    """Yield one document dict per paragraph of a MuSiQue example.

    Each yielded dict has:
      - ``text``: the paragraph title as a markdown ``#`` heading, a newline,
        then the paragraph body;
      - ``metadata``: ``parent_id`` (the example's ``id``), the paragraph
        ``idx`` and its ``is_supporting`` flag.

    Args:
        example: record exposing ``"id"`` and ``"paragraphs"``; every paragraph
            must provide ``idx``, ``title``, ``paragraph_text`` and
            ``is_supporting``.
        only_supporting: when True, paragraphs whose ``is_supporting`` is falsy
            are skipped.
    """
    for para in example["paragraphs"]:
        if only_supporting and not para["is_supporting"]:
            continue
        # Fields are read in a fixed order: idx, title, body, supporting flag.
        idx, title, body, flag = (
            para[key] for key in ("idx", "title", "paragraph_text", "is_supporting")
        )
        yield {
            "text": f"# {title}\n{body}",
            "metadata": {"parent_id": example["id"], "idx": idx, "is_supporting": flag},
        }

In [None]:
#|export

def format_question(example):
    """Return the question text passed to the QA function for ``example``.

    Currently this is just ``example['question']`` verbatim; any alternative
    prompt format (e.g. including the question decomposition) should be added
    as an explicit option rather than left as dead code.
    """
    return example['question']

class BaselineMHQA:
    """Baseline multi-hop QA: answer from the concatenated example paragraphs.

    The paragraphs of an example (all of them, or only the supporting ones) are
    rendered with ``make_docs``, joined with blank lines into a single context
    string, and passed to ``qa_func`` together with the question in one call.
    """

    def __init__(self, qa_func, only_supporting: bool = True):
        """
        Args:
            qa_func: callable invoked as ``qa_func(context=..., question=...)``
                and expected to return a ``QuestionAnsweringResult``.
            only_supporting: if True, only paragraphs flagged ``is_supporting``
                are included in the context.
        """
        self.qa_func = qa_func
        self.only_supporting = only_supporting

    def _answer(self, example) -> QuestionAnsweringResult:
        """Build the context for ``example`` and call ``qa_func`` once."""
        documents = make_docs(example, only_supporting=self.only_supporting)
        context = "\n\n".join(doc["text"] for doc in documents)
        return self.qa_func(context=context, question=format_question(example))

    def answer(self, example, ignore_errors: bool = False) -> QuestionAnsweringResult:
        """Answer ``example``, optionally turning failures into a placeholder result.

        Args:
            example: a MuSiQue record (dict or DataFrame row) with ``question``
                and ``paragraphs``; ``id`` is used for error reporting.
            ignore_errors: if True, any exception raised while answering is
                printed and a result with ``answer="N/A"``, empty reasoning and
                the exception text as ``raw_output`` is returned instead.

        Returns:
            The ``QuestionAnsweringResult`` from ``qa_func`` (or the placeholder).

        Raises:
            Exception: whatever building the context or ``qa_func`` raised,
                when ``ignore_errors`` is False.
        """
        try:
            return self._answer(example)
        except Exception as exc:
            if not ignore_errors:
                raise
            # `.get` so a record lacking "id" cannot raise a fresh KeyError here
            # and mask the original error; also avoids shadowing builtin `id`.
            example_id = example.get("id")
            print(f"Failed to answer the question {example_id}\n{exc}")
            return QuestionAnsweringResult(reasoning="", answer="N/A", raw_output=str(exc))

In [None]:
#|export

def benchmark(
    dataf: pd.DataFrame,
    qa_func: Callable,
    only_supporting: bool = True,
    ignore_errors: bool = False,
) -> tuple[pd.DataFrame, dict]:
    """Run ``BaselineMHQA`` over every row of ``dataf`` and score the predictions.

    Args:
        dataf: MuSiQue examples, one per row (needs ``id``, ``question`` and
            ``paragraphs`` plus whatever ``compare_answers`` reads).
        qa_func: question-answering callable forwarded to ``BaselineMHQA``.
        only_supporting: restrict the context to supporting paragraphs.
        ignore_errors: if True, a failing example gets the placeholder "N/A"
            answer instead of aborting the whole run and discarding the
            predictions already computed. Defaults to False (previous behavior).

    Returns:
        ``(dataf, scores)``: the rows extended with ``predicted_answer``,
        ``raw_llm_output`` and the comparison columns added by
        ``compare_answers``; and the metrics dict from ``calculate_metrics``
        with the mean ``fuzzy_match`` added.
    """
    mhqa = BaselineMHQA(qa_func, only_supporting=only_supporting)

    def process(example):
        output = mhqa.answer(example, ignore_errors=ignore_errors)
        example['predicted_answer'] = output.answer
        # Keep the full result object (reasoning, answer, raw output), not just the text.
        example['raw_llm_output'] = output
        return example

    dataf = dataf.apply(process, axis=1)
    dataf = compare_answers(dataf)
    scores = calculate_metrics(dataf)
    scores['fuzzy_match'] = dataf['fuzzy_match'].mean()
    return dataf, scores

In [None]:
# MuSiQue evaluation subset stored as JSON Lines (one example per line).
dataset_path = '../data/generated/musique-evaluation/dataset.jsonl'
df = pd.read_json(dataset_path, lines=True, orient='records')
print(df.shape[0])
df.head()

10


Unnamed: 0,id,paragraphs,question,question_decomposition,answer,answer_aliases,answerable
0,2hop__131818_161450,"[{'idx': 0, 'title': 'Maria Carrillo High Scho...",Where is the Voshmgir District located?,"[{'id': 131818, 'question': 'Which state is Vo...",in the north-east of the country south of the ...,[in the north-east of the country south of the...,True
1,2hop__444265_82341,"[{'idx': 0, 'title': 'Ocala, Florida', 'paragr...",In what part of Florida is Tom Denney's birthp...,"[{'id': 444265, 'question': 'Tom Denney >> pla...",in Northern Florida,"[in Northern Florida, Northern Florida]",True
2,2hop__711946_269414,"[{'idx': 0, 'title': 'Wild Thing (Tone Lōc son...",What record label is the performer who release...,"[{'id': 711946, 'question': 'All Your Faded Th...",Kill Rock Stars,[Kill Rock Stars],True
3,2hop__311931_417706,"[{'idx': 0, 'title': 'The Main Attraction (alb...",What record label does the performer of Emotio...,"[{'id': 311931, 'question': 'Emotional Rain >>...",Attic Records,"[Attic, Attic Records]",True
4,2hop__809785_606637,"[{'idx': 0, 'title': 'The Main Attraction (alb...",What record label does the performer of Advent...,"[{'id': 809785, 'question': 'Adventures in You...",Secret City Records,[Secret City Records],True


In [None]:
# QA function shared by the single-example demo and the benchmark run below.
qa_func = make_question_answer_func()
mhqa = BaselineMHQA(qa_func=qa_func, only_supporting=True)

In [None]:
i = 0
example = df.iloc[i].to_dict()
# Read fields off the result object directly (as `benchmark` does) instead of
# converting with `.dict()`, which is deprecated if the result is a pydantic v2 model.
output = mhqa.answer(example)
print("Question:", example['question'])
print("Reference answer:", example['answer'])
print("Predicted answer:", output.answer)
print("Reasoning:", output.reasoning)

Question: Where is the Voshmgir District located?
Reference answer: in the north-east of the country south of the Caspian Sea
Predicted answer: Aqqala County
Reasoning: Voshmgir District is in Golestan Province and Aqqala County is in Golestan Province, therefore Voshmgir District is in Aqqala County.


In [None]:
# Fixed seed so the sampled subset (and therefore the scores) is reproducible across runs.
mdf, scores = benchmark(df.sample(2, random_state=42), qa_func)
print(scores)
mdf

{'exact_match': 0.5, 'f1': 0.5, 'fuzzy_match': 0.5}


Unnamed: 0,id,paragraphs,question,question_decomposition,answer,answer_aliases,answerable,predicted_answer,raw_llm_output,exact_match,fuzzy_match
2,2hop__711946_269414,"[{'idx': 0, 'title': 'Wild Thing (Tone Lōc son...",What record label is the performer who release...,"[{'id': 711946, 'question': 'All Your Faded Th...",Kill Rock Stars,[Kill Rock Stars],True,Cold Crush Records,reasoning='Anna Oxygen released All Your Faded...,False,False
6,2hop__623501_297043,"[{'idx': 0, 'title': 'Geographical feature', '...",What watercourse is the river on which the Los...,"[{'id': 623501, 'question': 'Lostock Dam >> lo...",Hunter River,[Hunter River],True,Hunter River,reasoning='Lostock Dam is located across the P...,True,True


In [None]:
#|hide
# Regenerate the library modules from this notebook's #|export cells.
import nbdev
nbdev.nbdev_export()