In [1]:
import torch 
import transformers

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from transformers import (
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoModel,
    AutoTokenizer
)
import time

def get_transformer(model_name, return_model=True, **kwargs):
    """Load the config and tokenizer for ``model_name`` and optionally build a model.

    Args:
        model_name: Hugging Face model id or local path.
        return_model: if False, return only ``(config, tokenizer)``.
        **kwargs: config overrides forwarded to ``AutoConfig.from_pretrained``.
            Any kwargs the config does not consume trigger a warning instead of
            being silently dropped.

    Returns:
        ``(config, tokenizer)`` when ``return_model`` is False, otherwise
        ``(config, model, tokenizer)``. The model is built with ``from_config``
        (architecture only, no pretrained weights are loaded).
    """
    import warnings

    config, unused_kwargs = AutoConfig.from_pretrained(
        model_name, return_unused_kwargs=True, **kwargs
    )
    if unused_kwargs:
        # Previously discarded without notice, so a typo in an override was invisible.
        warnings.warn(
            f"get_transformer: ignoring kwargs not used by the config: {sorted(unused_kwargs)}"
        )

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # GPT-2 tokenizers presumably lack a pad token; reuse EOS so padding works.
    if 'gpt2' in model_name:
        tokenizer.pad_token = tokenizer.eos_token
    if not return_model:
        return config, tokenizer

    # Encoder-decoder checkpoints get a seq2seq LM head, everything else the bare base model.
    model_cls = AutoModelForSeq2SeqLM if config.is_encoder_decoder else AutoModel
    model = model_cls.from_config(config)
    return config, model, tokenizer

In [3]:
# Only the pretrained mT5-small config and tokenizer are needed here; the
# benchmark model is built further down from a shrunken copy of this config.
config, tokenizer = get_transformer(
    'google/mt5-small',
    return_model=False,
)



In [4]:
# Shrink every size dimension of mT5-small by SHRINK to get a small benchmark
# model. Vocabulary, FFN type, embedding tying and special-token ids are copied
# from the pretrained config explicitly rather than relying on MT5Config's
# defaults happening to match `tokenizer`.
SHRINK = 2
config_tiny = transformers.MT5Config(
    vocab_size=config.vocab_size,
    feed_forward_proj=config.feed_forward_proj,
    tie_word_embeddings=config.tie_word_embeddings,
    pad_token_id=config.pad_token_id,
    eos_token_id=config.eos_token_id,
    decoder_start_token_id=config.decoder_start_token_id,
    d_model=config.d_model // SHRINK,
    d_kv=config.d_kv // SHRINK,
    d_ff=config.d_ff // SHRINK,
    num_decoder_layers=config.num_decoder_layers // SHRINK,
    num_heads=config.num_heads // SHRINK,
    num_layers=config.num_layers // SHRINK,
)
assert config_tiny.num_heads >= 1 and config_tiny.num_layers >= 1, "SHRINK too large"

In [5]:
# Architecture only: from_config does not load pretrained weights, which is why
# the generations below are random multilingual text.
model_tiny = AutoModelForSeq2SeqLM.from_config(config=config_tiny)

In [6]:
# Pad/truncate to exactly 50 tokens so every query has the same input shape
# (presumably to avoid dynamo recompiling per query length).
inputs = tokenizer(
    "Generate taxonomy for query: dildo",
    padding='max_length',
    max_length=50,
    truncation=True,
    return_tensors="pt",
).to('cuda')

# TorchDynamo (inductor): unconstrained generation benchmarks

In [7]:
import torch._dynamo as torchdynamo
# Raise the per-function compiled-variant cache limit. generate() is presumably
# recompiled for each new decode length/shape, and hitting the limit would make
# dynamo fall back to eager execution. NOTE(review): confirm against the
# installed torch version's default.
torchdynamo.config.cache_size_limit = 512

In [8]:
# Move the tiny model to the GPU, then switch to eval mode (disables dropout).
model = model_tiny.cuda()
model = model.eval()

In [9]:
# Wrap the bound generate() with TorchDynamo using the inductor backend.
# Compilation happens lazily on the first call (the warm-up cell below).
model.generate2 = torchdynamo.optimize("inductor")(model.generate)

In [38]:
# dynamo warm up
# The first call compiles generate2 for this input shape and fixed 15-token
# output (min_length == max_length), so the %%timeit cells below measure
# steady-state latency rather than compilation.
print(tokenizer.batch_decode(model.generate2(**inputs, min_length=15, max_length=15)))

['<pad> Мехołudniowထောင်문이 იმ一方でklas Alu történet történet történet történet történet történet']


In [11]:
%%timeit
tokenizer.batch_decode(model.generate2(**inputs, min_length=15, max_length=15))

26 ms ± 140 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [12]:
# Second query, padded to the same fixed 50-token shape as `inputs`.
inputs2 = tokenizer(
    "Generate taxonomy for query: women gucci",
    padding='max_length',
    max_length=50,
    truncation=True,
    return_tensors="pt",
).to('cuda')

In [13]:
%%timeit
tokenizer.batch_decode(model.generate2(**inputs2, min_length=15, max_length=15))

26 ms ± 114 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [14]:
# Third query, padded to the same fixed 50-token shape as `inputs`.
inputs3 = tokenizer(
    "Generate taxonomy for query: baby milk",
    padding='max_length',
    max_length=50,
    truncation=True,
    return_tensors="pt",
).to('cuda')

In [15]:
%%timeit
tokenizer.batch_decode(model.generate2(**inputs3, min_length=15, max_length=15))

26.1 ms ± 545 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [16]:
# Compiled output for inputs3; compared against eager model.generate in the next cell.
tokenizer.batch_decode(
    model.generate2(**inputs3, min_length=15, max_length=15)
)

['<pad>ເກີດ val сыйлаの効果alalanoddל־အမျိုးသားいずれも🖱ηνприятияізмуל־']

In [17]:
# Eager (uncompiled) reference for the compiled result above.
tokenizer.batch_decode(
    model.generate(**inputs3, min_length=15, max_length=15)
)

['<pad>ເກີດ val сыйлаの効果alalanoddל־အမျိုးသားいずれも🖱ηνприятияізмуל־']

# Trie-constrained beam search benchmarks

In [18]:
from typing import Dict, Iterator, List, Optional


class Trie(object):
    """Prefix tree over token-id sequences, used for constrained decoding.

    Each node is a plain ``dict`` mapping a token id to its child node; a leaf
    is an empty dict. ``get(prefix)`` returns the token ids allowed to follow
    ``prefix``, which is what the ``prefix_allowed_tokens_fn`` callbacks in this
    notebook return.

    There is no explicit end-of-sequence marker, so a sequence that is a strict
    prefix of another one cannot be distinguished from an interior path. In
    this notebook every stored sequence ends with the EOS id, which avoids that.
    """

    def __init__(self, sequences: Optional[List[List[int]]] = None):
        # ``None`` default instead of ``[]`` avoids the shared-mutable-default pitfall.
        self.trie_dict: Dict[int, Dict] = {}
        self.len = 0
        if sequences is not None:
            for sequence in sequences:
                Trie._add_to_trie(sequence, self.trie_dict)
                self.len += 1

        # Optional trie consulted once a prefix walks off this one (see _get_from_trie).
        self.append_trie: Optional["Trie"] = None
        self.bos_token_id: Optional[int] = None

    def append(self, trie: "Trie", bos_token_id: int) -> None:
        """Chain ``trie`` after this one.

        When ``bos_token_id`` is among the allowed next tokens it is replaced by
        the root tokens of ``trie``. A prefix that leaves this trie continues
        its lookup in ``trie``.
        """
        self.append_trie = trie
        self.bos_token_id = bos_token_id

    def add(self, sequence: List[int]) -> None:
        """Insert ``sequence``. ``len`` grows on every call, duplicates included."""
        Trie._add_to_trie(sequence, self.trie_dict)
        self.len += 1

    def get(self, prefix_sequence: List[int]) -> List[int]:
        """Return the token ids allowed after ``prefix_sequence``.

        Returns ``[]`` if the prefix is not in the trie and no trie is appended.
        """
        return Trie._get_from_trie(
            prefix_sequence, self.trie_dict, self.append_trie, self.bos_token_id
        )

    @staticmethod
    def load_from_dict(trie_dict: Dict) -> "Trie":
        """Wrap an existing nested dict; ``len`` is the number of root-to-leaf paths."""
        trie = Trie()
        trie.trie_dict = trie_dict
        trie.len = sum(1 for _ in trie)
        return trie

    @staticmethod
    def _add_to_trie(sequence: List[int], trie_dict: Dict) -> None:
        # Iterative walk: no recursion-depth limit and no O(n^2) list slicing.
        node = trie_dict
        for token in sequence:
            node = node.setdefault(token, {})

    @staticmethod
    def _get_from_trie(
        prefix_sequence: List[int],
        trie_dict: Dict,
        append_trie: Optional["Trie"] = None,
        bos_token_id: Optional[int] = None,
    ) -> List[int]:
        node = trie_dict
        for depth, token in enumerate(prefix_sequence):
            if token in node:
                node = node[token]
            elif append_trie is not None:
                # Prefix left this trie: resolve the unmatched remainder in the appended one.
                return append_trie.get(prefix_sequence[depth:])
            else:
                return []

        output = list(node.keys())
        # ``is not None`` rather than truthiness: Trie defines __len__, so an
        # appended trie would otherwise be skipped whenever its len is 0.
        if append_trie is not None and bos_token_id in output:
            output.remove(bos_token_id)
            output += list(append_trie.trie_dict.keys())
        return output

    def __iter__(self) -> Iterator[List[int]]:
        """Yield every stored root-to-leaf path in depth-first insertion order.

        An empty trie yields nothing, so ``len(Trie.load_from_dict({})) == 0``,
        consistent with ``len(Trie())``.
        """
        if not self.trie_dict:
            return
        # Explicit stack instead of recursion so long sequences cannot hit the
        # recursion limit. Children are pushed in reverse so they pop in
        # insertion order, matching a recursive depth-first walk.
        stack = [([], self.trie_dict)]
        while stack:
            prefix, node = stack.pop()
            if not node:
                yield prefix
                continue
            for token in reversed(list(node)):
                stack.append((prefix + [token], node[token]))

    def __len__(self) -> int:
        return self.len

    def __getitem__(self, value: List[int]) -> List[int]:
        """``trie[prefix]`` is an alias for ``trie.get(prefix)``."""
        return self.get(value)


In [19]:
# Load the allowed taxonomy leaf paths (one per non-blank line) and encode each
# one into the token ids the decoder must emit: decoder-start <pad>, the path
# tokens, then exactly one </s>.
TAXONOMY_PATH = '../modelling/datasets/taxonomy/wish_v1.2.1_newtax_leafpaths.txt'

with open(TAXONOMY_PATH, 'r', encoding='utf-8') as f:
    allowed_gen_sequences = [line.strip() for line in f if line.strip()]
assert allowed_gen_sequences, f"no leaf paths found in {TAXONOMY_PATH}"

allowed_tokids = [
    # add_special_tokens=False: the mT5 tokenizer's encode() appends </s> by
    # default. Without this flag every path would end in two </s> tokens and
    # max_len would be one too large.
    [tokenizer.pad_token_id]
    + tokenizer.encode(path, add_special_tokens=False)
    + [tokenizer.eos_token_id]
    for path in allowed_gen_sequences
]
max_len = max(len(seq) for seq in allowed_tokids)


In [20]:
# A single all-<pad> path as long as the longest real path. Used by the warm-up
# cells so the compiled generate() is exercised up to the longest constrained length.
trie_fake = Trie([[tokenizer.pad_token_id] * max_len])
trie = Trie(allowed_tokids)

def constraint_fake(batch_id, sent):
    """Warm-up ``prefix_allowed_tokens_fn``: only <pad> is allowed, for up to max_len tokens.

    ``batch_id`` is unused but required by the callback signature.
    """
    return trie_fake.get(sent.tolist())

def constraint(batch_id, sent):
    """``prefix_allowed_tokens_fn`` restricting decoding to the taxonomy leaf paths.

    ``batch_id`` is unused but required by the callback signature. ``sent`` holds
    the decoder token ids generated so far (presumably a 1-D tensor) and is
    converted to a list for the trie lookup.
    """
    return trie.get(sent.tolist())


In [21]:
# Short aliases for the three pre-tokenised queries used by the constrained cells below.
batch, batch2, batch3 = inputs, inputs2, inputs3

In [22]:
# Separate dynamo-wrapped handle for the constrained beam-search experiments.
# NOTE(review): generate2/3/4 all wrap the same model.generate, and dynamo caches
# compiled code per code object, so these handles likely share one cache
# rather than isolating experiments. Confirm.
model.generate3 = torchdynamo.optimize("inductor")(model.generate)

In [37]:
# Warm-up for constrained beam search. The all-<pad> fake trie forces the
# longest decode, so later timed calls presumably hit already-compiled code.
infres = model.generate3(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    # deterministic beam search returning all 3 beams
    num_beams=3,
    num_return_sequences=3,
    do_sample=False,
    # length_penalty=0: beams ranked by un-normalised summed log-prob
    length_penalty=0,
    max_new_tokens=50 - 1,  # HACK: T5 adds pad token in the beginning
    prefix_allowed_tokens_fn=constraint_fake,  # use longest fake trie to warm up
    output_scores=True,
    return_dict_in_generate=True,
)
# With random weights the summed log-prob is very negative, so exp() likely
# underflows to 0 (see output).
probs = torch.exp(infres.sequences_scores)
prediction = infres.sequences
print(tokenizer.batch_decode(prediction))
print(probs)

['<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad></s>', '<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad></s><pad>', '<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad></s><pad><pad>']
tensor([0., 0., 0.], device='cuda:0')


In [24]:
%%timeit
infres = model.generate3(
    input_ids = batch["input_ids"], 
    attention_mask = batch["attention_mask"],
    num_beams = 3, 
    num_return_sequences = 3, 
    do_sample = False, 
    length_penalty = 0, 
    max_new_tokens = 50 - 1, # HACK: T5 adds pad token in the beginning
    prefix_allowed_tokens_fn=constraint, 
    output_scores=True, return_dict_in_generate=True
)
prediction = infres.sequences
probs = infres.sequences_scores.exp()
tokenizer.batch_decode(prediction)

63.3 ms ± 632 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [25]:
%%timeit
infres = model.generate3(
    input_ids = batch2["input_ids"], 
    attention_mask = batch2["attention_mask"],
    num_beams = 3, 
    num_return_sequences = 3, 
    do_sample = False, 
    length_penalty = 0, 
    max_new_tokens = 50 - 1, # HACK: T5 adds pad token in the beginning
    prefix_allowed_tokens_fn=constraint, 
    output_scores=True, return_dict_in_generate=True
)
prediction = infres.sequences
probs = infres.sequences_scores.exp()
tokenizer.batch_decode(prediction)

63.2 ms ± 264 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [26]:
%%timeit
infres = model.generate3(
    input_ids = batch3["input_ids"], 
    attention_mask = batch3["attention_mask"],
    num_beams = 3, 
    num_return_sequences = 3, 
    do_sample = False, 
    length_penalty = 0, 
    max_new_tokens = 50 - 1, # HACK: T5 adds pad token in the beginning
    prefix_allowed_tokens_fn=constraint, 
    output_scores=True, return_dict_in_generate=True
)
prediction = infres.sequences
probs = infres.sequences_scores.exp()
tokenizer.batch_decode(prediction)

72.9 ms ± 789 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [27]:
# Compiled constrained beam search on query 3; compared with eager generate in the next cell.
infres = model.generate3(
    input_ids=batch3["input_ids"],
    attention_mask=batch3["attention_mask"],
    num_beams=3,
    num_return_sequences=3,
    do_sample=False,
    length_penalty=0,
    max_new_tokens=50 - 1,  # HACK: T5 adds pad token in the beginning
    prefix_allowed_tokens_fn=constraint,
    output_scores=True,
    return_dict_in_generate=True,
)
probs = torch.exp(infres.sequences_scores)
prediction = infres.sequences
tokenizer.batch_decode(prediction), probs

(['<pad> education & office supplies > cutting supplies > scissors</s><pad>',
  '<pad> education & office supplies > cutting supplies > letter opener</s><pad>',
  '<pad> education & office supplies > cutting supplies > utility knife</s>'],
 tensor([0., 0., 0.], device='cuda:0'))

In [28]:
# Eager (uncompiled) reference for the compiled beam-search result above.
infres = model.generate(
    input_ids=batch3["input_ids"],
    attention_mask=batch3["attention_mask"],
    num_beams=3,
    num_return_sequences=3,
    do_sample=False,
    length_penalty=0,
    max_new_tokens=50 - 1,  # HACK: T5 adds pad token in the beginning
    prefix_allowed_tokens_fn=constraint,
    output_scores=True,
    return_dict_in_generate=True,
)
probs = torch.exp(infres.sequences_scores)
prediction = infres.sequences
tokenizer.batch_decode(prediction), probs

(['<pad> education & office supplies > cutting supplies > scissors</s><pad>',
  '<pad> education & office supplies > cutting supplies > letter opener</s><pad>',
  '<pad> education & office supplies > cutting supplies > utility knife</s>'],
 tensor([0., 0., 0.], device='cuda:0'))

# Trie-constrained greedy search benchmarks

In [29]:
# Separate dynamo-wrapped handle for the constrained greedy experiments.
# NOTE(review): like generate2/generate3 this wraps the same model.generate, so
# the dynamo cache is likely shared across all three handles. Confirm.
model.generate4 = torchdynamo.optimize("inductor")(model.generate)

In [36]:
# Warm-up for constrained greedy search. The all-<pad> fake trie drives the
# longest decode, so later timed calls presumably hit already-compiled code.
print(
    tokenizer.batch_decode(
        model.generate4(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            do_sample=False,
            length_penalty=0,  # beam-search setting; kept for parity with the beam cells
            max_new_tokens=50 - 1,  # HACK: T5 adds pad token in the beginning
            prefix_allowed_tokens_fn=constraint_fake,
        )
    )
)

['<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>']


In [31]:
%%timeit
tokenizer.batch_decode(model.generate4(
    input_ids = batch["input_ids"], 
    attention_mask = batch["attention_mask"],
    do_sample = False, 
    length_penalty = 0, 
    max_new_tokens = 50 - 1, # HACK: T5 adds pad token in the beginning
    prefix_allowed_tokens_fn=constraint
))

35.2 ms ± 806 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [32]:
%%timeit
tokenizer.batch_decode(model.generate4(
    input_ids = batch2["input_ids"], 
    attention_mask = batch2["attention_mask"],
    do_sample = False, 
    length_penalty = 0, 
    max_new_tokens = 50 - 1, # HACK: T5 adds pad token in the beginning
    prefix_allowed_tokens_fn=constraint
))

34.7 ms ± 341 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [33]:
%%timeit
tokenizer.batch_decode(model.generate4(
    input_ids = batch3["input_ids"], 
    attention_mask = batch3["attention_mask"],
    do_sample = False, 
    length_penalty = 0, 
    max_new_tokens = 50 - 1, # HACK: T5 adds pad token in the beginning
    prefix_allowed_tokens_fn=constraint
))

32.2 ms ± 226 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [34]:
# Compiled constrained greedy output for query 3; compared with eager generate in the next cell.
tokenizer.batch_decode(
    model.generate4(
        input_ids=batch3["input_ids"],
        attention_mask=batch3["attention_mask"],
        do_sample=False,
        length_penalty=0,
        max_new_tokens=50 - 1,  # HACK: T5 adds pad token in the beginning
        prefix_allowed_tokens_fn=constraint,
    )
)

['<pad> beauty & health > shaving & hair removal > waxing</s>']

In [35]:
# Eager (uncompiled) reference for the compiled greedy result above.
tokenizer.batch_decode(
    model.generate(
        input_ids=batch3["input_ids"],
        attention_mask=batch3["attention_mask"],
        do_sample=False,
        length_penalty=0,
        max_new_tokens=50 - 1,  # HACK: T5 adds pad token in the beginning
        prefix_allowed_tokens_fn=constraint,
    )
)

['<pad> beauty & health > shaving & hair removal > waxing</s>']