In [1]:
import re
import evaluate
import numpy as np
import pandas as pd
import textdistance
import torch
from nltk.tokenize import regexp_tokenize
from sklearn.metrics import accuracy_score, f1_score
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification

In [2]:
# Text-quality metrics. print_eval uses sacrebleu for self-BLEU between edits
# and references, and perplexity for fluency (scored with gpt2-large there).
sacrebleu = evaluate.load("sacrebleu")
perplexity = evaluate.load("perplexity", module_type="metric")

# Decodes the space-separated T5 token strings stored in the non-MiCE edit files
# (see decode_and_trim / read_edits).
t5_tokenizer = AutoTokenizer.from_pretrained('t5-small')

In [3]:
# External sentiment classifier used by print_eval to re-score references and
# edits (the "Valid." rows). Placed on the GPU because predict moves inputs there.
model_name = "mtreviso/roberta-base-imdb"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name).cuda()

In [4]:
def predict(model, tokenizer, inp, max_length=512):
    """Classify a single input and return the predicted label index.

    Args:
        model: sequence-classification model, called with ``input_ids`` and
            ``attention_mask`` and returning an object with ``.logits``.
        tokenizer: callable tokenizer that returns ``input_ids`` and
            ``attention_mask`` tensors when asked for ``return_tensors="pt"``.
        inp: the text to classify. ``.item()`` below requires the model to
            produce exactly one prediction, so pass one example at a time.
        max_length: truncation length in tokens (512 was previously hard-coded).

    Returns:
        int: argmax over the last dimension of the logits.
    """
    enc = tokenizer(
        inp,
        padding=True,
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    )
    # Follow the model's device instead of assuming CUDA, so the same helper
    # works for CPU models. On a CUDA model this is the same device `.cuda()`
    # would pick.
    device = getattr(model, "device", None)
    if device is None:
        device = next(model.parameters()).device
    out = model(
        input_ids=enc.input_ids.to(device),
        attention_mask=enc.attention_mask.to(device),
    )
    return out.logits.argmax(-1).item()


def get_predictions(model, tokenizer, inputs, verbose=True):
    """Predict a label index for every element of `inputs`, one at a time.

    Switches `model` to eval mode and disables gradient tracking while calling
    `predict` per input. When `verbose` is true, progress is shown with tqdm.

    Returns:
        list: one prediction per input, in input order.
    """
    model.eval()
    with torch.no_grad():
        iterator = tqdm(inputs) if verbose else inputs
        return [predict(model, tokenizer, item) for item in iterator]

In [5]:
def trim(text):
    """Strip decoding artefacts (`</s>`, `<unk>`, `<br />`) from a text.

    The steps run strictly in order because later ones depend on earlier ones.
    For example, angle brackets are padded with spaces first so the `<unk> br`
    and `br />` fix-ups can match. Everything after the first `</s>` is dropped,
    up to the next tab or newline. Finally, runs of spaces are collapsed and the
    result is stripped.
    """
    # (is_regex, pattern, replacement), applied top to bottom.
    steps = (
        (False, '<', ' <'),
        (False, '>', '> '),
        # Re-attach a "br" that follows <unk> (presumably a mangled "<br" tag).
        (False, '<unk> br', '<br'),
        # Collapse repeated end-of-sequence markers.
        (True, r'( </s>)+', ' </s>'),
        (False, '</s> </s>', '</s>'),
        # Normalise line-break variants to "<br />", then remove them.
        (False, "br />", "<br />"),
        (False, "<<", "<"),
        (False, "<br />", ""),
        # Truncate after the first </s> ([\S\ ] stops at tabs/newlines).
        (True, r'</s>[\S\ ]+', '</s>'),
        (False, '</s>', ''),
        (False, '<unk>', ''),
        (True, r'\ +', ' '),
    )
    for is_regex, pattern, replacement in steps:
        if is_regex:
            text = re.sub(pattern, replacement, text)
        else:
            text = text.replace(pattern, replacement)
    return text.strip()


def decode_and_trim(text):
    """Turn a space-separated string of T5 token strings back into clean text.

    The tokens are mapped to ids with the module-level `t5_tokenizer`, decoded
    with special tokens skipped, and then passed through `trim`.
    """
    token_ids = t5_tokenizer.convert_tokens_to_ids(text.strip().split())
    decoded = t5_tokenizer.decode(
        token_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    return trim(decoded)


def _labels_to_int(values):
    """Convert one label/prediction column (a pandas Series) to a list of ints.

    Numeric or numeric-string values go through ``int``. If that fails, the
    column is treated as 'Negative'/'Positive' strings mapped to 0/1. Any other
    value raises KeyError from the mapping lookup.
    """
    try:
        return values.map(int).tolist()
    except (TypeError, ValueError):
        label_map = {'Negative': 0, 'Positive': 1}
        return values.apply(label_map.__getitem__).tolist()


def read_edits(fname, valids_only=False, return_refs=False):
    """Load (non-MiCE) edits from a tab-separated file.

    Args:
        fname: path (or buffer) passed to ``pd.read_csv`` with ``sep='\\t'``.
        valids_only: keep only rows whose ``edits_labels`` equal ``edits_predictions``.
        return_refs: accepted for backward compatibility. It is unused, and
            references are always returned.

    Returns:
        Tuple ``(edits, edits_labels, edits_preds, refs, refs_labels, refs_preds)``.
        Texts are decoded with ``decode_and_trim``. Labels and predictions are
        ints, converted per column (numeric, or 'Negative'/'Positive').
    """
    df = pd.read_csv(
        fname,
        sep='\t',
        usecols=['orig_texts', 'orig_labels', 'orig_predictions', 'orig_z',
                 'edits_texts', 'edits_labels', 'edits_predictions', 'edits_z_pre', 'edits_z_pos']
    )
    if valids_only:
        df = df[df['edits_labels'] == df['edits_predictions']]
    edits = df['edits_texts'].map(decode_and_trim).tolist()
    refs = df['orig_texts'].map(decode_and_trim).tolist()
    # Convert each column independently. The previous all-or-nothing bare
    # `except:` failed on files mixing numeric and string label columns and
    # also swallowed unrelated errors such as KeyboardInterrupt.
    edits_labels = _labels_to_int(df['edits_labels'])
    edits_preds = _labels_to_int(df['edits_predictions'])
    refs_labels = _labels_to_int(df['orig_labels'])
    refs_preds = _labels_to_int(df['orig_predictions'])
    return edits, edits_labels, edits_preds, refs, refs_labels, refs_preds


def read_edits_mice(fname, use_last_search_step=False, valids_only=False):
    """Load MiCE edits from a tab-separated file.

    Args:
        fname: path (or buffer) of the MiCE output, read as tab-separated.
        use_last_search_step: if False (default), keep the row with
            ``sorted_idx == 0`` for each input. If True, keep the last row per
            ``data_idx`` among rows that have a ``new_pred``.
        valids_only: keep only rows whose ``new_pred`` equals ``contrast_label``.

    Returns:
        Tuple ``(edits, edits_labels, edits_preds, refs, refs_labels, refs_preds)``.
        Texts are passed through ``trim``; labels and predictions are cast to int.
    """

    def get_mice_counterfactuals(df_mice):
        # MiCE writes every edit found in Stage 2, but we only evaluate the
        # smallest one per input, i.e. the row ranked first (sorted_idx == 0).
        df_test = df_mice[df_mice['sorted_idx'] == 0]
        return df_test.sort_values(by='data_idx').reset_index(drop=True)

    def get_mice_counterfactuals_max(df_mice):
        # Keep the last search step per input. groupby().last() takes the last
        # non-null value *per column*, so it could splice edited_input from one
        # row with new_pred from another. Instead, drop rows without a
        # prediction and take the last whole row of each data_idx group.
        completed = df_mice[df_mice['new_pred'].notna()]
        df_test = completed.groupby('data_idx').tail(1).sort_values(by='data_idx')
        return df_test.reset_index(drop=True)

    try:
        df_mice = pd.read_csv(fname, delimiter='\t')
    except ValueError:
        # pandas' ParserError/EmptyDataError subclass ValueError. Retry with '\n'
        # as the only line terminator (presumably for stray '\r' inside fields;
        # TODO confirm). Unlike the former bare except, this no longer swallows
        # unrelated errors such as KeyboardInterrupt.
        df_mice = pd.read_csv(fname, delimiter='\t', lineterminator='\n')

    if use_last_search_step:
        df_mice_test = get_mice_counterfactuals_max(df_mice)
    else:
        df_mice_test = get_mice_counterfactuals(df_mice)

    # Drop inputs for which no prediction was recorded.
    df_mice_test = df_mice_test[df_mice_test['new_pred'].notna()].reset_index(drop=True)

    if valids_only:
        df_mice_test = df_mice_test[df_mice_test['contrast_label'] == df_mice_test['new_pred']]

    refs = df_mice_test['orig_input'].map(trim).tolist()
    refs_labels = df_mice_test['gold_label'].apply(int).tolist()
    refs_preds = df_mice_test['orig_pred'].apply(int).tolist()

    edits = df_mice_test['edited_input'].map(trim).tolist()
    edits_labels = df_mice_test['contrast_label'].apply(int).tolist()
    edits_preds = df_mice_test['new_pred'].apply(int).tolist()

    return edits, edits_labels, edits_preds, refs, refs_labels, refs_preds


# Word runs, dollar amounts, or any other non-space run. Compiled once, as a raw
# string: the previous plain string contained invalid escapes (\w, \$, \d, \S),
# which trigger a SyntaxWarning on Python 3.12+. The flags match nltk's
# RegexpTokenizer defaults, so findall() yields the same tokens as
# nltk.regexp_tokenize with this pattern.
_TOKEN_PATTERN = re.compile(r'\w+|\$[\d\.]+|\S+', re.UNICODE | re.MULTILINE | re.DOTALL)


def get_tokenized_texts(texts):
    """Re-tokenize each text and join its tokens with single spaces.

    Args:
        texts: iterable of strings.

    Returns:
        list[str]: one space-joined token string per input text.
    """
    return [' '.join(_TOKEN_PATTERN.findall(text)) for text in texts]


def dist_ratio(es, rs):
    """Mean normalized token-level Levenshtein distance between paired texts.

    Args:
        es: edited texts; each is split on whitespace into tokens.
        rs: reference texts, aligned index-by-index with `es`.

    Returns:
        float: mean of ``textdistance.levenshtein.normalized_distance`` over the
        pairs, or NaN when there are no pairs.

    Raises:
        ValueError: if `es` and `rs` differ in length. ``zip`` would otherwise
            silently drop the unmatched tail and skew the metric.
    """
    es, rs = list(es), list(rs)
    if len(es) != len(rs):
        raise ValueError(
            'dist_ratio: got {} edits but {} references'.format(len(es), len(rs))
        )
    if not es:
        # np.mean([]) would warn ("Mean of empty slice") and return NaN anyway.
        return float('nan')
    return np.mean([
        textdistance.levenshtein.normalized_distance(e.split(), r.split())
        for e, r in zip(es, rs)
    ])


def clean(text):
    """Remove `</s>` and `[SEP]` markers, collapse runs of spaces, and strip."""
    for marker in ('</s>', '[SEP]'):
        text = text.replace(marker, '')
    # Only plain spaces are collapsed; tabs/newlines are left as they are.
    return re.sub(r'\ +', ' ', text).strip()


def _print_acc_f1(prefix, y_true, y_pred):
    """Print accuracy and macro-F1 of `y_pred` against `y_true`, labelled with `prefix`."""
    print('{} Acc: {:.4f}'.format(prefix, accuracy_score(y_true, y_pred)))
    print('{} F1: {:.4f}'.format(prefix, f1_score(y_true, y_pred, average='macro')))


def print_eval(filename, valids_only=False, use_last_search_step=False,
               perplexity_model_id='gpt2-large'):
    """Print the full evaluation report for one edits file.

    Files whose path contains 'mice' are read with `read_edits_mice`; all others
    with `read_edits`. The report prints, in order:

    - accuracy/F1 of the labels and predictions stored in the file ("Rat." rows);
    - accuracy/F1 of the notebook's classifier on refs and edits ("Valid." rows);
    - token-level Levenshtein distance and mean token counts;
    - self-BLEU of edits against references;
    - mean perplexity of refs and edits under `perplexity_model_id`.

    Args:
        filename: path of the edits file.
        valids_only: forwarded to the reader.
        use_last_search_step: forwarded to `read_edits_mice` only.
        perplexity_model_id: model used by the perplexity metric (previously
            hard-coded to 'gpt2-large').
    """
    if 'mice' in filename:
        loaded = read_edits_mice(
            filename, use_last_search_step=use_last_search_step, valids_only=valids_only
        )
    else:
        loaded = read_edits(filename, valids_only=valids_only)
    edits, edits_labels, edits_preds, refs, refs_labels, refs_preds = loaded

    # Labels/predictions as stored in the file.
    _print_acc_f1('Ref Rat.', refs_labels, refs_preds)
    _print_acc_f1('Edit Rat.', edits_labels, edits_preds)
    print('---')

    # Re-score refs and edits with the classifier loaded in this notebook.
    _print_acc_f1('Ref Valid.', refs_labels, get_predictions(model, tokenizer, refs, verbose=False))
    _print_acc_f1('Edit Valid.', edits_labels, get_predictions(model, tokenizer, edits, verbose=False))
    print('---')

    edits = [clean(text) for text in edits]
    refs = [clean(text) for text in refs]
    # Tokenize once and reuse (was recomputed for every statistic).
    edits_tok = get_tokenized_texts(edits)
    refs_tok = get_tokenized_texts(refs)
    print('Levensh. dist: {:.2f}'.format(dist_ratio(edits_tok, refs_tok)))
    print('Num. tokens ref: {:.1f}'.format(np.mean([len(t.split()) for t in refs_tok])))
    print('Num. tokens edit: {:.1f}'.format(np.mean([len(t.split()) for t in edits_tok])))
    print('---')

    bleu = sacrebleu.compute(predictions=edits, references=refs)
    print('Self-bleu: {:.2f}'.format(bleu['score']))
    print('---')

    # Score refs and edits in a single call, then split the flat result back.
    ppl = perplexity.compute(predictions=refs + edits, model_id=perplexity_model_id)['perplexities']
    print('Ref Perpl: {:.2f}'.format(np.mean(ppl[:len(refs)])))
    print('Edit Perpl: {:.2f}'.format(np.mean(ppl[len(refs):])))
    

In [6]:
print_eval('../data/edits_paper/imdb/crest_30p.tsv')

Ref Rat. Acc: 0.9734
Ref Rat. F1: 0.9734
Edit Rat. Acc: 0.8115
Edit Rat. F1: 0.8115
---
Ref Valid. Acc: 0.9754
Ref Valid. F1: 0.9754
Edit Valid. Acc: 0.7582
Edit Valid. F1: 0.7582
---
Levensh. dist: 0.33
Num. tokens ref: 182.9
Num. tokens edit: 180.9
---
Self-bleu: 57.58
---


Using pad_token, but it is not set yet.


  0%|          | 0/61 [00:00<?, ?it/s]

Ref Perpl: 68.20
Edit Perpl: 67.29


In [7]:
print_eval('../data/edits_paper/imdb/crest_50p.tsv')

Ref Rat. Acc: 0.9754
Ref Rat. F1: 0.9754
Edit Rat. Acc: 0.9324
Edit Rat. F1: 0.9324
---
Ref Valid. Acc: 0.9754
Ref Valid. F1: 0.9754
Edit Valid. Acc: 0.9365
Edit Valid. F1: 0.9365
---
Levensh. dist: 0.67
Num. tokens ref: 182.9
Num. tokens edit: 193.9
---
Self-bleu: 23.08
---


Using pad_token, but it is not set yet.


  0%|          | 0/61 [00:00<?, ?it/s]

Ref Perpl: 68.20
Edit Perpl: 50.68


In [8]:
print_eval('../data/edits_paper/imdb/mice_binary_search.csv')

Ref Rat. Acc: 0.5102
Ref Rat. F1: 0.5101
Edit Rat. Acc: 0.7520
Edit Rat. F1: 0.7519
---
Ref Valid. Acc: 0.5102
Ref Valid. F1: 0.5102
Edit Valid. Acc: 0.7213
Edit Valid. F1: 0.7209
---
Levensh. dist: 0.20
Num. tokens ref: 183.0
Num. tokens edit: 171.3
---
Self-bleu: 73.76
---


Using pad_token, but it is not set yet.


  0%|          | 0/61 [00:00<?, ?it/s]

Ref Perpl: 67.82
Edit Perpl: 76.72


In [9]:
print_eval('../data/edits_paper/imdb/mice_30p.csv')

Ref Rat. Acc: 0.5092
Ref Rat. F1: 0.5092
Edit Rat. Acc: 0.7659
Edit Rat. F1: 0.7659
---
Ref Valid. Acc: 0.5092
Ref Valid. F1: 0.5092
Edit Valid. Acc: 0.7680
Edit Valid. F1: 0.7680
---
Levensh. dist: 0.39
Num. tokens ref: 182.7
Num. tokens edit: 161.2
---
Self-bleu: 49.64
---


Using pad_token, but it is not set yet.


  0%|          | 0/61 [00:00<?, ?it/s]

Ref Perpl: 67.93
Edit Perpl: 79.32


In [10]:
print_eval('../data/edits_paper/imdb/mice_50p.csv')

Ref Rat. Acc: 0.5102
Ref Rat. F1: 0.5101
Edit Rat. Acc: 0.8484
Edit Rat. F1: 0.8481
---
Ref Valid. Acc: 0.5102
Ref Valid. F1: 0.5102
Edit Valid. Acc: 0.8320
Edit Valid. F1: 0.8320
---
Levensh. dist: 0.65
Num. tokens ref: 183.0
Num. tokens edit: 115.7
---
Self-bleu: 20.70
---


Using pad_token, but it is not set yet.


  0%|          | 0/61 [00:00<?, ?it/s]

Ref Perpl: 67.82
Edit Perpl: 89.92
