In [None]:
import os
import json
import torch
import numpy as np
import pickle
import random
import pandas
from tqdm import tqdm
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification, AutoTokenizer, AdamW
from torch.autograd import grad
from datasets import load_dataset
from sklearn.metrics import average_precision_score
from multiprocessing import Process, Queue, Pool
from timeit import default_timer as timer
from sklearn.metrics import precision_recall_curve, average_precision_score, auc, roc_curve

In [None]:
# Training split of the cleaned E2E data (Dusek et al. 2019).
train_data = pandas.read_csv(filepath_or_buffer='./data/e2e_cleaned_dusek_et_al_2019/train-fixed.no-ol.csv')

In [None]:
# Development split of the cleaned E2E data (Dusek et al. 2019).
dev_data = pandas.read_csv(filepath_or_buffer='./data/e2e_cleaned_dusek_et_al_2019/devel-fixed.no-ol.csv')

In [None]:
# Dev-set rows flagged (fixed == 1) as containing semantic errors according to Dusek et al. 2019.
# Each entry is the pandas row (a Series), kept in dev-set order.
bad_rows = [row for _, row in dev_data.iterrows() if row['fixed'] == 1]

In [None]:
# Seeded permutation of positions into bad_rows, so the sample picked below is reproducible.
inds = [*range(len(bad_rows))]
random.Random(21).shuffle(inds)

In [None]:
# Select 5 error examples: fixed positions (0, 1, 3, 7, 9) of a seed-21 shuffle of bad_rows.
# The permutation is rebuilt here with the same seed as the shuffle cell instead of reading
# `inds`, because `inds` is rebound later in the notebook to the loss-diff ranking; re-running
# this cell after that point would otherwise silently select different examples.
_bad_row_perm = list(range(len(bad_rows)))
random.Random(21).shuffle(_bad_row_perm)
selected = [_bad_row_perm[k] for k in (0, 1, 3, 7, 9)]

In [None]:
# Original meaning representations of the selected error examples.
sample_mrs = [bad_rows[idx]['orig_mr'] for idx in selected]

In [None]:
# Display the selected meaning representations (last expression is the cell output).
sample_mrs

In [None]:
# Erroneous references of the selected examples, aligned with sample_mrs.
sample_refs = [bad_rows[idx]['ref'] for idx in selected]

In [None]:
# Display the erroneous references so they can be corrected by hand in the next cell.
sample_refs

In [None]:
# Manually fix the erroneous references with minimal edits.
# Entry k is the corrected version of sample_refs[k]; it is paired with sample_mrs[k]
# when the corrected batch is used for the theta_fix update.
# NOTE(review): the last entry has no full stop after "Chinese food" — confirm whether that
# intentionally mirrors the original reference.
fixed_refs = [
    "Aromi's a coffee shop with a 3 out of 5 rating down at riverside. It has Chinese food and allows kids on the premises.",
    "The Punter is an adult English coffee shop near Café Sicilia with a price range of £20-25 and a high customer rating.",
    "There is a high-priced English coffee shop in the riverside area.  It is called Fitzbillies and it is family friendly, but it does have a 1 out of 5 rating.",
    "Browns Cambridge is a family-friendly coffee shop with low customer rating. It serves Chinese food. They are located in Riverside near the Crowne Plaza Hotel.",
    "Taste of Cambridge is a family-friendly coffee shop providing Chinese food It is located in the city centre. It is near Crowne Plaza Hotel.",
]

In [None]:
# One {'document': MR, 'summary': reference} dict per training row, in row order.
train_examples = [
    {'document': mr, 'summary': ref}
    for mr, ref in zip(train_data['orig_mr'], train_data['ref'])
]

In [None]:
# One {'document': MR, 'summary': reference} dict per dev row, in row order.
val_examples = [
    {'document': mr, 'summary': ref}
    for mr, ref in zip(dev_data['orig_mr'], dev_data['ref'])
]

In [None]:
# These are the training examples that contained semantic errors according to Dusek et al. 2019.
# We will use these as the oracle labels for computing retrieval metrics.
train_is_bad = (train_data['fixed'] == 1).to_numpy()

# Index labels of the erroneous training rows.
train_bad_inds = train_data.index[train_is_bad].tolist()

# labels[k] is built positionally, so it stays aligned with train_examples[k] (also built in
# row order) even if train_data does not carry a default 0..n-1 index.
labels = [int(flag) for flag in train_is_bad]

In [None]:
# Index labels of dev rows flagged (fixed == 1) as erroneous.
val_bad_inds = dev_data.index[dev_data['fixed'] == 1].tolist()

In [None]:
def get_loss(model, article, summary, device):
    """Score a single (article, summary) pair with a seq2seq model.

    Tokenizes both strings with the module-level ``tokenizer`` (one example, no
    padding), runs one forward pass under ``torch.no_grad()`` and returns the
    model's ``loss`` output as a Python float.

    Args:
        model: seq2seq model exposing ``prepare_decoder_input_ids_from_labels``.
        article: source text (here, a meaning representation).
        summary: target text to score.
        device: anything accepted by ``.to()``.

    Returns:
        float: ``outputs.loss`` for this pair.
    """
    target_ids = tokenizer(summary, return_tensors='pt', truncation=True)['input_ids'].to(device)
    model_inputs = tokenizer(article, return_tensors='pt', truncation=True).to(device)
    model_inputs['decoder_input_ids'] = model.prepare_decoder_input_ids_from_labels(target_ids)
    model_inputs['labels'] = target_ids
    with torch.no_grad():
        return model(**model_inputs).loss.item()

In [None]:
def batch_train_one(model, optimizer, articles, summaries, device, mask_padding=True, verbose=True):
    """Take one optimizer step on a padded batch of (article, summary) pairs.

    Tokenizes with the module-level ``tokenizer``, runs a forward/backward pass
    and calls ``optimizer.step()``.

    Args:
        model: seq2seq model exposing ``prepare_decoder_input_ids_from_labels``.
        optimizer: torch optimizer over ``model``'s parameters.
        articles: list of source strings.
        summaries: list of target strings, aligned with ``articles``.
        device: anything accepted by ``.to()``.
        mask_padding: if True (default), label positions holding the pad token are set
            to -100 so the loss ignores them. Set False to reproduce the earlier
            behaviour, where padding of shorter summaries was scored as a target.
        verbose: if True, print the batch loss before the step and re-score the
            batch after it (the extra forward pass is skipped when False).

    Returns:
        float: the batch loss before the update.
    """
    optimizer.zero_grad()
    model.zero_grad()
    batch = tokenizer(articles, return_tensors='pt', truncation=True, padding=True).to(device)
    labels = tokenizer(summaries, return_tensors='pt', truncation=True, padding=True)['input_ids'].to(device)

    # Build decoder inputs from the unmasked ids so they are exactly what they were before.
    decoder_input_ids = model.prepare_decoder_input_ids_from_labels(labels)
    if mask_padding:
        # -100 is the ignore index of the cross-entropy loss used by HF seq2seq models;
        # without this, pad tokens added by padding=True count as training targets.
        labels = labels.masked_fill(labels == tokenizer.pad_token_id, -100)

    batch['labels'] = labels
    batch['decoder_input_ids'] = decoder_input_ids

    loss = model(**batch).loss
    if verbose:
        print(loss)
    loss.backward()
    optimizer.step()

    if verbose:
        # Diagnostic only: loss on the same batch after the update; no autograd graph needed.
        with torch.no_grad():
            print(model(**batch).loss)

    return loss.item()

In [None]:
def batch_get_losses_after_update(articles, summaries, device, chkpt_dir='./checkpoint_9', lr=1e-5, steps=1,
                                  examples=None, weight_decay=0.0):
    """Fine-tune a fresh copy of a checkpoint on one batch, then score a dataset with it.

    Loads ``chkpt_dir`` anew on every call, so successive calls start from the
    same weights. It then takes ``steps`` optimizer steps on
    (``articles``, ``summaries``) via ``batch_train_one`` and returns the loss of
    the updated model on every item of ``examples``.

    Args:
        articles: list of source strings to train on.
        summaries: list of target strings, aligned with ``articles``.
        device: anything accepted by ``.to()``.
        chkpt_dir: path of the seq2seq checkpoint to start from.
        lr: AdamW learning rate.
        steps: number of gradient steps taken on the batch.
        examples: list of ``{'document': ..., 'summary': ...}`` dicts to score after
            the update. Defaults to the global ``train_examples``.
        weight_decay: weight decay for all parameters except biases and LayerNorm
            weights (which always get 0.0). Defaults to 0.0, matching the previous
            behaviour where both groups were hard-coded to 0.0.

    Returns:
        list[float]: one loss per element of ``examples``, in order.
    """
    if examples is None:
        examples = train_examples

    model = AutoModelForSeq2SeqLM.from_pretrained(chkpt_dir).to(device)

    # Split parameters into decayed / non-decayed groups in a single pass (order preserved).
    no_decay = ("bias", "LayerNorm.weight")
    decay_params, no_decay_params = [], []
    for name, param in model.named_parameters():
        if any(nd in name for nd in no_decay):
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    optimizer = AdamW(
        [
            {"params": decay_params, "weight_decay": weight_decay},
            {"params": no_decay_params, "weight_decay": 0.0},
        ],
        lr=lr,
    )

    for _ in range(steps):
        batch_train_one(model, optimizer, articles, summaries, device)

    return [get_loss(model, ex['document'], ex['summary'], device) for ex in tqdm(examples)]

In [None]:
# BART tokenizer; get_loss and batch_train_one read it through the global name `tokenizer`.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path='facebook/bart-base')

In [None]:
# Target for every .to(device) call: CUDA device 0 when a GPU is available, otherwise CPU
# (a hard-coded 0 makes the whole notebook fail on CPU-only machines).
device = 0 if torch.cuda.is_available() else 'cpu'

In [None]:
# theta_orig: take gradient steps on the original, erroneous references, then compute
# the loss of every training sample under the updated model.
losses_orig = batch_get_losses_after_update(
    sample_mrs,
    sample_refs,
    device,
    chkpt_dir='./bart_base/checkpoint_0/',
    lr=5e-6,
    steps=3,
)

In [None]:
# theta_fix: same procedure, but trained on the manually corrected references.
losses_fix = batch_get_losses_after_update(
    sample_mrs,
    fixed_refs,
    device,
    chkpt_dir='./bart_base/checkpoint_0/',
    lr=5e-6,
    steps=3,
)

In [None]:
# Per-training-example loss difference: loss under theta_orig minus loss under theta_fix.
diff = [orig_loss - fix_loss for orig_loss, fix_loss in zip(losses_orig, losses_fix)]

In [None]:
# Rank the training samples by loss diff, ascending.
# Because diff = loss_orig - loss_fix, the lowest-scoring samples are the most likely to be
# erroneous: they have a relatively small loss under theta_orig and a relatively large loss
# under theta_fix.
inds = [*np.argsort(diff)]

In [None]:
# Inspect the oracle labels of the 50 top-ranked samples to check that the method works as expected.
# NOTE: label=1 means the training instance contained semantic errors according to Dusek et al. 2019.
[labels[pos] for pos in inds[:50]]

In [None]:
# NOTE: We can then take the top X samples and bottom X samples and distill these into an electra classifier

In [None]:
# Load the distilled ELECTRA classifier trained for the paper (from the top 500 and bottom 500 samples).
# NOTE: this rebinds the global `tokenizer` that get_loss / batch_train_one read; re-run the
# BART tokenizer cell before calling those helpers again.
tokenizer = AutoTokenizer.from_pretrained('google/electra-large-discriminator')
model = AutoModelForSequenceClassification.from_pretrained('./classifier/')
model = model.to(device)

In [None]:
# Score every training pair with the classifier: the score is the softmax probability of class 1.
# This is pure inference, so torch.no_grad() avoids building an autograd graph for each forward pass.
scores = []
with torch.no_grad():
    for example in tqdm(train_examples):
        input_ids = tokenizer.encode(
            example['document'],
            example['summary'],
            return_tensors='pt',
            truncation=True,
            max_length=512,
        ).to(device)
        probs = model(input_ids).logits.softmax(dim=-1)
        scores.append(probs[0][1].item())

In [None]:
# Average precision (in %) of the classifier scores against the oracle labels.
# `labels` is binary, so sklearn ignores the `average` argument; it is omitted here.
average_precision_score(y_true=labels, y_score=scores) * 100

In [None]:
# ROC-AUC (in %), with label 1 (erroneous reference) as the positive class.
fpr, tpr, thresholds = roc_curve(y_true=labels, y_score=scores, pos_label=1)
100 * auc(fpr, tpr)