In [1]:
%load_ext autoreload
%autoreload 2
%env CUDA_VISIBLE_DEVICES=0
%env TOKENIZERS_PARALLELISM=false

env: CUDA_VISIBLE_DEVICES=0
env: TOKENIZERS_PARALLELISM=false


In [2]:
import re

import torch
from tqdm import tqdm

In [4]:
from transformers import AutoTokenizer, T5ForConditionalGeneration
from modeling_ct5 import CT5ForConditionalGeneration

dirname = 'ct5-small-en-wiki-pytorch'
# Pick the device the same way the benchmark/profiling cells do, so the
# notebook still runs (slowly) on a machine without a GPU instead of crashing
# on `.cuda()`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

tokenizer = AutoTokenizer.from_pretrained(dirname)
# Both models are loaded from the same checkpoint directory: the chunked
# variant (CT5) and the stock Hugging Face T5 used as the baseline.
# `.eval()` and `.to()` return the module itself, so they can be chained.
model_ct5 = CT5ForConditionalGeneration.from_pretrained(dirname).eval().to(device)
model_t5 = T5ForConditionalGeneration.from_pretrained(dirname).eval().to(device)

## Timing
---

In [5]:
from datasets import load_dataset
from pretrain_chunked_t5 import FlaxDataCollatorForT5MLM

dataset = load_dataset('wikitext', 'wikitext-103-v1')
train_dataset, val_dataset, test_dataset = (
    dataset[split_name] for split_name in ('train', 'validation', 'test')
)

# Word-count threshold used by `filter_ex` and the token length that
# `encode_ex` pads/truncates to.
min_length, max_length = 480, 512
# NOTE(review): not used by this cell; the benchmark and profiling cells set
# their own batch sizes.
batch_size = 16

def encode_ex(example):
    """Tokenize the ``text`` field of a single dataset row.

    Uses the module-level ``tokenizer``. The text is padded to and truncated
    at the module-level ``max_length``, and no attention mask is requested.
    Returns the tokenizer's output unchanged.
    """
    tokenizer_options = dict(
        return_attention_mask=False,
        padding='max_length',
        max_length=max_length,
        truncation=True,
    )
    return tokenizer(example['text'], **tokenizer_options)

def filter_ex(example):
    """Return True if the row's text has at least ``min_length`` whitespace-separated words."""
    word_count = len(example["text"].split())
    return word_count >= min_length

def sanitize(t):
    """Normalize spacing around ``<...>`` markup in decoded text.

    Inserts a space before every ``<`` and after every ``>``, so that tags such
    as sentinel tokens are separated from neighbouring words. Then collapses
    every run of spaces into a single space and strips both ends.

    Args:
        t: decoded text.

    Returns:
        The re-spaced string.
    """
    spaced = t.replace('<', ' <').replace('>', '> ')
    # Collapse *all* runs of spaces. A single str.replace('  ', ' ') pass only
    # halves a run, so '>' followed by spaces and '<' would still leave
    # double spaces behind.
    return re.sub(r' {2,}', ' ', spaced).strip()

# Keep only long paragraphs, tokenize them, then run them through the
# span-masking collator (presumably T5-style span corruption, given
# noise_density / mean_noise_span_length — confirm in pretrain_chunked_t5).
tokenized_dataset = train_dataset.filter(filter_ex).map(encode_ex)
data_collator = FlaxDataCollatorForT5MLM(
    tokenizer=tokenizer,
    noise_density=0.15,
    mean_noise_span_length=3.0,
    input_length=None,
    target_length=None,
    pad_token_id=model_ct5.config.pad_token_id,
    decoder_start_token_id=model_ct5.config.decoder_start_token_id,
)
model_inputs = data_collator(tokenized_dataset)

# Decode the masked inputs and their targets back to text: re-space the
# special tokens and drop the '</s>' end-of-sequence marker.
input_samples, label_samples = (
    [
        sanitize(tokenizer.decode(token_ids)).replace('</s>', '').strip()
        for token_ids in model_inputs[field]
    ]
    for field in ('input_ids', 'labels')
)

Reusing dataset wikitext (/home/mtreviso/.cache/huggingface/datasets/wikitext/wikitext-103-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)


  0%|          | 0/3 [00:00<?, ?it/s]



  0%|          | 0/1802 [00:00<?, ?ba/s]

  0%|          | 0/807 [00:00<?, ?ex/s]

In [6]:
def generation_chunked_t5(input_ids, num_beams, max_length=512):
    """Decode a batch with the chunked T5 model (``model_ct5``).

    Args:
        input_ids: padded batch of encoder token ids, e.g. from
            ``tokenizer(..., padding=True).input_ids``.
        num_beams: beam width passed to ``generate`` (with ``do_sample=False``,
            ``num_beams=1`` is greedy decoding).
        max_length: maximum generation length passed to ``generate``
            (default 512, as before).

    Returns:
        The result of ``model_ct5.generate``. It is returned so the outputs can
        be inspected; the timing code simply ignores it.
    """
    return model_ct5.generate(
        input_ids,
        # Mask padding with the tokenizer's own pad id instead of a literal 0.
        attention_mask=input_ids != tokenizer.pad_token_id,
        # Id of the '</c>' token (presumably the end-of-chunk marker).
        eoc_token_id=tokenizer.vocab['</c>'],
        max_chunk_size=10,
        # NOTE(review): caching is off here but on for the regular T5 baseline;
        # confirm the chunked decoder cannot use it, since it affects the comparison.
        use_cache=False,
        do_sample=False,
        max_length=max_length,
        num_beams=num_beams,
    )
    
def generation_regular_t5(input_ids, num_beams, max_length=512):
    """Decode a batch with the stock Hugging Face T5 baseline (``model_t5``).

    Args:
        input_ids: padded batch of encoder token ids, e.g. from
            ``tokenizer(..., padding=True).input_ids``.
        num_beams: beam width passed to ``generate`` (with ``do_sample=False``,
            ``num_beams=1`` is greedy decoding).
        max_length: maximum generation length passed to ``generate``
            (default 512, as before).

    Returns:
        The result of ``model_t5.generate``. It is returned so the outputs can
        be inspected; the timing code simply ignores it.
    """
    return model_t5.generate(
        input_ids,
        # Mask padding with the tokenizer's own pad id instead of a literal 0.
        attention_mask=input_ids != tokenizer.pad_token_id,
        # Key/value caching enabled: the usual fast path for autoregressive T5.
        use_cache=True,
        do_sample=False,
        max_length=max_length,
        num_beams=num_beams,
    )

In [7]:
import torch.utils.benchmark as benchmark
from itertools import product

# Grid of decoding configurations: every (batch size, beam width) pair.
beam_sizes = [1, 3, 5, 10]
batch_sizes = [1, 2, 4, 8, 16, 32, 64]
device = 'cuda' if torch.cuda.is_available() else 'cpu'
num_threads = 1
# num_beams == 1 is greedy search; larger values are beam search.
label = 'serial vs parallel greedy/beam search'
results = []

# Each configuration slices `input_samples`; make sure the largest batch really
# has as many sequences as its sub-label claims.
assert len(input_samples) >= max(batch_sizes), 'not enough samples for the largest batch size'


def _time_generation(generate_fn, input_ids, num_beams, sub_label, description):
    """Time ``generate_fn(input_ids, num_beams)`` with ``torch.utils.benchmark``.

    The function is handed to the timer through ``globals`` rather than
    ``from __main__ import ...``, so this does not depend on running in the
    ``__main__`` namespace. Returns the ``blocked_autorange`` measurement.
    """
    torch.cuda.empty_cache()
    timer = benchmark.Timer(
        stmt='generate_fn(input_ids, num_beams)',
        globals={'generate_fn': generate_fn, 'input_ids': input_ids, 'num_beams': num_beams},
        num_threads=num_threads,
        label=label,
        sub_label=sub_label,
        description=description,
    )
    return timer.blocked_autorange(min_run_time=1)


# Materialize the grid so tqdm knows its length and can show progress and ETA
# (a bare `product` iterator has no len()).
configs = list(product(batch_sizes, beam_sizes))
for batch_size, num_beams in tqdm(configs):
    sub_label = f'[{batch_size}, {num_beams}]'
    input_ids = tokenizer(input_samples[:batch_size], return_tensors="pt", padding=True).input_ids
    input_ids = input_ids.to(device)
    results.append(_time_generation(generation_chunked_t5, input_ids, num_beams, sub_label, 'chunked t5'))
    results.append(_time_generation(generation_regular_t5, input_ids, num_beams, sub_label, 'regular t5'))

compare = benchmark.Compare(results)
compare.print()

  batched_outputs = func(*batched_inputs, **kwargs)
28it [09:07, 19.54s/it]

[---- serial vs parallel greedy search ----]
                |  chunked t5  |  regular t5
1 threads: ---------------------------------
      [1, 1]    |      80.0    |     605.1  
      [1, 3]    |     149.0    |    1030.2  
      [1, 5]    |     181.6    |     775.0  
      [1, 10]   |     284.9    |     911.2  
      [2, 1]    |      88.9    |     966.1  
      [2, 3]    |     219.7    |    1084.6  
      [2, 5]    |     306.8    |     920.5  
      [2, 10]   |     512.2    |    1056.3  
      [4, 1]    |     117.9    |     940.2  
      [4, 3]    |     387.2    |    1481.7  
      [4, 5]    |     623.6    |    1551.1  
      [4, 10]   |    1145.2    |    1984.9  
      [8, 1]    |     134.8    |    1239.0  
      [8, 3]    |     724.3    |    3681.2  
      [8, 5]    |    1073.9    |    4370.6  
      [8, 10]   |    1929.4    |    6779.5  
      [16, 1]   |     223.7    |    1257.0  
      [16, 3]   |    1414.1    |    4774.9  
      [16, 5]   |    2111.1    |    6560.9  
      [16,




## Profiling
---

In [8]:
import torch.autograd.profiler as profiler

batch_size = 16
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Only request CUDA timings (and sort by them) when the model actually runs on
# the GPU; on a CPU-only machine, fall back to CPU timings.
use_cuda = device == 'cuda'
input_ids = tokenizer(input_samples[:batch_size], return_tensors="pt", padding=True).input_ids
input_ids = input_ids.to(device)

with profiler.profile(profile_memory=True, use_cuda=use_cuda, with_flops=True) as prof:
    # One greedy (num_beams=1) decoding pass of the chunked model.
    generated_ids = model_ct5.generate(
        input_ids,
        # Mask padding with the tokenizer's own pad id instead of a literal 0.
        attention_mask=input_ids != tokenizer.pad_token_id,
        eoc_token_id=tokenizer.vocab['</c>'],
        max_chunk_size=10,
        use_cache=False,
        do_sample=False,
        max_length=512,
        num_beams=1,
    )

# Rank operators by their exclusive ("self") time on the device that did the work.
sort_key = 'self_cuda_time_total' if use_cuda else 'self_cpu_time_total'
print(prof.key_averages().table(sort_by=sort_key))

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  Total MFLOPs  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                               aten::mm         3.41%      14.375ms         4.57%      19.287ms      29.856us     100.119ms        20.81%     100.119ms     154.983us           0 

As the profile shows, apart from the matrix multiplications (`aten::mm`, which tops the self-CUDA-time ranking), `aten::repeat_interleave` is the most expensive operation. It is used in the `_update_input_ids_in_parallel` method to autoregressively insert the new tokens of each chunk for every sequence in the batch. Its cost could be reduced by replacing `repeat_interleave` with a more efficient approach, such as smart indexing combined with gathering.