In [7]:
import pandas as pd
import torch
import datasets
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM

In [8]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model_name = 'gpt2-xl'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

In [54]:
input_txt = 'Transformers are the'
input_ids = tokenizer(input_txt, return_tensors='pt')['input_ids'].to(device)
iterations = []
n_steps = 8
choices_per_step = 5

with torch.inference_mode():
    for step in range(n_steps):
        iteration = dict()
        iteration['Input'] = tokenizer.decode(input_ids[0]) # grab the first item to avoid error
        output = model(input_ids=input_ids) # predict input's and next token's logits
        next_token_logits = output.logits[0, -1, :] # next token's logits
        next_token_probs = torch.softmax(next_token_logits, dim=-1) # convert logits to softmax
        sorted_ids = torch.argsort(next_token_probs, dim=-1, descending=True)
        for choice_idx in range(choices_per_step):
            token_id = sorted_ids[choice_idx]
            token_prob = next_token_probs[token_id].cpu().numpy()
            token_choice = f"{tokenizer.decode(token_id)} ({token_prob:.2f})%"
            iteration[f'Choice {choice_idx+1}'] = token_choice
        input_ids = torch.cat([input_ids, sorted_ids[None, None, 0]], dim=-1)
        
        iterations.append(iteration)
        
display(pd.DataFrame(iterations))

Unnamed: 0,Input,Choice 1,Choice 2,Choice 3,Choice 4,Choice 5
0,Transformers are the,most (0.09)%,only (0.05)%,best (0.05)%,Transformers (0.04)%,ultimate (0.02)%
1,Transformers are the most,popular (0.17)%,powerful (0.05)%,common (0.05)%,famous (0.04)%,successful (0.03)%
2,Transformers are the most popular,toy (0.11)%,toys (0.07)%,Transformers (0.07)%,of (0.05)%,and (0.04)%
3,Transformers are the most popular toy,line (0.34)%,in (0.18)%,of (0.12)%,brand (0.06)%,line (0.03)%
4,Transformers are the most popular toy line,in (0.46)%,of (0.15)%,", (0.05)%",on (0.04)%,ever (0.03)%
5,Transformers are the most popular toy line in,the (0.66)%,history (0.12)%,America (0.07)%,Japan (0.02)%,North (0.01)%
6,Transformers are the most popular toy line in the,world (0.69)%,United (0.05)%,history (0.04)%,US (0.04)%,U (0.02)%
7,Transformers are the most popular toy line in ...,", (0.40)%",. (0.31)%,and (0.10)%,with (0.02)%,today (0.02)%


In [50]:
print(input_ids.shape)
next_token_probs = model(input_ids).logits[0, -1, :].softmax(dim=-1)
print(next_token_probs.shape)
sorted_ids = torch.argsort(next_token_probs, dim=-1, descending=True)
print(sorted_ids.shape)
print(sorted_ids[None, None, 0].shape)
token_id = next_token_probs.argmax(dim=-1)
print(token_id.shape)
print(torch.cat([input_ids, sorted_ids[None, 0, None]], dim=-1))
print(torch.cat([input_ids, sorted_ids[0, None, None]], dim=-1))

torch.Size([1, 4])
torch.Size([50257])
torch.Size([50257])
torch.Size([1, 1])
torch.Size([])
tensor([[41762,   364,   389,   262,   749]], device='cuda:0')
tensor([[41762,   364,   389,   262,   749]], device='cuda:0')
