In [31]:
from torch_impl.model import GPT
from dataset import SortDataset
from torch_impl.utils import ModelConfig, TrainConfig, TrainerCallbackEvent, set_seed
from torch_impl.train import Trainer

import torch
from torch.utils.data.dataloader import DataLoader

set_seed(3407)

%reload_ext autoreload
%autoreload 2

In [32]:
# Build both splits of the sorting task (train first, then test).
train_dataset, test_dataset = SortDataset('train'), SortDataset('test')

In [33]:
# Peek at one training pair: input token ids, then targets.
# The leading -1 targets presumably mark positions excluded from the loss — confirm in SortDataset.
x, y = train_dataset[1]
print(x, y, sep='\n')

tensor([1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1])
tensor([-1, -1, -1, -1, -1,  0,  0,  0,  0,  1,  1])


In [34]:
# Size the model from the dataset: vocabulary of token ids and the context length it must cover.
model_config = ModelConfig(
    model_type='gpt-nano',
    vocab_size=train_dataset.get_vocab_size(),
    context_window=train_dataset.get_block_size(),
)
model_config

ModelConfig(vocab_size=3, context_window=11, n_embd=48, n_head=3, n_layer=3, model_type='gpt-nano', attn_pdrop=0.1, recid_pdrop=0.1, embd_pdrop=0.1)

In [35]:
# Instantiate the GPT from the config above; construction reports the parameter count (see output).
model = GPT(model_config)

number of parameters: 0.086M


In [36]:
# Start from TrainConfig's defaults and override only the iteration budget and step size.
train_config = TrainConfig()
train_config.max_iters = 2000
train_config.learning_rate = 5e-4
print(train_config)

# Construction reports the device the model was placed on (see output).
trainer = Trainer(train_config, model, train_dataset, test_dataset)

TrainConfig(device='auto', num_workers=0, batch_size=64, sequence_len=6, max_iters=2000, learning_rate=0.0005, betas=(0.9, 0.95), weight_decay=0.1, grad_norm_clip=1.0)
model is running on cuda


In [37]:
def print_batch_loss(trainer: "Trainer", log_every: int = 100) -> None:
    """Training callback: print the current batch loss every ``log_every`` iterations.

    Meant to be registered for ``TrainerCallbackEvent.on_train_batch_end``.

    Args:
        trainer: the running trainer. Only ``trainer.n_iter`` (iteration
            counter) and ``trainer.loss`` (an object exposing ``.item()``,
            presumably a scalar torch tensor) are read.
        log_every: logging interval in iterations. Defaults to 100, the value
            that used to be hard-coded, so existing registrations are unchanged.

    Raises:
        ValueError: if ``log_every`` is not positive (modulo by zero would
            otherwise raise ZeroDivisionError on every batch).
    """
    if log_every <= 0:
        raise ValueError(f'log_every must be positive, got {log_every}')
    if trainer.n_iter % log_every == 0:
        # The f-string `=` specifier echoes the expression text, e.g.
        # "trainer.n_iter=100, trainer.loss.item()=0.1713".
        print(f'{trainer.n_iter=}, {trainer.loss.item()=:.4f}')

# Register the loss logger for the end-of-training-batch event.
trainer.add_callback(
    TrainerCallbackEvent.on_train_batch_end,
    print_batch_loss,
)

In [38]:
# Run the training loop; the loss callback registered earlier prints progress (see output).
trainer.run()

trainer.n_iter=0, trainer.loss.item()=1.0576
trainer.n_iter=100, trainer.loss.item()=0.1713
trainer.n_iter=200, trainer.loss.item()=0.0806
trainer.n_iter=300, trainer.loss.item()=0.0367
trainer.n_iter=400, trainer.loss.item()=0.0446
trainer.n_iter=500, trainer.loss.item()=0.0231
trainer.n_iter=600, trainer.loss.item()=0.0408
trainer.n_iter=700, trainer.loss.item()=0.0171
trainer.n_iter=800, trainer.loss.item()=0.0090
trainer.n_iter=900, trainer.loss.item()=0.0274
trainer.n_iter=1000, trainer.loss.item()=0.0282
trainer.n_iter=1100, trainer.loss.item()=0.0138
trainer.n_iter=1200, trainer.loss.item()=0.0039
trainer.n_iter=1300, trainer.loss.item()=0.0037
trainer.n_iter=1400, trainer.loss.item()=0.0024
trainer.n_iter=1500, trainer.loss.item()=0.0060
trainer.n_iter=1600, trainer.loss.item()=0.0081
trainer.n_iter=1700, trainer.loss.item()=0.0003
trainer.n_iter=1800, trainer.loss.item()=0.0029
trainer.n_iter=1900, trainer.loss.item()=0.0021
trainer.n_iter=2000, trainer.loss.item()=0.0078


In [39]:
# Score the trained model on both splits; each call prints that split's accuracy.
for split in ('train', 'test'):
    trainer.eval(split)

train final score: 3200.0/3200 = 100.00% correct
test final score: 5000.0/5000 = 100.00% correct


In [40]:
# Sanity-check the trained model on one hand-picked sequence: greedily decode
# n tokens after the prompt and compare them against torch.sort of the prompt.

# Put the model in eval mode explicitly rather than relying on whatever mode
# earlier cells left it in. The config uses dropout (p=0.1), and an active
# dropout would make "greedy" decoding non-deterministic.
model.eval()

n = train_dataset.sequence_length
inp = torch.tensor([[0, 0, 2, 1, 0, 1]], dtype=torch.long).to(trainer.device)
assert inp[0].nelement() == n, f'example prompt must have exactly {n} tokens'
assert int(inp.max()) < train_dataset.get_vocab_size(), 'prompt contains a token id outside the vocabulary'

with torch.no_grad():
    cat = model.generate(inp, n, do_sample=False)
sol = torch.sort(inp[0])[0]
# `cat` holds the prompt followed by the n generated tokens, so the slice past
# the prompt is the model's proposed sorted sequence (shape (1, n)).
sol_candidate = cat[:, n:]
print('input sequence  :', inp.tolist())
print('predicted sorted:', sol_candidate.tolist())
print('gt sort         :', sol.tolist())
print('matches         :', bool((sol == sol_candidate).all()))

input sequence  : [[0, 0, 2, 1, 0, 1]]
predicted sorted: [[0, 0, 0, 1, 1, 2]]
gt sort         : [0, 0, 0, 1, 1, 2]
matches         : True
