# Compute BERTScores

Compute the similarity between each sentence and each of its short-text answers using BERTScore ([BERTScore: Evaluating Text Generation with BERT](https://arxiv.org/abs/1904.09675)).
GPU required.

In [2]:
import pandas as pd
import numpy as np
from mapping_util import long_to_wide, wide_to_long


# Annotator labels; passed to `long_to_wide` as the column names.
annotators = ["A1", "A2", "A3", "A4", "A5"]

# Read both CSV files using the same (batch, file) index so they can be joined.
index_cols = ["batch", "file"]
nyt = pd.read_csv("data/gold_data_mapped.csv", index_col=index_cols)
sta_df = pd.read_csv("data/clean_answers.csv", index_col=index_cols)

# Reshape the answers to wide form, attach them to the sentences on the shared
# index, then go back to long form (one sentence/answer pair per row, as the
# preview below shows).
n_annotators = len(annotators)
sta_wide = long_to_wide(sta_df, n_annotators, col_names=annotators).answer
sentences = nyt[["sentence"]]
df = sentences.join(sta_wide)
df_long = wide_to_long(df)

# Preview the first rows of the joined data.
df_long.head()

Unnamed: 0_level_0,Unnamed: 1_level_0,sentence,answer
batch,file,Unnamed: 2_level_1,Unnamed: 3_level_1
1,1,"In the race for Westchester County executive, ...",The changing of the United States Constitution...
1,1,"In the race for Westchester County executive, ...",Mr Brodsky is criticizing Mr O'Rourke for not ...
1,1,"In the race for Westchester County executive, ...","Aside from New York city, Westchester has the ..."
1,1,"In the race for Westchester County executive, ...",about abortion
1,1,"In the race for Westchester County executive, ...",Abortion rights


In [None]:
from bert_score import BERTScorer
import warnings
import logging


# Keep the cell output readable: hide FutureWarning messages and silence
# everything below ERROR level on the `transformers` logger.
warnings.simplefilter("ignore", FutureWarning)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.ERROR)

# Single English scorer with baseline rescaling enabled; shared by the scoring
# helper defined below.
scorer = BERTScorer(rescale_with_baseline=True, lang="en")


def bertscore_(prediction: str, reference: str):
    """Score one prediction/reference pair with the module-level ``scorer``.

    Returns whatever ``scorer.score`` returns for one-element candidate and
    reference lists.
    """
    return scorer.score([prediction], [reference])


def bertscore(predictions, references):
    """Score many prediction/reference pairs in one batched ``scorer.score`` call.

    Drop-in replacement for ``np.vectorize(bertscore_)``: it still accepts
    array-likes (or scalars) that broadcast against each other and still
    returns a tuple of three arrays shaped like the broadcast input. The
    difference is that all pairs are handed to the scorer at once, so the
    scorer can batch them itself rather than being called once per pair.

    Parameters
    ----------
    predictions, references:
        Strings or array-likes of strings (e.g. pandas Series) with
        broadcast-compatible shapes.

    Returns
    -------
    tuple of three ``np.ndarray``
        The three score arrays in the order ``scorer.score`` returns them
        (precision, recall, F-score per the bert_score API -- confirm against
        the installed version).

    Notes
    -----
    Batched evaluation may differ from per-pair evaluation in the last
    floating-point digits -- NOTE(review): confirm this is acceptable.
    """
    preds = np.asarray(predictions, dtype=object)
    refs = np.asarray(references, dtype=object)
    preds, refs = np.broadcast_arrays(preds, refs)
    shape = preds.shape

    # Nothing to score: return empty arrays instead of calling the model.
    if preds.size == 0:
        return tuple(np.empty(shape, dtype=np.float32) for _ in range(3))

    metrics = scorer.score(preds.ravel().tolist(), refs.ravel().tolist())
    # Convert the returned tensors to NumPy and restore the input shape.
    return tuple(
        metric.detach().cpu().numpy().reshape(shape) for metric in metrics
    )

In [None]:
# takes about 6 minutes with P100 GPU on kaggle.com
# Both texts are lowercased before scoring.
sentences_lower = df_long.sentence.str.lower()
answers_lower = df_long.answer.str.lower()
scores = bertscore(sentences_lower, answers_lower)

In [None]:
# inspect results
# Take an explicit, independent copy: `df_long.loc[:]` only yields a shallow
# copy that still shares its data with `df_long`.
scores_long = df_long.copy()
# `scores` is expected to hold one array per metric; transposing gives one row
# per sentence/answer pair with the three metrics as columns.
scores_long[["precision", "recall", "fscore"]] = np.array(scores).T
scores_long.head()

In [None]:
# Write the scored pairs to disk as CSV.
scores_long.to_csv(path_or_buf="data/bert_scores_long_uncased.csv")