In [None]:
%load_ext autoreload
%autoreload 2
%env CUDA_VISIBLE_DEVICES=1
%env TOKENIZERS_PARALLELISM=false

In [None]:
# !pip3 install python-Levenshtein

In [None]:
import torch
from tqdm import tqdm
from utils import merge_input_and_gen_ids

In [None]:
from transformers import AutoTokenizer, T5ForConditionalGeneration
from modeling_ct5 import CT5ForConditionalGeneration

# One checkpoint directory feeds the tokenizer and both model classes, so the
# chunked (CT5) and plain T5 wrappers are built from the same saved files.
dirname = 'ct5-small-en-wiki-pytorch'
tokenizer = AutoTokenizer.from_pretrained(dirname)

# Inference only: switch to eval mode and move to the GPU right after loading.
model_ct5 = CT5ForConditionalGeneration.from_pretrained(dirname).eval().cuda()
model_t5 = T5ForConditionalGeneration.from_pretrained(dirname).eval().cuda()

In [None]:
from datasets import load_dataset
from pretrain_chunked_t5 import FlaxDataCollatorForT5MLM

# Load WikiText-103 and expose its three splits under their own names.
dataset = load_dataset('wikitext', 'wikitext-103-v1')
train_dataset, val_dataset, test_dataset = (
    dataset[split_name] for split_name in ('train', 'validation', 'test')
)

In [None]:
# min_length: minimum number of whitespace-separated words a row must have (see filter_ex).
min_length = 480
# max_length: token length every row is padded/truncated to (see encode_ex).
max_length = 512
# batch_size: number of samples grouped into each evaluation batch.
batch_size = 16

def encode_ex(example):
    """Tokenize the ``text`` field of one dataset row.

    The row is padded or truncated to the module-level ``max_length`` and no
    attention mask is requested from the tokenizer.

    Args:
        example: mapping with a ``'text'`` entry.

    Returns:
        Whatever the module-level ``tokenizer`` returns for that text.
    """
    tokenizer_options = dict(
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_attention_mask=False,
    )
    return tokenizer(example['text'], **tokenizer_options)

def filter_ex(example):
    """Return True when the row's text has at least ``min_length`` whitespace-separated words."""
    word_count = len(example["text"].split())
    return word_count >= min_length

def sanitize(t):
    """Space out ``<...>`` markers and collapse runs of spaces in ``t``.

    A space is inserted before every ``<`` and after every ``>``, every run of
    spaces is then reduced to a single space, and the result is stripped.

    Args:
        t: input string.

    Returns:
        The cleaned string.
    """
    spaced = t.replace('<', ' <').replace('>', '> ')
    # One replace('  ', ' ') pass leaves a double space wherever three or more
    # spaces meet (e.g. between "<a> <b>"), so repeat until none remain.
    while '  ' in spaced:
        spaced = spaced.replace('  ', ' ')
    return spaced.strip()

# Keep only the long rows of the training split, then tokenize each of them.
tokenized_dataset = train_dataset.filter(filter_ex).map(encode_ex)

# NOTE(review): presumably this collator applies T5-style span masking with the
# given noise density / mean span length -- confirm in pretrain_chunked_t5.
data_collator = FlaxDataCollatorForT5MLM(
    tokenizer=tokenizer,
    noise_density=0.15,
    mean_noise_span_length=3.0,
    input_length=None,
    target_length=None,
    pad_token_id=model_ct5.config.pad_token_id,
    decoder_start_token_id=model_ct5.config.decoder_start_token_id,
)
model_inputs = data_collator(tokenized_dataset)


def _ids_to_clean_text(ids):
    """Decode one id sequence, space out the markers and drop every '</s>'."""
    return sanitize(tokenizer.decode(ids)).replace('</s>', '').strip()


def _print_word_stats(samples):
    """Print the sample count, then the smallest and largest whitespace word count."""
    word_counts = [len(text.split()) for text in samples]
    print(len(samples))
    print(min(word_counts))
    print(max(word_counts))


def _to_padded_batches(samples):
    """Re-tokenize the strings in groups of ``batch_size`` and return the padded id tensors."""
    batches = []
    for start in range(0, len(samples), batch_size):
        encoded = tokenizer(samples[start:start + batch_size], return_tensors="pt", padding=True)
        batches.append(encoded.input_ids)
    return batches


# Decode the collated ids back to text so they can be re-tokenized batch by batch.
input_samples = list(map(_ids_to_clean_text, model_inputs['input_ids']))
label_samples = list(map(_ids_to_clean_text, model_inputs['labels']))

_print_word_stats(input_samples)
_print_word_stats(label_samples)

input_batches = _to_padded_batches(input_samples)
label_batches = _to_padded_batches(label_samples)

## Evaluate
---

In [16]:
import numpy as np
from Levenshtein import seqratio

def break_by_chunks(ids, sentinel_min=32000, sentinel_max=32099):
    """Split a 1-D id tensor at every sentinel id.

    Args:
        ids: 1-D integer tensor of token ids.
        sentinel_min: smallest id treated as a sentinel (inclusive).
        sentinel_max: largest id treated as a sentinel (inclusive).
            The defaults keep the previously hard-coded 32000..32099 range.

    Returns:
        The tuple produced by ``torch.tensor_split``: the ids before the first
        sentinel (possibly empty), followed by one chunk per sentinel, each
        starting with that sentinel.
    """
    is_sentinel = (ids >= sentinel_min) & (ids <= sentinel_max)
    # Squeeze only the trailing dim so the index tensor stays 1-D even when
    # there is exactly one sentinel: a 0-dim tensor would make tensor_split
    # interpret the value as a number of sections instead of a split index.
    split_points = is_sentinel.nonzero().squeeze(-1).cpu()
    return torch.tensor_split(ids, split_points)

def chunk_accuracy(pred_ids, gold_ids, fuzzy=False):
    """Average per-chunk agreement between predicted and gold id sequences.

    Both sequences are split with ``break_by_chunks`` and the chunks are
    compared pairwise in order. A pair is skipped when either chunk has at
    most one element.

    Args:
        pred_ids: 1-D tensor of generated ids; its first element is dropped
            before chunking.
        gold_ids: 1-D tensor of reference ids.
        fuzzy: when True, score each pair with ``Levenshtein.seqratio``;
            otherwise score 1.0 for an exact match and 0.0 otherwise.

    Returns:
        Mean score over the compared pairs, or 0.0 when no pair was comparable.
    """
    # [1:] to remove decoder_start_id
    chunked_pred_ids = break_by_chunks(pred_ids[1:])
    chunked_gold_ids = break_by_chunks(gold_ids)
    r = 0
    n = 0
    for pred, gold in zip(chunked_pred_ids, chunked_gold_ids):
        if len(pred) <= 1 or len(gold) <= 1:
            continue
        # Drop id 0, stringify the ids (seqratio compares sequences of strings),
        # then [1:] to remove the sentinel token.
        pred = list(map(str, pred[pred != 0].cpu().tolist()))[1:]
        gold = list(map(str, gold[gold != 0].cpu().tolist()))[1:]
        n += 1
        if fuzzy:
            r += seqratio(pred, gold)
        else:
            r += float(pred == gold)
    # Nothing comparable (e.g. the prediction contains no sentinel): report zero
    # instead of raising ZeroDivisionError in the middle of an evaluation run.
    if n == 0:
        return 0.0
    return r / n

def compute_accuracy(input_batches, label_batches, gen_model='ct5'):
    """Greedy-generate for every batch and print chunk-accuracy statistics.

    Args:
        input_batches: iterable of padded input id tensors, one per batch.
        label_batches: iterable of padded label id tensors aligned with
            ``input_batches``.
        gen_model: ``'ct5'`` uses ``model_ct5`` (with an ``eoc_token_id``
            generation kwarg); any other value uses ``model_t5``.

    Prints the mean and standard deviation of the exact-match accuracy, the
    fuzzy accuracy, and the fuzzy accuracy against random labels.
    """
    acc_random = []
    acc_fuzzy = []
    acc_match = []
    model_ct5.cuda()
    model_t5.cuda()
    # The model and its extra generate() kwargs are the same for every batch,
    # so pick them once. `model_ct5` is the notebook's CT5 model; the bare
    # name `model` is not defined anywhere in this notebook.
    if gen_model == 'ct5':
        # NOTE(review): evaluate_model's caller passes tokenizer.vocab['</c>']
        # as eoc_token_id -- confirm sep_token_id refers to the same token.
        kw = dict(eoc_token_id=tokenizer.sep_token_id)
        m = model_ct5
    else:
        kw = dict()
        m = model_t5

    for input_ids, label_ids in tqdm(zip(input_batches, label_batches)):
        device_input_ids = input_ids.cuda()
        generated_ids = m.generate(
            device_input_ids,
            attention_mask=device_input_ids != 0,
            use_cache=False,
            do_sample=False,
            max_length=512,
            num_beams=1,
            **kw
        )

        generated_ids = generated_ids.cpu()
        for i in range(input_ids.shape[0]):
            # Random baseline: keep ids >= 32000 from the labels in place and
            # replace every other position with a random id below 32000.
            ell = label_ids[i] >= 32000
            random_labels = torch.randint(0, 32000, size=label_ids[i].shape)
            random_labels[ell] = label_ids[i][ell]
            acc_random.append(chunk_accuracy(generated_ids[i], random_labels, fuzzy=True))
            acc_fuzzy.append(chunk_accuracy(generated_ids[i], label_ids[i], fuzzy=True))
            acc_match.append(chunk_accuracy(generated_ids[i], label_ids[i], fuzzy=False))

    print('Acc match: {:.4f} ({:.4f})'.format(np.mean(acc_match), np.std(acc_match)))
    print('Acc fuzzy: {:.4f} ({:.4f})'.format(np.mean(acc_fuzzy), np.std(acc_fuzzy)))
    print('Acc random: {:.4f} ({:.4f})'.format(np.mean(acc_random), np.std(acc_random)))

def masked_perplexity(log_probas, mask, reduce='mean'):
    """Perplexity over the masked-in positions of each row.

    For every row this is ``exp(-(sum of masked log-probs) / (number of masked
    positions))``, computed along the last dimension.

    Args:
        log_probas: tensor of log-probabilities.
        mask: tensor of the same shape; truthy entries are the ones counted.
        reduce: ``'mean'`` returns the average over rows as a Python float;
            any other value returns the per-row tensor.

    Returns:
        A float when ``reduce == 'mean'``, otherwise a tensor of per-row values.
    """
    weights = mask.float()
    summed_log_prob = (log_probas * weights).sum(dim=-1)
    kept_count = mask.sum(-1).float()
    per_row = torch.exp(-summed_log_prob / kept_count)
    return per_row.mean().item() if reduce == 'mean' else per_row

In [20]:
@torch.no_grad()  # pure evaluation: do not keep an autograd graph for the re-scoring pass
def evaluate_model(input_batches, label_batches, gen_model, gen_kwargs=None):
    """Generate with one model, then print chunk accuracies and perplexity.

    For every batch the selected model generates ids, the generated ids are
    fed back through the model as decoder inputs to score them, and a
    per-sequence perplexity is derived from those scores. The generated ids
    are also compared with the labels (and with randomised labels) through
    ``chunk_accuracy``.

    Args:
        input_batches: iterable of padded input id tensors, one per batch.
        label_batches: iterable of padded label id tensors aligned with
            ``input_batches``.
        gen_model: ``'ct5'`` selects ``model_ct5``; anything else ``model_t5``.
        gen_kwargs: optional dict of keyword arguments forwarded to
            ``generate``; ``None`` means no extra arguments.

    Prints the model name followed by mean (std) of exact-match accuracy,
    fuzzy accuracy, random-baseline fuzzy accuracy and perplexity.
    """
    gen_kwargs = {} if gen_kwargs is None else gen_kwargs
    acc_random = []
    acc_fuzzy = []
    acc_match = []
    perpls = []
    m = model_ct5 if gen_model == 'ct5' else model_t5

    for input_ids, label_ids in tqdm(zip(input_batches, label_batches)):
        input_ids = input_ids.cuda()
        label_ids = label_ids.cuda()

        generated_ids = m.generate(
            input_ids,
            attention_mask=input_ids != 0,
            **gen_kwargs,
        )
        generated_ids = generated_ids.cuda()

        # NOTE(review): in both branches the scored targets are
        # generated_ids.roll(-1), so the last position is paired with the
        # wrapped-around first id, and the masks are built from the decoder
        # inputs rather than from the targets -- confirm this alignment is intended.
        if gen_model == 'ct5':
            outputs = m(
                input_ids=input_ids,
                attention_mask=input_ids != 0,
                decoder_input_ids=generated_ids,
                decoder_attention_mask=m._get_decoder_attention_mask_from_input_ids(generated_ids),
                # labels=label_ids
            )
            log_probas = outputs.logits.log_softmax(dim=-1)
            # True where the decoder input is in 32000..32099 or is <= 1
            # (presumably sentinels, pad and EOS -- verify against the vocab);
            # inverted below so True marks the positions that are scored.
            fix_mask = ((generated_ids >= 32000) & (generated_ids <= 32099)) | (generated_ids <= 1)
            fix_mask = (~fix_mask).bool()
            # Number of scored positions per row, and the cumulative boundaries
            # used to cut the flattened selection back into rows.
            lens = fix_mask.sum(-1)
            slices = lens.cumsum(-1).cpu()
            # Boolean indexing flattens across rows; splitting at the cumulative
            # lengths restores one piece per row. The piece after the final
            # boundary is empty and is dropped with [:-1].
            log_probas = log_probas[fix_mask].tensor_split(slices)[:-1]
            log_probas = torch.nn.utils.rnn.pad_sequence(log_probas, batch_first=True)
            # Next-token ids (roll by -1) taken at the same scored positions.
            gen_ids = generated_ids.roll(-1, dims=-1)[fix_mask].tensor_split(slices)[:-1]
            gen_ids = torch.nn.utils.rnn.pad_sequence(gen_ids, batch_first=True).long()
            # Log-probability assigned to each of those next-token ids.
            log_probas = log_probas.gather(2, gen_ids.unsqueeze(-1)).squeeze(-1)
            # True for the real entries of each row, False for the padding
            # introduced by pad_sequence.
            mask = torch.nn.utils.rnn.pad_sequence([torch.ones(l) for l in lens], batch_first=True)
            mask = mask.long().bool().cuda()
        else:
            outputs = m(
                input_ids=input_ids,
                attention_mask=input_ids != 0,
                decoder_input_ids=generated_ids,
                decoder_attention_mask=generated_ids != 0,
                # labels=label_ids
            )
            log_probas = outputs.logits.log_softmax(dim=-1)
            # Log-probability of the next generated id at every position.
            log_probas = log_probas.gather(2, generated_ids.roll(-1, dims=-1).unsqueeze(-1)).squeeze(-1)
            mask = generated_ids != 0
        perpl = masked_perplexity(log_probas.cuda(), mask.cuda(), reduce=None)
        perpls.extend(perpl.cpu().tolist())

        generated_ids = generated_ids.cpu()
        label_ids = label_ids.cpu()
        for i in range(input_ids.shape[0]):
            # Random baseline: keep ids >= 32000 from the labels in place and
            # replace every other position with a random id below 32000.
            ell = label_ids[i] >= 32000
            random_labels = torch.randint(0, 32000, size=label_ids[i].shape)
            random_labels[ell] = label_ids[i][ell]
            acc_random.append(chunk_accuracy(generated_ids[i], random_labels, fuzzy=True))
            acc_fuzzy.append(chunk_accuracy(generated_ids[i], label_ids[i], fuzzy=True))
            acc_match.append(chunk_accuracy(generated_ids[i], label_ids[i], fuzzy=False))

    print(gen_model)
    print('Acc match: {:.4f} ({:.4f})'.format(np.mean(acc_match), np.std(acc_match)))
    print('Acc fuzzy: {:.4f} ({:.4f})'.format(np.mean(acc_fuzzy), np.std(acc_fuzzy)))
    print('Acc random: {:.4f} ({:.4f})'.format(np.mean(acc_random), np.std(acc_random)))
    print('Perplexity: {:.4f} ({:.4f})'.format(np.mean(perpls), np.std(perpls)))


In [21]:
import time

def _timed_evaluation(gen_model, gen_kwargs):
    """Run evaluate_model for one model and print the wall-clock time it took."""
    torch.cuda.empty_cache()  # clear cache before timing
    torch.cuda.synchronize(0)  # wait for initialization to finish
    started = time.perf_counter()
    evaluate_model(input_batches, label_batches, gen_model=gen_model, gen_kwargs=gen_kwargs)
    torch.cuda.synchronize(0)  # make sure all queued GPU work is done before stopping the clock
    finished = time.perf_counter()
    print('Elapsed time: {}'.format(finished - started))


# Generation settings shared by both runs.
shared_gen_kwargs = dict(
    do_sample=False,
    top_k=None,
    top_p=0.95,
    max_length=512,
    num_beams=1,
)

# CT5 run: no cache, plus the two extra generation arguments it is given.
gen_kwargs = dict(
    shared_gen_kwargs,
    use_cache=False,
    eoc_token_id=tokenizer.vocab['</c>'],
    max_chunk_size=5,
)
_timed_evaluation('ct5', gen_kwargs)

# T5 run: cache enabled, no extra arguments.
gen_kwargs = dict(shared_gen_kwargs, use_cache=True)
_timed_evaluation('t5', gen_kwargs)
print('')

51it [00:10,  4.77it/s]


ct5
Acc match: 0.0896 (0.0549)
Acc fuzzy: 0.5750 (0.0324)
Acc random: 0.4766 (0.0157)
Perplexity: 1.5445 (0.3301)
Elapsed time: 10.685760903172195


51it [01:58,  2.32s/it]

t5
Acc match: 0.1034 (0.0591)
Acc fuzzy: 0.5737 (0.0392)
Acc random: 0.4685 (0.0208)
Perplexity: 3.3637 (0.6629)
Elapsed time: 118.43209824990481




