In [2]:
import re
from typing import Optional

from jaxtyping import Float, Int
import torch
from torch import Tensor
from transformers import PreTrainedModel, AutoModelForCausalLM, AutoTokenizer, StoppingCriteria

In [3]:
model_ref = "distilgpt2"
model = AutoModelForCausalLM.from_pretrained(model_ref)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_ref)

In [9]:
# Pre-compute which vocabulary IDs correspond to full words
full_word_re = re.compile(r"^[ĠĊ]|[^\w]")
# There are some punct tokens that are not full words -- contractions. TODO figure out how the tokenizer deals with this. For now we'll have a manual list of exceptions
subword_exceptions = ["'s", "'ve", "'re", "'m", "'ll", "'d", "'t"]
# subword_exception_ids = [tokenizer.encode(subword)[0] for subword in subword_exceptions]

full_word_mask = torch.zeros(tokenizer.vocab_size, dtype=torch.bool)
for token, idx in tokenizer.get_vocab().items():
    full_word_mask[idx] = (idx == tokenizer.eos_token_id or bool(full_word_re.match(token))) and (token not in subword_exceptions)

class FullWordStoppingCriteria(StoppingCriteria):
    def __call__(self,
                 input_ids: Int[Tensor, "batch seq_len"],
                 scores: Float[Tensor, "batch vocab_size"]) -> bool:
        """
        Stop when the last generated token is a full word.
        """
        return bool(full_word_mask[input_ids[-1]].all().item())

In [10]:
inputs = tokenizer("Can we please", return_tensors="pt")
top_p = 0.95
cum_p = 0.0
fringe_size = 100

results, result_scores = [], []

with torch.no_grad():
    outputs = model(inputs["input_ids"])
    next_token_logits = outputs.logits[:, -1, :]
    next_token_logprobs = next_token_logits.log_softmax(dim=-1)

    # Initialize fringe with top k tokens concatenated to context
    indices = torch.topk(next_token_logits, k=fringe_size, dim=-1).indices.T
    # fringe: collection of token sequences to be expanded
    fringe = torch.cat([inputs["input_ids"].repeat(fringe_size, 1), indices], dim=-1)
    # fringe_logprobs: accumulated log-probability within each fringe element
    fringe_logprobs = next_token_logprobs[0, indices.squeeze()]

    j = 0
    while cum_p < top_p:
        print(cum_p)
        if j > 2:
            break
        j += 1

        print(f"Step {j}, fringe size: {len(fringe)}")

        if j > 1:
            # Stage/pop fringe elements which show new full words
            to_stage = full_word_mask[fringe[:, -1]]

            # sequences_to_stage = fringe[to_stage, :-1].tolist()
            # DEV: track the last token too
            sequences_to_stage = fringe[to_stage].tolist()

            results.extend(sequences_to_stage)
            result_scores.extend(fringe_logprobs[to_stage].tolist())

            # Update cumulative probability
            cum_p += fringe_logprobs[to_stage].exp().sum().item()

            # Remove fringe elements which show new full words
            fringe = fringe[~to_stage]
            fringe_logprobs = fringe_logprobs[~to_stage]

        if len(fringe) == 0:
            print("Empty fringe. This ain't good but we gotta stop. Try increasing K.")
            break

        # Repopulate fringe by expanding the remaining sequences
        # Draw `repop_k`-many max-probability continuation tokens for each sequence in fringe
        repop_k = 20
        # repop_outputs.sequences will be len(fringe) * repop_k
        repop_outputs = model.generate(fringe, return_dict_in_generate=True, output_scores=True,
                                       pad_token_id=tokenizer.eos_token_id,
                                       max_length=fringe.shape[-1] + 1,  # just one more token
                                       num_beams=repop_k, num_return_sequences=repop_k)
        assert repop_outputs.scores[0].shape == (len(fringe) * repop_k, model.config.vocab_size)

        # Draw `fringe_size` of the continuations with greatest total log probability
        new_fringe_size = min(fringe_size, len(repop_outputs.sequences))
        _, repop_indices = torch.topk(repop_outputs.sequences_scores, k=new_fringe_size, dim=-1)
        repop_tokens = repop_outputs.sequences[repop_indices, -1]
        repop_logprobs = repop_outputs.sequences_scores[repop_indices]

        # Locate originating sequences for topk items
        originating_sequences_ = torch.arange(len(fringe)).unsqueeze(-1).repeat(1, repop_k).flatten()
        originating_sequences = originating_sequences_[repop_indices]
                                   
        fringe = torch.cat([fringe[originating_sequences], repop_tokens.unsqueeze(-1)], dim=-1)
        # TODO shouldn't add this yet. it may be the logprob of a sentinel. we want to use the sentinel content but not incorporate its logprob
        fringe_logprobs = fringe_logprobs[originating_sequences] + repop_logprobs

# TODO merge dupes


0.0
Step 1, fringe size: 100
0.0
Step 2, fringe size: 100
0.4680291712284088
Step 3, fringe size: 40
Empty fringe. This ain't good but we gotta stop. Try increasing K.


In [33]:
# for seq, prob in zip(fringe, fringe_probs):
#     print(tokenizer.convert_ids_to_tokens(seq), prob)

In [11]:
for seq, prob in zip(results, result_scores):
    print(tokenizer.convert_ids_to_tokens(seq), prob)

['Can', 'Ġwe', 'Ġplease', 'Ċ', 'Ċ'] -5.343081951141357
['Can', 'Ġwe', 'Ġplease', 'Ġrefrain', 'Ġfrom'] -5.750398635864258
['Can', 'Ġwe', 'Ġplease', 'Ġspread', 'Ġthe'] -5.864449977874756
['Can', 'Ġwe', 'Ġplease', 'Ġreach', 'Ġout'] -5.881028652191162
['Can', 'Ġwe', 'Ġplease', 'Ġnote', 'Ġthat'] -4.9881744384765625
['Can', 'Ġwe', 'Ġplease', 'Ġcontinue', 'Ġto'] -5.307211875915527
['Can', 'Ġwe', 'Ġplease', 'Ġrefer', 'Ġto'] -6.079874038696289
['Can', 'Ġwe', 'Ġplease', 'Ġthank', 'Ġyou'] -6.370788097381592
['Can', 'Ġwe', 'Ġplease', 'Ġdo', 'Ġnot'] -3.342820167541504
['Can', 'Ġwe', 'Ġplease', 'Ġif', 'Ġyou'] -6.364151477813721
['Can', 'Ġwe', 'Ġplease', '."', 'Ċ'] -6.490191459655762
['Can', 'Ġwe', 'Ġplease', 'Ġenable', 'ĠJavaScript'] -4.78279447555542
['Can', 'Ġwe', 'Ġplease', 'Ġensure', 'Ġthat'] -6.545234203338623
['Can', 'Ġwe', 'Ġplease', 'Ġunderstand', 'Ġthat'] -6.279345989227295
['Can', 'Ġwe', 'Ġplease', 'Ġinform', 'Ġyou'] -6.3013014793396
['Can', 'Ġwe', 'Ġplease', 'Ġreport', 'Ġany'] -5.70466470

In [None]:
outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True,
                         num_beams=5, num_return_sequences=5, stopping_criteria=[FullWordStoppingCriteria()])

In [11]:
outputs

BeamSearchDecoderOnlyOutput(sequences=tensor([[15496,    11,   616,  3290,   318, 13779,    11,   290,   314,  1842,
           683,    13,   314,  1842,   683,    13,   314,  1842,   683,    13],
        [15496,    11,   616,  3290,   318, 13779,    11,   290,   314,  1842,
           683,    13,   314,  1842,   683,    11,   290,   314,  1842,   683],
        [15496,    11,   616,  3290,   318, 13779,    11,   290,   314,  1842,
           284,   711,   351,   340,    13,   314,  1842,   284,   711,   351],
        [15496,    11,   616,  3290,   318, 13779,    11,   290,   314,  1842,
           284,   711,   351,   683,    13,   314,  1842,   284,   711,   351],
        [15496,    11,   616,  3290,   318, 13779,    11,   290,   314,  1842,
           683,    13,   314,  1842,   683,    13,   314,  1842,   683,    11]]), sequences_scores=tensor([-0.8121, -0.8173, -0.8743, -0.8833, -0.9129]), scores=(tensor([[ -3.0483,  -6.7191, -10.8251,  ..., -20.6052, -17.7893,  -7.3971],
        [

In [13]:
full_word_mask[outputs.sequences[0]]

tensor([False,  True,  True, False, False,  True,  True,  True, False, False,
         True, False, False, False,  True, False, False, False,  True, False])

In [14]:
for seq in outputs.sequences:
    print(tokenizer.decode(seq))

Hello, my dog is cute, and I love him. I love him. I love him.
Hello, my dog is cute, and I love him. I love him, and I love him
Hello, my dog is cute, and I love to play with it. I love to play with
Hello, my dog is cute, and I love to play with him. I love to play with
Hello, my dog is cute, and I love him. I love him. I love him,


In [10]:
tokenizer.convert_ids_to_tokens(outputs.sequences[0])

['Hello',
 ',',
 'Ġmy',
 'Ġdog',
 'Ġis',
 'Ġcute',
 ',',
 'Ġand',
 'ĠI',
 'Ġlove',
 'Ġto',
 'Ġplay',
 'Ġwith',
 'Ġit',
 '.',
 'ĠI',
 'Ġlove',
 'Ġto',
 'Ġplay',
 'Ġwith']