In [1]:
import torch
import pandas as pd
import numpy as np
from torch.utils.data import DataLoader, TensorDataset

In [2]:
from control_simp.models.end_to_end import BartFinetuner
from control_simp.data.bart import BartDataModule

# model_loc = "/media/liam/data2/control_simp_ckps/3fn5qcza/checkpoints/epoch=0.ckpt" # end-to-end baseline 4class
model_loc = "/media/liam/data2/control_simp_ckps/d68v4v5z/checkpoints/epoch=3-step=259528.ckpt" # end-to-end baseline 3class
# model_loc = "/media/liam/data2/control_simp_ckps/2yc0i7e5/checkpoints/epoch=4-step=225686.ckpt" # end-to-end control_tok 4class

model = BartFinetuner.load_from_checkpoint(model_loc, strict=False).to("cuda").eval()
# model = BartFinetuner()
tk = model.tokenizer
tk.add_prefix_space = False

In [3]:
toks = tk("Hello, this is my sentence.")["input_ids"]
toks2 = tk("<ident> Hello, this is my sentence.")["input_ids"]
print(toks)
print(toks2)
print(model.ids_to_clean_text(toks2))
x = torch.tensor([0, 50265, 20920, 6, 42, 16, 127, 3645, 4, 2])
model.ids_to_clean_text(x)

print(tk.convert_ids_to_tokens(toks))
print(tk.convert_ids_to_tokens(toks2))
print(tk.convert_ids_to_tokens(x))
print(tk.convert_ids_to_tokens(20920))

[0, 31414, 6, 42, 16, 127, 3645, 4, 2]
[0, 50265, 31414, 6, 42, 16, 127, 3645, 4, 2]
['<s>', '<ident>', 'Hello', ',', 'this', 'is', 'my', 'sentence', '.', '</s>']
['<s>', 'Hello', ',', 'Ġthis', 'Ġis', 'Ġmy', 'Ġsentence', '.', '</s>']
['<s>', '<ident>', 'Hello', ',', 'Ġthis', 'Ġis', 'Ġmy', 'Ġsentence', '.', '</s>']
['<s>', '<ident>', 'ĠHello', ',', 'Ġthis', 'Ġis', 'Ġmy', 'Ġsentence', '.', '</s>']
ĠHello


In [4]:
test_file = "/media/liam/data2/discourse_data/simp_clf_data/gen/control_simp_valid_exp1.csv"
test_set = pd.read_csv(test_file)
batch_size = 16

In [5]:
from control_simp.models.end_to_end import CONTROL_TOKENS

USE_CTRL = False

if USE_CTRL:
    control_token_seqs = []
    for i, row in test_set.iterrows():
        seq = CONTROL_TOKENS[row.label] + " " + row.complex
        control_token_seqs.append(seq)

In [6]:
dm = BartDataModule(model.tokenizer, hparams=model.hparams)
input_seqs = control_token_seqs if USE_CTRL else list(test_set["complex"])
test = dm.preprocess(input_seqs, list(test_set["simple"]))
dataset = TensorDataset(
        test['input_ids'].to("cuda"),
        test['attention_mask'].to("cuda"),
        test['labels'].to("cuda"))
test_data = DataLoader(dataset, batch_size=batch_size)

pred_ys = []
for batch in test_data:
#     results = model._generative_step(batch)
    input_ids, attention_mask, labels = batch
# #     print(len(input_ids[0]))
# #     print(model.ids_to_clean_text(labels[0]))
    generated_ids = model.model.generate(
        input_ids,
        attention_mask=attention_mask,
        use_cache=True,
        decoder_start_token_id=model.decoder_start_token_id,
        num_beams=model.eval_beams,
        max_length=128,
    )
    results = model.ids_to_clean_text(generated_ids)
    pred_ys += results#["preds"]
#     print(results)#["preds"])

test_set["pred"] = pred_ys

In [7]:
test_set.to_csv(f"{model_loc.split('checkpoints')[0]}preds.csv", index=None)

In [8]:
test['input_ids']

tensor([[    0,    20,    78,  ...,     1,     1,     1],
        [    0,  2223,    24,  ...,     1,     1,     1],
        [    0,   152,    16,  ...,     1,     1,     1],
        ...,
        [    0,  1604,    26,  ...,     1,     1,     1],
        [    0,   125, 23047,  ...,     1,     1,     1],
        [    0,    20,  3453,  ...,     1,     1,     1]])

In [9]:
from control_simp.models.eval import calculate_bleu, calculate_sari
from easse.sari import corpus_sari

bleus = calculate_bleu(test_set["pred"], test_set["simple"])
sari = corpus_sari(test_set["complex"], test_set["pred"], [test_set["simple"]])

In [10]:
print(f"BLEU: {np.mean(bleus)}\nSARI: {sari}")

BLEU: 68.74078505591442
SARI: 53.96989099664179


In [11]:
for i, row in test_set[:100].iterrows():
    print(f"{input_seqs[i]}\n--> {bleus[i]}\n--> {row.pred}\n--> {row.simple}\n-----")

The first part constitutes the Introduction.
--> 100.00000000000004
--> The first part is the introduction.
--> The first part is the introduction.
-----
Although it is difficult to imagine Trump wanting to undo that legacy, or that his new U.S. Trade Representative, who was a senior member of the Reagan trade team, would want to do so either, the mercantilist orientation of the Trump Administration means that it may want to get something in exchange for continuing FTA..
--> 100.00000000000004
--> It is difficult to imagine Trump wanting to undo that legacy, or that his new U.S. Trade Representative, who was a senior member of the Reagan trade team, would want to do so either. Nevertheless, the mercantilist orientation of the Trump Administration means that it may want to get something in exchange for continuing FTA.
--> It is difficult to imagine Trump wanting to undo that legacy, or that his new U.S. Trade Representative, who was a senior member of the Reagan trade team, would want t