In [1]:
import torch
from model import GPTLanguageModel
from preprocess import decode

In [2]:
learn_rate = 3e-4
max_iters = 10000
eval_interval = 500
eval_iters = 200
block_size = 128 # TODO: Align in config
batch_size = 4

train_data = torch.load("assets/train.pt")
valid_data = torch.load("assets/valid.pt")

In [3]:
# initialize model & optimizer
model = GPTLanguageModel()
optimizer = torch.optim.AdamW(model.parameters(), lr=learn_rate)

# number of model parameters
n_params = sum(p.numel() for p in model.parameters())
n_params

71425

In [4]:
@torch.no_grad()
def estimate_loss(model):
    out = {}
    model.eval()

    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split, block_size)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    
    model.train()
    return out


def get_batch(split, block_size):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else valid_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

In [5]:
# learning iterations
for iter in range(max_iters):

    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0 or iter == max_iters - 1:
        losses = estimate_loss(model)
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample batch of data
    x_batch, y_batch = get_batch('train', block_size=32)

    # evaluate the loss
    logits, loss = model(x_batch, y_batch)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

step 0: train loss 4.2051, val loss 4.2039
step 500: train loss 2.7864, val loss 2.7987
step 1000: train loss 2.6321, val loss 2.6635
step 1500: train loss 2.5605, val loss 2.5786
step 2000: train loss 2.5336, val loss 2.5478
step 2500: train loss 2.5367, val loss 2.5738
step 3000: train loss 2.5517, val loss 2.6023
step 3500: train loss 2.5854, val loss 2.6409
step 4000: train loss 2.5938, val loss 2.6472
step 4500: train loss 2.6030, val loss 2.6477
step 5000: train loss 2.6316, val loss 2.6904
step 5500: train loss 2.6532, val loss 2.7012
step 6000: train loss 2.6225, val loss 2.7201
step 6500: train loss 2.6593, val loss 2.7262
step 7000: train loss 2.6827, val loss 2.7636
step 7500: train loss 2.7191, val loss 2.7944
step 8000: train loss 2.7411, val loss 2.7973
step 8500: train loss 2.7755, val loss 2.8264
step 9000: train loss 2.7687, val loss 2.8230
step 9500: train loss 2.7786, val loss 2.8505
step 9999: train loss 2.8231, val loss 2.9023


In [6]:
# generate from the model
from preprocess import get_vocab

text = open("assets/input.txt", "r").read()
vocab = get_vocab(text)

context = torch.zeros((1, 1), dtype=torch.long)
sampled = model.generate(context, max_new_tokens=500)[0]
print(decode(sampled, vocab))


CAy Clak; mow, ar it a Moke, houshead your she be Of h'd t Hyound guve ave mud BeumPeruber ITin My athe s tw -ld, UmyINord w Le be Terneor o ge Bumy he Ae have outhe heAeensute Prsense Phf Foususlo I wno ICo;
Shthilurseeshaike wicos I es Kize marre Gous Rd ne Henomnt ss Rnfind otare y Mnulint 's ine aciny k'kpene, ondswaree bin t Isuloohaky
Bueempamearend.
Th Maimy-
Y o
O ws mane fowe s My He ch h h Bucen Lore r Ve Iff htreno Wh t it f Go os Ore iot my KIs Lousthe Is beesise Gr Eo I'shacoulisono
