In [1]:
# Enable the autoreload extension in mode 2 for this kernel.
%load_ext autoreload
%autoreload 2
# Environment variables for this kernel; the values are echoed in the cell output.
%env CUDA_VISIBLE_DEVICES=1
%env TOKENIZERS_PARALLELISM=false

env: CUDA_VISIBLE_DEVICES=1
env: TOKENIZERS_PARALLELISM=false


In [2]:
import torch
from tqdm import tqdm
from Levenshtein import seqratio

In [3]:
def merge_input_and_gen_ids(input_ids, generated_ids, pad_id=0, eos_id=1, idx_a=32000, idx_b=32099):
    """Fill the sentinel slots of each input row with the tokens generated for them.

    For every row, the generated sequence is scanned left to right. Each sentinel
    id found in the generated row (an id in ``[idx_a, idx_b]``) first copies the
    input tokens up to the next input sentinel, then skips that input sentinel;
    every other generated token is appended as-is. Whatever is left of the input
    afterwards is appended unchanged.

    Args:
        input_ids: 2-D id tensor; rows are assumed to be right-padded with ``pad_id``.
        generated_ids: 2-D id tensor with one generated row per input row.
        pad_id: padding id. Pad tokens inside ``generated_ids`` (including a
            leading pad used as decoder start token) are skipped.
        eos_id: end-of-sequence id. Scanning of a generated row stops at the
            first ``eos_id``; the merged row is terminated with ``eos_id``.
        idx_a: smallest sentinel id (inclusive).
        idx_b: largest sentinel id (inclusive).

    Returns:
        A 2-D long tensor on ``input_ids.device`` with the merged rows,
        right-padded with ``pad_id``.
    """
    def is_sentinel(tok):
        return idx_a <= tok <= idx_b

    new_input_ids = []
    for k in range(len(input_ids)):
        # Work on plain Python ints so the merged list never mixes ints and tensors.
        inp = input_ids[k].tolist()
        gen = generated_ids[k].tolist()
        # Number of non-pad input tokens (equals the unpadded length for right padding).
        inp_len = sum(tok != pad_id for tok in inp)

        i = 0
        new_inp = []
        for tok in gen:
            if tok == pad_id:
                # Decoder start token or padding: carries no content.
                continue
            if tok == eos_id:
                break
            if is_sentinel(tok):
                # Copy plain input text up to the next input sentinel / eos.
                # The bound is checked first so `inp[i]` is never read out of range.
                while (
                    i < inp_len
                    and not is_sentinel(inp[i])
                    and inp[i] != pad_id
                    and inp[i] != eos_id
                ):
                    new_inp.append(inp[i])
                    i += 1
                # Skip the input sentinel (or eos) that stopped the copy.
                i += 1
                if i >= inp_len:
                    break
            else:
                new_inp.append(tok)

        if i < inp_len:
            new_inp.extend(inp[i:inp_len])
        # Guarantee termination; also covers the case where nothing was merged.
        if not new_inp or new_inp[-1] not in (eos_id, pad_id):
            new_inp.append(eos_id)
        new_input_ids.append(torch.as_tensor(new_inp, dtype=torch.long))

    x_new = torch.nn.utils.rnn.pad_sequence(new_input_ids, batch_first=True, padding_value=pad_id)
    x_new = x_new.to(input_ids.device)
    return x_new
    

In [4]:
# Local checkpoint directory and the files expected inside it.
dirname = 'ct5-small-en-wiki-pytorch'
config_file = f'{dirname}/config.json'
checkpoint_file = f'{dirname}/flax_model.msgpack'
tokenizer_config_file = f'{dirname}/tokenizer_config.json'
tokenizer_file = f'{dirname}/tokenizer.json'
special_tokens_map_file = f'{dirname}/special_tokens_map.json'

In [5]:
from transformers import AutoTokenizer, T5ForConditionalGeneration
from modeling_ct5 import CT5ForConditionalGeneration

# Tokenizer and both models are loaded from the same local directory `dirname`.
tokenizer = AutoTokenizer.from_pretrained(dirname)
# Project-specific model class imported from modeling_ct5.
model = CT5ForConditionalGeneration.from_pretrained(dirname)
# Stock transformers T5 class loaded from the same directory; used as the comparison model below.
model_t5 = T5ForConditionalGeneration.from_pretrained(dirname)

2022-07-23 10:24:32.531462: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


In [6]:
# Switch both models to evaluation mode before any generation is run.
model = model.eval()
model_t5 = model_t5.eval()

In [7]:
from datasets import load_dataset

# Load the 'wikitext-103-v1' configuration of 'wikitext' and keep handles to its three splits.
dataset = load_dataset('wikitext', 'wikitext-103-v1')
train_dataset = dataset['train']
val_dataset = dataset['validation']
test_dataset = dataset['test']

Reusing dataset wikitext (/home/mtreviso/.cache/huggingface/datasets/wikitext/wikitext-103-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)


  0%|          | 0/3 [00:00<?, ?it/s]

In [8]:
# Length settings; only `max_length` is referenced by the helpers below.
min_length, max_length = 480, 512


def encode_ex(example):
    """Tokenize ``example['text']`` with padding/truncation set to ``max_length``."""
    tokenizer_options = dict(
        return_attention_mask=False,
        padding='max_length',
        max_length=max_length,
        truncation=True,
    )
    return tokenizer(example['text'], **tokenizer_options)


c = 0


def filter_ex(example):
    """Keep examples whose whitespace word count is within 15% of ``max_length``."""
    n_words = len(example["text"].split())
    lower_bound = max_length*0.85
    upper_bound = max_length*1.15
    return lower_bound <= n_words <= upper_bound

# Keep only the training examples accepted by filter_ex, then tokenize each one with encode_ex.
tokenized_dataset = train_dataset.filter(filter_ex).map(encode_ex)



  0%|          | 0/1802 [00:00<?, ?ba/s]

  0%|          | 0/1449 [00:00<?, ?ex/s]

In [9]:
from pretrain_chunked_t5 import FlaxDataCollatorForT5MLM

# Collator imported from pretrain_chunked_t5; pad / decoder-start ids are taken from the model config.
# NOTE(review): presumably applies T5-style span corruption (noise_density, mean span length) — confirm
# against pretrain_chunked_t5, including how input_length=None / target_length=None are handled.
data_collator = FlaxDataCollatorForT5MLM(
    tokenizer=tokenizer,
    noise_density=0.15,
    mean_noise_span_length=3.0,
    input_length=None,
    target_length=None,
    pad_token_id=model.config.pad_token_id,
    decoder_start_token_id=model.config.decoder_start_token_id,
)

# The whole filtered dataset is passed to the collator in a single call;
# the result is indexed with 'input_ids' and 'labels' further down.
model_inputs = data_collator(tokenized_dataset)

In [10]:
def sanitize(t):
    """Put exactly one space around ``<...>`` tokens and trim the ends.

    A space is inserted before every ``<`` and after every ``>``; the runs of
    spaces this creates are then collapsed to a single space. The collapse is
    repeated until no double space is left, because a single
    ``replace('  ', ' ')`` pass still leaves a double space when three or more
    spaces are adjacent (e.g. for ``'<x> <y>'``).

    Args:
        t: text to normalise.

    Returns:
        The normalised string, stripped of leading/trailing whitespace.
    """
    spaced = t.replace('<', ' <').replace('>', '> ')
    while '  ' in spaced:
        spaced = spaced.replace('  ', ' ')
    return spaced.strip()

# Decode every collated row back to text and normalise spacing around the <...> tokens.
input_texts = [sanitize(tokenizer.decode(x)) for x in model_inputs['input_ids']]
label_texts = [sanitize(tokenizer.decode(x)) for x in model_inputs['labels']]

In [11]:
import numpy as np
from Levenshtein import seqratio

def break_by_chunks(ids, idx_a=32000, idx_b=32099):
    """Split a 1-D id tensor into chunks that each start at a sentinel id.

    Args:
        ids: 1-D tensor of token ids.
        idx_a: smallest sentinel id (inclusive).
        idx_b: largest sentinel id (inclusive).

    Returns:
        Tuple of tensors. The first chunk holds whatever precedes the first
        sentinel (possibly empty); every following chunk begins with a sentinel.
        Without any sentinel the tuple contains ``ids`` as its only chunk.
    """
    sentinel_mask = (ids >= idx_a) & (ids <= idx_b)
    # squeeze(-1) keeps the indices 1-D even when there is exactly one sentinel.
    # A bare squeeze() would produce a 0-dim tensor, which tensor_split reads as
    # "number of sections" instead of a split position.
    split_points = sentinel_mask.nonzero().squeeze(-1).cpu()
    return torch.tensor_split(ids, split_points)

def chunk_accuracy(pred_ids, gold_ids, fuzzy=False):
    """Average per-chunk agreement between a generated row and its gold labels.

    Both rows are split at sentinel ids with ``break_by_chunks`` and the chunks
    are compared pairwise, in order. A pair is skipped when either chunk has at
    most one element. Pad ids (0) and the leading sentinel are removed from each
    chunk before comparison.

    Args:
        pred_ids: 1-D tensor of generated ids; its first element (the decoder
            start id) is dropped.
        gold_ids: 1-D tensor of gold label ids.
        fuzzy: if True, score each pair with ``seqratio``; otherwise score 1.0
            for an exact match and 0.0 otherwise.

    Returns:
        Mean score over the compared pairs, or 0.0 when no pair could be
        compared (previously this raised ZeroDivisionError).
    """
    # [1:] to remove decoder_start_id
    chunked_pred_ids = break_by_chunks(pred_ids[1:])
    chunked_gold_ids = break_by_chunks(gold_ids)
    total = 0.0
    n = 0
    for pred, gold in zip(chunked_pred_ids, chunked_gold_ids):
        if len(pred) <= 1 or len(gold) <= 1:
            continue
        # [1:] to remove the sentinel token
        pred = list(map(str, pred[pred != 0].cpu().tolist()))[1:]
        gold = list(map(str, gold[gold != 0].cpu().tolist()))[1:]
        n += 1
        if fuzzy:
            total += seqratio(pred, gold)
        else:
            total += float(pred == gold)
    if n == 0:
        # Nothing comparable (e.g. empty generation): report zero instead of dividing by zero.
        return 0.0
    return total / n

def compute_accuracy(input_batches, label_batches, gen_model='ct5'):
    """Greedy-decode every batch with the selected model and print chunk accuracies.

    ``gen_model == 'ct5'`` uses the global ``model`` and passes ``eoc_token_id``;
    any other value uses the global ``model_t5``. Three statistics are printed as
    mean (std): exact match, fuzzy match, and fuzzy match against labels whose
    non-sentinel ids were replaced by random ids.
    """
    model.cuda()
    model_t5.cuda()

    # The model choice does not change between batches, so resolve it once.
    if gen_model == 'ct5':
        generator = model
        extra_kwargs = dict(eoc_token_id=tokenizer.sep_token_id)
    else:
        generator = model_t5
        extra_kwargs = dict()

    match_scores, fuzzy_scores, random_scores = [], [], []
    for input_ids, label_ids in tqdm(zip(input_batches, label_batches)):
        encoder_ids = input_ids.cuda()
        generated_ids = generator.generate(
            encoder_ids,
            attention_mask=encoder_ids != 0,
            use_cache=False,
            do_sample=False,
            max_length=512,
            num_beams=1,
            **extra_kwargs
        ).cpu()

        for row in range(input_ids.shape[0]):
            gold = label_ids[row]
            predicted = generated_ids[row]
            # Random baseline: keep the ids >= 32000 of the gold row, randomise the rest.
            keep = gold >= 32000
            random_labels = torch.randint(0, 32000, size=gold.shape)
            random_labels[keep] = gold[keep]
            random_scores.append(chunk_accuracy(predicted, random_labels, fuzzy=True))
            fuzzy_scores.append(chunk_accuracy(predicted, gold, fuzzy=True))
            match_scores.append(chunk_accuracy(predicted, gold, fuzzy=False))

    print('Acc match: {:.4f} ({:.4f})'.format(np.mean(match_scores), np.std(match_scores)))
    print('Acc fuzzy: {:.4f} ({:.4f})'.format(np.mean(fuzzy_scores), np.std(fuzzy_scores)))
    print('Acc random: {:.4f} ({:.4f})'.format(np.mean(random_scores), np.std(random_scores)))

def masked_perplexity(log_probas, mask, reduce='mean'):
    """Perplexity over the masked positions of the last axis.

    Computes ``exp(-sum(log_probas * mask) / sum(mask))`` along the last
    dimension. With ``reduce='mean'`` the values are averaged and returned as a
    Python float; otherwise the un-reduced tensor is returned.
    """
    weights = mask.float()
    masked_total = (log_probas * weights).sum(dim=-1)
    valid_count = mask.sum(-1).float()
    perpl = torch.exp(-masked_total / valid_count)
    return perpl.mean().item() if reduce == 'mean' else perpl

## Profiling
---

In [None]:
import torch.autograd.profiler as profiler

# Profile a single greedy generate() call of the chunked model.
# NOTE(review): `input_ids` is not defined in any cell above this one; it has to
# exist in the kernel already (e.g. from the "Example" section) — confirm.
with profiler.profile(profile_memory=True, use_cuda=True, with_flops=True) as prof:
    generated_ids = model.generate(
        input_ids, 
        attention_mask=input_ids != 0,
        eoc_token_id=tokenizer.sep_token_id,
        use_cache=False,
        do_sample=False,
        max_length=512,
        num_beams=1,
    )

# Operator-level summary table, sorted by self CUDA time.
print(prof.key_averages().table(sort_by="self_cuda_time_total"))

## Timing
---

In [None]:
def generation_chunked_t5(input_ids):
    """Timing helper: one greedy ``generate`` call on the chunked model; output is discarded."""
    decode_options = dict(
        attention_mask=input_ids != 0,
        eoc_token_id=tokenizer.sep_token_id,
        use_cache=False,
        do_sample=False,
        max_length=512,
        num_beams=1,
    )
    model.generate(input_ids, **decode_options)
    
def generation_regular_t5(input_ids):
    """Timing helper: one greedy ``generate`` call on the stock T5 model; output is discarded."""
    decode_options = dict(
        attention_mask=input_ids != 0,
        use_cache=False,
        do_sample=False,
        max_length=512,
        num_beams=1,
    )
    model_t5.generate(input_ids, **decode_options)

In [None]:
import torch.utils.benchmark as benchmark
from itertools import product

# Benchmark grid: every (batch_size, seq_len) pair is timed for both generation helpers.
# sequence_lengths = [64, 128, 512, 1024]
sequence_lengths = [512]
batch_sizes = [8, 16, 32, 64, 128]
# NOTE(review): `device` is computed here but the loop below calls .cuda() directly.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
num_threads = 1
results = []

for batch_size, seq_len in tqdm(product(batch_sizes, sequence_lengths)):
    label = 'serial vs parallel greedy search'
    sub_label = f'[{batch_size}, {seq_len}]'
    # NOTE(review): `input_samples` is only assigned in the "Evaluate" section further
    # down, so this cell depends on that cell having been run first. `seq_len` is used
    # for the label only; the tokenizer call does not receive it.
    input_ids = tokenizer(input_samples[:batch_size], return_tensors="pt", padding=True).input_ids
    input_ids = input_ids.cuda()
    # Chunked model measurement.
    results.append(benchmark.Timer(
        stmt='generation_chunked_t5(input_ids)',
        setup='from __main__ import generation_chunked_t5',
        globals={'input_ids': input_ids},
        num_threads=num_threads,
        label=label,
        sub_label=sub_label,
        description='chunked t5',
    ).blocked_autorange(min_run_time=1))
    # Stock T5 measurement on the same batch.
    results.append(benchmark.Timer(
        stmt='generation_regular_t5(input_ids)',
        setup='from __main__ import generation_regular_t5',
        globals={'input_ids': input_ids},
        num_threads=num_threads,
        label=label,
        sub_label=sub_label,
        description='regular t5',
    ).blocked_autorange(min_run_time=1))

# Side-by-side table of all collected measurements.
compare = benchmark.Compare(results)
compare.print()

## Example
---

In [16]:
# Hand-written prompts containing <extra_id_N> sentinel tokens to be filled in by the models.
texts = [
    "The <extra_id_0> walks in <extra_id_1> park",
    "UN Chief says there is no way to <extra_id_0> in Syria",
    "My house is <extra_id_0>",
    # "<extra_id_0> is cool",
]

# Tokenize as one padded batch.
input_ids = tokenizer(texts, return_tensors="pt", padding=True).input_ids
# True where the id lies in the sentinel range 32000..32099.
# NOTE(review): input_mask is not referenced by any later cell visible here.
input_mask = (input_ids >= 32000) & (input_ids <= 32099)
print(input_ids[0].tolist())
print(input_ids[1].tolist())
print(input_ids[2].tolist())
# print(input_ids[3].tolist())

[37, 32099, 10681, 16, 32098, 2447, 1, 0, 0, 0, 0, 0]
[4417, 5116, 845, 132, 19, 150, 194, 12, 32099, 16, 11380, 1]
[499, 629, 19, 32099, 1, 0, 0, 0, 0, 0, 0, 0]


In [41]:
# Generate with the chunked model, then splice the generated chunks back into the prompts.
# NOTE(review): do_sample=False is combined with top_p / top_k here; presumably the
# sampling settings have no effect in that case — confirm against the generate() API.
generated_ids = model.generate(
    input_ids, 
    attention_mask=input_ids != 0,
    use_cache=False,
    do_sample=False,
    top_p=0.95,
    top_k=30,
    num_beams=1,
    eoc_token_id=tokenizer.sep_token_id,
    max_chunk_size=5,
)
merged_ids = merge_input_and_gen_ids(input_ids, generated_ids)
# Print the filled-in sentences without special tokens.
for i in range(len(merged_ids)):
    print(tokenizer.decode(merged_ids[i], skip_special_tokens=True))

The Walking Trail walks in the park
UN Chief says there is no way to treat Syria in Syria
My house is My house is My house


In [42]:
# Show the raw generated tokens of every row, special tokens included.
for generated_row in generated_ids:
    print(tokenizer.convert_ids_to_tokens(generated_row))

['<pad>', '<extra_id_0>', '▁Walking', '▁Trail', '</c>', '<extra_id_1>', '▁the', '</c>', '<extra_id_2>', '</s>', '<pad>', '<pad>', '<pad>']
['<pad>', '<extra_id_0>', '▁treat', '▁Syria', '</c>', '<extra_id_1>', '</s>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']
['<pad>', '<extra_id_0>', '▁My', '▁house', '▁is', '▁My', '▁house', '<pad>', '</s>', '<pad>', '<pad>', '<pad>', '<pad>']


In [43]:
# Display, as integers, the first entry of the decoder attention mask the model builds from the generated ids.
model._get_decoder_attention_mask_from_input_ids(generated_ids).long()[0]

tensor([[1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]])

In [155]:
# Same generation settings as the example above, but also returning the per-step scores.
out = model.generate(
    input_ids, 
    attention_mask=input_ids != 0,
    use_cache=False,
    do_sample=False,
    top_p=0.95,
    top_k=30,
    num_beams=1,
    eoc_token_id=tokenizer.sep_token_id,
    max_chunk_size=5,
    output_scores=True,
    return_dict_in_generate=True
)
generated_ids = out['sequences']

# Stack the entries of out['scores'], normalise over the last axis and keep the
# highest log-probability at every position.
log_probas = torch.nn.utils.rnn.pad_sequence(out['scores'], batch_first=True).log_softmax(dim=-1).max(dim=-1)[0]
# One all-ones mask row per entry of out['scores'], as long as that entry's first dimension.
mask = torch.nn.utils.rnn.pad_sequence([torch.ones(len(out['scores'][i])) 
                                        for i in range(len(out['scores']))], batch_first=True)
# NOTE(review): masked_perplexity reduces over the last axis, which here indexes the
# first dimension of each out['scores'] entry. If those entries are per-step
# (batch, vocab) tensors, this averages across the batch rather than across time — confirm.
masked_perplexity(log_probas, mask.long().bool(), reduce=None)

In [70]:
# Stock T5 on the same prompts, this time with sampling (do_sample=True, top-p / top-k).
# NOTE(review): no random seed is set in the visible cells, so the printed sentences
# are presumably not reproducible across runs — confirm.
generated_ids = model_t5.generate(
    input_ids, 
    attention_mask=input_ids != 0,
    use_cache=True,
    do_sample=True,
    top_p=0.95,
    top_k=30,
    num_beams=1,
)
merged_ids = merge_input_and_gen_ids(input_ids, generated_ids)
# Print the filled-in sentences without special tokens.
for i in range(len(merged_ids)):
    print(tokenizer.decode(merged_ids[i], skip_special_tokens=True))

The Hague walks in the park park
UN Chief says there is no way to say in Syria. Syria in Syria
My house is my house. My house in My house is my house. My house My house has


## Evaluate
---

In [250]:
def evaluate_model(input_batches, label_batches, gen_model, gen_kwargs=None):
    """Generate with one of the two global models and print accuracy / perplexity.

    For every batch the selected model generates a sequence, the same model then
    scores its own generation with one forward pass, and the chunk accuracies of
    each example are collected. Mean and standard deviation of all statistics are
    printed at the end.

    Args:
        input_batches: iterable of 2-D id tensors (one row per example) for the encoder.
        label_batches: iterable of 2-D id tensors with the gold labels, aligned
            with ``input_batches``.
        gen_model: ``'ct5'`` selects the global ``model`` and adds
            ``eoc_token_id=tokenizer.sep_token_id`` to the generation arguments;
            any other value selects the global ``model_t5``.
        gen_kwargs: optional dict of extra keyword arguments for ``generate``.
            The dict is copied, so neither the caller's dict nor a shared default
            is modified (the previous ``{}`` default was mutated in place).

    Returns:
        None. Results are printed.
    """
    gen_kwargs = dict(gen_kwargs) if gen_kwargs else {}
    is_ct5 = gen_model == 'ct5'
    if is_ct5:
        m = model
        gen_kwargs['eoc_token_id'] = tokenizer.sep_token_id
    else:
        m = model_t5

    acc_random = []
    acc_fuzzy = []
    acc_match = []
    perpls = []

    for input_ids, label_ids in tqdm(zip(input_batches, label_batches)):
        encoder_ids = input_ids.cuda()
        encoder_mask = encoder_ids != 0

        generated_ids = m.generate(
            encoder_ids,
            attention_mask=encoder_mask,
            **gen_kwargs,
        )
        decoder_ids = generated_ids.cuda()

        # The two models differ only in how the decoder attention mask is built.
        if is_ct5:
            decoder_mask = m._get_decoder_attention_mask_from_input_ids(decoder_ids)
        else:
            decoder_mask = decoder_ids != 0

        # Scoring only: no gradients are needed, so avoid building the autograd graph.
        with torch.no_grad():
            outputs = m(
                input_ids=encoder_ids,
                attention_mask=encoder_mask,
                decoder_input_ids=decoder_ids,
                decoder_attention_mask=decoder_mask,
            )
            log_probas = outputs.logits.log_softmax(dim=-1)

            # Positions to score: everything except sentinel ids (32000..32099) and ids <= 1.
            excluded = ((decoder_ids >= 32000) & (decoder_ids <= 32099)) | (decoder_ids <= 1)
            fix_mask = ~excluded
            lens = fix_mask.sum(-1)
            slices = lens.cumsum(-1).cpu()

            # Boolean indexing flattens the batch; split it back into one piece per row.
            # tensor_split yields one extra trailing piece, hence the [:-1].
            log_probas = log_probas[fix_mask].tensor_split(slices)[:-1]
            log_probas = torch.nn.utils.rnn.pad_sequence(log_probas, batch_first=True)
            # roll(-1) pairs every kept position with the id that follows it.
            gen_ids = decoder_ids.roll(-1, dims=-1)[fix_mask].tensor_split(slices)[:-1]
            gen_ids = torch.nn.utils.rnn.pad_sequence(gen_ids, batch_first=True).long()
            log_probas = log_probas.gather(2, gen_ids.unsqueeze(-1)).squeeze(-1)
            mask = torch.nn.utils.rnn.pad_sequence(
                [torch.ones(n) for n in lens.tolist()], batch_first=True
            )
            mask = mask.long().bool()

            perpl = masked_perplexity(log_probas.cuda(), mask.cuda(), reduce=None)
        perpls.extend(perpl.cpu().tolist())

        generated_ids = generated_ids.cpu()
        for i in range(input_ids.shape[0]):
            # Random baseline: keep the ids >= 32000 of the gold row, randomise the rest.
            ell = label_ids[i] >= 32000
            random_labels = torch.randint(0, 32000, size=label_ids[i].shape)
            random_labels[ell] = label_ids[i][ell]
            acc_random.append(chunk_accuracy(generated_ids[i], random_labels, fuzzy=True))
            acc_fuzzy.append(chunk_accuracy(generated_ids[i], label_ids[i], fuzzy=True))
            acc_match.append(chunk_accuracy(generated_ids[i], label_ids[i], fuzzy=False))

    print(gen_model)
    print('Acc match: {:.4f} ({:.4f})'.format(np.mean(acc_match), np.std(acc_match)))
    print('Acc fuzzy: {:.4f} ({:.4f})'.format(np.mean(acc_fuzzy), np.std(acc_fuzzy)))
    print('Acc random: {:.4f} ({:.4f})'.format(np.mean(acc_random), np.std(acc_random)))
    print('Perplexity: {:.4f} ({:.4f})'.format(np.mean(perpls), np.std(perpls)))


    

In [240]:
# Put both models in eval mode and move them to the GPU before evaluation.
model = model.eval().cuda()
model_t5 = model_t5.eval().cuda()
# Number of samples per tokenized batch built below.
batch_size = 16
# NOTE(review): min_len / max_len are not referenced by any code visible in this notebook.
min_len, max_len = 128, 512

def filter_ex(example):
    """Keep examples with at least 85% of 512 whitespace-separated words (no upper bound)."""
    n_words = len(example["text"].split())
    min_words = 512*0.85
    return min_words <= n_words

# Rebuild the tokenized dataset with the filter_ex defined in this cell.
tokenized_dataset = train_dataset.filter(filter_ex).map(encode_ex)
data_collator = FlaxDataCollatorForT5MLM(
    tokenizer=tokenizer,
    noise_density=0.15,
    mean_noise_span_length=3.0,
    input_length=None,
    target_length=None,
    pad_token_id=model.config.pad_token_id,
    decoder_start_token_id=model.config.decoder_start_token_id,
)
model_inputs = data_collator(tokenized_dataset)
# Decode the collated ids back to text, normalise spacing and drop the '</s>' marker;
# the text is re-tokenized into padded batches below.
input_samples = [sanitize(tokenizer.decode(x)).replace('</s>', '').strip() for x in model_inputs['input_ids']]
label_samples = [sanitize(tokenizer.decode(x)).replace('</s>', '').strip() for x in model_inputs['labels']]

# Sanity checks: number of samples and min/max whitespace word counts of inputs and labels.
print(len(input_samples))
print(len(label_samples))
print(min(map(lambda x: len(x.split()), input_samples)))
print(max(map(lambda x: len(x.split()), input_samples)))
print(min(map(lambda x: len(x.split()), label_samples)))
print(max(map(lambda x: len(x.split()), label_samples)))

# Tokenize in slices of `batch_size`; each batch is padded independently (padding=True).
input_batches = [
    tokenizer(input_samples[i:i+batch_size], return_tensors="pt", padding=True).input_ids
    for i in range(0, len(input_samples), batch_size)
]
label_batches = [
    tokenizer(label_samples[i:i+batch_size], return_tensors="pt", padding=True).input_ids
    for i in range(0, len(label_samples), batch_size)
]

  0%|          | 0/1802 [00:00<?, ?ba/s]

  0%|          | 0/1607 [00:00<?, ?ex/s]

1607
1607
218
371
96
123


In [251]:
import time

# --- Chunked model: greedy decoding with use_cache=False and max_chunk_size=5 ---
# NOTE(review): top_p is set while do_sample=False; presumably it has no effect — confirm.
gen_kwargs = dict(
    do_sample=False,
    top_k=None,
    top_p=0.95,
    max_length=512,
    num_beams=1,
    max_chunk_size=5,
    use_cache=False
)
torch.cuda.empty_cache()  # clear cache before timing
torch.cuda.synchronize(0)  # wait for initialization to finish
# The timed region is the whole evaluate_model call: generation, the scoring
# forward pass and the metric computation are all included.
time1 = time.perf_counter()
evaluate_model(input_batches, label_batches, gen_model='ct5', gen_kwargs=gen_kwargs)
torch.cuda.synchronize(0)
time2 = time.perf_counter()
print('Elapsed time: {}'.format(time2 - time1))


# --- Stock T5 baseline: greedy decoding with use_cache=True ---
# Note that the two runs differ in use_cache as well as in the model.
gen_kwargs = dict(
    do_sample=False,
    top_k=None,
    top_p=0.95,
    max_length=512,
    # num_beams=1,
    # max_chunk_size=5,
    use_cache=True
)
torch.cuda.empty_cache()  # clear cache before timing
torch.cuda.synchronize(0)  # wait for initialization to finish
time1 = time.perf_counter()
evaluate_model(input_batches, label_batches, gen_model='t5', gen_kwargs=gen_kwargs)
torch.cuda.synchronize(0)
time2 = time.perf_counter()
print('Elapsed time: {}'.format(time2 - time1))
print('')

101it [00:21,  4.61it/s]


ct5
Acc match: 0.0591 (0.0485)
Acc fuzzy: 0.5753 (0.0325)
Acc random: 0.4675 (0.0166)
Perplexity: 2.5666 (0.8016)
Elapsed time: 21.91607860568911


101it [01:58,  1.17s/it]

t5
Acc match: 0.1025 (0.0596)
Acc fuzzy: 0.5610 (0.0374)
Acc random: 0.4720 (0.0206)
Perplexity: 3.3864 (0.5265)
Elapsed time: 118.01890839915723






In [252]:
import time

# Second run with exactly the same settings as the previous cell.
# --- Chunked model: greedy decoding with use_cache=False and max_chunk_size=5 ---
gen_kwargs = dict(
    do_sample=False,
    top_k=None,
    top_p=0.95,
    max_length=512,
    num_beams=1,
    max_chunk_size=5,
    use_cache=False
)
torch.cuda.empty_cache()  # clear cache before timing
torch.cuda.synchronize(0)  # wait for initialization to finish
# The timed region covers the whole evaluate_model call, not only generation.
time1 = time.perf_counter()
evaluate_model(input_batches, label_batches, gen_model='ct5', gen_kwargs=gen_kwargs)
torch.cuda.synchronize(0)
time2 = time.perf_counter()
print('Elapsed time: {}'.format(time2 - time1))


# --- Stock T5 baseline: greedy decoding with use_cache=True ---
gen_kwargs = dict(
    do_sample=False,
    top_k=None,
    top_p=0.95,
    max_length=512,
    # num_beams=1,
    # max_chunk_size=5,
    use_cache=True
)
torch.cuda.empty_cache()  # clear cache before timing
torch.cuda.synchronize(0)  # wait for initialization to finish
time1 = time.perf_counter()
evaluate_model(input_batches, label_batches, gen_model='t5', gen_kwargs=gen_kwargs)
torch.cuda.synchronize(0)
time2 = time.perf_counter()
print('Elapsed time: {}'.format(time2 - time1))
print('')

101it [00:21,  4.62it/s]


ct5
Acc match: 0.0591 (0.0485)
Acc fuzzy: 0.5753 (0.0325)
Acc random: 0.4679 (0.0159)
Perplexity: 2.5666 (0.8016)
Elapsed time: 21.852749910205603


101it [01:57,  1.17s/it]

t5
Acc match: 0.1025 (0.0596)
Acc fuzzy: 0.5610 (0.0374)
Acc random: 0.4726 (0.0203)
Perplexity: 3.3864 (0.5265)
Elapsed time: 117.84094660636038




