# The power of constrained language models.

**Why and how to build constrained language models with a custom beam search algorithm. A guide with Hugging Face code.**

Pre-trained generative language models (such as OpenAI's GPT2 and GPT3) or seq2seq models (such as T5 or the recently released T0) generate free-flowing natural language. This means that their output sentences can have any shape. To get the most value out of these models, we would sometimes like the outputs to follow a certain structure. In this notebook I will show you how to achieve this and gain more value out of your language model using a custom beam search algorithm.

This notebook is ment to accompany a blogpost. Read the blogpost to fully understand the benefits of a custom beam search algorithm.

## Installing the necessary packages

In [271]:
!pip install transformers
!pip install torch
!pip install numpy

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


UnboundLocalError: local variable 'child' referenced before assignment

## The code







First we will download our model and corresponding tokenizer. I'm currently using GPT2 but any generative language model which has a `beam_search` method should work.

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained("gpt2")

  from .autonotebook import tqdm as notebook_tqdm
'(ReadTimeoutError("HTTPSConnectionPool(host='huggingface.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: 8c8d0de8-2050-4e00-98a4-c715a9cb55f6)')' thrown while requesting HEAD https://huggingface.co/gpt2/resolve/main/tokenizer_config.json
'(ReadTimeoutError("HTTPSConnectionPool(host='huggingface.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: bd6e8064-1ffb-4748-bf1f-794e2b763ef6)')' thrown while requesting HEAD https://huggingface.co/gpt2/resolve/main/config.json


We define a helper function: `set_scores_to_inf_for_banned_tokens`. Ignore this for now

In [None]:
# src: https://huggingface.co/transformers/v4.1.1/_modules/transformers/generation_logits_process.html

def set_scores_to_inf_for_banned_tokens(scores, banned_tokens):
    """
    Modifies the scores in place by setting the banned token positions to `-inf`. Banned token is expected to be a
    list of list of banned tokens to ban in the format [[batch index, vocabulary position],...

    Args:
        scores: logits distribution of shape (batch size, vocabulary size)
        banned_tokens: list of list of tokens to ban of length (batch_size)
    """
    banned_mask_list = []
    for idx, batch_banned_tokens in enumerate(banned_tokens):
        for token in batch_banned_tokens:
            banned_mask_list.append([idx, token])
    if not banned_mask_list:
        return scores

    banned_mask = torch.LongTensor(banned_mask_list)
    indices = torch.ones(len(banned_mask))

    banned_mask = (
        torch.sparse.LongTensor(banned_mask.t(), indices, scores.size()).to(scores.device).to_dense().bool()
    )
    scores = scores.masked_fill(banned_mask, -float("inf"))
    return scores

We implement the `LogitsProcessor` class to get our desired effect. Our custom class should implement the `__call__` method of `LogitsProcessor`.

This method will be called during each step of the beam search algorithm. The method takes as input the `input_ids` sequence of the partially generated beam and the `scores` of the next possible tokens.

By manipulating these `scores` based on the tokens present in the `input_ids`, we can control the structure of the generated sentence.

We implement two custom `LogitsProcessor` classes: `EvenLogits` and `ABCLogits`. The `EvenLogits` class makes sure that all generated tokens contain an even amount of characters. The `ABCLogits` class is a bit more complex: it makes sure that our generated tokens follow an 'a -> b -> c' pattern. The first token starts with an 'a', the second with a 'b', the third with a 'c', the fourth again with an 'a', etc.

In both implementations, we achieve this by dynamically creating a list of all tokens we are not allowed to output and then setting the corresponding `scores` to `-inf` using our helper function `set_scores_to_inf_for_banned_tokens`.

In [None]:
from transformers import LogitsProcessor
import numpy as np

class CyrillicLogits(LogitsProcessor):
  def __call__(self, input_ids, scores):

    banned_tokens = []
    for beam_index, (beam_input_ids, beam_scores) in enumerate(zip(input_ids, scores)):
      elementwise_length = np.vectorize(len)
      keys = np.array(list(tokenizer.vocab.keys()))
      values = np.array(list(tokenizer.vocab.values()))

      # indexes of tokens that are too long
      indexes = np.where(elementwise_length(keys) % 2 == 0)[0]

      banned_tokens.append(values[indexes])

    scores = set_scores_to_inf_for_banned_tokens(scores, banned_tokens)
    return scores

class EvenLogits(LogitsProcessor):
  def __call__(self, input_ids, scores):

    banned_tokens = []
    for beam_index, (beam_input_ids, beam_scores) in enumerate(zip(input_ids, scores)):
      elementwise_length = np.vectorize(len)
      keys = np.array(list(tokenizer.vocab.keys()))
      values = np.array(list(tokenizer.vocab.values()))

      # indexes of tokens that are too long
      indexes = np.where(elementwise_length(keys) % 2 == 0)[0]

      banned_tokens.append(values[indexes])

    scores = set_scores_to_inf_for_banned_tokens(scores, banned_tokens)
    return scores

class ABCLogits(LogitsProcessor):
  def __init__(self, vocab):
    """
    vocab is a dictionary where the keys are tokens
    and the values are the corresponding ids.
    """
    # create an array of tokens
    # remove the 'Ġ' token (used to represent a blank space in the tokenizer)
    self.keys = list(tokenizer.vocab.keys())
    index_to_pop = self.keys.index('Ġ')
    self.keys.pop(index_to_pop)
    self.keys = np.array(self.keys)

    # create an array of ids
    # also remove the 'Ġ' token
    self.values = list(tokenizer.vocab.values())
    self.values.pop(index_to_pop)
    self.values = np.array(self.values)

    # vectorized function used to get the first character of a token
    # ignores leading whitespaces and 'Ġ' tokens
    first_char = lambda x: x.strip('Ġ ')[0].lower()
    self.first_char = np.vectorize(first_char)

    # get the indexes of all IDs that do not start with the given letter
    not_a_indexes = np.where(self.first_char(self.keys) != 'a')
    not_b_indexes = np.where(self.first_char(self.keys) != 'b')
    not_c_indexes = np.where(self.first_char(self.keys) != 'c')

    # create sets of tokens that do not start with 'a', 'b' or 'c'
    self.not_a_values = self.values[not_a_indexes]
    self.not_b_values = self.values[not_b_indexes]
    self.not_c_values = self.values[not_c_indexes]

  def __call__(self, input_ids, scores):
    banned_tokens = []
    # for every beam (partially generated sentence)
    for beam_index, (beam_input_ids, beam_scores) in enumerate(zip(input_ids, scores)):
      # get the last token of this beam
      last_word = tokenizer.decode(beam_input_ids[-1])
      # get the first character of this last token
      starting_char = self.first_char(last_word)
      # if the last token starts with 'a',
      # ban all words that do not start with 'b', etc.
      if starting_char == 'a':
        banned_tokens.append(self.not_b_values)
      elif starting_char == 'b':
        banned_tokens.append(self.not_c_values)
      elif starting_char == 'c':
        banned_tokens.append(self.not_a_values)
      else:
        banned_tokens.append(self.not_a_values)
    # set the scores of all banned tokens over the beams to -inf
    scores = set_scores_to_inf_for_banned_tokens(scores, banned_tokens)
    return scores

We use our custom `LogitsProcessor` classes during the beam search algorithm by passing them to the `logits_processor` attribute of the `beam_search` method of our model.

In the code blow, we use GPT2 to continue to prompt 'My cute dog is a' in the 'a -> b -> c' pattern.

In [None]:
from transformers import (
    BeamScorer,
    BeamSearchScorer,
    LogitsProcessorList,
    StoppingCriteria,
    StoppingCriteriaList,
    MaxLengthCriteria
)
import torch

# how many beams to track during the Viterbi algorithm
num_beams = 10
# how many beams to return after the algorithm
num_return_beams = 10

# the prompt to continue
prompt = 'My cute dog is a'

# tokenizing the prompt
prompt_tokenized = tokenizer(prompt, return_tensors='pt' )
prompt_tokenized = prompt_tokenized['input_ids']

# instantiating a BeamSearchScorer
beam_scorer = BeamSearchScorer(
    batch_size = prompt_tokenized.shape[0],
    num_beams = num_beams,
    num_beam_hyps_to_keep = num_return_beams,
    device=model.device
)

# instantiating a list of LogitsProcessor instances
# using our custom ABCLogits class
logits_processor = LogitsProcessorList([ABCLogits(tokenizer.vocab)])

# running beam search using our custom LogitsProcessor
generated = model.beam_search(
    torch.cat([prompt_tokenized] * num_beams),
    beam_scorer,
    logits_processor = logits_processor,
    stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=12)])
)

# printing the output beams
for index, output_tokenized in enumerate(generated):
  output = tokenizer.decode(output_tokenized)
  print(f'beam {index}: {output}')

beam 0: My cute dog is a bit confused about being called a bitch
beam 1: My cute dog is a bit confused about being called a baby
beam 2: My cute dog is a big cat and big cats are big
beam 3: My cute dog is a bit confused about being called a bunny
beam 4: My cute dog is a bit confused about being called a big
beam 5: My cute dog is a bit confused about being called a boy
beam 6: My cute dog is a bit confused about being called a b
beam 7: My cute dog is a bit confused about being called a bad
beam 8: My cute dog is a bit confused about being confused about being
beam 9: My cute dog is a bit confused about being called a black


Notice how the output of the model adheres to this structure without us having to provide the model with additional examples. I particularly like how the model still manages to make some coherent sentences despite these constraints.

Let's now add an additional constraint: all tokens should have an even amount of characters.

In [None]:
# list of preferred words
preferred = ['big']
preferred = [w.lower() for w in preferred]

class PreferredWordsLogits(LogitsProcessor):
  def __init__(self, tokenizer):
    """
    vocab is a dictionary where the keys are tokens
    and the values are the corresponding ids.
    """
    self.eos_token_id = tokenizer.encode('<|endoftext|>')
    self.choose_eof_scores = torch.full((len(tokenizer.vocab), ), -torch.inf)
    self.choose_eof_scores[self.eos_token_id] = 1

    self.choose_eos_token_ids = set({int(self.eos_token_id[0])})
    for token, token_id in tokenizer.vocab.items():
      if token[-1]=='.':
        self.choose_eos_token_ids.add(token_id) 

    self.preferred_token_ids = []
    for token, token_id in tokenizer.vocab.items():
      for word in preferred:
        if word in token:
          self.preferred_token_ids.append(token_id)
    self.preferred_token_mask = torch.full((len(tokenizer.vocab), ), 0)
    self.preferred_token_mask[self.preferred_token_ids] = 1


  def __call__(self, input_ids, scores):

    # for every beam (partially generated sentence)
    for beam_index, (beam_input_ids, beam_scores) in enumerate(zip(input_ids, scores)):

      last_token_id = beam_input_ids[-1]

      if int(last_token_id) in self.choose_eos_token_ids:
        scores[beam_index] = self.choose_eof_scores
        continue

      # DOESNT WORK!!

      
      # scores[beam_index][self.preferred_token_mask] *= 500
      scores[beam_index][self.preferred_token_mask] = -torch.inf 

      # # get the last word of this beam
      # last_word = ""
      # num_prev_tokens = 0
      # while num_prev_tokens < len(beam_input_ids) and (len(last_word) == 0 or last_word[0] not in 'Ġ '):
      #   num_prev_tokens += 1
      #   last_token_id = beam_input_ids[-num_prev_tokens]
      #   last_token = tokenizer.decode(last_token_id)
      #   last_word = last_token + last_word
      #   # sometimes the first previous token is ' '. restart to find the true previous word:
      #   if last_word in 'Ġ ':
      #     last_word = ""

      # if last_word.lower() in preferred:        
      #   print(beam_index, scores.shape, )
      #   scores[beam_index] -= 100
        
    return scores

prompt = 'My cute dog is a'

# tokenizing the prompt
prompt_tokenized = tokenizer(prompt, return_tensors='pt' )
prompt_tokenized = prompt_tokenized['input_ids']

num_beams = 10
num_beam_hyps_to_keep = 10

# instantiating a BeamSearchScorer
beam_scorer = BeamSearchScorer(
    batch_size = prompt_tokenized.shape[0],
    num_beams = num_beams,
    num_beam_hyps_to_keep = num_beam_hyps_to_keep,
    device=model.device,
)

logits_processor = LogitsProcessorList([
                                        PreferredWordsLogits(tokenizer), 
                                        # ABCLogits(tokenizer.vocab)
                                        ])

# running beam search using our custom LogitsProcessor
generated = model.beam_search(
    torch.cat([prompt_tokenized] * num_beams),
    beam_scorer,
    logits_processor = logits_processor,
    stopping_criteria = StoppingCriteriaList([
       MaxLengthCriteria(max_length=25)
       ]),
    pad_token_id = tokenizer.encode('<|endoftext|>')[0]
)

# printing the output beams
for index, output_tokenized in enumerate(generated):
  output = tokenizer.decode(output_tokenized)
  print(f'beam {index}: {output}')

beam 0: My cute dog is a good friend.<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>
beam 1: My cute dog is a good friend of mine.<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>
beam 2: My cute dog is a big deal.<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>
beam 3: My cute dog is a good dog.<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>
beam 4: My cute dog is a bit of a pain in the ass.<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>
beam 5: My cute dog is a bit of a mystery

In [None]:
p =  PreferredWordsLogits(tokenizer)
len(set(p.preferred_token_ids)), p.preferred_token_ids

(9, [28849, 33985, 18203, 5749, 1263, 27102, 14261, 25015, 4094])

In [None]:
beam_scorer.is_done

tensor(False)

We again notice how the generated part of the sentence adheres to the ' a -> b -> c' constraint. Additionally, all generated words have an even amount of tokens.

These are quite heavy constraints, causing the model to give less coherent outputs.

In [None]:
tokenizer.special_tokens_map['eos_token']

{'bos_token': '<|endoftext|>',
 'eos_token': '<|endoftext|>',
 'unk_token': '<|endoftext|>'}

In [None]:
tokenizer.encode('<|endoftext|>')

[50256]

In [277]:
import re
from tqdm import tqdm
from typing import Optional

english200 = ["hello", "the", "be", "of", "and", "a", "to", "in", "he", "have", "it", "that", "for", "they", "I", "with", "as", "not", "on", "she", "at", "by", "this", "we", "you", "do", "but", "from", "or", "which", "one", "would", "all", "will", "there", "say", "who", "make", "when", "can", "more", "if", "no", "man", "out", "other", "so", "what", "time", "up", "go", "about", "than", "into", "could", "state", "only", "new", "year", "some", "take", "come", "these", "know", "see", "use", "get", "like", "then", "first", "any", "work", "now", "may", "such", "give", "over", "think", "most", "even", "find", "day", "also", "after", "way", "many", "must", "look", "before", "great", "back", "through", "long", "where", "much", "should", "well", "people", "down", "own", "just", "because", "good", "each", "those", "feel", "seem", "how", "high", "too", "place", "little", "world", "very", "still", "nation", "hand", "old", "life", "tell", "write", "become", "here", "show", "house", "both", "between", "need", "mean", "call", "develop", "under", "last", "right", "move", "thing", "general", "school", "never", "same", "another", "begin", "while", "number", "part", "turn", "real", "leave", "might", "want", "point", "form", "off", "child", "few", "small", "since", "against", "ask", "late", "home", "interest", "large", "person", "end", "open", "public", "follow", "during", "present", "without", "again", "hold", "govern", "around", "possible", "head", "consider", "word", "program", "problem", "however", "lead", "system", "set", "order", "eye", "plan", "run", "keep", "face", "fact", "group", "play", "stand", "increase", "early", "course", "change", "help", "line"]

In [None]:
vocab = tokenizer.vocab
pattern = re.compile(r'^[a-zA-ZĠ]+$')

latin_vocab = {token: token_id for token, token_id in vocab.items() if pattern.match(token)}

In [278]:
class TokenTrieNode:
    terminal_word_id: Optional[int] = None
    terminal_score: Optional[int] = None
    children_score: Optional[int] = None
    children: dict[int, 'TokenTrieNode'] # token_id : TokenTrieNode of postfix

    def __init__(self):
        self.children = {}
        self.eos_token_id = tokenizer.encode('<|endoftext|>')

    def add_postfix(self, postfix: str, word_id: int, depth: int = 1, verbose: bool = False):
        if verbose: print('.'*depth, postfix)
        if len(postfix) == 0:
            self.terminal_word_id = word_id
            if verbose: print('+'*depth)
            return True

        for token, token_id in latin_vocab.items():
            if postfix.startswith(token):
                new_postfix = postfix[len(token):]
                if verbose: print('.'*depth, token_id, f"({token}/{tokenizer.decode(token_id)})")
                
                if token_id in self.children:
                    # try to let the child store it
                    self.children[token_id].add_postfix(new_postfix, word_id, depth=depth+1)
                else:
                    # create a candidate child. store it only if this token sequence works it out.
                    candidate = TokenTrieNode()
                    if candidate.add_postfix(new_postfix, word_id, depth=depth+1, verbose=verbose):
                        self.children[token_id] = candidate
                        if verbose: print('+'*depth)
                    elif verbose:
                        print('-'*depth)
        return len(self.children) > 0

    def print(self, prefix: str = '', prefix_tokens: list[int] = [], depth: int = 0):
        if self.terminal_word_id is not None:
            print(self.terminal_word_id, prefix, prefix_tokens)
        # print(self, len(self.children))
        # if depth == 10: 
            # print('='*50)
            # return
        for token_id, child in self.children.items():
            # print('-', child)
            child.print(prefix + '-' + tokenizer.decode(token_id), [token_id] + prefix_tokens, depth=depth+1)

    
data = english200
# data = english200[:5]
# data = ['AND']

root = TokenTrieNode()
for word_id, word in tqdm(enumerate(data)):
    root.add_postfix('Ġ' + word.lower(), word_id)
    root.add_postfix('Ġ' + word.capitalize(), word_id)


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
201it [02:14,  1.49it/s]


In [279]:
root.print()

8 - -he [258, 220]
0 - -he-ll-o [78, 297, 258, 220]
0 - -he-l-l-o [78, 75, 75, 258, 220]
0 - -he-l-lo [5439, 75, 258, 220]
199 - -he-l-p [79, 75, 258, 220]
0 - -he-llo [18798, 258, 220]
122 - -he-re [260, 258, 220]
122 - -he-r-e [68, 81, 258, 220]
176 - -he-ad [324, 258, 220]
176 - -he-a-d [67, 64, 258, 220]
199 - -he-lp [34431, 258, 220]
0 - -hell-o [78, 12758, 220]
0 - -hello [31373, 220]
8 - -h-e [68, 71, 220]
0 - -h-e-ll-o [78, 297, 68, 71, 220]
0 - -h-e-l-l-o [78, 75, 75, 68, 71, 220]
0 - -h-e-l-lo [5439, 75, 68, 71, 220]
199 - -h-e-l-p [79, 75, 68, 71, 220]
0 - -h-e-llo [18798, 68, 71, 220]
122 - -h-e-re [260, 68, 71, 220]
122 - -h-e-r-e [68, 81, 68, 71, 220]
176 - -h-e-ad [324, 68, 71, 220]
176 - -h-e-a-d [67, 64, 68, 71, 220]
199 - -h-e-lp [34431, 68, 71, 220]
0 - -h-el-l-o [78, 75, 417, 71, 220]
0 - -h-el-lo [5439, 417, 71, 220]
199 - -h-el-p [79, 417, 71, 220]
0 - -h-ell-o [78, 695, 71, 220]
0 - -h-ello [11109, 71, 220]
9 - -h-ave [1015, 71, 220]
9 - -h-a-v-e [68, 85, 64, 71,

In [None]:

light = TokenTrieNode()
verbose = False
for word_id, word in enumerate(data[:5]):
    light.add_postfix('Ġ' + word.lower(), word_id, verbose=verbose)
    light.add_postfix('Ġ' + word.capitalize(), word_id, verbose=verbose)
if verbose:
    light.print()

tokens = [24988, 589]
tokens = [128, 254, 1169]

query_data = data.copy()[:5]
# query_data = [word.capitalize() for word in query_data]
query_data = [' ' + word for word in query_data]
# query_data = [word.lower() for word in query_data]

for i, tokens in enumerate(tokenizer(query_data).input_ids):
    print(f"Calling with |{tokens}| |{tokenizer.decode(tokens)}|")
    print(light.scores_for_word(tokens))

In [286]:
def scores_for_word(self, tokens_postfix):
    # after walking the tree, we obtained the tokens that we prefer to be set next, to stay in vocabulary:
    if len(tokens_postfix) == 0: return list(self.children.keys()) + [eos_token_id]
    
    print(f" Called with {tokens_postfix}")
    
    # continue walking down the tree
    curr_token = tokens_postfix[0]
    if curr_token in self.children:
        return scores_for_word(self.children[curr_token], tokens_postfix=tokens_postfix[1:])

    print('Found nothing.', self.children.keys())

    # this token sequence is not in vocabulary
    return []


test_data = [' ' + word for word in english200[:5]]
for test_word in test_data:
    print(test_word)
    scores_for_word(root, tokenizer(test_word))

 hello
 Called with {'input_ids': [23748], 'attention_mask': [1]}
Found nothing. dict_keys([220, 932, 339, 289, 23748, 5968, 5783, 679, 5053, 367, 18435, 294, 256, 262, 383, 309, 536, 307, 275, 347, 1355, 267, 286, 440, 3226, 290, 257, 281, 843, 317, 1052, 284, 1675, 1312, 287, 554, 314, 387, 423, 8192, 9398, 23284, 340, 632, 326, 28110, 1320, 11511, 329, 277, 19434, 1114, 376, 484, 1119, 45967, 351, 20868, 266, 11759, 2080, 370, 40648, 355, 1081, 645, 407, 299, 1892, 1400, 399, 319, 1550, 264, 673, 427, 311, 911, 1375, 379, 1629, 416, 2750, 428, 770, 356, 775, 331, 27406, 345, 25455, 921, 575, 288, 466, 360, 2141, 809, 475, 9842, 887, 1216, 422, 8400, 1305, 3574, 9734, 393, 1471, 543, 348, 9022, 854, 530, 1881, 24486, 561, 10928, 22173, 477, 435, 978, 1439, 39716, 481, 2561, 5187, 612, 10811, 12634, 1318, 910, 473, 13816, 10318, 508, 5338, 285, 17266, 787, 337, 6889, 15841, 6669, 618, 483, 1649, 5792, 1275, 460, 269, 327, 1680, 6488, 2146, 6941, 517, 3461, 3125, 4270, 611, 1002, 582, 

In [None]:
class TrieNode:
    children: dict[int, 'TrieNode'] # token_id : TrieNode of postfix
    terminal_word_id: Optional[int] = None
    # terminal_score: Optional[int] = None
    # children_score: Optional[int] = None
    in_vocab_token_ids: Optional[set[int]] = None

    def __init__(self, tokenizer):
        self.children = {}
        self.eos_token_id = tokenizer.encode('<|endoftext|>')[0]
        self.vocab = tokenizer.vocab

    def add_word(self, word: str, word_id: int, depth: int = 1, verbose: bool = False):
        """
        word: a word (or postfix) of a word which is to be added to the trie (sub-trie). expected in lowercase.
        """
        char = word[0]
        assert ord('a') <= ord(char) <= ord('z'), f"Expected a lowercase character, got {char}"

        if len(word) == 1:
            self.terminal_word_id = word_id
            return

        if char not in self.children:
            self.children[char] = TrieNode()
        self.children[char].add_word(word[1:], word_id, depth=depth+1, verbose=verbose)

    def add_in_vocab_token_ids(self):
        """
        Adds all token ids that are in vocabulary to the set self.in_vocab_token_ids.
        """
        self.in_vocab_token_ids = set()
        for token, token_id in self.vocab.items():
            # iterate down the trie for all chars in the token, then call recursively.
            curr = self
            token_out_of_vocab
            for char in token:
                if char in curr.children:
                    curr = curr.children[char]
                else:
                    break

In [None]:
tokenizer.encode()

In [None]:
word = data[0]
new_postfix = 'ġ' + word.lower()

['hello']

In [None]:
root, len(root.children)

(<__main__.Node at 0x7f7c872b28b0>, 431)

In [None]:
import numpy as np

In [None]:
data = ['googleplex', 'between', 'developer']


for i, tokens in enumerate(tokenizer(data).input_ids[:10]):
    if len(tokens) > 1:
        print(i, data[i], tokens, tokenizer.decode(tokens))

In [None]:
{k:v for k,v in tokenizer.vocab.items() if v==128}

{'Ä': 128}

In [None]:
tokenizer.decode([128]), \
tokenizer.decode([128, 254, 38]), \
tokenizer.decode([128, 380]), \
tokenizer.decode([128, 254, 14150]), \
tokenizer.decode([128, 38])


('�', 'ĠG', '�ri', 'Ġhave', '�G')

In [22]:
import re

In [23]:
text = """
Advantages:
State of Health (SoH) Monitoring: EIS can provide detailed information on the internal condition of the battery, including degradation mechanisms. This is crucial for predicting the remaining useful life of the battery.

Early Fault Detection: By detecting changes in impedance, EIS can identify issues such as electrode delamination or electrolyte degradation before they lead to significant performance drops.

Optimizing Charging Protocols: EIS data can be used to optimize charging algorithms, ensuring that the battery is charged in a way that maximizes lifespan and safety.

Disadvantages:
Complexity: EIS is a complex technique that requires sophisticated hardware and software, potentially increasing the cost and complexity of the BMS.

Data Interpretation: The data obtained from EIS can be challenging to interpret and requires advanced algorithms and expertise to convert it into actionable insights.

Slow Response Time: EIS measurements can be relatively slow compared to other monitoring methods, which might limit its usefulness for real-time applications.

Overall, while EIS can provide valuable insights for a BMS, its complexity and slow response time might limit its practical applications in certain types of battery systems. However, for systems where long-term health and safety are paramount, and where the additional complexity can be justified, EIS can be an extremely useful tool.
"""

In [24]:
text_split = text.split()

# Define a function to strip special characters from each word
def strip_special_characters(word):
    # Use regex to replace non-alphabetic characters with an empty string
    return re.sub(r'[^a-zA-Z]', '', word)

# Use the function to clean up each word and then filter out any empty strings or non-alphabetic words
cleaned_words = [strip_special_characters(word) for word in text_split]
cleaned_words = [word for word in cleaned_words if word.isalpha()]

cleaned_words

['Advantages',
 'State',
 'of',
 'Health',
 'SoH',
 'Monitoring',
 'EIS',
 'can',
 'provide',
 'detailed',
 'information',
 'on',
 'the',
 'internal',
 'condition',
 'of',
 'the',
 'battery',
 'including',
 'degradation',
 'mechanisms',
 'This',
 'is',
 'crucial',
 'for',
 'predicting',
 'the',
 'remaining',
 'useful',
 'life',
 'of',
 'the',
 'battery',
 'Early',
 'Fault',
 'Detection',
 'By',
 'detecting',
 'changes',
 'in',
 'impedance',
 'EIS',
 'can',
 'identify',
 'issues',
 'such',
 'as',
 'electrode',
 'delamination',
 'or',
 'electrolyte',
 'degradation',
 'before',
 'they',
 'lead',
 'to',
 'significant',
 'performance',
 'drops',
 'Optimizing',
 'Charging',
 'Protocols',
 'EIS',
 'data',
 'can',
 'be',
 'used',
 'to',
 'optimize',
 'charging',
 'algorithms',
 'ensuring',
 'that',
 'the',
 'battery',
 'is',
 'charged',
 'in',
 'a',
 'way',
 'that',
 'maximizes',
 'lifespan',
 'and',
 'safety',
 'Disadvantages',
 'Complexity',
 'EIS',
 'is',
 'a',
 'complex',
 'technique',
 't

In [66]:
from transformers import GPT2TokenizerFast, AutoTokenizer
from itertools import product

# Initialize the tokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", token="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")

# Initialize the tokenizer
# tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

# Your word
word = "EXAMpLE"

# Initialize lists of prefixes and postfixes
prefixes = [
    '', ' ', '.', ',', '!', '?', '"', '(', ')', '-', '—', ':', ';', 
    '\n', '\t', "'", '“', '”', '‘', '’', '...', '—'
]

postfixes = [
    '', 
    ' ', '.', ',', '!', '?', '"', '(', ')', '-', '—', ':', ';', 
    '\n', '\t', "'", '“', '”', '‘', '’', '...'
]

# Extended list of prefixes
prefixes = [
    '', ' ', '.', ',', '!', '?', '"', '(', ')', '-', '—', ':', ';',
    '\n', '\t', "'", '“', '”', '‘', '’', '...', '—',
    '«', '»', '$', '€', '£', '+', '=', '%', '#', '@',
    '{', '}', '[', ']', '/', '\\', '*', '<', '>'
]

# Extended list of postfixes
postfixes = [
    '', ' ', '.', ',', '!', '?', '"', '(', ')', '-', '—', ':', ';',
    '\n', '\t', "'", '“', '”', '‘', '’', '...', '..',
    # 'st', 'nd', 'rd', 'th', 'cm', 'kg', 'mi', 'lb', 
    '.txt', '.doc',
    '>', '<'
]

# Function to tokenize prefixes/postfixes and get unique tokens
def get_unique_tokens(tokenizer, items):
    unique_tokens = set()
    for item in items:
        # Tokenize each item individually
        tokens = tokenizer.tokenize(item)
        # Update the unique tokens set with the new tokens
        unique_tokens.update(tokens)
    return unique_tokens

# Get unique tokens for prefixes and postfixes
special_tokens = get_unique_tokens(tokenizer, prefixes + postfixes)

trouble_prefix={}
trouble_postfix={}
for word in cleaned_words[:]:
    # Tokenize the word in different formats and print the tokens that directly refer to the word
    seen={}
    for pre in prefixes:
        for post in postfixes:
            for word in [word.lower(), word.capitalize()]:
                formatted_word = f"{pre}{word}{post}"
                tokens = tokenizer.tokenize(formatted_word)
                word_tokens = [t for t in tokens 
                               if t not in special_tokens 
                               and sum(c in t for c in word) > 0]
                tokens = [t for t in tokens]
                # print(f"Format: '{fmt}' -> Tokens: {tokens} -> Word Tokens: {word_tokens}")

                key = tuple(word_tokens)
                if key in seen:
                    seen[key].append((pre, post))
                else:
                    seen[key] = [(pre, post)]
    
    if len(seen)!=4 or any(len(v) not in (52,1014) for v in seen.values()):
        for k,v in seen.items():
            print(k,len(v))
            assert len(k)>0, f"{k} {v}"
            
            pre, post = k[0][0], k[-1][-1]
            if pre not in [word[0].lower(), word[0].capitalize()]:
                if pre in trouble_prefix:
                    trouble_prefix[pre].append(k)
                else:
                    trouble_prefix[pre] = [k]
            if post not in [word[-1].lower(), word[-1].capitalize()]:
                if post in trouble_prefix:
                    trouble_prefix[post].append(k)
                else:
                    trouble_prefix[post] = [k]

        print('-'*50)


In [63]:
seen

{('▁can',): [('', ''),
  ('', ' '),
  ('', '.'),
  ('', ','),
  ('', '!'),
  ('', '?'),
  ('', '"'),
  ('', '('),
  ('', ')'),
  ('', '-'),
  ('', '—'),
  ('', ':'),
  ('', ';'),
  ('', '\n'),
  ('', '\t'),
  ('', "'"),
  ('', '“'),
  ('', '”'),
  ('', '‘'),
  ('', '’'),
  ('', '...'),
  ('', '..'),
  ('', '.txt'),
  ('', '.doc'),
  ('', '>'),
  ('', '<'),
  (' ', ''),
  (' ', ' '),
  (' ', '.'),
  (' ', ','),
  (' ', '!'),
  (' ', '?'),
  (' ', '"'),
  (' ', '('),
  (' ', ')'),
  (' ', '-'),
  (' ', '—'),
  (' ', ':'),
  (' ', ';'),
  (' ', '\n'),
  (' ', '\t'),
  (' ', "'"),
  (' ', '“'),
  (' ', '”'),
  (' ', '‘'),
  (' ', '’'),
  (' ', '...'),
  (' ', '..'),
  (' ', '.txt'),
  (' ', '.doc'),
  (' ', '>'),
  (' ', '<')],
 ('▁Can',): [('', ''),
  ('', ' '),
  ('', '.'),
  ('', ','),
  ('', '!'),
  ('', '?'),
  ('', '"'),
  ('', '('),
  ('', ')'),
  ('', '-'),
  ('', '—'),
  ('', ':'),
  ('', ';'),
  ('', '\n'),
  ('', '\t'),
  ('', "'"),
  ('', '“'),
  ('', '”'),
  ('', '‘'),
  ('', 

In [65]:
seen[('Can',)]

[('.', ''),
 ('.', ' '),
 ('.', '.'),
 ('.', ','),
 ('.', '!'),
 ('.', '?'),
 ('.', '"'),
 ('.', '('),
 ('.', ')'),
 ('.', '-'),
 ('.', '—'),
 ('.', ':'),
 ('.', ';'),
 ('.', '\n'),
 ('.', '\t'),
 ('.', "'"),
 ('.', '“'),
 ('.', '”'),
 ('.', '‘'),
 ('.', '’'),
 ('.', '...'),
 ('.', '..'),
 ('.', '.txt'),
 ('.', '.doc'),
 ('.', '>'),
 ('.', '<'),
 (',', ''),
 (',', ' '),
 (',', '.'),
 (',', ','),
 (',', '!'),
 (',', '?'),
 (',', '"'),
 (',', '('),
 (',', ')'),
 (',', '-'),
 (',', '—'),
 (',', ':'),
 (',', ';'),
 (',', '\n'),
 (',', '\t'),
 (',', "'"),
 (',', '“'),
 (',', '”'),
 (',', '‘'),
 (',', '’'),
 (',', '...'),
 (',', '..'),
 (',', '.txt'),
 (',', '.doc'),
 (',', '>'),
 (',', '<'),
 ('!', ''),
 ('!', ' '),
 ('!', '.'),
 ('!', ','),
 ('!', '!'),
 ('!', '?'),
 ('!', '"'),
 ('!', '('),
 ('!', ')'),
 ('!', '-'),
 ('!', '—'),
 ('!', ':'),
 ('!', ';'),
 ('!', '\n'),
 ('!', '\t'),
 ('!', "'"),
 ('!', '“'),
 ('!', '”'),
 ('!', '‘'),
 ('!', '’'),
 ('!', '...'),
 ('!', '..'),
 ('!', '.txt')

In [61]:
for k,v in trouble_prefix.items():
    print(k,len(v))

print()

for k,v in trouble_postfix.items():
    print(k,len(v))




In [48]:
filtered_vocab = {k: v for k, v in tokenizer.vocab.items() if any(c.isalpha() for c in k) and not k[0].isalpha()}
len(filtered_vocab), len(tokenizer.vocab)

(16313, 32000)

In [50]:
# Collect all unique first characters of keys in the filtered vocabulary
unique_first_chars = {token[0] for token in filtered_vocab.keys()}

# Display the unique first characters
print(unique_first_chars)


{'<', '▁'}


In [41]:
tokenizer.tokenize(R"""example'text)""")

['example', "'t", 'ext', ')']

## Conclusion

By writing our own version of the beam search algorithm, we are able to constrain the output of a pre-trained language model. This can be applied to generative models such as GPT2 and GPT3 and even seq2seq models such as T5 and T0. This is particularly useful when we want the output of our models to follow a certain pre-defined structure or adhere to a set of rules.

If you are interested in getting even more value out of your GPT model, check out [this](https://towardsdatascience.com/almost-no-data-and-no-time-unlocking-the-true-potential-of-gpt3-a-case-study-b4710ca0614a) post about how 'prompt engineering' can help you unlock more value out of your generative language models.