# MuSiQue evaluation utilities

In [None]:
#|default_exp musique.eval

In [None]:
#|hide
from fastcore.test import *
from nbdev.showdoc import *

In [None]:
#|export

import collections
import re
import string
from typing import Callable

import pandas as pd

from bellem.text.utils import fuzzy_match

In [None]:
#|export

def fuzzy_match_metric(prediction: str, references: list[str]) -> float:
    """Return the best `fuzzy_match` result of `prediction` over `references`.

    Each `fuzzy_match(prediction, ref)` result is converted to `float` and the
    largest value is returned. An empty `references` list makes `max` raise
    `ValueError`.
    """
    return max(float(fuzzy_match(prediction, candidate)) for candidate in references)

In [None]:
#|export


def normalize_answer(s: str) -> str:
    """Normalize an answer string so that two answers can be compared.

    The steps run in this order: lowercase the text, delete every character
    found in `string.punctuation`, replace the standalone words "a", "an" and
    "the" with a space, and finally collapse all whitespace runs into single
    spaces (which also trims both ends).
    """
    lowered = s.lower()

    # Deleting punctuation joins the neighbouring characters ("U.S." -> "us").
    punctuation_table = str.maketrans("", "", string.punctuation)
    without_punctuation = lowered.translate(punctuation_table)

    # Articles become a space so that the surrounding words stay separated.
    without_articles = re.sub(r"\b(a|an|the)\b", " ", without_punctuation, flags=re.UNICODE)

    return " ".join(without_articles.split())


def get_tokens(s: str) -> list[str]:
    """Split `s` into normalized tokens; a falsy `s` yields an empty list."""
    return normalize_answer(s).split() if s else []


def compute_exact_match(a_gold: str, a_pred: str) -> int:
    """Return 1 when both answers are identical after normalization, else 0."""
    normalized_gold = normalize_answer(a_gold)
    normalized_pred = normalize_answer(a_pred)
    return 1 if normalized_gold == normalized_pred else 0


def compute_f1(a_gold: str, a_pred: str) -> float:
    """Compute the token-level F1 score between a gold answer and a prediction.

    Both strings are tokenized with `get_tokens`. The overlap is a multiset
    intersection, so a repeated token only counts as many times as it occurs
    on both sides. The score is symmetric in its two arguments.

    Args:
        a_gold: Reference answer.
        a_pred: Predicted answer.

    Returns:
        The F1 score as a float between 0.0 and 1.0. When either side has no
        tokens the result is 1.0 if both sides are empty and 0.0 otherwise.
    """
    gold_toks = get_tokens(a_gold)
    pred_toks = get_tokens(a_pred)
    if not gold_toks or not pred_toks:
        # If either is no-answer, then F1 is 1 if they agree, 0 otherwise.
        # Converted to float so that every branch matches the annotation.
        return float(gold_toks == pred_toks)

    common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
    num_same = sum(common.values())
    if num_same == 0:
        # No shared token: precision and recall are both zero, and the
        # harmonic mean below would divide by zero.
        return 0.0

    precision = num_same / len(pred_toks)
    recall = num_same / len(gold_toks)
    return (2 * precision * recall) / (precision + recall)


def metric_max_over_ground_truths(
    metric_fn: Callable[[str, str], float],
    prediction: str,
    ground_truths: list[str],
) -> float:
    """Return the best score `metric_fn` gives `prediction` across `ground_truths`.

    `metric_fn` is called as `metric_fn(prediction, ground_truth)` once per
    ground truth, in list order. An empty `ground_truths` makes `max` raise
    `ValueError`.
    """
    return max(metric_fn(prediction, candidate) for candidate in ground_truths)


def compute_scores(prediction: str, reference: list[str]) -> dict:
    """Score one prediction against its list of reference answers.

    Returns a dict with the keys "exact_match", "f1" and "fuzzy_match".
    Exact match and F1 are the maximum over all references; the fuzzy score
    comes from `fuzzy_match_metric`.
    """
    return {
        "exact_match": metric_max_over_ground_truths(compute_exact_match, prediction, reference),
        "f1": metric_max_over_ground_truths(compute_f1, prediction, reference),
        "fuzzy_match": fuzzy_match_metric(prediction, reference),
    }

In [None]:
# Example: a prediction that misspells the first reference by one letter,
# scored against two reference answers. The dict is displayed as cell output.
scores = compute_scores("Alexandre the Great", ["Alexander the Great", "Great Alexander"])
scores

{'exact_match': 0, 'f1': 0.5, 'fuzzy_match': 1.0}

In [None]:
#|export

def compute_scores_dataframe(dataf: pd.DataFrame) -> pd.DataFrame:
    """Append per-row score columns to `dataf` and return the new frame.

    Every row must provide a "predicted_answer" value and an "answers" list.
    The dict returned by `compute_scores` for each row is expanded into one
    column per key and concatenated to the right of the input columns.
    """

    def score_row(row):
        return compute_scores(row["predicted_answer"], row["answers"])

    score_columns = dataf.apply(score_row, axis=1, result_type="expand")
    return pd.concat([dataf, score_columns], axis=1)

In [None]:
# Example frame with one row per question: "answers" holds the accepted
# reference answers and "predicted_answer" the answer being scored.
df = pd.DataFrame([
    {
        "answers": ["Alexandre", "Alexandre Dumas"],
        "predicted_answer": "Alexandre Dumas",
    },
    {
        "answers": ["North London"],
        "predicted_answer": "London",
    },
    {
        "answers": ["Paris", "Paris, France"],
        "predicted_answer": "Marseille",
    },
])
# Display the frame with the score columns appended.
compute_scores_dataframe(df)

Unnamed: 0,answers,predicted_answer,exact_match,f1,fuzzy_match
0,"[Alexandre, Alexandre Dumas]",Alexandre Dumas,1.0,1.0,1.0
1,[North London],London,0.0,0.666667,0.0
2,"[Paris, Paris, France]",Marseille,0.0,0.0,0.0


In [None]:
#|export

def aggregate_scores(dataf: pd.DataFrame) -> dict:
    """Average the "exact_match", "f1" and "fuzzy_match" columns of `dataf`.

    Returns a dict mapping each of those column names to its mean.
    """
    metric_columns = ["exact_match", "f1", "fuzzy_match"]
    column_means = dataf[metric_columns].mean()
    return column_means.to_dict()

In [None]:
# Example: mean of each score column over the rows of the example frame.
aggregate_scores(compute_scores_dataframe(df))

{'exact_match': 0.3333333333333333,
 'f1': 0.5555555555555555,
 'fuzzy_match': 0.3333333333333333}

In [None]:
#|hide
# Trigger nbdev's notebook export.
import nbdev; nbdev.nbdev_export()