In [None]:
import os
import json
import pickle
from sklearn.metrics import precision_recall_curve, average_precision_score, auc, roc_curve

In [None]:
# Entity pair for the canary experiment: `orig` is the entity that appears in the
# source article, `canary` is the entity that shows up in a summary in its place.
# Later cells search for exactly this pattern.
orig, canary = 'Australia', 'France'

In [None]:
# Load the training split (JSON Lines: one JSON object per line).
# The encoding is pinned to UTF-8 so parsing does not depend on the platform's
# default locale encoding, and blank lines (e.g. an extra empty line at the end
# of the file) are skipped instead of crashing json.loads.
with open('./data/xsum_with_canaries/train.json', encoding='utf-8') as fin:
    train_examples = [json.loads(line) for line in fin if line.strip()]

In [None]:
# Load the validation split (JSON Lines: one JSON object per line).
# The encoding is pinned to UTF-8 so parsing does not depend on the platform's
# default locale encoding, and blank lines (e.g. an extra empty line at the end
# of the file) are skipped instead of crashing json.loads.
with open('./data/xsum_with_canaries/val.json', encoding='utf-8') as fin:
    val_examples = [json.loads(line) for line in fin if line.strip()]

In [None]:
# Split the validation records into two parallel lists:
# the source articles and their reference summaries.
val_articles = [record['document'] for record in val_examples]
val_refs = [record['summary'] for record in val_examples]

In [None]:
# Read the validation summaries generated by the model that was trained on the
# canary dataset.
# NOTE: pickle.load can execute arbitrary code, so only load files you trust.
with open('./data/xsum_val_preds.pk', 'rb') as preds_file:
    val_summaries = pickle.load(preds_file)

In [None]:
# Indices of the training examples that carry the canary: the summary mentions
# `canary` while the document mentions `orig` and never mentions `canary`.
train_bad_inds = {
    idx
    for idx, ex in enumerate(train_examples)
    if canary in ex['summary']
    and orig in ex['document']
    and canary not in ex['document']
}

In [None]:
# Binary ground-truth vector aligned with train_examples:
# 1 for the canary-carrying examples found above, 0 for everything else.
labels = [1 if idx in train_bad_inds else 0 for idx in range(len(train_examples))]

In [None]:
# Validation indices where the model reproduced the canary error: the generated
# summary contains `canary`, the article contains `orig` but not `canary`, and
# the reference summary does not contain `canary` either.
canary_inds = []
for idx, generated in enumerate(val_summaries):
    if canary not in generated:
        continue
    source_doc = val_articles[idx]
    if orig not in source_doc or canary in source_doc:
        continue
    if canary in val_refs[idx]:
        continue
    canary_inds.append(idx)
print(len(canary_inds))

In [None]:
# Pick 5 error examples by their (hand-chosen) positions within canary_inds.
# This needs at least 11 entries in canary_inds, otherwise indexing fails.
selected = [canary_inds[pos] for pos in (0, 2, 3, 5, 10)]

In [None]:
# Gather the source article and the generated summary for every selected
# validation index (the two lists stay aligned with `selected`).
selected_articles = [val_articles[idx] for idx in selected]
selected_summaries = [val_summaries[idx] for idx in selected]

In [None]:
# Display the selected generated summaries (each one contains the canary entity).
selected_summaries

In [None]:
# Produce the corrected version of each selected summary with a minimal edit:
# every occurrence of the canary entity is swapped back to the original entity.
fixed_summaries = [summary.replace(canary, orig) for summary in selected_summaries]

In [None]:
# Display the corrected summaries (canary entity replaced by the original entity).
fixed_summaries

In [None]:
# NOTE: These examples can then be used for comparing error attribution methods.
# See cae_e2e.ipynb to see how to use our method for error attribution

In [None]:
# Read the scores produced by our classifier; the cells below compute the
# metrics from them.
# Presumably one score per training example, in the same order as `labels` -- verify.
# NOTE: pickle.load can execute arbitrary code, so only load files you trust.
with open('./data/classifier_distillation/australia_france/train_scored.pk', 'rb') as scores_file:
    scores = pickle.load(scores_file)

In [None]:
# Average precision of the classifier scores against the canary labels, as a
# percentage. `labels` is a binary 0/1 vector, and scikit-learn ignores the
# `average` argument for binary targets, so the former average="samples" had no
# effect on the result; it is dropped to avoid implying sample-wise averaging.
average_precision_score(labels, scores) * 100

In [None]:
# ROC curve of the classifier scores against the canary labels, treating
# label 1 (canary-carrying example) as the positive class.
fpr, tpr, thresholds = roc_curve(labels, scores, pos_label=1)
# Area under the ROC curve, scaled by 100 to read as a percentage.
auc(fpr, tpr)*100