# In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import math
from dataclasses import dataclass
from transformers import GPT2LMHeadModel, set_seed
import tiktoken

# Sampling configuration: B parallel continuations, each grown until the
# sequence (prompt included) reaches max_length tokens.
B = 5
max_length = 30
input_sentence = "Hello, I'm a language model,"

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Load the pretrained 'gpt2' checkpoint, move it to the chosen device and
# switch it to eval mode.
model = GPT2LMHeadModel.from_pretrained('gpt2').to(device).eval()

# Encode the prompt with the 'gpt2' tiktoken encoding, then tile it into B
# identical rows so every row can be extended independently.
tokenizer = tiktoken.get_encoding('gpt2')
x = torch.tensor(
    tokenizer.encode(input_sentence), dtype=torch.long, device=device
)[None, :].repeat(B, 1)  # (B, T)

# Seed the RNGs before sampling so repeated runs draw the same continuations.
set_seed(42)

# Inference only: a single no_grad scope around the whole loop avoids building
# autograd state and avoids re-entering the context on every step.
with torch.no_grad():
    # Append one sampled token per row until each row holds max_length tokens.
    while x.shape[1] < max_length:
        # GPT2LMHeadModel returns an output object rather than a bare tensor;
        # the per-position vocabulary scores are in its `.logits` field.
        logits = model(x).logits  # (B, T, vocab_size)
        logits = logits[:, -1, :]  # scores for the next token: (B, vocab_size)
        probs = F.softmax(logits, dim=-1)  # logits are unnormalized; convert to probabilities
        # Top-k sampling with k=50 (noted in the original as the Hugging Face
        # pipeline default): keep only the 50 most likely tokens so rare ones
        # can never be drawn.
        topk_probs, topk_ids = torch.topk(probs, 50, dim=-1)  # both (B, 50)
        # multinomial accepts unnormalized non-negative weights, so the
        # truncated probabilities need no renormalization. The result indexes
        # into the top-k slice, not into the vocabulary.
        ix = torch.multinomial(topk_probs, 1)  # (B, 1)
        # Map each sampled top-k position back to its vocabulary token id.
        nx = torch.gather(topk_ids, -1, ix)  # (B, 1)
        x = torch.cat((x, nx), dim=1)  # (B, T+1)

# Decode every row (prompt plus sampled continuation) back to text and print it.
lst = x.tolist()
print(input_sentence)
for idx in lst:
    print(">", tokenizer.decode(idx))