In [2]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


In [3]:
device = "mps" if torch.backends.mps.is_available() else "cpu"

In [4]:
device

'mps'

In [5]:
model_name = "gpt2-xl"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

Downloading:   0%|          | 0.00/689 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/6.43G [00:00<?, ?B/s]

In [9]:
import pandas as pd
text = "Transformers are the "
input_ids = tokenizer(text, return_tensors="pt").input_ids.to(device)
iterations = []
num_steps = 8
choices_per_step = 5

with torch.no_grad():
    for step in range(num_steps):
        iteration = dict()
        iteration["input"] = tokenizer.decode(input_ids[0])
        output = model(input_ids)[0]
        # Select logits of the first batch and the last token
        logits = output[0, -1, :]
        probs = torch.softmax(logits, dim=-1)
        sorted_ids = torch.argsort(probs, descending=True)
        for choice_idx in range(choices_per_step):
            token_id = sorted_ids[choice_idx]
            token_prob = probs[token_id].cpu().numpy()
            token_choice = tokenizer.decode(token_id) + f" ({token_prob:.2f})"
            iteration[f"choice_{choice_idx}"] = token_choice
        input_ids = torch.cat((input_ids, sorted_ids[None, 0, None]), dim=1)
        iterations.append(iteration)
pd.DataFrame(iterations)


Unnamed: 0,input,choice_0,choice_1,choice_2,choice_3,choice_4
0,Transformers are the,(0.15),ills (0.15),________ (0.07),icky (0.05),_____ (0.05)
1,Transformers are the,most (0.07),ultimate (0.05),original (0.02),""" (0.02)",main (0.02)
2,Transformers are the most,popular (0.17),common (0.05),powerful (0.05),famous (0.04),successful (0.03)
3,Transformers are the most popular,toy (0.10),toys (0.07),Transformers (0.06),of (0.06),and (0.05)
4,Transformers are the most popular toy,line (0.50),in (0.10),of (0.08),lines (0.05),line (0.04)
5,Transformers are the most popular toy line,in (0.37),of (0.14),", (0.06)",on (0.04),. (0.03)
6,Transformers are the most popular toy line in,the (0.71),history (0.08),America (0.04),Japan (0.02),all (0.01)
7,Transformers are the most popular toy line in...,world (0.68),US (0.05),history (0.03),universe (0.03),United (0.03)
