In [1]:
import re
import evaluate
import numpy as np
import pandas as pd
import textdistance
import torch
from nltk.tokenize import regexp_tokenize
from sklearn.metrics import accuracy_score, f1_score
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification

In [2]:
# Text metrics from the Hugging Face `evaluate` library (BLEU between edits and
# references, GPT-2 perplexity of the texts).
sacrebleu = evaluate.load("sacrebleu")
perplexity = evaluate.load("perplexity", module_type="metric")

# Tokenizer that `decode_and_trim` uses to map space-separated T5 tokens back to text.
t5_tokenizer = AutoTokenizer.from_pretrained("t5-small")

In [3]:
# Validator model used by `print_eval` to re-classify references and edits.
model_name = "mtreviso/roberta-base-snli"
# Use the GPU when present; fall back to CPU instead of crashing in `.cuda()`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [4]:
def predict(model, tokenizer, inp):
    """Return the predicted class id for a single input text.

    Class id 1 is never returned: when it is the arg-max, the runner-up class
    is returned instead. (The label maps in this notebook map id 1 to
    "neutral"; presumably the validator model uses the same id order --
    TODO confirm.)

    Args:
        model: sequence-classification model whose output exposes ``.logits``
            and which has a ``.device`` attribute.
        tokenizer: callable returning an encoding with ``input_ids`` and
            ``attention_mask`` tensors.
        inp: a single input string. ``.item()`` below requires a batch of one.

    Returns:
        int: the predicted class id (never 1).
    """
    enc = tokenizer(
        inp,
        padding=True,
        max_length=512,
        truncation=True,
        return_tensors="pt",
    )
    # Send inputs to wherever the model lives instead of hard-coding CUDA,
    # so CPU-only runs work too.
    device = model.device
    out = model(
        input_ids=enc.input_ids.to(device),
        attention_mask=enc.attention_mask.to(device),
    )
    logits = out.logits
    y_hat = logits.argmax(-1).item()
    if y_hat == 1:
        # Class 1 is suppressed: fall back to the second most likely class.
        return logits.argsort(-1)[:, -2].item()
    return y_hat

def get_predictions(model, tokenizer, inputs, verbose=True):
    """Classify every item of ``inputs`` with ``predict`` and return the ids.

    The model is switched to eval mode and run without gradient tracking.
    When ``verbose`` is true, iteration is wrapped in a tqdm progress bar.
    """
    model.eval()
    iterator = inputs
    if verbose:
        iterator = tqdm(inputs)
    with torch.no_grad():
        return [predict(model, tokenizer, item) for item in iterator]

In [5]:
def trim(text):
    """Turn a ``'</s>'``-separated pair into ``'first [SEP] second'``.

    Runs of consecutive ``</s>`` separators (with any surrounding whitespace)
    count as a single separator. Runs of spaces are squeezed, and only the
    first two segments are kept, so trailing separators are ignored.

    Raises:
        ValueError: if ``text`` contains no ``</s>`` separator.
    """
    # Collapse each run of separators into one ' </s> '. The previous code
    # padded every '</s>' with spaces *before* collapsing with '( </s>)+', so
    # adjacent separators were always two spaces apart and never merged;
    # "a</s></s>b" then lost its second segment.
    text = re.sub(r'(\s*</s>\s*)+', ' </s> ', text)
    text = re.sub(r'\ +', ' ', text).strip()
    parts = text.split('</s>')
    if len(parts) < 2:
        raise ValueError("expected a '</s>' separator in {!r}".format(text))
    return parts[0].strip() + ' [SEP] ' + parts[1].strip()


def decode_and_trim(text):
    """Decode a string of space-separated T5 token strings and pass it to ``trim``.

    The tokens are mapped to ids with the notebook-level ``t5_tokenizer`` and
    decoded back to text before trimming.
    """
    token_strings = text.strip().split()
    token_ids = t5_tokenizer.convert_tokens_to_ids(token_strings)
    decoded = t5_tokenizer.decode(token_ids)
    return trim(decoded)


_NLI_LABEL_TO_ID = {"entailment": 0, "neutral": 1, "contradiction": 2}


def _labels_to_ids(column):
    """Convert a label column (pandas Series) to a list of int class ids.

    Values are cast with ``int`` when possible. Otherwise (e.g. strings such as
    ``"entailment"``) they are looked up in ``_NLI_LABEL_TO_ID``, and unknown
    names raise ``KeyError``.
    """
    try:
        return column.map(int).tolist()
    except (TypeError, ValueError):
        # Only conversion failures fall back to the name lookup; the previous
        # bare ``except:`` also swallowed KeyboardInterrupt and similar.
        return column.apply(_NLI_LABEL_TO_ID.__getitem__).tolist()


def read_edits(fname, valids_only=False, return_refs=False):
    """Load edits and their original texts from a tab-separated file.

    Texts are stored as space-separated T5 tokens and are converted with
    ``decode_and_trim``. Label and prediction columns may hold integer ids or
    the names ``entailment`` / ``neutral`` / ``contradiction``.

    Args:
        fname: path to the TSV file.
        valids_only: keep only rows whose ``edits_predictions`` equals
            ``edits_labels``.
        return_refs: unused; kept for backward compatibility.

    Returns:
        ``(edits, edits_labels, edits_preds, refs, refs_labels, refs_preds)``:
        lists of texts and int class ids.
    """
    df = pd.read_csv(
        fname,
        sep='\t',
        usecols=['orig_texts', 'orig_labels', 'orig_predictions', 'orig_z',
                 'edits_texts', 'edits_labels', 'edits_predictions', 'edits_z_pre', 'edits_z_pos']
    )
    if valids_only:
        df = df[df['edits_labels'] == df['edits_predictions']]
    edits = df['edits_texts'].map(decode_and_trim).tolist()
    refs = df['orig_texts'].map(decode_and_trim).tolist()
    # Each column is converted on its own; this matches the old all-or-nothing
    # conversion wherever that succeeded.
    edits_labels = _labels_to_ids(df['edits_labels'])
    edits_preds = _labels_to_ids(df['edits_predictions'])
    refs_labels = _labels_to_ids(df['orig_labels'])
    refs_preds = _labels_to_ids(df['orig_predictions'])
    return edits, edits_labels, edits_preds, refs, refs_labels, refs_preds


def read_edits_mice(fname, use_last_search_step=False, valids_only=False):
    """Load MiCE edits and their original inputs from a tab-separated file.

    Args:
        fname: path to the MiCE output (tab-delimited).
        use_last_search_step: if False, keep the row with ``sorted_idx == 0``
            for each input; if True, keep the last row written for each
            ``data_idx``.
        valids_only: keep only rows whose ``new_pred`` equals ``contrast_label``.

    Returns:
        ``(edits, edits_labels, edits_preds, refs, refs_labels, refs_preds)``.
        Texts are passed through ``trim``; labels become int ids. Rows whose
        ``contrast_label`` is not a known label are dropped first.
    """

    def get_mice_counterfactuals(df_mice):
        # MiCE writes all edits that are found in Stage 2,
        # but we only want to evaluate the smallest per input.
        df_test = df_mice[df_mice['sorted_idx'] == 0]
        df_test = df_test.sort_values(by='data_idx')
        return df_test.reset_index(drop=True)

    def get_mice_counterfactuals_max(df_mice):
        # Keep the last row MiCE wrote for each input (the original note
        # calls this the longest edit). `groupby().last()` is avoided: it
        # takes the last non-null value of each column independently and can
        # stitch together fields from different rows.
        df_test = df_mice.groupby('data_idx').tail(1)
        df_test = df_test.sort_values(by='data_idx')
        return df_test.reset_index(drop=True)

    try:
        df_mice = pd.read_csv(fname, delimiter='\t')
    except pd.errors.ParserError:
        # Retry with '\n' as the only line terminator for files the default
        # parser cannot tokenize.
        df_mice = pd.read_csv(fname, delimiter='\t', lineterminator='\n')
    if not use_last_search_step:
        df_mice_test = get_mice_counterfactuals(df_mice)
    else:
        df_mice_test = get_mice_counterfactuals_max(df_mice)

    # Labels may be stored either as names or as int ids.
    label_map = {"entailment": 0, "neutral": 1, "contradiction": 2, 0: 0, 1: 1, 2: 2}
    valid_rows = df_mice_test['contrast_label'].map(lambda x: x in label_map)
    df_mice_test = df_mice_test[valid_rows].reset_index(drop=True)

    if valids_only:
        df_mice_test = df_mice_test[df_mice_test['contrast_label'] == df_mice_test['new_pred']]

    refs = df_mice_test['orig_input'].map(trim).tolist()
    refs_labels = df_mice_test['gold_label'].apply(label_map.__getitem__).tolist()
    refs_preds = df_mice_test['orig_pred'].apply(label_map.__getitem__).tolist()

    edits = df_mice_test['edited_input'].map(trim).tolist()
    edits_labels = df_mice_test['contrast_label'].apply(label_map.__getitem__).tolist()
    edits_preds = df_mice_test['new_pred'].apply(label_map.__getitem__).tolist()

    return edits, edits_labels, edits_preds, refs, refs_labels, refs_preds


def get_tokenized_texts(texts, pattern=r'\w+|\$[\d\.]+|\S+'):
    """Re-tokenize each text and join its tokens with single spaces.

    Tokens come from ``regexp_tokenize`` with ``pattern``. The default pattern
    matches word-character runs, ``$`` followed by digits/dots, or any other
    run of non-whitespace characters.

    The default is a raw string. The previous plain literal relied on Python
    keeping unknown escapes such as ``\\w`` verbatim, which emits a
    DeprecationWarning (a SyntaxWarning since Python 3.12). The resulting
    pattern text is identical.
    """
    return [' '.join(regexp_tokenize(text, pattern)) for text in texts]


def dist_ratio(es, rs):
    """Mean normalized Levenshtein distance between paired texts.

    Each pair is compared token-wise: both strings are split on whitespace and
    the token lists are passed to ``textdistance.levenshtein``.
    """
    distances = []
    for edit_text, ref_text in zip(es, rs):
        edit_tokens = edit_text.split()
        ref_tokens = ref_text.split()
        distances.append(textdistance.levenshtein.normalized_distance(edit_tokens, ref_tokens))
    return np.mean(distances)


def clean(text):
    """Remove every ``[SEP]`` marker, squeeze runs of spaces and strip the ends."""
    without_sep = text.replace('[SEP]', '')
    return re.sub(r'\ +', ' ', without_sep).strip()


def split_and_clean(text):
    """Return the segment after the first ``[SEP]``, with spaces squeezed and stripped."""
    second_segment = text.split('[SEP]')[1]
    squeezed = re.sub(r'\ +', ' ', second_segment)
    return squeezed.strip()


def _print_acc_f1(prefix, y_true, y_pred, complement=False):
    """Print accuracy and macro-F1 as '<prefix> Acc: x' and '<prefix> F1: x'.

    With ``complement=True`` the printed values are ``1 - acc`` and ``1 - f1``.
    """
    acc = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred, average='macro')
    if complement:
        acc, f1 = 1 - acc, 1 - f1
    print('{} Acc: {:.4f}'.format(prefix, acc))
    print('{} F1: {:.4f}'.format(prefix, f1))


def _mean_num_tokens(texts):
    """Mean number of tokens per text after ``get_tokenized_texts``."""
    return np.mean([len(t.split()) for t in get_tokenized_texts(texts)])


def print_eval(filename, valids_only=False, use_last_search_step=False):
    """Print classification, edit-distance, BLEU and perplexity stats for an edits file.

    Files whose name contains ``'mice'`` are read with ``read_edits_mice``; all
    others with ``read_edits``. Uses the notebook globals ``model``,
    ``tokenizer``, ``sacrebleu`` and ``perplexity``.

    Args:
        filename: path of the edits file.
        valids_only: forwarded to the reader; keep only rows whose stored
            prediction matches the edit label.
        use_last_search_step: forwarded to ``read_edits_mice`` only.
    """
    if 'mice' in filename:
        loaded = read_edits_mice(
            filename, use_last_search_step=use_last_search_step, valids_only=valids_only
        )
    else:
        loaded = read_edits(filename, valids_only=valids_only)
    edits, edits_labels, edits_preds, refs, refs_labels, refs_preds = loaded

    # Scores of the predictions stored in the file itself.
    _print_acc_f1('Ref Rat.', refs_labels, refs_preds)
    _print_acc_f1('Edit Rat.', edits_labels, edits_preds)

    print('---')

    # Scores of the notebook's validator model on the same texts.
    ref_valid_preds = list(get_predictions(model, tokenizer, refs, verbose=False))
    _print_acc_f1('Ref Valid.', refs_labels, ref_valid_preds)

    # Predict on the edits once and reuse the labels for both the edit-label
    # scores and the contrast scores (previously the identical call ran twice).
    edit_valid_preds = list(get_predictions(model, tokenizer, edits, verbose=False))
    _print_acc_f1('Edit Valid.', edits_labels, edit_valid_preds)
    # "Cont." = 1 - score against the ORIGINAL labels, i.e. 1 - acc is the
    # fraction of edits whose prediction differs from the reference label.
    _print_acc_f1('Edit Valid. Cont.', refs_labels, edit_valid_preds, complement=True)

    print('---')

    # Distance, length and BLEU use only the segment after '[SEP]'.
    second_refs = list(map(split_and_clean, refs))
    second_edits = list(map(split_and_clean, edits))
    res = dist_ratio(get_tokenized_texts(second_edits), get_tokenized_texts(second_refs))
    print('Levensh. dist: {:.2f}'.format(res))
    print('Num. tokens ref: {:.1f}'.format(_mean_num_tokens(second_refs)))
    print('Num. tokens edit: {:.1f}'.format(_mean_num_tokens(second_edits)))

    print('---')

    res = sacrebleu.compute(predictions=second_edits, references=second_refs)
    print('Self-bleu: {:.2f}'.format(res['score']))

    print('---')

    # Perplexity uses the whole text with '[SEP]' removed. Refs and edits go
    # through one compute() call and are split back apart by position.
    full_refs = list(map(clean, refs))
    full_edits = list(map(clean, edits))
    res = perplexity.compute(predictions=full_refs + full_edits, model_id='gpt2-large')
    perp_refs = res['perplexities'][:len(full_refs)]
    perp_edits = res['perplexities'][len(full_refs):]
    print('Ref Perpl: {:.2f}'.format(np.mean(perp_refs)))
    print('Edit Perpl: {:.2f}'.format(np.mean(perp_edits)))
    

In [6]:
print_eval('../data/edits_paper/snli/crest_30p.tsv')

Ref Rat. Acc: 0.8989
Ref Rat. F1: 0.6239
Edit Rat. Acc: 0.7004
Edit Rat. F1: 0.4924
---
Ref Valid. Acc: 0.9675
Ref Valid. F1: 0.9674
Edit Valid. Acc: 0.7545
Edit Valid. F1: 0.7541
Edit Valid. Cont. Acc: 0.7545
Edit Valid. Cont. F1: 0.7547
---
Levensh. dist: 0.29
Num. tokens ref: 7.5
Num. tokens edit: 7.4
---
Self-bleu: 41.36
---


Using pad_token, but it is not set yet.


  0%|          | 0/35 [00:00<?, ?it/s]

Ref Perpl: 63.52
Edit Perpl: 62.00


In [7]:
print_eval('../data/edits_paper/snli/crest_50p.tsv')

Ref Rat. Acc: 0.8773
Ref Rat. F1: 0.6134
Edit Rat. Acc: 0.7581
Edit Rat. F1: 0.5313
---
Ref Valid. Acc: 0.9675
Ref Valid. F1: 0.9674
Edit Valid. Acc: 0.8123
Edit Valid. F1: 0.8123
Edit Valid. Cont. Acc: 0.8123
Edit Valid. Cont. F1: 0.8138
---
Levensh. dist: 0.41
Num. tokens ref: 7.5
Num. tokens edit: 7.3
---
Self-bleu: 30.53
---


Using pad_token, but it is not set yet.


  0%|          | 0/35 [00:00<?, ?it/s]

Ref Perpl: 63.52
Edit Perpl: 62.60


In [8]:
print_eval('../data/edits_paper/snli/mice_binary_search.csv')

Ref Rat. Acc: 0.8845
Ref Rat. F1: 0.6155
Edit Rat. Acc: 0.7653
Edit Rat. F1: 0.5364
---
Ref Valid. Acc: 0.9675
Ref Valid. F1: 0.9674
Edit Valid. Acc: 0.7617
Edit Valid. F1: 0.7617
Edit Valid. Cont. Acc: 0.7617
Edit Valid. Cont. F1: 0.7637
---
Levensh. dist: 0.35
Num. tokens ref: 7.5
Num. tokens edit: 7.9
---
Self-bleu: 42.18
---


Using pad_token, but it is not set yet.


  0%|          | 0/35 [00:00<?, ?it/s]

Ref Perpl: 63.52
Edit Perpl: 63.19


In [9]:
print_eval('../data/edits_paper/snli/mice_30p.csv')

Ref Rat. Acc: 0.8845
Ref Rat. F1: 0.6155
Edit Rat. Acc: 0.7148
Edit Rat. F1: 0.5056
---
Ref Valid. Acc: 0.9675
Ref Valid. F1: 0.9674
Edit Valid. Acc: 0.7726
Edit Valid. F1: 0.7726
Edit Valid. Cont. Acc: 0.7726
Edit Valid. Cont. F1: 0.7748
---
Levensh. dist: 0.40
Num. tokens ref: 7.5
Num. tokens edit: 8.3
---
Self-bleu: 34.08
---


Using pad_token, but it is not set yet.


  0%|          | 0/35 [00:00<?, ?it/s]

Ref Perpl: 63.52
Edit Perpl: 59.71


In [10]:
print_eval('../data/edits_paper/snli/mice_50p.csv')

Ref Rat. Acc: 0.8845
Ref Rat. F1: 0.6155
Edit Rat. Acc: 0.7906
Edit Rat. F1: 0.5509
---
Ref Valid. Acc: 0.9675
Ref Valid. F1: 0.9674
Edit Valid. Acc: 0.8448
Edit Valid. F1: 0.8447
Edit Valid. Cont. Acc: 0.8448
Edit Valid. Cont. F1: 0.8496
---
Levensh. dist: 0.52
Num. tokens ref: 7.5
Num. tokens edit: 7.6
---
Self-bleu: 24.27
---


Using pad_token, but it is not set yet.


  0%|          | 0/35 [00:00<?, ?it/s]

Ref Perpl: 63.52
Edit Perpl: 68.32
