In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from models import GPT

In [2]:
# Load the Shakespeare corpus as a list of single characters (character-level tokens).
# The encoding is explicit so the resulting vocabulary does not depend on the OS locale.
with open("shakespeare.txt", encoding="utf-8") as f:
    text: list[str] = list(f.read())

In [3]:
# look up table encoder decoder from letters in words to numbers
def make_encoder_decoder(words: list[str]) -> tuple[dict[str, int], dict[int, str]]:
    letters = sorted(set("".join(words)))
    encoder = {letter: i for i, letter in enumerate(letters)}
    decoder = {i: letter for i, letter in enumerate(letters)}
    return encoder, decoder

In [4]:
# Vocabulary tables built from the full corpus loaded above.
encoder, decoder = make_encoder_decoder(words=text)

In [None]:
# Inference configuration. Prefer Apple-silicon "mps" and fall back to CPU, so the
# notebook still runs on machines without MPS.
device = "mps" if torch.backends.mps.is_available() else "cpu"
# Number of most-recent characters fed to the model at each generation step.
# Presumably must not exceed the context length GPT was trained with — TODO confirm in models.GPT.
seq_len = 128
# The first GPT argument is presumably the vocabulary size. It must match the checkpoint,
# otherwise load_state_dict reports a size mismatch, so the encoder must be built
# from the same corpus used in training.
model = GPT(len(encoder)).to(device)
# map_location remaps tensors saved on a different device (e.g. CUDA) onto `device`.
# NOTE(review): torch.load unpickles the file — only load checkpoints from trusted sources.
checkpoint = torch.load("checkpoint.pt", map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])

In [15]:
# generate a random sequence of letters
@torch.no_grad()
def generate_text(model, start_text, max_len=200):
    model.eval()
    text = start_text
    for i in range(max_len):
        x = torch.tensor([encoder[letter] for letter in text[-seq_len:]], dtype=torch.long).to("mps")
        x = x.unsqueeze(0)
        logits = model(x)
        logits = logits[:, -1, :]
        # sample from the distribution
        probs = F.softmax(logits, dim=-1)
        letter = torch.multinomial(probs, 1).squeeze(0)[-1]
        text += decoder[letter.item()]

    return text

In [None]:
# Sample a 200-character continuation of the prompt. Seeding first makes re-runs of
# this cell give repeatable text (for a given device and PyTorch version).
SAMPLE_SEED = 0
torch.manual_seed(SAMPLE_SEED)
print(generate_text(model, "I am a toy", max_len=200))