In [2]:
import re
from typing import Optional

from jaxtyping import Float, Int
import torch
from torch import Tensor
from transformers import PreTrainedModel, AutoModelForCausalLM, AutoTokenizer, StoppingCriteria

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
model_ref = "distilgpt2"
model = AutoModelForCausalLM.from_pretrained(model_ref)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_ref)

In [4]:
# Pre-compute which vocabulary IDs correspond to full words
full_word_re = re.compile(r"^[ĠĊ]|[^\w]")
full_word_mask = torch.zeros(tokenizer.vocab_size, dtype=torch.bool)
for token, idx in tokenizer.get_vocab().items():
    full_word_mask[idx] = idx == tokenizer.eos_token_id or bool(full_word_re.match(token))

class FullWordStoppingCriteria(StoppingCriteria):
    def __call__(self,
                 input_ids: Int[Tensor, "batch seq_len"],
                 scores: Float[Tensor, "batch vocab_size"]) -> bool:
        """
        Stop when the last generated token is a full word.
        """
        return bool(full_word_mask[input_ids[-1]].all().item())

In [26]:
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
top_p = 0.95
cum_p = 0.0
top_k = 100

results, result_scores = [], []

with torch.no_grad():
    outputs = model(inputs["input_ids"])
    next_token_logits = outputs.logits[:, -1, :]
    # todo overflow
    next_token_probs = next_token_logits.softmax(dim=-1)

    # # Top-p filter
    # sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
    # cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

    # Initialize fringe with top k tokens
    indices = torch.topk(next_token_logits, k=top_k, dim=-1).indices.T
    fringe = torch.cat([inputs["input_ids"].repeat(top_k, 1), indices], dim=-1)
    fringe_probs = next_token_probs[0, indices.squeeze()]
    print(fringe.shape, fringe_probs.shape)

    j = 0
    while cum_p < top_p:
        print(cum_p)
        if j > 2:
            break
        j += 1
        # For fringe elements which show new full words, add them to results
        new_full_words = full_word_mask[fringe[:, -1]]

        # TODO reduce dupes
        results.extend(fringe[new_full_words, :-1].tolist())

        # TODO this is the wrong probability -- we want the penultimate token
        result_scores.extend(fringe_probs[new_full_words].tolist())

        # Update cumulative probability
        cum_p += fringe_probs[new_full_words].sum().item()

        # Remove fringe elements which show new full words
        fringe = fringe[~new_full_words]
        fringe_probs = fringe_probs[~new_full_words]

        # Repopulate fringe
        repop_k = 20
        # repop_outputs.sequences will be len(fringe) * repop_k
        repop_outputs = model.generate(fringe, return_dict_in_generate=True, output_scores=True,
                                       max_length=fringe.shape[-1] + 1,
                                       num_beams=repop_k, num_return_sequences=repop_k)
        repop_next_token_probs = repop_outputs.scores[0][:, -1]
        repop_probs, repop_indices = torch.topk(repop_next_token_probs, k=top_k, dim=-1)
        repop_tokens = repop_outputs.sequences[repop_indices, -1]

        # Locate originating sequences for topk items
        originating_sequences_ = torch.arange(len(fringe)).unsqueeze(-1).repeat(1, repop_k).flatten()
        originating_sequences = originating_sequences_[repop_indices]
                                   
        fringe = torch.cat([fringe[originating_sequences], repop_tokens.unsqueeze(-1)], dim=-1)
        fringe_probs = fringe_probs[originating_sequences] * repop_probs.exp() # TODO product


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


torch.Size([100, 7]) torch.Size([100])
0.0
tensor([ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True, False,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True, False,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True, False,  True,  True,
         True,  True,  True,  True,  True,  True, False,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True, False,  True,  True,  True])


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


0.9323692917823792
tensor([ True,  True,  True, False,  True,  True, False,  True,  True,  True,
        False, False,  True, False, False,  True, False, False,  True, False,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
        False, False, False, False, False,  True, False, False,  True, False,
        False, False, False, False, False, False, False, False, False, False,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True, False, False, False, False,  True,  True, False,  True, False,
        False, False, False, False,  True,  True, False, False,  True,  True])


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


0.9327764491026755
tensor([ True,  True,  True, False,  True,  True,  True,  True,  True, False,
        False,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True, False, False,  True,  True,  True, False,
        False, False,  True, False,  True,  True, False,  True,  True,  True,
         True,  True,  True,  True,  True,  True, False, False,  True,  True,
         True,  True,  True,  True, False,  True,  True,  True, False,  True,
         True, False,  True,  True,  True,  True,  True, False,  True,  True,
         True,  True, False,  True, False,  True,  True, False,  True, False,
         True,  True,  True,  True,  True, False,  True, False, False, False,
         True, False,  True,  True,  True,  True,  True,  True,  True,  True])
0.9327764491093644


In [23]:
for seq, prob in zip(fringe, fringe_probs):
    print(tokenizer.convert_ids_to_tokens(seq), prob)

['Hello', ',', 'Ġmy', 'Ġdog', 'Ġis', 'Ġcute', 'âĢ', 'ľ', 'âĢĶ', 'âĢĶ', 'âĢĵ', 'âĢ¦', 'Ċ'] tensor(8.5705e-21)
['Hello', ',', 'Ġmy', 'Ġdog', 'Ġis', 'Ġcute', 'âĢ', 'ľ', 'âĢĶ', 'âĢĶ', 'âĢĵ', 'âĢ¦', '?'] tensor(8.5705e-21)
['Hello', ',', 'Ġmy', 'Ġdog', 'Ġis', 'Ġcute', 'âĢ', 'ľ', 'âĢĶ', 'âĢĶ', 'âĢĵ', 'âĢ¦', 'Ġbut'] tensor(8.5705e-21)
['Hello', ',', 'Ġmy', 'Ġdog', 'Ġis', 'Ġcute', 'âĢ', 'ľ', 'âĢĶ', 'âĢĶ', 'âĢĵ', 'âĢ¦', 'ĊĊ'] tensor(8.5705e-21)
['Hello', ',', 'Ġmy', 'Ġdog', 'Ġis', 'Ġcute', 'âĢ', 'ľ', 'âĢĶ', 'âĢĶ', 'âĢĵ', 'âĢ¦', 'ĠâĢĶ'] tensor(8.5705e-21)
['Hello', ',', 'Ġmy', 'Ġdog', 'Ġis', 'Ġcute', 'âĢ', 'ľ', 'âĢĶ', 'âĢĶ', 'âĢĵ', 'âĢ¦', 'but'] tensor(8.5705e-21)
['Hello', ',', 'Ġmy', 'Ġdog', 'Ġis', 'Ġcute', 'âĢ', 'ľ', 'âĢĶ', 'âĢĶ', 'âĢĵ', 'âĢ¦', ','] tensor(8.5705e-21)
['Hello', ',', 'Ġmy', 'Ġdog', 'Ġis', 'Ġcute', 'âĢ', 'ľ', 'âĢĶ', 'âĢĶ', 'âĢĵ', 'âĢ¦', '!'] tensor(8.5705e-21)
['Hello', ',', 'Ġmy', 'Ġdog', 'Ġis', 'Ġcute', 'âĢ', 'ľ', 'âĢĶ', 'âĢĶ', 'âĢĵ', 'âĢ¦', '['] tensor(8.5705e-21)
['Hello', 

In [27]:
for seq, prob in zip(results, result_scores):
    print(tokenizer.convert_ids_to_tokens(seq), prob)

['Hello', ',', 'Ġmy', 'Ġdog', 'Ġis', 'Ġcute'] 0.2846169173717499
['Hello', ',', 'Ġmy', 'Ġdog', 'Ġis', 'Ġcute'] 0.21214528381824493
['Hello', ',', 'Ġmy', 'Ġdog', 'Ġis', 'Ġcute'] 0.2009846419095993
['Hello', ',', 'Ġmy', 'Ġdog', 'Ġis', 'Ġcute'] 0.04743831977248192
['Hello', ',', 'Ġmy', 'Ġdog', 'Ġis', 'Ġcute'] 0.02154489792883396
['Hello', ',', 'Ġmy', 'Ġdog', 'Ġis', 'Ġcute'] 0.009362243115901947
['Hello', ',', 'Ġmy', 'Ġdog', 'Ġis', 'Ġcute'] 0.007975026965141296
['Hello', ',', 'Ġmy', 'Ġdog', 'Ġis', 'Ġcute'] 0.007465483620762825
['Hello', ',', 'Ġmy', 'Ġdog', 'Ġis', 'Ġcute'] 0.007376144640147686
['Hello', ',', 'Ġmy', 'Ġdog', 'Ġis', 'Ġcute'] 0.007116660941392183
['Hello', ',', 'Ġmy', 'Ġdog', 'Ġis', 'Ġcute'] 0.006821037735790014
['Hello', ',', 'Ġmy', 'Ġdog', 'Ġis', 'Ġcute'] 0.006216147914528847
['Hello', ',', 'Ġmy', 'Ġdog', 'Ġis', 'Ġcute'] 0.0054660337045788765
['Hello', ',', 'Ġmy', 'Ġdog', 'Ġis', 'Ġcute'] 0.005319549702107906
['Hello', ',', 'Ġmy', 'Ġdog', 'Ġis', 'Ġcute'] 0.005311277229338884
[

In [None]:
outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True,
                         num_beams=5, num_return_sequences=5, stopping_criteria=[FullWordStoppingCriteria()])

In [11]:
outputs

BeamSearchDecoderOnlyOutput(sequences=tensor([[15496,    11,   616,  3290,   318, 13779,    11,   290,   314,  1842,
           683,    13,   314,  1842,   683,    13,   314,  1842,   683,    13],
        [15496,    11,   616,  3290,   318, 13779,    11,   290,   314,  1842,
           683,    13,   314,  1842,   683,    11,   290,   314,  1842,   683],
        [15496,    11,   616,  3290,   318, 13779,    11,   290,   314,  1842,
           284,   711,   351,   340,    13,   314,  1842,   284,   711,   351],
        [15496,    11,   616,  3290,   318, 13779,    11,   290,   314,  1842,
           284,   711,   351,   683,    13,   314,  1842,   284,   711,   351],
        [15496,    11,   616,  3290,   318, 13779,    11,   290,   314,  1842,
           683,    13,   314,  1842,   683,    13,   314,  1842,   683,    11]]), sequences_scores=tensor([-0.8121, -0.8173, -0.8743, -0.8833, -0.9129]), scores=(tensor([[ -3.0483,  -6.7191, -10.8251,  ..., -20.6052, -17.7893,  -7.3971],
        [

In [13]:
full_word_mask[outputs.sequences[0]]

tensor([False,  True,  True, False, False,  True,  True,  True, False, False,
         True, False, False, False,  True, False, False, False,  True, False])

In [14]:
for seq in outputs.sequences:
    print(tokenizer.decode(seq))

Hello, my dog is cute, and I love him. I love him. I love him.
Hello, my dog is cute, and I love him. I love him, and I love him
Hello, my dog is cute, and I love to play with it. I love to play with
Hello, my dog is cute, and I love to play with him. I love to play with
Hello, my dog is cute, and I love him. I love him. I love him,


In [10]:
tokenizer.convert_ids_to_tokens(outputs.sequences[0])

['Hello',
 ',',
 'Ġmy',
 'Ġdog',
 'Ġis',
 'Ġcute',
 ',',
 'Ġand',
 'ĠI',
 'Ġlove',
 'Ġto',
 'Ġplay',
 'Ġwith',
 'Ġit',
 '.',
 'ĠI',
 'Ġlove',
 'Ġto',
 'Ġplay',
 'Ġwith']