# Evaluation
Evaluate the TE-based fact recall predictions against the fact recall dataset.

In [87]:
import pandas as pd
from sklearn import metrics

## Load the fact recall dataset

In [77]:
# Read the gold fact recall samples (JSON lines). The file appears to contain
# fully duplicated rows (cause unknown), so only the first occurrence is kept.
fact_recall_data = pd.read_json(
    "/cephyr/users/lovhag/Alvis/projects/fact-recall-detection/data/data_creation/final_splits/confident_fact_recall_preds.jsonl",
    lines=True,
).drop_duplicates()

fact_recall_data.head()

Unnamed: 0,obj_label,sub_label,predicate_id,source,sub_view_rates,obj_view_rates,string_match,person_name,prompt,template,answers,p_answers,pred_rank,prompt_bias,correct,surface_pred,trivial_pred,consistency_counts
0,Jerusalem,Obadiah ben Abraham,P20,TREx_UHN,261.166667,165070.916667,False,False,Obadiah ben Abraham died in,[X] died in [Y],Jerusalem,0.048551,1,False,True,False,False,6
1,Jerusalem,Obadiah ben Abraham,P20,TREx_UHN,261.166667,165070.916667,False,False,Obadiah ben Abraham died at,[X] died at [Y],Jerusalem,0.075843,1,False,True,False,False,6
2,Jerusalem,Obadiah ben Abraham,P20,TREx_UHN,261.166667,165070.916667,False,False,Obadiah ben Abraham passed away in,[X] passed away in [Y],Jerusalem,0.064481,1,False,True,False,False,6
3,Jerusalem,Obadiah ben Abraham,P20,TREx_UHN,261.166667,165070.916667,False,False,Obadiah ben Abraham passed away at,[X] passed away at [Y],Jerusalem,0.026302,2,False,True,False,False,6
4,Jerusalem,Obadiah ben Abraham,P20,TREx_UHN,261.166667,165070.916667,False,False,Obadiah ben Abraham lost their life at,[X] lost their life at [Y],Jerusalem,0.056528,1,False,True,False,False,6


In [78]:
# Number of gold fact recall samples left after dropping duplicates.
fact_recall_data.shape[0]

1602

## Load the TE based data

In [79]:
# Read the per-prediction TE values; the first rows are displayed as a sanity check.
data = pd.read_csv(
    filepath_or_buffer="/cephyr/users/lovhag/Alvis/projects/rome/data/fact_recall_detection/gpt2_xl_final.csv"
)
data.head()

Unnamed: 0,subject,template,pred,pred_rank,correct_answer,te
0,Obadiah ben Abraham,{} died in,the,0,Jerusalem,0.030592
1,Obadiah ben Abraham,{} died in,Jerusalem,1,Jerusalem,0.031775
2,Obadiah ben Abraham,{} died in,5,2,Jerusalem,0.018109
3,Obadiah ben Abraham,{} died in,6,3,Jerusalem,0.01232
4,Obadiah ben Abraham,{} died in,12,4,Jerusalem,0.016043


Apply the same filtering here as for the fact recall data. Only keep the top 3 model predictions.

In [80]:
# Keep only predictions with rank 0-2, i.e. the model's top 3 predictions.
# .copy() detaches the result from the unfiltered frame so that the column
# assignments in the following cells do not raise SettingWithCopyWarning.
data = data[data["pred_rank"] < 3].copy()
len(data)

4824

In [81]:
# Predicted tokens that are flagged as trivial predictions.
forbidden_predictions = ["a", "the", "collaboration", "response", "public", '"', "order", "partnership", "honor", "AD", "open", "H", "age", "creating", "disgrace", "her", "his", "in", "left", "not", "providing", "tragedy", "which", "whom"]
forbidden_mask = data["pred"].isin(forbidden_predictions)

# A sample is labelled as TE-detected fact recall when its TE exceeds the
# threshold AND its prediction is not one of the trivial tokens above.
te_thresh = 0.1
te_fact_recall = (data["te"] > te_thresh) & ~forbidden_mask

# assign() returns a new frame, so nothing is written onto a filtered slice
# and re-running the cell yields the same result.
data = data.assign(trivial_pred=forbidden_mask, te_fact_recall=te_fact_recall)

# The count excludes trivial predictions; the message reports the configured
# threshold instead of a hard-coded value.
print(f"{data['te_fact_recall'].sum()} data samples have a TE above {te_thresh}")

1048 data samples have a TE above 0.1


Reformat the dataset to make it compatible with the gold labels dataset

In [82]:
# Bring the TE data into the same format as the gold fact recall data:
#   * templates use "[X]" for the subject slot and end with " [Y]"
#   * predictions get a leading space and are stored under "answers"
#   * the subject column is called "sub_label"
# Vectorised string operations replace the row-wise apply(lambda ...) calls.
data = (
    data.assign(
        template=data["template"].str.replace("{}", "[X]", regex=False) + " [Y]",
        pred=" " + data["pred"],
    )
    .rename(columns={"subject": "sub_label", "pred": "answers"})
    # The data appears to contain fully duplicated rows (cause unknown);
    # keep the first occurrence of each.
    .drop_duplicates()
    # Detach from the pre-deduplication frame so later column assignments are safe.
    .copy()
)

data

Unnamed: 0,sub_label,template,answers,pred_rank,correct_answer,te,trivial_pred,te_fact_recall
0,Obadiah ben Abraham,[X] died in [Y],the,0,Jerusalem,0.030592,True,False
1,Obadiah ben Abraham,[X] died in [Y],Jerusalem,1,Jerusalem,0.031775,False,False
2,Obadiah ben Abraham,[X] died in [Y],5,2,Jerusalem,0.018109,False,False
10,Obadiah ben Abraham,[X] died at [Y],the,0,Jerusalem,-0.003685,True,False
11,Obadiah ben Abraham,[X] died at [Y],Jerusalem,1,Jerusalem,0.058086,False,False
...,...,...,...,...,...,...,...,...
16061,Topeka,"[X], that is the capital of [Y]",the,1,Kansas,-0.134828,True,False
16062,Topeka,"[X], that is the capital of [Y]",Arkansas,2,Kansas,0.004054,False,False
16070,Topeka,"[X], that is the capital city of [Y]",Kansas,0,Kansas,0.744173,False,True
16071,Topeka,"[X], that is the capital city of [Y]",the,1,Kansas,-0.141578,True,False


## Compare the sets

In [83]:
# Columns that together identify a sample in both datasets.
key_cols = ["sub_label", "template", "answers"]

# Number of gold rows per key. groupby leaves out rows with a missing key
# value, which agrees with the element-wise equality used previously
# (a missing value never compares equal).
gold_match_counts = (
    fact_recall_data.groupby(key_cols).size().reset_index(name="n_gold_matches")
)

# One left merge replaces the row-by-row scan of the whole gold frame.
# The right-hand keys are unique, so the merge keeps exactly one row per
# TE sample, in the original order.
n_gold_matches = (
    data[key_cols]
    .merge(gold_match_counts, on=key_cols, how="left")["n_gold_matches"]
    .fillna(0)
    .to_numpy()
)
assert len(n_gold_matches) == len(data), "merge must not add or drop TE samples"

# As before, a sample is gold only if it matches exactly one gold row.
data = data.assign(gold_fact_recall=n_gold_matches == 1)
gold_fact_recall = data["gold_fact_recall"].tolist()
data.head()

Unnamed: 0,sub_label,template,answers,pred_rank,correct_answer,te,trivial_pred,te_fact_recall,gold_fact_recall
0,Obadiah ben Abraham,[X] died in [Y],the,0,Jerusalem,0.030592,True,False,False
1,Obadiah ben Abraham,[X] died in [Y],Jerusalem,1,Jerusalem,0.031775,False,False,True
2,Obadiah ben Abraham,[X] died in [Y],5,2,Jerusalem,0.018109,False,False,False
10,Obadiah ben Abraham,[X] died at [Y],the,0,Jerusalem,-0.003685,True,False,False
11,Obadiah ben Abraham,[X] died at [Y],Jerusalem,1,Jerusalem,0.058086,False,False,True


In [84]:
# Number of TE samples that were matched to a gold fact recall sample.
data["gold_fact_recall"].sum()

1602

In [85]:
# Count every combination of TE-based label and gold label.
data[["te_fact_recall", "gold_fact_recall"]].value_counts()

te_fact_recall  gold_fact_recall
False           False               3091
True            True                 931
False           True                 671
True            False                113
dtype: int64

In [None]:
# Score the TE-based labels against the gold labels, treating "fact recall"
# (True) as the positive class.
precision, recall, f1, _ = metrics.precision_recall_fscore_support(
    data["gold_fact_recall"], data["te_fact_recall"], average="binary"
)
pd.Series(
    {
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "accuracy": metrics.accuracy_score(
            data["gold_fact_recall"], data["te_fact_recall"]
        ),
    },
    name="te_vs_gold",
)