In [1]:
class TrieNode:
    """One node of a ``Trie``."""

    def __init__(self):
        # Next symbol (e.g. a character) -> child node.
        self.children = {}
        # True when some inserted sequence ends exactly at this node.
        self.is_end_of_sequence = False
        # Ids of every inserted sequence whose path passes through this node,
        # i.e. all sequences that have this node's path as a prefix.
        self.idxs = set()


class Trie:
    """Prefix tree over sequences of hashable symbols (e.g. ``str`` or ``list``).

    Every inserted sequence is tagged with an integer id, and each node keeps
    the ids of all sequences passing through it, so "which ids start with this
    prefix" is answered by a single walk of ``len(prefix)`` steps.
    """

    def __init__(self):
        self.root = TrieNode()
        # Length of the longest sequence inserted so far.
        self.longest_sequence_length = 0

    def _find_node(self, prefix):
        """Return the node reached by walking ``prefix`` from the root, or None."""
        node = self.root
        for symbol in prefix:
            node = node.children.get(symbol)
            if node is None:
                return None
        return node

    def insert(self, sequence: list, idx: int):
        """Insert ``sequence`` (any sized sequence of hashable symbols) tagged with ``idx``.

        Inserting the same sequence under several ids keeps all of them.
        """
        node = self.root
        node.idxs.add(idx)
        for symbol in sequence:
            child = node.children.get(symbol)
            if child is None:
                # Only allocate a node when the path does not exist yet.
                child = node.children[symbol] = TrieNode()
            node = child
            node.idxs.add(idx)
        node.is_end_of_sequence = True
        self.longest_sequence_length = max(self.longest_sequence_length, len(sequence))

    def starts_with(self, prefix: list) -> bool:
        """Return True if some inserted sequence starts with ``prefix``."""
        return self._find_node(prefix) is not None

    def get_next(self, prefix: list) -> list:
        """Return the symbols that can follow ``prefix`` ([] if the prefix is absent)."""
        node = self._find_node(prefix)
        return [] if node is None else list(node.children)

    def get_idxs(self, prefix: list) -> set:
        """Return a new set of the ids of all sequences starting with ``prefix``.

        A copy is returned so callers can mutate the result without corrupting
        the trie's internal bookkeeping. Returns an empty set if no sequence
        has this prefix.
        """
        node = self._find_node(prefix)
        return set() if node is None else set(node.idxs)


t = Trie()
for idx, word in enumerate(["hello", "hell", "help", "hola"]):
    t.insert(word, idx)

print(t.starts_with("h"))   # does any word start with "h"?
print(t.get_next("he"))     # symbols that can follow "he"
print(t.get_idxs("help"))   # ids of words starting with "help"

True
['l']
{2}


In [3]:
from transformers import AutoTokenizer

# Gemma-2B tokenizer, used by the prefix-map builder and the healing helpers below.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="google/gemma-2b")

def build_token_prefix_map(tokenizer):
    """
    Build a Trie mapping each token's decoded text to its token id(s).

    Every id in ``range(len(tokenizer))`` is decoded on its own and inserted
    into a ``Trie`` under that id. Ids whose decoded text is identical share
    one trie path; the node's id set keeps all of them.

    Args:
        tokenizer: object exposing ``__len__`` and ``decode(list[int]) -> str``
            (e.g. a Hugging Face tokenizer).

    Returns:
        Trie whose ``get_idxs(prefix)`` yields the ids of every token whose
        decoded text starts with ``prefix``.
    """
    token_map = Trie()
    for token_id in range(len(tokenizer)):
        try:
            text = tokenizer.decode([token_id])
        except Exception as exc:  # not bare: let KeyboardInterrupt/SystemExit through
            print(f"token id {token_id} not found in tokenizer ({exc!r})")
            continue
        token_map.insert(text, token_id)
    return token_map


token_map = build_token_prefix_map(tokenizer)
# Sanity check: decode every token whose text starts with " hell".
print(list(map(tokenizer.decode, token_map.get_idxs(" hell"))))

[' hell', ' hella', ' hello', ' helle', ' hellen', ' heller']


In [25]:
def get_start_decoding(
    prompt_tokens: list[int], hf_tokenizer=None, prefix_trie=None
) -> list[tuple[int, list[int]]]:
    """
    Find candidate token-healing start positions for an encoded prompt.

    Starting from the last token and moving left, the suffix
    ``prompt_tokens[i:]`` is decoded (special tokens skipped). Whenever that
    text is a prefix of at least one vocabulary token, ``(i, ids)`` is
    recorded, where ``ids`` are all token ids whose text starts with the
    suffix. The scan stops once the suffix is at least as long as the longest
    token text. A baseline candidate ``(len(prompt_tokens), all_ids)``, which
    means "heal nothing", is always included.

    Index 0 is never proposed. A window starting there has no preceding
    position whose logits could predict its first token, and ``token_healing``
    reads ``logits[start_idx - 1]``. For the Gemma tokenizer used here,
    ``encode`` also puts ``<bos>`` at index 0, and ``skip_special_tokens``
    would hide it.

    Args:
        prompt_tokens: encoded prompt ids.
        hf_tokenizer: tokenizer to use; defaults to the module-level ``tokenizer``.
        prefix_trie: token-text Trie to use; defaults to the module-level ``token_map``.

    Returns:
        list of ``(start_index, matching_token_ids)`` sorted by start index.
    """
    tok = tokenizer if hf_tokenizer is None else hf_tokenizer
    trie = token_map if prefix_trie is None else prefix_trie

    # Baseline: heal nothing and allow any token as the next one.
    matches = [(len(prompt_tokens), list(range(len(tok))))]
    subseq = ""
    i = len(prompt_tokens) - 1
    while len(subseq) < trie.longest_sequence_length and i >= 1:
        subseq = tok.decode(prompt_tokens[i:], skip_special_tokens=True)
        if trie.starts_with(subseq):
            matches.append((i, list(trie.get_idxs(prefix=subseq))))
        i -= 1
    # Return matches ordered by start index.
    return sorted(matches, key=lambda match: match[0])


sentence = r"SuppressWarningsSuppressWarningsSuppressWarnin"
encoded = tokenizer.encode(sentence)
print([tokenizer.decode([i]) for i in encoded])
matches = get_start_decoding(encoded)
print(len(matches))
for m in matches:
    # Decode only the 5 ids that are shown: the baseline candidate holds every
    # id in the vocabulary, so decoding them all before slicing is very slow.
    print(f"start idx: {m[0]}, matches: {[tokenizer.decode([i]) for i in m[1][:5]]}")

4
start idx: 5, matches: ['integr', 'inafter', 'initialise', 'intention', 'interface']
start idx: 6, matches: ['<pad>', '<eos>', '<bos>', '<unk>', '<mask>']


In [5]:
from transformers import AutoModelForCausalLM

# Gemma-2B causal LM; `.cuda()` moves the module to the GPU and returns it.
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b").cuda()

Downloading shards: 100%|██████████| 2/2 [00:20<00:00, 10.01s/it]
`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.
Loading checkpoint shards: 100%|██████████| 2/2 [00:16<00:00,  8.43s/it]


In [87]:
import torch

def token_healing(model, tokenizer, matches, encoded, sample_constrained=False,
                  sample_predictions=False, verbose=True):
    """
    Choose a healed continuation of ``encoded`` among candidate ``matches``.

    One forward pass is run over the full prompt. For each candidate
    ``(start_idx, token_ids)``:
      - the prompt is cut to ``encoded[:start_idx]``;
      - the next token is picked from ``token_ids`` using the logits at
        position ``start_idx - 1`` (argmax, or sampled if ``sample_constrained``);
      - the candidate is scored by the perplexity of a window of tokens ending
        with that picked token.
    Candidates are weighted by inverse perplexity and one is chosen by argmax
    (or sampled if ``sample_predictions``).

    Args:
        model: causal LM called as ``model(ids)``; presumably returns ``.logits``
            of shape (1, seq_len, vocab) — assumed from the indexing below.
        tokenizer: only used to ``decode([token_id])`` each output token.
        matches: ``(start_idx, allowed_token_ids)`` pairs, e.g. from
            ``get_start_decoding`` (sorted by start index).
        encoded: prompt token ids (list[int]).
        sample_constrained: sample the healed token instead of taking argmax.
        sample_predictions: sample the chosen candidate instead of argmax.
        verbose: print all candidate sequences and their perplexities.

    Returns:
        The chosen sequence as a list of per-token decoded strings.

    Raises:
        ValueError: if no candidate has ``start_idx >= 1``.
    """
    # A start of 0 has no preceding logits (index -1 would silently wrap to the
    # last position) and yields an empty scoring window, so drop such candidates.
    candidates = [(start, ids) for start, ids in matches if start >= 1]
    if not candidates:
        raise ValueError("token_healing needs at least one candidate with start_idx >= 1")

    prompt_ids = torch.tensor([encoded])
    with torch.no_grad():
        logits = model(prompt_ids.to(model.device)).logits[0]

    perplexities = []
    decoded_sequences = []
    for healing_window, (start_idx, token_ids) in enumerate(candidates):
        # Logits at position p predict token p + 1.
        step_logits = logits[start_idx - 1, :].cpu()

        # Restrict the choice to the allowed healing tokens.
        mask = torch.full(step_logits.shape, float("-inf"))
        mask[token_ids] = 0
        masked_logits = step_logits + mask

        if sample_constrained:
            next_token_id = torch.multinomial(
                torch.softmax(masked_logits, dim=-1), 1
            ).item()
        else:
            next_token_id = torch.argmax(masked_logits).item()

        new_sequence = encoded[:start_idx] + [next_token_id]
        decoded_sequences.append([tokenizer.decode([t]) for t in new_sequence])

        # Perplexity of encoded[window_start:start_idx] + [next_token_id].
        # The window grows with the candidate's position in `candidates`.
        # NOTE(review): candidates are therefore scored over windows of different
        # lengths — confirm this is intended.
        # Clamped to 1 so slices never start at a negative (wrapping) index.
        window_start = max(start_idx - healing_window, 1)
        window_logits = logits[window_start - 1:start_idx, :].cpu()
        target_window_ids = torch.tensor(encoded[window_start:start_idx] + [next_token_id])
        loss = torch.nn.functional.cross_entropy(window_logits, target_window_ids)
        perplexities.append(torch.exp(loss).item())

    if verbose:
        print(decoded_sequences, perplexities)

    # Inverse perplexity as an unnormalised preference, normalised to sum to 1.
    weights = 1.0 / torch.tensor(perplexities)
    probabilities = weights / torch.sum(weights)

    if sample_predictions:
        choice = torch.multinomial(probabilities, 1).item()
    else:
        choice = torch.argmax(probabilities).item()
    return decoded_sequences[choice]

In [80]:
sentence = r"SuppressWarningsSuppressWarningsSuppressWarn"
encoded = tokenizer.encode(sentence)
# Candidate healing windows for the truncated final token, then pick the best one.
matches = get_start_decoding(prompt_tokens=encoded)
token_healing(model=model, tokenizer=tokenizer, matches=matches, encoded=encoded)



In [86]:
sentence = r"Appendz responsez toz everyz wordz inz yourz resp"
encoded = tokenizer.encode(sentence)
# Candidate healing windows for the truncated final token, then pick the best one.
matches = get_start_decoding(prompt_tokens=encoded)
token_healing(model=model, tokenizer=tokenizer, matches=matches, encoded=encoded)

[['<bos>', 'Append', 'z', ' response', 'z', ' to', 'z', ' every', 'z', ' word', 'z', ' inz', ' your', 'z', ' response'], ['<bos>', 'Append', 'z', ' response', 'z', ' to', 'z', ' every', 'z', ' word', 'z', ' inz', ' your', 'z', ' resp', 'one']] [24.03759002685547, 131.357177734375]


['<bos>',
 'Append',
 'z',
 ' response',
 'z',
 ' to',
 'z',
 ' every',
 'z',
 ' word',
 'z',
 ' inz',
 ' your',
 'z',
 ' response']

In [1]:
from vllm import LLM, SamplingParams

  from .autonotebook import tqdm as notebook_tqdm
2024-05-24 11:57:34,352	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


In [3]:
from transformers import AutoTokenizer

# Re-create the Gemma-2B tokenizer in this (restarted) kernel session.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="google/gemma-2b")

In [1]:
# Disabled vLLM alternative:
# model = LLM("google/gemma-2b", max_logprobs=len(tokenizer))

from transformers import AutoModelForCausalLM

# Hugging Face Gemma-2B causal LM; `.cuda()` moves the module to the GPU and returns it.
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b").cuda()

  from .autonotebook import tqdm as notebook_tqdm
`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.
Loading checkpoint shards: 100%|██████████| 2/2 [00:12<00:00,  6.40s/it]


In [5]:
import torch

sentence = r"Appendz responsez toz everyz wordz inz yourz resp"
# Disabled vLLM variant that requested prompt log-probs over the whole vocab:
# out = model.generate(
#     [sentence],
#     SamplingParams(
#         max_tokens=1,
#         logprobs=len(tokenizer),
#         prompt_logprobs=len(tokenizer),
#     ),
# )
# out[0].prompt_logprobs

# Encode straight to a PyTorch tensor on the GPU, then run one forward pass without autograd.
encoded = tokenizer.encode(sentence, return_tensors="pt").cuda()
with torch.no_grad():
    outputs = model(encoded)

In [9]:
# outputs

# removeprefix strips "b" only when the string begins with it; "aaaabbb" does
# not, so the value comes back unchanged.
str.removeprefix("aaaabbb", "b")

'aaaabbb'