
# Decoding Strategies: Controlling the Generated Text

1. Greedy Search
2. Beam Search
3. Random sampling (with temperature, top-k, and top-p)

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
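model_name = 'gpt2-medium'  # autoregressive (causal) text generator

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Move the model to the same device as the input ids, otherwise the forward pass fails on GPU
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)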

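## 1. Greedy search decoding: suited to short sequences where factual accuracy matters

At each time step, greedy search appends the single most probable next token (the argmax of the softmax over the vocabulary). The function below does this by hand and also records the top few candidates at every step.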

In [None]:
import pandas as pd

time_steps = 8
choices_per_step = 5

def get_next_token_greedy_search(input_txt, input_ids):
  iterations = []
  with torch.no_grad():
    for _ in range(time_steps):
      iteration = dict()
      iteration['Input'] = tokenizer.decode(input_ids[0])
      output = model(input_ids=input_ids)
      print('Size of logits: ',output.logits.size())
      next_token_logits = output.logits[0,-1,:]
      next_token_probs = torch.softmax(next_token_logits, dim=-1)

      sorted_ind = torch.argsort(next_token_probs, dim=-1, descending=True)
      for choice_ind in range(choices_per_step):
        token_idx = sorted_ind[choice_ind]
        token_prob = next_token_probs[token_idx].cpu().numpy()

        token_choice = (
            f"{tokenizer.decode(token_idx)} ({100*token_prob:.2f})"
        )

        iteration[f'Choice {choice_ind+1}'] = token_choice
      # input_ids = torch.cat()
      iterations.append(iteration)
      input_ids = torch.cat([input_ids, sorted_ind[None, 0, None]], dim=-1)

  return pd.DataFrame(iterations)

In [None]:
input_txt = 'Bitcoin will be'
input_ids = tokenizer(input_txt, return_tensors='pt')['input_ids'].to(device)

get_next_token_greedy_search(input_txt, input_ids)

In [None]:
# Greedy search using the built-in generate() method (do_sample=False with the default num_beams=1)
output = model.generate(input_ids, max_new_tokens=time_steps, do_sample=False)
print(tokenizer.decode(output[0]))  # first (and only) sequence in the batch

In [None]:
input_txt = "In today's rapidly evolving digital landscape, the substantial advance and rapid growth of data presents companies and their operations with a set of opportunities from different sources that can profoundly impact their competitiveness and success. The literature suggests that data can be considered a hidden weapon that fosters decision-making while determining a company's success in a rapidly changing market. Data are also used to support most organizational activities and decisions. As a result"
input_ids = tokenizer(input_txt, return_tensors='pt')['input_ids'].to(device)
time_steps = 128
output = model.generate(input_ids, max_new_tokens=time_steps, do_sample=False)
print(tokenizer.decode(output[0]))  # first (and only) sequence in the batch

## 2. Beam Search Decoding

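Rather than committing to the single most probable token at each step, beam search keeps the `num_beams` most probable partial sequences (ranked by cumulative log-probability) and returns the best complete one. A from-scratch implementation is walked through in this video: https://youtu.be/KPtna8FahZ8?si=KB5MxI5XNrvo7Emv&t=16466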
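As a self-contained illustration (a sketch of the general technique, not the implementation from the video or from transformers), beam search can be run over a toy next-token model. `TOY_MODEL`, `toy_next_token_probs`, and `beam_search` are hypothetical names made up for this example. The probabilities are chosen so that greedy decoding (beam search with `num_beams=1`) commits to the locally best first token "the" but ends with a less probable sequence than the one a 2-beam search finds.

```python
import math
from typing import Callable, Dict, List, Tuple

Seq = Tuple[str, ...]

# Hypothetical toy "language model": next-token probabilities for a few prefixes.
TOY_MODEL: Dict[Seq, Dict[str, float]] = {
    (): {"the": 0.5, "a": 0.4, "<eos>": 0.1},
    ("the",): {"dog": 0.4, "cat": 0.3, "<eos>": 0.3},
    ("a",): {"dog": 0.9, "<eos>": 0.1},
}


def toy_next_token_probs(prefix: Seq) -> Dict[str, float]:
    # Any prefix not listed above ends the sentence with certainty.
    return TOY_MODEL.get(prefix, {"<eos>": 1.0})


def beam_search(
    next_token_probs: Callable[[Seq], Dict[str, float]],
    num_beams: int,
    max_steps: int,
    eos: str = "<eos>",
) -> Tuple[Seq, float]:
    beams: List[Tuple[Seq, float]] = [((), 0.0)]  # (sequence, cumulative log-prob)
    finished: List[Tuple[Seq, float]] = []
    for _ in range(max_steps):
        # Expand every live beam by every possible next token.
        candidates = [
            (seq + (tok,), score + math.log(p))
            for seq, score in beams
            for tok, p in next_token_probs(seq).items()
        ]
        # Keep only the num_beams highest-scoring expansions.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for seq, score in candidates[:num_beams]:
            if seq[-1] == eos:
                finished.append((seq, score))  # hypothesis is complete
            else:
                beams.append((seq, score))  # still being extended
        if not beams:
            break
    # Beams still unfinished when max_steps is reached are also eligible.
    return max(finished + beams, key=lambda c: c[1])


best_seq, best_logp = beam_search(toy_next_token_probs, num_beams=2, max_steps=5)
greedy_seq, greedy_logp = beam_search(toy_next_token_probs, num_beams=1, max_steps=5)
print("beam search:", best_seq, round(math.exp(best_logp), 2))
print("greedy:     ", greedy_seq, round(math.exp(greedy_logp), 2))
```

Here beam search returns "a dog" (probability 0.4 × 0.9 = 0.36), while greedy returns "the dog" (0.5 × 0.4 = 0.2). This simplified version lets the beam shrink when a hypothesis finishes and does not normalize scores by length; library implementations such as the one behind `generate(num_beams=...)` typically refill the beam from the next-best candidates and apply a length penalty.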

In [None]:
input_txt = "In today's rapidly evolving digital landscape, the substantial advance and rapid growth of data presents companies and their operations with a set of opportunities from different sources that can profoundly impact their competitiveness and success. The literature suggests that data can be considered a hidden weapon that fosters decision-making while determining a company's success in a rapidly changing market. Data are also used to support most organizational activities and decisions. As a result"
input_ids = tokenizer(input_txt, return_tensors='pt')['input_ids'].to(device)
time_steps = 128
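# Beam search with 5 beams: keep the 5 most probable partial sequences at every step
output = model.generate(input_ids, max_new_tokens=time_steps, do_sample=False, num_beams=5)
print(tokenizer.decode(output[0]))  # first (and only) sequence in the batch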

In [None]:
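# no_repeat_ngram_size=2 forbids any 2-gram (pair of consecutive tokens) from occurring twice,
# which curbs the repetitive loops that beam search tends to produce
output = model.generate(input_ids, max_new_tokens=time_steps, do_sample=False, num_beams=5, no_repeat_ngram_size=2)
print(tokenizer.decode(output[0]))  # first (and only) sequence in the batch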
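To make the `no_repeat_ngram_size` idea concrete, here is a minimal re-implementation sketch of n-gram blocking over a list of token ids. The helpers `banned_next_tokens` and `block_repeated_ngrams` are hypothetical names, not transformers functions. The idea: look at the last n-1 generated tokens; any token that previously followed that same (n-1)-gram is banned, because choosing it would repeat an n-gram. Banned tokens get a logit of -inf, so they receive zero probability after the softmax.

```python
from typing import List, Set


def banned_next_tokens(tokens: List[int], n: int) -> Set[int]:
    """Token ids that would recreate an n-gram already present in `tokens`."""
    if len(tokens) < n - 1:
        return set()
    # The last n-1 tokens: the next token would extend this (n-1)-gram into an n-gram.
    prefix = tuple(tokens[len(tokens) - (n - 1):])
    banned = set()
    for i in range(len(tokens) - n + 1):
        if tuple(tokens[i:i + n - 1]) == prefix:
            banned.add(tokens[i + n - 1])
    return banned


def block_repeated_ngrams(logits: List[float], tokens: List[int], n: int) -> List[float]:
    # Setting banned logits to -inf gives those tokens zero probability after the softmax.
    masked = list(logits)
    for tok in banned_next_tokens(tokens, n):
        masked[tok] = float("-inf")
    return masked


history = [5, 7, 9, 5]  # hypothetical token ids; the bigram (5, 7) has already occurred
print(banned_next_tokens(history, n=2))  # {7}
```

With `n=2` and the sequence ending in token 5, token 7 is banned because "5 7" already appeared.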

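## 3. Random sampling

#### Sampling with temperature

Temperature $T$ rescales the logits $z_i$ before the softmax: $p_i = \exp(z_i / T) / \sum_j \exp(z_j / T)$. A high temperature ($T > 1$) flattens the distribution and gives rare tokens more probability mass, which increases diversity but reduces coherence. A low temperature ($T < 1$) sharpens the distribution toward the most likely tokens.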
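The effect of temperature can be checked without a model. This is a minimal pure-Python sketch; the function name `softmax_with_temperature` and the three logits are made up for illustration. Dividing the logits by T before the softmax moves probability mass away from the top token as T grows.

```python
import math
from typing import List


def softmax_with_temperature(logits: List[float], temperature: float) -> List[float]:
    # Divide each logit by T, subtract the max for numerical stability, then normalize.
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.0]  # made-up logits for three tokens
for t in (0.5, 1.0, 2.0):
    probs = softmax_with_temperature(logits, t)
    print(f"T={t}:", [round(p, 3) for p in probs])
```

As T increases, the most likely token's probability drops and the least likely token's rises; T = 1 recovers the ordinary softmax.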

In [None]:
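# Pure sampling at a high temperature; top_k=0 disables the default top-k (50) filter,
# so every token in the vocabulary can be sampled
output = model.generate(input_ids, max_new_tokens=time_steps, do_sample=True, temperature=2.0, top_k=0)
print(tokenizer.decode(output[0]))  # first (and only) sequence in the batch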

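#### Top-k and top-p (nucleus) sampling

Combined with sampling, top-k and top-p restrict the candidates at each step to a small set of plausible tokens, cutting off the long tail of unlikely ones.

- top-k: at each time step, keep only the k most probable tokens, renormalize, and sample from them.
- top-p: at each time step, keep the smallest set of most probable tokens whose cumulative probability is at least p, renormalize, and sample from them.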
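Both filters can be sketched in a few lines of plain Python over a made-up next-token distribution. The helpers `top_k_filter`, `top_p_filter`, and `sample` are hypothetical names that follow the definitions above; library implementations may differ in edge cases such as ties or a minimum number of tokens to keep.

```python
import random
from typing import Dict


def top_k_filter(probs: Dict[str, float], k: int) -> Dict[str, float]:
    # Keep the k most probable tokens and renormalize so they sum to 1.
    kept = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)[:k]
    total = sum(p for _, p in kept)
    return {tok: p / total for tok, p in kept}


def top_p_filter(probs: Dict[str, float], p: float) -> Dict[str, float]:
    # Keep the smallest set of most probable tokens whose cumulative probability >= p.
    kept, cumulative = [], 0.0
    for tok, prob in sorted(probs.items(), key=lambda kv: kv[1], reverse=True):
        kept.append((tok, prob))
        cumulative += prob
        if cumulative >= p:
            break
    total = sum(prob for _, prob in kept)
    return {tok: prob / total for tok, prob in kept}


def sample(probs: Dict[str, float], rng: random.Random) -> str:
    # Draw one token according to its (renormalized) probability.
    tokens = list(probs)
    return rng.choices(tokens, weights=[probs[t] for t in tokens], k=1)[0]


probs = {"the": 0.5, "a": 0.3, "dog": 0.15, "cat": 0.05}  # made-up distribution
rng = random.Random(0)  # fixed seed so the draw is reproducible
print(top_k_filter(probs, k=2))    # only "the" and "a" survive
print(top_p_filter(probs, p=0.9))  # "the", "a", "dog" (0.5 + 0.3 + 0.15 >= 0.9)
print(sample(top_p_filter(probs, p=0.9), rng))
```

The difference in design: top-k always keeps a fixed number of candidates, while top-p adapts the candidate count to how peaked the distribution is at each step.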

In [None]:
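# Nucleus (top-p) sampling; top_k=0 switches off the default top-k (50) filter so only top-p applies
output = model.generate(input_ids, max_new_tokens=time_steps, do_sample=True, top_k=0, top_p=0.90)
print(tokenizer.decode(output[0]))  # first (and only) sequence in the batch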