In [1]:
import torch
import torch.nn as nn
from model import TinyTransformer
from dataset import get_batch_split, vocab_size, stoi, decode
import math

In [2]:
# Choose the compute backend in priority order: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

Using device: mps


In [3]:
# Training hyperparameters.
batch_size = 64      # sequences drawn per call to get_batch_split
steps = 3000         # number of optimizer updates in the training loop
eval_interval = 500  # estimate_loss is run every this many steps
eval_steps = 50      # batches averaged per split inside estimate_loss

# Model, optimizer and loss. The same block_size literal (128) is also passed
# to get_batch_split wherever batches are drawn.
model = TinyTransformer(
    vocab_size,
    n_embd=128,
    block_size=128,
    n_heads=4,
    n_layers=4,
).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
loss_fn = nn.CrossEntropyLoss()

@torch.no_grad()
def estimate_loss():
    """Estimate the mean loss on the 'train' and 'val' splits.

    For each split, draws ``eval_steps`` batches with ``get_batch_split``,
    runs ``model`` in eval mode with gradients disabled, and averages the
    per-batch ``loss_fn`` values.

    Returns:
        dict: maps ``'train'`` and ``'val'`` to a 0-dim tensor holding the
        mean of the ``eval_steps`` per-batch losses for that split.

    The model's train/eval mode is restored to whatever it was on entry,
    including when evaluation raises, instead of being forced to train mode.
    """
    was_training = model.training
    model.eval()
    try:
        out = {}
        for split in ['train', 'val']:
            losses = torch.zeros(eval_steps)
            for k in range(eval_steps):
                x, y = get_batch_split(split, block_size=128, batch_size=batch_size, device=device)
                logits = model(x)
                B, T, C = logits.shape
                # Collapse the first two dims so every (batch, position) pair
                # is scored as one sample against its target id.
                loss = loss_fn(logits.view(B*T, C), y.view(B*T))
                losses[k] = loss.item()
            out[split] = losses.mean()
        return out
    finally:
        # Put the model back in the mode the caller had it in.
        model.train(was_training)

# Training loop: one clipped AdamW update per step, with a loss report every
# eval_interval steps and on the final step.
for step in range(steps):
    x, y = get_batch_split('train', block_size=128, batch_size=batch_size, device=device)

    # Clear gradients left over from the previous update before building new ones.
    optimizer.zero_grad()

    # Forward pass; flatten logits and targets so each position is one sample.
    logits = model(x)
    B, T, C = logits.shape
    loss = loss_fn(logits.view(B * T, C), y.view(B * T))

    # Backward pass, clip the global gradient norm to 1.0, then update.
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()

    # Skip reporting unless this is an evaluation step or the last step.
    if step % eval_interval != 0 and step != steps - 1:
        continue
    losses = estimate_loss()
    print(f"Step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

# Persist only the weights (state_dict), not the full module object.
torch.save(model.state_dict(), 'tiny_shakespeare_model.pth')
print("Model saved to tiny_shakespeare_model.pth")


Step 0: train loss 3.8977, val loss 3.9050
Step 500: train loss 2.0654, val loss 2.1082
Step 1000: train loss 1.7457, val loss 1.8718
Step 1500: train loss 1.5881, val loss 1.7663
Step 2000: train loss 1.4858, val loss 1.6738
Step 2500: train loss 1.4286, val loss 1.6302
Step 2999: train loss 1.3857, val loss 1.6079
Model saved to tiny_shakespeare_model.pth


In [8]:
def generate(model, start, length=100, temperature=0.8, top_k=40):
    """Sample a continuation of the prompt ``start`` from ``model``.

    Args:
        model: module exposing ``generate(idx, max_new_tokens=..., temperature=..., top_k=...)``.
        start: prompt string; every character is mapped to an id via ``stoi``.
        length: passed to ``model.generate`` as ``max_new_tokens``.
        temperature: forwarded unchanged to ``model.generate``.
        top_k: forwarded unchanged to ``model.generate``.

    Returns:
        Whatever ``decode`` returns for the first sequence produced by
        ``model.generate`` (converted to a list of ids).

    Raises:
        KeyError: if ``start`` contains a character that is not in ``stoi``.

    The model is switched to eval mode for sampling and then returned to the
    mode it was in on entry, so calling this mid-training does not silently
    leave the model in eval mode.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            prompt_ids = [stoi[s] for s in start]
            # Add a leading batch dimension of 1 and move the prompt to the
            # notebook's selected device.
            x = torch.tensor(prompt_ids, dtype=torch.long).unsqueeze(0).to(device)
            idx = model.generate(x, max_new_tokens=length, temperature=temperature, top_k=top_k)
            return decode(idx[0].tolist())
    finally:
        # Restore the caller's train/eval mode, including on error.
        model.train(was_training)

# Request 200 new tokens continuing the prompt "ROM" and print the decoded text.
print(generate(model, start="ROM", length=200))

ROMEO:
Sir, in the state, unto like again, hast
With he soldiershood by the business in the oes,
His presence their fearful and art, good stop,
Still Edward not blood death.

HERMIONE:
Have denianted the
