# The power of constrained language models.

**Why and how to build constrained language models with a custom beam search algorithm. A guide with Hugging Face code.**

Pre-trained generative language models (such as OpenAI's GPT2 and GPT3) or seq2seq models (such as T5 or the recently released T0) generate free-flowing natural language. This means that their output sentences can have any shape. To get the most value out of these models, we would sometimes like the outputs to follow a certain structure. In this notebook I will show you how to achieve this and gain more value out of your language model using a custom beam search algorithm.

This notebook is meant to accompany a blog post. Read the blog post to fully understand the benefits of a custom beam search algorithm.

## Installing the necessary packages

## The code







First we will download our model and the corresponding tokenizer. The code below loads a local copy of OpenChat 3.5, but any generative language model that has a `beam_search` method (GPT2, for example) should work.

In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained('/data2/bdlml/models/LLMs/openchat_3.5')
model = AutoModelForCausalLM.from_pretrained("/data2/bdlml/models/LLMs/openchat_3.5")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.




In [2]:
from transformers import (
    BeamSearchScorer,
    LogitsProcessorList,
    StoppingCriteriaList,
    MaxLengthCriteria
)
import torch
import re
from transformers import LogitsProcessor
import numpy as np

We define a helper function, `set_scores_to_neginf_for_banned_tokens`, that sets the scores of banned tokens to `-inf`. You can ignore it for now.

In [3]:
# src: https://huggingface.co/transformers/v4.1.1/_modules/transformers/generation_logits_process.html

def set_scores_to_neginf_for_banned_tokens(scores, banned_tokens):
    """
    Returns a copy of `scores` in which the banned token positions are set to `-inf`
    (`masked_fill` is not in-place). `banned_tokens` holds one list per row of `scores`
    with the vocabulary ids to ban for that row; internally these are converted to
    [row index, vocabulary id] pairs to build a boolean mask.

    Args:
        scores: logits distribution of shape (batch size, vocabulary size)
        banned_tokens: list of lists of token ids to ban, of length (batch_size)
    """
    banned_mask_list = []
    for idx, batch_banned_tokens in enumerate(banned_tokens):
        for token in batch_banned_tokens:
            banned_mask_list.append([idx, int(token)])
    if not banned_mask_list:
        return scores

    banned_mask = torch.LongTensor(banned_mask_list)
    indices = torch.ones(len(banned_mask))

    # torch.sparse_coo_tensor replaces the deprecated torch.sparse.LongTensor constructor
    banned_mask = (
        torch.sparse_coo_tensor(banned_mask.t(), indices, scores.size()).to(scores.device).to_dense().bool()
    )
    scores = scores.masked_fill(banned_mask, -float("inf"))
    return scores
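To make the expected `banned_tokens` format concrete, here is a minimal pure-Python sketch of what the helper computes, using plain lists instead of tensors. `ban_tokens` is a hypothetical name used only for illustration; it is not part of `transformers` or of the notebook.

```python
import math

def ban_tokens(scores, banned_tokens):
    """Hypothetical pure-Python counterpart of set_scores_to_neginf_for_banned_tokens.

    scores: one list of floats per row (beam), each of length vocab_size
    banned_tokens: one iterable of vocabulary ids per row
    Returns new rows; the input lists are left unchanged.
    """
    result = []
    for row, banned in zip(scores, banned_tokens):
        banned = set(banned)
        # copy the row, replacing every banned position with -inf
        result.append([-math.inf if tok in banned else s for tok, s in enumerate(row)])
    return result

scores = [[0.1, 0.5, 0.2, 0.9],   # row 0
          [0.3, 0.3, 0.8, 0.1]]   # row 1
masked = ban_tokens(scores, [[1, 3], [2]])
print(masked)  # [[0.1, -inf, 0.2, -inf], [0.3, 0.3, -inf, 0.1]]
```

Row `i` of `banned_tokens` only affects row `i` of `scores`, which is why the processors below append one list per beam.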

We subclass `LogitsProcessor` to get the desired effect. Our custom class must implement the `__call__` method of `LogitsProcessor`.

This method is called at every step of the beam search algorithm. It receives `input_ids`, the partially generated sequences of all beams (shape `(batch_size * num_beams, sequence_length)`), and `scores`, the scores of the possible next tokens for each beam (shape `(batch_size * num_beams, vocab_size)`).

By manipulating these `scores` based on the tokens present in the `input_ids`, we can control the structure of the generated sentence.
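Where the processor sits in the search can be illustrated with a toy, self-contained beam search. The vocabulary and the bigram table that stands in for the language model are hypothetical, and the search is deliberately simplified (no length normalisation, no end-of-sequence handling, processor called per beam rather than on a batch as `transformers` does). It is a sketch of the idea, not the library's implementation: once per step, the processor may rewrite each beam's next-token scores before the best continuations are selected.

```python
import math

# Hypothetical toy vocabulary and bigram log-probabilities standing in for the model.
VOCAB = ["<s>", "my", "dog", "is", "cute"]
LOGPROBS = {
    "<s>": {"my": -0.1, "dog": -2.5},
    "my": {"dog": -0.2, "cute": -1.8},
    "dog": {"is": -0.3, "cute": -1.5},
    "is": {"cute": -0.2, "my": -2.0},
    "cute": {"dog": -0.7, "is": -1.0},
}

def next_token_scores(seq):
    """Log-probability of every vocabulary token given the last token of seq."""
    table = LOGPROBS.get(seq[-1], {})
    return [table.get(tok, -math.inf) for tok in VOCAB]

def beam_search(prompt, num_beams, steps, processor=None):
    beams = [(prompt, 0.0)]  # (token sequence, cumulative log-probability)
    for _ in range(steps):
        candidates = []
        for seq, total in beams:
            scores = next_token_scores(seq)
            if processor is not None:
                # the hook: the processor may rewrite this beam's scores before selection
                scores = processor(seq, scores)
            for tok_id, s in enumerate(scores):
                if s != -math.inf:
                    candidates.append((seq + [VOCAB[tok_id]], total + s))
        # keep the num_beams best-scoring continuations
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:num_beams]
    return beams

def ban_dog(seq, scores):
    """Example processor: never allow the token 'dog'."""
    return [-math.inf if VOCAB[i] == "dog" else s for i, s in enumerate(scores)]

print(beam_search(["<s>"], num_beams=2, steps=3)[0][0])                     # ['<s>', 'my', 'dog', 'is']
print(beam_search(["<s>"], num_beams=2, steps=3, processor=ban_dog)[0][0])  # ['<s>', 'my', 'cute', 'is']
```

The unconstrained search prefers "my dog is"; once the processor sets the score of `dog` to `-inf`, no beam can ever contain it and the search settles on the best remaining path.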

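We implement two custom `LogitsProcessor` classes: `EvenLogits` and `OddLogits`. `EvenLogits` bans every token whose vocabulary string has an even number of characters, and `OddLogits` bans every token whose string has an odd number of characters. The length is measured on the raw vocabulary string, not on the visible word: in a SentencePiece-style vocabulary (which the OpenChat 3.5 tokenizer appears to use), word-initial tokens carry a leading `▁`, so banning even-length tokens tends to leave words with an even number of visible characters. This is why the `EvenLogits` outputs below are dominated by even-length words.

In both implementations, we build a list of all tokens we are not allowed to output and then set the corresponding `scores` to `-inf` using our helper function `set_scores_to_neginf_for_banned_tokens`.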
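The effect of the `▁` marker can be checked on a small example. The vocabulary below is hypothetical and hand-written in the SentencePiece style (assumption: the real tokenizer marks word-initial tokens the same way); `banned_ids` and `surviving_tokens` are illustrative helpers, not notebook code. Banning even-length strings, as `EvenLogits` does, leaves word-initial tokens whose visible words have even length.

```python
# Hypothetical SentencePiece-style vocabulary: "▁" marks a word-initial token.
vocab = {"▁He": 0, "▁is": 1, "▁a": 2, "▁dog": 3, "▁very": 4, "s": 5, "ed": 6}

def banned_ids(vocab, ban_even):
    """Ids of tokens whose raw string length is even (ban_even=True) or odd (ban_even=False)."""
    return sorted(i for tok, i in vocab.items() if (len(tok) % 2 == 0) == ban_even)

def surviving_tokens(vocab, ban_even):
    banned = set(banned_ids(vocab, ban_even))
    return [tok for tok, i in vocab.items() if i not in banned]

# EvenLogits-style ban: even-length strings ("▁a", "▁dog", "ed") are removed ...
even_ban_survivors = surviving_tokens(vocab, ban_even=True)
print(even_ban_survivors)  # ['▁He', '▁is', '▁very', 's']
# ... yet the visible words of the word-initial survivors all have even length.
print([tok.lstrip("▁") for tok in even_ban_survivors if tok.startswith("▁")])  # ['He', 'is', 'very']
```

Sub-word pieces such as `s` do not carry the marker, so for multi-token words the correspondence only holds approximately.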

In [4]:


class EvenLogits(LogitsProcessor):
  """Bans every token whose vocabulary string has an even number of characters."""
  def __call__(self, input_ids, scores):
    elementwise_length = np.vectorize(len)
    keys = np.array(list(tokenizer.vocab.keys()))
    values = np.array(list(tokenizer.vocab.values()))

    # positions of tokens whose string length is even: these are banned
    indexes = np.where(elementwise_length(keys) % 2 == 0)[0]

    # the banned set does not depend on the beam, so reuse it for every row of `scores`
    banned_tokens = [values[indexes]] * input_ids.shape[0]

    scores = set_scores_to_neginf_for_banned_tokens(scores, banned_tokens)
    return scores

class OddLogits(LogitsProcessor):
  """Bans every token whose vocabulary string has an odd number of characters."""
  def __call__(self, input_ids, scores):
    elementwise_length = np.vectorize(len)
    keys = np.array(list(tokenizer.vocab.keys()))
    values = np.array(list(tokenizer.vocab.values()))

    # positions of tokens whose string length is odd: these are banned
    indexes = np.where(elementwise_length(keys) % 2 != 0)[0]

    # the banned set does not depend on the beam, so reuse it for every row of `scores`
    banned_tokens = [values[indexes]] * input_ids.shape[0]

    scores = set_scores_to_neginf_for_banned_tokens(scores, banned_tokens)
    return scores


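We use our custom `LogitsProcessor` classes during beam search by wrapping them in a `LogitsProcessorList` and passing that list to the `logits_processor` argument of the model's `beam_search` method. Newer `transformers` releases deprecate calling `beam_search` directly; there, the same list can be passed to `model.generate` together with `num_beams`.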
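A `LogitsProcessorList` calls its members one after another, each receiving the scores returned by the previous one, so several banning processors in one list ban the union of their token sets. A pure-Python sketch of that chaining, with hypothetical helper names (`make_ban`, `apply_in_order`) and toy scores:

```python
import math

def make_ban(banned):
    """Return a toy processor that sets the given vocabulary ids to -inf in every row."""
    banned = set(banned)
    def processor(input_ids, scores):
        return [[-math.inf if i in banned else s for i, s in enumerate(row)] for row in scores]
    return processor

def apply_in_order(processors, input_ids, scores):
    """Sequential application, mirroring how a LogitsProcessorList chains its members."""
    for processor in processors:
        scores = processor(input_ids, scores)
    return scores

scores = [[0.4, 0.1, 0.3, 0.2]]  # one beam, four-token toy vocabulary
combined = apply_in_order([make_ban([0]), make_ban([2])], input_ids=None, scores=scores)
print(combined)  # [[-inf, 0.1, -inf, 0.2]]
```

When combining constraints this way (for example `EvenLogits` with a second filter), keep at least one token allowed: if the union covers the whole vocabulary, every candidate scores `-inf` and the search has nothing meaningful left to choose from.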

To try another constraint, I wrote two more processors along the same lines as `EvenLogits`: `NumbersLogits` (intended to generate numbers) and `NoNumbersLogits` (intended to generate no numbers). However, the results show the opposite. I think I know the reason but want to confirm it.

In [5]:
class NumbersLogits(LogitsProcessor):
    """Bans every token whose vocabulary string parses as an integer.

    Despite the name, numeric tokens are removed from the candidates, so the
    generated text contains no tokens that are numbers.
    """

    def __call__(self, input_ids, scores):
        keys = list(tokenizer.vocab.keys())
        values = np.array(list(tokenizer.vocab.values()))
        number_indexes = []
        for i, elem in enumerate(keys):
            try:
                int(elem)
                number_indexes.append(i)  # numeric token -> banned
            except ValueError:
                pass  # non-numeric token -> allowed
        # the banned set does not depend on the beam, so reuse it for every row of `scores`
        banned_tokens = [values[number_indexes]] * input_ids.shape[0]
        return set_scores_to_neginf_for_banned_tokens(scores, banned_tokens)

class NoNumbersLogits(LogitsProcessor):
    """Bans every token whose vocabulary string does NOT parse as an integer.

    Despite the name, only numeric tokens survive, so the generated text
    consists of digits; the end-of-sequence token is banned too.
    """

    def __call__(self, input_ids, scores):
        keys = list(tokenizer.vocab.keys())
        values = np.array(list(tokenizer.vocab.values()))
        non_number_indexes = []
        for i, elem in enumerate(keys):
            try:
                int(elem)  # numeric token -> allowed
            except ValueError:
                non_number_indexes.append(i)  # non-numeric token -> banned
        # the banned set does not depend on the beam, so reuse it for every row of `scores`
        banned_tokens = [values[non_number_indexes]] * input_ids.shape[0]
        return set_scores_to_neginf_for_banned_tokens(scores, banned_tokens)
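A likely explanation for the reversed results: like `EvenLogits`, both classes build a list of tokens to *ban*, not to allow. `NumbersLogits` collects the numeric tokens and bans them, so only non-numeric tokens remain; `NoNumbersLogits` collects everything that is not a number (including the end-of-sequence token `</s>`) and bans it, which is why its beams below contain nothing but digits and cannot stop early. The pure-Python sketch below makes the allowed sets explicit on a hypothetical toy vocabulary; `is_number` and `allowed_tokens` are illustrative names, not part of `transformers`.

```python
# Hypothetical toy vocabulary (token string -> id); "</s>" plays the end-of-sequence role.
vocab = {"1": 0, "2": 1, "42": 2, "▁dog": 3, "▁is": 4, "</s>": 5}

def is_number(token):
    """Same test the notebook uses: does int() accept the token string?"""
    try:
        int(token)
        return True
    except ValueError:
        return False

def allowed_tokens(vocab, banned_ids):
    return [tok for tok, i in vocab.items() if i not in banned_ids]

numeric_ids = {i for tok, i in vocab.items() if is_number(tok)}
non_numeric_ids = {i for tok, i in vocab.items() if not is_number(tok)}

# NumbersLogits bans numeric_ids -> only words (and "</s>") remain.
print(allowed_tokens(vocab, numeric_ids))      # ['▁dog', '▁is', '</s>']
# NoNumbersLogits bans non_numeric_ids -> only numbers remain, and "</s>" is gone.
print(allowed_tokens(vocab, non_numeric_ids))  # ['1', '2', '42']
```

To get the behaviour the names suggest, either swap the two conditions or swap the class names, so that each class bans the complement of what it should produce.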

In [6]:


# how many beams to track during beam search
num_beams = 5
# how many beams to return after the algorithm
num_return_beams = 5
max_length=20

# the prompt to continue
prompt = 'My cute dog'

# tokenizing the prompt

print('Without logits_processor ', '\n')

prompt_tokenized = tokenizer(prompt, return_tensors='pt' )
prompt_tokenized = prompt_tokenized['input_ids']


# instantiating a BeamSearchScorer
beam_scorer = BeamSearchScorer(
    batch_size = prompt_tokenized.shape[0],
    num_beams = num_beams,
    num_beam_hyps_to_keep = num_return_beams,
    device=model.device
)

generated = model.beam_search(
    torch.cat([prompt_tokenized] * num_beams),
    beam_scorer,
    stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])
)

# printing the output beams
for index, output_tokenized in enumerate(generated):
  output = tokenizer.decode(output_tokenized)
  oup=re.sub('<s> '+ prompt, '',output )
  elementwise_length=np.vectorize(len)
  even_criteria_words=(elementwise_length(oup.split(' '))%2==0)
  
  print(f'beam {index}: {output}', '\t', 'even ' ,even_criteria_words.sum(), 'odd ' ,len(even_criteria_words)- even_criteria_words.sum())

print('\n')   
print('With OddLogits ', '\n')

beam_scorer = BeamSearchScorer(
    batch_size = prompt_tokenized.shape[0],
    num_beams = num_beams,
    num_beam_hyps_to_keep = num_return_beams,
    device=model.device
)

logits_processor = LogitsProcessorList([OddLogits()])

# running beam search using our custom LogitsProcessor
generated = model.beam_search(
    torch.cat([prompt_tokenized] * num_beams),
    beam_scorer,
    logits_processor = logits_processor,
    stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])
)

# printing the output beams
for index, output_tokenized in enumerate(generated):
  output = tokenizer.decode(output_tokenized)
  oup=re.sub('<s> '+ prompt, '',output )
  elementwise_length=np.vectorize(len)
  even_criteria_words=(elementwise_length(oup.split(' '))%2==0)

  print(f'beam {index}: {output}', '\t', 'even ' ,even_criteria_words.sum(), 'odd ' ,len(even_criteria_words)- even_criteria_words.sum())

print('\n')
print('With EvenLogits ', '\n')

beam_scorer = BeamSearchScorer(
    batch_size = prompt_tokenized.shape[0],
    num_beams = num_beams,
    num_beam_hyps_to_keep = num_return_beams,
    device=model.device
)

logits_processor = LogitsProcessorList([EvenLogits()])

# running beam search using our custom LogitsProcessor
generated = model.beam_search(
    torch.cat([prompt_tokenized] * num_beams),
    beam_scorer,
    logits_processor = logits_processor,
    stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])
)

# printing the output beams
for index, output_tokenized in enumerate(generated):
  output = tokenizer.decode(output_tokenized)
  oup=re.sub('<s> '+ prompt, '',output )
  elementwise_length=np.vectorize(len)
  even_criteria_words=(elementwise_length(oup.split(' '))%2==0)
  
  print(f'beam {index}: {output}', '\t','even ' ,even_criteria_words.sum(), 'odd ' ,len(even_criteria_words)- even_criteria_words.sum())

    


Without logits_processor  

beam 0: <s> My cute doggie is a 10-year-old Chihuahua named 	 even  1 odd  5
beam 1: <s> My cute doggie is a 10-year-old Chihuahua. 	 even  2 odd  3
beam 2: <s> My cute doggie is a 10-year-old Chihuahua mix 	 even  1 odd  5
beam 3: <s> My cute doggie is a 10-year-old Chihuahua- 	 even  2 odd  3
beam 4: <s> My cute doggie is a 10-year-old Chihuahua who 	 even  1 odd  5


With OddLogits  





beam 0: <s> My cute dog has a big personality and a big personality needs a big dog bed for comfort and 	 even  1 odd  16
beam 1: <s> My cute dog has a big personality and a big personality needs a big dog bed for a small 	 even  1 odd  16
beam 2: <s> My cute dog has a big personality and a big personality needs a big dog bed for a dog 	 even  1 odd  16
beam 3: <s> My cute dog has a big personality and a big personality needs a big dog bed for a comfortable 	 even  1 odd  16
beam 4: <s> My cute dog has a big personality and a big personality needs a big dog bed for a big 	 even  1 odd  16


With EvenLogits  

beam 0: <s> My cute dog is 100% purebred American Akita. He is  	 even  9 odd  0
beam 1: <s> My cute dog is 100% purebred American Akita. He's 	 even  7 odd  0
beam 2: <s> My cute dog is 100% purebred American Akita. He’s 	 even  7 odd  0
beam 3: <s> My cute dog is 100% purebred American Akita. He is very 	 even  9 odd  0
beam 4: <s> My cute dog is 100% purebred. He's 10 	 even  5
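A caveat when reading the `even`/`odd` counts above: `oup.split(' ')` keeps the empty strings produced by a leading or trailing space, and `len('') == 0` is counted as even. In the first `OddLogits` beam the single "even" word is the empty string before "has", and in the first `EvenLogits` beam two of the nine "even" words are empty strings. A sketch of a counter that splits on runs of whitespace instead (`count_word_parity` is a hypothetical helper, not notebook code):

```python
def count_word_parity(text):
    """Return (even, odd): how many whitespace-separated words have an even / odd length."""
    words = text.split()  # no argument: split on whitespace runs and drop empty strings
    even = sum(1 for w in words if len(w) % 2 == 0)
    return even, len(words) - even

# continuation of OddLogits beam 0 after the prompt has been stripped
continuation = " has a big personality and a big personality needs a big dog bed for comfort and"
print(count_word_parity(continuation))                      # (0, 16)
print(sum(len(w) % 2 == 0 for w in continuation.split(' ')))  # 1, the empty string before "has"

# continuation of EvenLogits beam 0 (note the trailing space)
print(count_word_parity(" is 100% purebred American Akita. He is "))  # (7, 0)
```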

In [13]:
from transformers import (
    BeamSearchScorer,
    LogitsProcessorList,
    StoppingCriteriaList,
    MaxLengthCriteria
)
import torch

# how many beams to track during beam search
num_beams = 5
# how many beams to return after the algorithm
num_return_beams = 5

max_length=20

# the prompt to continue
prompt = 'Bollywood actor Aamir Khan is'


# tokenizing the prompt


print('Without logits_processor ', '\n')


prompt_tokenized = tokenizer(prompt, return_tensors='pt' )
prompt_tokenized = prompt_tokenized['input_ids']


# instantiating a BeamSearchScorer
beam_scorer = BeamSearchScorer(
    batch_size = prompt_tokenized.shape[0],
    num_beams = num_beams,
    num_beam_hyps_to_keep = num_return_beams,
    device=model.device
)

generated = model.beam_search(
    torch.cat([prompt_tokenized] * num_beams),
    beam_scorer,
    stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])
)

# printing the output beams
for index, output_tokenized in enumerate(generated):
  output = tokenizer.decode(output_tokenized)
  oup=re.sub('<s> '+ prompt, '',output )
  elementwise_length=np.vectorize(len)
  even_criteria_words=(elementwise_length(oup.split(' '))%2==0)
  
  print(f'beam {index}: {output}', '\t', 'even ' ,even_criteria_words.sum(), 'odd ' ,len(even_criteria_words)- even_criteria_words.sum())
        

print('With OddLogits ', '\n')
print('\n')

prompt_tokenized = tokenizer(prompt, return_tensors='pt' )
prompt_tokenized = prompt_tokenized['input_ids']


beam_scorer = BeamSearchScorer(
    batch_size = prompt_tokenized.shape[0],
    num_beams = num_beams,
    num_beam_hyps_to_keep = num_return_beams,
    device=model.device
)

logits_processor = LogitsProcessorList([OddLogits()])

# running beam search using our custom LogitsProcessor
generated = model.beam_search(
    torch.cat([prompt_tokenized] * num_beams),
    beam_scorer,
    logits_processor = logits_processor,
    stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])
)

# printing the output beams
for index, output_tokenized in enumerate(generated):
  output = tokenizer.decode(output_tokenized)
  oup=re.sub('<s> '+ prompt, '',output )
  elementwise_length=np.vectorize(len)
  even_criteria_words=(elementwise_length(oup.split(' '))%2==0)

  print(f'beam {index}: {output}', '\t', 'even ' ,even_criteria_words.sum(), 'odd ' ,len(even_criteria_words)- even_criteria_words.sum())


print('With EvenLogits ', '\n')
print('\n')

beam_scorer = BeamSearchScorer(
    batch_size = prompt_tokenized.shape[0],
    num_beams = num_beams,
    num_beam_hyps_to_keep = num_return_beams,
    device=model.device
)

logits_processor = LogitsProcessorList([EvenLogits()])

# running beam search using our custom LogitsProcessor
generated = model.beam_search(
    torch.cat([prompt_tokenized] * num_beams),
    beam_scorer,
    logits_processor = logits_processor,
    stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])
)

# printing the output beams
for index, output_tokenized in enumerate(generated):
  output = tokenizer.decode(output_tokenized)
  oup=re.sub('<s> '+ prompt, '',output )
  elementwise_length=np.vectorize(len)
  even_criteria_words=(elementwise_length(oup.split(' '))%2==0)
  
  print(f'beam {index}: {output}', '\t','even ' ,even_criteria_words.sum(), 'odd ' ,len(even_criteria_words)- even_criteria_words.sum())


Without logits_processor  

beam 0: <s> Bollywood actor Aamir Khan is known for his perfectionism and dedication towards his 	 even  2 odd  7
beam 1: <s> Bollywood actor Aamir Khan is all set to make his digital debut with the upcoming 	 even  5 odd  6
beam 2: <s> Bollywood actor Aamir Khan is all set to return to the small screen as the 	 even  6 odd  5
beam 3: <s> Bollywood actor Aamir Khan is all set to return to the small screen with his 	 even  6 odd  5
beam 4: <s> Bollywood actor Aamir Khan is all set to return to the small screen with a 	 even  6 odd  5
With OddLogits  



beam 0: <s> Bollywood actor Aamir Khan is known for his fitness and the actor never hesit 	 even  1 odd  9
beam 1: <s> Bollywood actor Aamir Khan is known for his fitness and the actor keeps sharing his 	 even  1 odd  10
beam 2: <s> Bollywood actor Aamir Khan is known for his fitness and the actor has now taken 	 even  1 odd  10
beam 3: <s> Bollywood actor Aamir Khan is known for his fitness and the actor has 

In [8]:
from transformers import (
    BeamSearchScorer,
    LogitsProcessorList,
    StoppingCriteriaList,
    MaxLengthCriteria
)
import torch

# how many beams to track during beam search
num_beams = 5
# how many beams to return after the algorithm
num_return_beams = 5

max_length=50

# the prompt to continue
prompt = 'Generate some adjectives which describe a good nature of a human'

# tokenizing the prompt


print('Without logits_processor ', '\n')

prompt_tokenized = tokenizer(prompt, return_tensors='pt' )
prompt_tokenized = prompt_tokenized['input_ids']



# instantiating a BeamSearchScorer
beam_scorer = BeamSearchScorer(
    batch_size = prompt_tokenized.shape[0],
    num_beams = num_beams,
    num_beam_hyps_to_keep = num_return_beams,
    device=model.device
)

generated = model.beam_search(
    torch.cat([prompt_tokenized] * num_beams),
    beam_scorer,
    stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])
)

# printing the output beams
for index, output_tokenized in enumerate(generated):
  output = tokenizer.decode(output_tokenized)
  oup=re.sub('<s> '+ prompt, '',output )
  elementwise_length=np.vectorize(len)
  even_criteria_words=(elementwise_length(oup.split(' '))%2==0)
  
  print(f'beam {index}: {output}', '\t', 'even ' ,even_criteria_words.sum(), 'odd ' ,len(even_criteria_words)- even_criteria_words.sum())
        
print('\n')

print('With OddLogits ', '\n')


prompt_tokenized = tokenizer(prompt, return_tensors='pt' )
prompt_tokenized = prompt_tokenized['input_ids']


beam_scorer = BeamSearchScorer(
    batch_size = prompt_tokenized.shape[0],
    num_beams = num_beams,
    num_beam_hyps_to_keep = num_return_beams,
    device=model.device
)

logits_processor = LogitsProcessorList([OddLogits()])

# running beam search using our custom LogitsProcessor
generated = model.beam_search(
    torch.cat([prompt_tokenized] * num_beams),
    beam_scorer,
    logits_processor = logits_processor,
    stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])
)

# printing the output beams
for index, output_tokenized in enumerate(generated):
  output = tokenizer.decode(output_tokenized)
  oup=re.sub('<s> '+ prompt, '',output )
  elementwise_length=np.vectorize(len)
  even_criteria_words=(elementwise_length(oup.split(' '))%2==0)

  print(f'beam {index}: {output}', '\t', 'even ' ,even_criteria_words.sum(), 'odd ' ,len(even_criteria_words)- even_criteria_words.sum())

print('\n')

print('With EvenLogits ', '\n')


beam_scorer = BeamSearchScorer(
    batch_size = prompt_tokenized.shape[0],
    num_beams = num_beams,
    num_beam_hyps_to_keep = num_return_beams,
    device=model.device
)

logits_processor = LogitsProcessorList([EvenLogits()])

# running beam search using our custom LogitsProcessor
generated = model.beam_search(
    torch.cat([prompt_tokenized] * num_beams),
    beam_scorer,
    logits_processor = logits_processor,
    stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])
)

# printing the output beams
for index, output_tokenized in enumerate(generated):
  output = tokenizer.decode(output_tokenized)
  oup=re.sub('<s> '+ prompt, '',output )
  elementwise_length=np.vectorize(len)
  even_criteria_words=(elementwise_length(oup.split(' '))%2==0)
  
  print(f'beam {index}: {output}', '\t','even ' ,even_criteria_words.sum(), 'odd ' ,len(even_criteria_words)- even_criteria_words.sum())


Without logits_processor  

beam 0: <s> Generate some adjectives which describe a good nature of a human being.

1. Altruistic
2. Benevolent
3. Charitable
4. Compassionate
5. Empathetic
6 	 even  4 odd  3
beam 1: <s> Generate some adjectives which describe a good nature of a human being.

1. Altruistic
2. Benevolent
3. Compassionate
4. Empathetic
5. Generous
6 	 even  4 odd  3
beam 2: <s> Generate some adjectives which describe a good nature of a human being.

1. Altruistic
2. Benevolent
3. Caring
4. Compassionate
5. Empathetic
6 	 even  4 odd  3
beam 3: <s> Generate some adjectives which describe a good nature of a human being.

1. Altruistic
2. Benevolent
3. Compassionate
4. Empathetic
5. Forgiving
 	 even  4 odd  3
beam 4: <s> Generate some adjectives which describe a good nature of a human being.

1. Altruistic
2. Benevolent
3. Charitable
4. Compassionate
5. Caring
6. 	 even  3 odd  4


With OddLogits  

beam 0: <s> Generate some adjectives which describe a good nature of a human b

In [9]:
from transformers import (
    BeamSearchScorer,
    LogitsProcessorList,
    StoppingCriteriaList,
    MaxLengthCriteria
)
import torch

# how many beams to track during beam search
num_beams = 5
# how many beams to return after the algorithm
num_return_beams = 5

max_length=50

# the prompt to continue
prompt = 'Tell me about Artificial Intelligence'

# tokenizing the prompt
prompt_tokenized = tokenizer(prompt, return_tensors='pt' )
prompt_tokenized = prompt_tokenized['input_ids']

print('Without logits_processor ')

# instantiating a BeamSearchScorer
beam_scorer = BeamSearchScorer(
    batch_size = prompt_tokenized.shape[0],
    num_beams = num_beams,
    num_beam_hyps_to_keep = num_return_beams,
    device=model.device
)

generated = model.beam_search(
    torch.cat([prompt_tokenized] * num_beams),
    beam_scorer,
    stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])
)

# printing the output beams
for index, output_tokenized in enumerate(generated):
  output = tokenizer.decode(output_tokenized)
  oup=re.sub('<s> '+ prompt, '',output )
  elementwise_length=np.vectorize(len)
  even_criteria_words=(elementwise_length(oup.split(' '))%2==0)
  
  print(f'beam {index}: {output}', '\t', 'even ' ,even_criteria_words.sum(), 'odd ' ,len(even_criteria_words)- even_criteria_words.sum())
    
print('\n')

print('With OddLogits ')


beam_scorer = BeamSearchScorer(
    batch_size = prompt_tokenized.shape[0],
    num_beams = num_beams,
    num_beam_hyps_to_keep = num_return_beams,
    device=model.device
)

logits_processor = LogitsProcessorList([OddLogits()])

# running beam search using our custom LogitsProcessor
generated = model.beam_search(
    torch.cat([prompt_tokenized] * num_beams),
    beam_scorer,
    logits_processor = logits_processor,
    stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])
)

# printing the output beams
for index, output_tokenized in enumerate(generated):
  output = tokenizer.decode(output_tokenized)
  oup=re.sub('<s> '+ prompt, '',output )
  elementwise_length=np.vectorize(len)
  even_criteria_words=(elementwise_length(oup.split(' '))%2==0)

  print(f'beam {index}: {output}', '\t', 'even ' ,even_criteria_words.sum(), 'odd ' ,len(even_criteria_words)- even_criteria_words.sum())

print('\n')
print('With EvenLogits ')

beam_scorer = BeamSearchScorer(
    batch_size = prompt_tokenized.shape[0],
    num_beams = num_beams,
    num_beam_hyps_to_keep = num_return_beams,
    device=model.device
)

logits_processor = LogitsProcessorList([EvenLogits()])

# running beam search using our custom LogitsProcessor
generated = model.beam_search(
    torch.cat([prompt_tokenized] * num_beams),
    beam_scorer,
    logits_processor = logits_processor,
    stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])
)

# printing the output beams
for index, output_tokenized in enumerate(generated):
  output = tokenizer.decode(output_tokenized)
  oup=re.sub('<s> '+ prompt, '',output )
  elementwise_length=np.vectorize(len)
  even_criteria_words=(elementwise_length(oup.split(' '))%2==0)
  
  print(f'beam {index}: {output}', '\t','even ' ,even_criteria_words.sum(), 'odd ' ,len(even_criteria_words)- even_criteria_words.sum())

    


Without logits_processor 
beam 0: <s> Tell me about Artificial Intelligence (AI)

Artificial Intelligence (AI) refers to the simulation of human intelligence in machines programmed to think like humans and mimic their actions. The term may also be applied to any machine that exhib 	 even  21 odd  12
beam 1: <s> Tell me about Artificial Intelligence (AI) and Machine Learning (ML)

Artificial Intelligence (AI) refers to the simulation of human intelligence in machines programmed to think like humans and mimic their actions. The term may also be 	 even  21 odd  10
beam 2: <s> Tell me about Artificial Intelligence (AI)

Artificial Intelligence (AI) refers to the simulation of human intelligence in machines programmed to think like humans and mimic their actions. The technology is designed to replicate human tasks such as 	 even  22 odd  10
beam 3: <s> Tell me about Artificial Intelligence (AI)

Artificial Intelligence (AI) refers to the simulation of human intelligence in machines programm

In [10]:
from transformers import (
    BeamSearchScorer,
    LogitsProcessorList,
    StoppingCriteriaList,
    MaxLengthCriteria
)
import torch

# how many beams to track during beam search
num_beams = 5
# how many beams to return after the algorithm
num_return_beams = 5

max_length=30

# the prompt to continue
prompt = 'Give me examples of even numbers'

print('Without logits_processor ', '\n')

prompt_tokenized = tokenizer(prompt, return_tensors='pt' )
prompt_tokenized = prompt_tokenized['input_ids']


# instantiating a BeamSearchScorer
beam_scorer = BeamSearchScorer(
    batch_size = prompt_tokenized.shape[0],
    num_beams = num_beams,
    num_beam_hyps_to_keep = num_return_beams,
    device=model.device
)

generated = model.beam_search(
    torch.cat([prompt_tokenized] * num_beams),
    beam_scorer,
    stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])
)

# printing the output beams
for index, output_tokenized in enumerate(generated):
  output = tokenizer.decode(output_tokenized)
  oup=re.sub('<s> '+ prompt, '',output )
  elementwise_length=np.vectorize(len)
  even_criteria_words=(elementwise_length(oup.split(' '))%2==0)
  
  print(f'beam {index}: {output}')
    
print('\n')
print('With NoNumbersLogits ', '\n')

beam_scorer = BeamSearchScorer(
    batch_size = prompt_tokenized.shape[0],
    num_beams = num_beams,
    num_beam_hyps_to_keep = num_return_beams,
    device=model.device
)

logits_processor = LogitsProcessorList([NoNumbersLogits()])

# running beam search using our custom LogitsProcessor
generated = model.beam_search(
    torch.cat([prompt_tokenized] * num_beams),
    beam_scorer,
    logits_processor = logits_processor,
    stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])
)

# printing the output beams
for index, output_tokenized in enumerate(generated):
  output = tokenizer.decode(output_tokenized)
  oup=re.sub('<s> '+ prompt, '',output )
  elementwise_length=np.vectorize(len)
  even_criteria_words=(elementwise_length(oup.split(' '))%2==0)

  print(f'beam {index}: {output}')

print('\n')
print('With NumbersLogits ', '\n')

beam_scorer = BeamSearchScorer(
    batch_size = prompt_tokenized.shape[0],
    num_beams = num_beams,
    num_beam_hyps_to_keep = num_return_beams,
    device=model.device
)

logits_processor = LogitsProcessorList([NumbersLogits()])

# running beam search using our custom LogitsProcessor
generated = model.beam_search(
    torch.cat([prompt_tokenized] * num_beams),
    beam_scorer,
    logits_processor = logits_processor,
    stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])
)

# printing the output beams
for index, output_tokenized in enumerate(generated):
  output = tokenizer.decode(output_tokenized)
  oup=re.sub('<s> '+ prompt, '',output )
  elementwise_length=np.vectorize(len)
  even_criteria_words=(elementwise_length(oup.split(' '))%2==0)
  
  print(f'beam {index}: {output}')

    


Without logits_processor  

beam 0: <s> Give me examples of even numbers that are not divisible by 2.

Here are some examples of even numbers that are not divisible
beam 1: <s> Give me examples of even numbers that are not divisible by 2.

Even numbers are numbers that are divisible by 2.
beam 2: <s> Give me examples of even numbers between 10 and 30.

Even numbers are numbers that are divisible by 2.
beam 3: <s> Give me examples of even numbers that are not divisible by 2.

Even numbers are those that are divisible by 2.
beam 4: <s> Give me examples of even numbers between 10 and 30.

Even numbers are those that are divisible by 2.


With NoNumbersLogits  

beam 0: <s> Give me examples of even numbers12345678910111213141516
beam 1: <s> Give me examples of even numbers12345678901234567890123
beam 2: <s> Give me examples of even numbers12345678910121314151617
beam 3: <s> Give me examples of even numbers12345678901112131415161
beam 4: <s> Give me examples of even numbers12345678912345678

In [11]:
from transformers import (
    BeamSearchScorer,
    LogitsProcessorList,
    StoppingCriteriaList,
    MaxLengthCriteria
)
import torch

# how many beams to track during beam search
num_beams = 5
# how many beams to return after the algorithm
num_return_beams = 5

max_length=30

# the prompt to continue
prompt = 'Generate some numbers'

print('Without logits_processor ', '\n')

prompt_tokenized = tokenizer(prompt, return_tensors='pt' )
prompt_tokenized = prompt_tokenized['input_ids']


# instantiating a BeamSearchScorer
beam_scorer = BeamSearchScorer(
    batch_size = prompt_tokenized.shape[0],
    num_beams = num_beams,
    num_beam_hyps_to_keep = num_return_beams,
    device=model.device
)

generated = model.beam_search(
    torch.cat([prompt_tokenized] * num_beams),
    beam_scorer,
    stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])
)

# printing the output beams
for index, output_tokenized in enumerate(generated):
  output = tokenizer.decode(output_tokenized)
  oup=re.sub('<s> '+ prompt, '',output )
  elementwise_length=np.vectorize(len)
  even_criteria_words=(elementwise_length(oup.split(' '))%2==0)
  
  print(f'beam {index}: {output}')
    
print('\n')
print('With NoNumbersLogits ', '\n')

beam_scorer = BeamSearchScorer(
    batch_size = prompt_tokenized.shape[0],
    num_beams = num_beams,
    num_beam_hyps_to_keep = num_return_beams,
    device=model.device
)

logits_processor = LogitsProcessorList([NoNumbersLogits()])

# running beam search using our custom LogitsProcessor
generated = model.beam_search(
    torch.cat([prompt_tokenized] * num_beams),
    beam_scorer,
    logits_processor = logits_processor,
    stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])
)

# printing the output beams
for index, output_tokenized in enumerate(generated):
    output = tokenizer.decode(output_tokenized)
    print(f'beam {index}: {output}')

print('\n')
print('With NumbersLogits ', '\n')

beam_scorer = BeamSearchScorer(
    batch_size = prompt_tokenized.shape[0],
    num_beams = num_beams,
    num_beam_hyps_to_keep = num_return_beams,
    device=model.device
)

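# judging by the output below, NumbersLogits (defined earlier) masks the numeric tokens instead,
# so no digits appear
logits_processor = LogitsProcessorList([NumbersLogits()])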

# running beam search using our custom LogitsProcessor
generated = model.beam_search(
    torch.cat([prompt_tokenized] * num_beams),
    beam_scorer,
    logits_processor = logits_processor,
    stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])
)

# printing the output beams
for index, output_tokenized in enumerate(generated):
    output = tokenizer.decode(output_tokenized)
    print(f'beam {index}: {output}')


Without logits_processor  

beam 0: <s> Generate some numbers

```python
import random

def generate_numbers(n):
    numbers = []
    for
beam 1: <s> Generate some numbers

```python
import random

def generate_numbers():
    numbers = []
    for _ in
beam 2: <s> Generate some numbers

```python
import random

def generate_numbers(n):
    numbers = [random.rand
beam 3: <s> Generate some numbers

```python
import random

def generate_random_numbers(n):
    numbers = []

beam 4: <s> Generate some numbers

```python
import random

def generate_numbers():
    numbers = []
    for i in


With NoNumbersLogits  

beam 0: <s> Generate some numbers12345678910111213141516171
beam 1: <s> Generate some numbers12345678901234567890123456
beam 2: <s> Generate some numbers12345678912345678912345678
beam 3: <s> Generate some numbers12345678909876543210987654
beam 4: <s> Generate some numbers10000000000000000000000000


With NumbersLogits  

beam 0: <s> Generate some numbers

```python
import random

de
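
The calls above combine a `BeamSearchScorer` with a `LogitsProcessor`. The toy, pure-Python sketch below shows roughly how those two pieces interact. It rests on my own simplifying assumptions. `TOY_LM`, `ban_tokens` and `beam_search` are hypothetical names. The "model" is a hard-coded table of next-token log-probabilities. The sketch also leaves out length penalty and early stopping. As with `MaxLengthCriteria`, `max_length` counts the start token.

```python
import math
from typing import Dict, FrozenSet, List, Tuple

# Hypothetical toy "model": maps the previous token to next-token log-probabilities.
# It stands in for the real model's log-softmaxed logits; it is not a transformers API.
TOY_LM: Dict[str, Dict[str, float]] = {
    "<s>": {"1": math.log(0.5), "the": math.log(0.3), "a": math.log(0.2)},
    "1": {"2": math.log(0.6), "cat": math.log(0.4)},
    "the": {"cat": math.log(0.7), "3": math.log(0.3)},
    "a": {"dog": math.log(0.9), "4": math.log(0.1)},
    "2": {"</s>": 0.0},
    "3": {"</s>": 0.0},
    "4": {"</s>": 0.0},
    "cat": {"</s>": 0.0},
    "dog": {"</s>": 0.0},
}

Beam = Tuple[List[str], float]  # (tokens so far, cumulative log-probability)


def ban_tokens(scores: Dict[str, float], banned: FrozenSet[str]) -> Dict[str, float]:
    """Logits-processor-style step: set the score of every banned token to -inf."""
    return {tok: (-math.inf if tok in banned else s) for tok, s in scores.items()}


def beam_search(num_beams: int, max_length: int,
                banned: FrozenSet[str] = frozenset()) -> List[Beam]:
    """Keep the num_beams best hypotheses at each step (no length penalty)."""
    beams: List[Beam] = [(["<s>"], 0.0)]
    # max_length counts the start token, so we take max_length - 1 steps
    for _ in range(max_length - 1):
        candidates: List[Beam] = []
        for tokens, score in beams:
            if tokens[-1] == "</s>":
                candidates.append((tokens, score))  # finished: carry over unchanged
                continue
            for tok, s in ban_tokens(TOY_LM[tokens[-1]], banned).items():
                if s != -math.inf:  # a token scored -inf can never be chosen
                    candidates.append((tokens + [tok], score + s))
        # prune to the num_beams highest cumulative log-probabilities
        beams = sorted(candidates, key=lambda b: b[1], reverse=True)[:num_beams]
    return beams


if __name__ == "__main__":
    digits = frozenset({"1", "2", "3", "4"})
    for tokens, score in beam_search(num_beams=2, max_length=4):
        print("free:", " ".join(tokens), round(math.exp(score), 2))
    for tokens, score in beam_search(num_beams=2, max_length=4, banned=digits):
        print("no digits:", " ".join(tokens), round(math.exp(score), 2))
```

Without the ban, the two surviving beams are `1 2` (0.3) and `the cat` (0.21). With digits banned, they are `the cat` (0.21) and `a dog` (0.18). A `-inf` score removes a token from every beam at that step. The beams that remain are therefore the best among the *allowed* continuations, not the best overall.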

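The printing loops only let us eyeball whether a constraint held. To check it programmatically, we can strip the prompt and measure how much of each continuation is digits. This is a hypothetical sketch: `continuation` and `digit_fraction` are my own names. It assumes the decoded text starts with `<s> ` followed by the prompt, as in the outputs above. With another tokenizer, or with `tokenizer.decode(..., skip_special_tokens=True)`, the prefix will differ.

```python
def continuation(decoded: str, prompt: str, bos: str = "<s> ") -> str:
    """Return the generated part of a decoded beam.

    Strips the BOS marker and the prompt as literal prefixes (no regex), so
    prompts containing characters such as '(' or '+' are handled safely.
    """
    return decoded.removeprefix(bos).removeprefix(prompt)  # Python 3.9+


def digit_fraction(text: str) -> float:
    """Fraction of non-whitespace characters that are digits (0.0 for empty text)."""
    chars = [c for c in text if not c.isspace()]
    if not chars:
        return 0.0
    return sum(c.isdigit() for c in chars) / len(chars)


if __name__ == "__main__":
    prompt = "Generate some numbers"
    # two beams copied from the NoNumbersLogits output printed above
    beams = [
        "<s> Generate some numbers12345678910111213141516171",
        "<s> Generate some numbers10000000000000000000000000",
    ]
    for decoded in beams:
        generated_text = continuation(decoded, prompt)
        print(repr(generated_text), digit_fraction(generated_text))
```

Every `NoNumbersLogits` beam above consists only of digits, so each should score 1.0. A lower value would mean a non-digit token slipped through the constraint.
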
In [12]:
from transformers import (
    BeamSearchScorer,
    LogitsProcessorList,
    StoppingCriteriaList,
    MaxLengthCriteria
)
import torch

# how many beams (partial hypotheses) to keep at each step of beam search
num_beams = 5
# how many beams to return after the algorithm
num_return_beams = 5

# maximum total length in tokens; MaxLengthCriteria counts the prompt tokens too
max_length = 30

# the prompt to continue
prompt = 'Do whatever you want '

print('Without logits_processor ', '\n')

# tokenize the prompt and keep only the input ids, on the same device as the model
prompt_tokenized = tokenizer(prompt, return_tensors='pt')['input_ids'].to(model.device)


# instantiating a BeamSearchScorer
beam_scorer = BeamSearchScorer(
    batch_size = prompt_tokenized.shape[0],
    num_beams = num_beams,
    num_beam_hyps_to_keep = num_return_beams,
    device=model.device
)

generated = model.beam_search(
    torch.cat([prompt_tokenized] * num_beams),
    beam_scorer,
    stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])
)

# printing the output beams
for index, output_tokenized in enumerate(generated):
    output = tokenizer.decode(output_tokenized)
    print(f'beam {index}: {output}')
print('\n')
print('With NoNumbersLogits ', '\n')

# BeamSearchScorer keeps track of finished hypotheses, so each run needs a fresh one
beam_scorer = BeamSearchScorer(
    batch_size = prompt_tokenized.shape[0],
    num_beams = num_beams,
    num_beam_hyps_to_keep = num_return_beams,
    device=model.device
)

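# judging by the output below, NoNumbersLogits (defined earlier) masks every non-numeric token,
# so only digits can be generated
logits_processor = LogitsProcessorList([NoNumbersLogits()])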

# running beam search using our custom LogitsProcessor
generated = model.beam_search(
    torch.cat([prompt_tokenized] * num_beams),
    beam_scorer,
    logits_processor = logits_processor,
    stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])
)

# printing the output beams
for index, output_tokenized in enumerate(generated):
    output = tokenizer.decode(output_tokenized)
    print(f'beam {index}: {output}')

print('\n')
print('With NumbersLogits ', '\n')

beam_scorer = BeamSearchScorer(
    batch_size = prompt_tokenized.shape[0],
    num_beams = num_beams,
    num_beam_hyps_to_keep = num_return_beams,
    device=model.device
)

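# judging by the output below, NumbersLogits (defined earlier) masks the numeric tokens instead,
# so no digits appear
logits_processor = LogitsProcessorList([NumbersLogits()])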

# running beam search using our custom LogitsProcessor
generated = model.beam_search(
    torch.cat([prompt_tokenized] * num_beams),
    beam_scorer,
    logits_processor = logits_processor,
    stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])
)

# printing the output beams
for index, output_tokenized in enumerate(generated):
    output = tokenizer.decode(output_tokenized)
    print(f'beam {index}: {output}')


Without logits_processor  

beam 0: <s> Do whatever you want 100% of the time.

That’s what I’ve been telling myself lately.


beam 1: <s> Do whatever you want 100% of the time.

That’s what I’ve been doing lately.

I
beam 2: <s> Do whatever you want 100% of the time.

That’s what I tell myself every day.

I’
beam 3: <s> Do whatever you want 100% of the time.

That’s what I’ve been doing for the last 1
beam 4: <s> Do whatever you want 100% of the time.

That’s what I’ve been doing for the past 1


With NoNumbersLogits  

beam 0: <s> Do whatever you want 100000000000000000000000
beam 1: <s> Do whatever you want 100020002000300040005000
beam 2: <s> Do whatever you want 100020000000000000000000
beam 3: <s> Do whatever you want 100000000000000000000001
beam 4: <s> Do whatever you want 100020002000000000000000


With NumbersLogits  

beam 0: <s> Do whatever you want 🤘🤘🤘🤘🤘🤘
beam 1: <s> Do whatever you want ������������������������
beam 2: <s> Do whatever you want 🤘🤘🤘🤘🤘

I'
beam 3: <s> Do wha