In [1]:
import torch

from src.model import GPTLanguageModel
from src.utils import decode, estimate_loss, get_batch, get_prompt

In [2]:
# load the serialized training and validation data from the assets folder
train_data = torch.load("assets/train.pt")
valid_data = torch.load("assets/valid.pt")

# read the vocabulary text; the context manager guarantees the file handle
# is closed instead of lingering until garbage collection
with open("assets/vocab.txt", "r") as vocab_file:
    vocab = vocab_file.read()

# vocabulary size is the number of characters in the vocabulary file
vocab_size = len(vocab)

In [3]:
# training configuration
learn_rate = 3e-4    # learning rate passed to AdamW
iters = 5            # number of passes through the training loop
eval_interval = 500  # evaluate the loss whenever the step is a multiple of this

# build the model and an AdamW optimizer over its parameters
model = GPTLanguageModel(vocab_size)
optimizer = torch.optim.AdamW(params=model.parameters(), lr=learn_rate)

# total element count across all parameter tensors (displayed as cell output)
param_sizes = [param.numel() for param in model.parameters()]
n_params = sum(param_sizes)
n_params

316666

In [None]:
# training loop: run `iters` optimization steps
for step in range(iters):

    # report estimated train/valid losses at every eval interval and on the
    # final step (the estimate is taken before that step's parameter update)
    is_eval_step = (step % eval_interval == 0) or (step == iters - 1)
    if is_eval_step:
        train_loss = estimate_loss(model, train_data)
        valid_loss = estimate_loss(model, valid_data)
        print(f"step {step}: train loss {train_loss:.4f}, valid loss {valid_loss:.4f}")

    # draw a batch from the training data
    x_batch, y_batch = get_batch(train_data)

    # forward pass, then clear old gradients, backpropagate, and update
    logits, loss = model(x_batch, y_batch)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

In [None]:

# build the prompt that seeds text generation in the next cell
# NOTE(review): get_prompt comes from src.utils; its return format is not
# visible here — presumably whatever model.generate expects; confirm
prompt = get_prompt(vocab)

In [None]:
# sample a continuation of the prompt from the model (max_new_tokens=100),
# then decode it with the vocabulary and print the result
sampled = model.generate(prompt, max_new_tokens=100)
decoded_text = decode(sampled, vocab)
print(decoded_text)