In [1]:
import transformer_lens
from transformer_lens import HookedTransformer
import torch

In [28]:
model = transformer_lens.HookedTransformer.from_pretrained("gpt2")
model.eval()
tokenizer = model.tokenizer

Loaded pretrained model gpt2 into HookedTransformer


In [43]:
# Greedy sampling
def gen_greedy(model, tokenizer, seq, max_new_tokens=5):
    tokens = model.to_str_tokens(seq)
    i = 0
    
    while i < max_new_tokens:
        logits = model(tokens)
        last_token_logits = logits[:, 1, :][-1]
        soft_logits = torch.softmax(last_token_logits, dim=-1)

        next_token = model.to_string(torch.argmax(soft_logits, dim=-1).item())
        tokens.append(next_token)
        i += 1
        
    return "".join(tokens)

In [44]:
gen_greedy(model, tokenizer, 'jack and jill went up the hill. jack game some', max_new_tokens=10)

'<|endoftext|>jack and jill went up the hill. jack game some of the first,\nThe first,\nThe'

In [45]:
# Temperature based
def gen_temp(model, tokenizer, seq, max_new_tokens=5, temperature=1):
    tokens = model.to_str_tokens(seq)
    i = 0
    
    while i < max_new_tokens:
        logits = model(tokens)      
        last_token_logits = logits[:, 1, :][-1]
        soft_logits = torch.softmax(last_token_logits/temperature, dim=-1)

        next_token = model.to_string(torch.multinomial(soft_logits, num_samples=1).item())
        tokens.append(next_token)
        i += 1
        
    return "".join(tokens)

In [46]:
gen_temp(model, tokenizer, "jack and jill went up the hill. jack gave some", max_new_tokens=10, temperature=0.5)

'<|endoftext|>jack and jill went up the hill. jack gave some of the "I am I\'m so that is'

In [76]:
# top-k search
def top_k_search(model, tokenizer, seq, max_new_tokens=5, top_k=5, temperature=0.8):
    tokens = model.to_str_tokens(seq)
    i = 0
    
    while i < max_new_tokens:
        logits = model(tokens)      
        last_token_logits = logits[:, 1, :][-1]
        
        soft_topk, topk_idx = torch.topk(last_token_logits, k=top_k)
        soft_logits = torch.softmax(soft_topk/temperature, dim=-1)
        chosen_idx = torch.multinomial(soft_logits, num_samples=1).item()
        next_token_id = topk_idx[chosen_idx].item()
        
        tokens.append(model.to_string(next_token_id))
        i += 1
        
    return "".join(tokens)


In [77]:
top_k_search(model, tokenizer, "jack and jill went up the hill. jack gave some")

'<|endoftext|>jack and jill went up the hill. jack gave some people who was not a'

In [None]:
# Beam search [incomplete]
def beam_search(model, tokenizer, seq, max_new_tokens=5, n_beam=3):
    tokens = model.to_str_tokens(seq)
    i = 0
    
    while i < max_new_tokens:
        beam_count = 0
        while n_beams:
            logits = model(tokens)    
            last_token_logits = logits[:, 1, :][-1]
            soft_logits = torch.softmax(last_token_logits, dim=-1)
            
            next_token = model.to_string(torch.multinomial(soft_logits, num_samples=1).item())
            tokens.append(next_token)
        i += 1
        
    return "".join(tokens)