#### Implementation of decoding methods from scratch

In [2]:
import torch.nn as nn
import torch.nn.functional as F

In [None]:
import torch

from transformers import AutoTokenizer, AutoModelForCausalLM

if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

model_name = 'gpt2-large'

# Load the pretrained causal LM and its matching tokenizer from the same checkpoint name.
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Pad on the left so that, in a batch of prompts, every prompt's last real token sits in
# the final position -- the position the next-token logits are read from below.
# The tokenizer's default is right padding, hence the override.
tokenizer.padding_side = 'left'

model = model.to(device)
model.eval()  # inference mode (dropout off); the returned module is this cell's displayed output

2025-08-15 15:51:03.395601: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1755273063.418433     117 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1755273063.425310     117 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 1280)
    (wpe): Embedding(1024, 1280)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-35): 36 x GPT2Block(
        (ln_1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=3840, nx=1280)
          (c_proj): Conv1D(nf=1280, nx=1280)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=5120, nx=1280)
          (c_proj): Conv1D(nf=1280, nx=5120)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=1280, out_features=50257, bias=False)
)

In [4]:
# GPT-2's end-of-sequence token doubles as its end-of-document marker; keep its id handy.
eod_token_id = tokenizer.eos_token_id

# Reuse that same token for padding, registering it on both the tokenizer and the
# model config so padded batches are handled consistently.
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.pad_token_id

In [5]:
prompt = "The quick brown fox"

inputs = tokenizer(prompt, return_tensors='pt').to(device)

# Inference only: without no_grad the forward pass records an autograd graph through the
# whole model and keeps its activations alive via `logits`/`last_logits`, wasting memory.
with torch.no_grad():
    outputs = model(**inputs)
logits = outputs.logits  # (batch, seq_len, vocab_size)

# Next-token scores: the first (and only) prompt in the batch, at its last position.
last_logits = logits[0, -1]

#### Greedy decoding

In [11]:
# Greedy decoding: the next token is simply the single highest-scoring vocabulary entry.
token_id = last_logits.argmax(dim=-1)
tokenizer.decode(token_id)

' jumps'

#### Pure sampling

In [12]:
# Pure (ancestral) sampling: draw the next token from the full softmax distribution.
# A locally seeded generator makes the draw reproducible under Restart & Run All
# without touching the global RNG state.
with torch.no_grad():
    probs = F.softmax(last_logits, dim=-1)
    generator = torch.Generator(device=probs.device).manual_seed(0)  # fixed seed for reproducibility
    idx = probs.multinomial(num_samples=1, generator=generator)
tokenizer.decode(idx)  # idx is the sampled token_id, shape (1,)

' jumps'

#### Top-k sampling

In [None]:
# The 10 highest-scoring next-token candidates as a (values, indices) pair, sorted best-first.
logits_k = last_logits.topk(k=10, dim=-1)

# Show each candidate token on its own line.
for token_id in logits_k.indices:
    print(tokenizer.decode(token_id))

 jumps
 jumped
 jump
,
 gets
 leapt
 is
 leaps
 was
 jumping


In [33]:
# Top-k sampling: renormalise the k best logits with a softmax and sample among them.
with torch.no_grad():
    probs = F.softmax(logits_k.values, dim=-1)
    # Sample from the probabilities, not the raw logits: logits are not a distribution
    # (and multinomial raises on negative entries), so sampling them is wrong.
    generator = torch.Generator(device=probs.device).manual_seed(0)  # fixed seed for reproducibility
    idx = probs.multinomial(num_samples=1, generator=generator)

# idx indexes into the top-k list; map it back to the vocabulary token id.
tokenizer.decode(logits_k.indices[idx])

' is'

#### Nucleus sampling

In [None]:
# Nucleus (top-p) sampling: keep the smallest set of most-probable tokens whose
# cumulative probability reaches p, renormalise that set, and sample from it.
p = 0.9  # Top-p with p=0.9
with torch.no_grad():
    probs = F.softmax(last_logits, dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumul_probs = torch.cumsum(sorted_probs, dim=-1)

    # Tokens to keep: all whose running total is still below p, plus the token that
    # crosses p. Including the crossing token means the nucleus is never empty (e.g. when
    # the top token alone has probability > p); the min() guards against float round-off
    # leaving the final cumulative sum just under p.
    n_keep = min(int((cumul_probs < p).sum()) + 1, sorted_probs.numel())
    nucleus_probs = sorted_probs[:n_keep]
    nucleus_probs = nucleus_probs / nucleus_probs.sum()

    generator = torch.Generator(device=nucleus_probs.device).manual_seed(0)  # fixed seed for reproducibility
    idx = nucleus_probs.multinomial(num_samples=1, generator=generator)
    # idx is a position in the sorted order; map it back to the vocabulary token id.
    token_id = sorted_indices[idx]

tokenizer.decode(token_id)

' jumps'