In [12]:
from model import GPT
from dataset import SortDataset
from utils import ModelConfig, TrainConfig, TrainerCallbackEvent, set_seed
from train import Trainer

import torch
from torch.utils.data.dataloader import DataLoader

set_seed(3407)

%reload_ext autoreload
%autoreload 2

[autoreload of pkg_resources failed: Traceback (most recent call last):
  File "/Users/lukakuma/miniconda3/envs/ml/lib/python3.10/site-packages/IPython/extensions/autoreload.py", line 257, in check
    superreload(m, reload, self.old_objects)
  File "/Users/lukakuma/miniconda3/envs/ml/lib/python3.10/site-packages/IPython/extensions/autoreload.py", line 455, in superreload
    module = reload(module)
  File "/Users/lukakuma/miniconda3/envs/ml/lib/python3.10/importlib/__init__.py", line 169, in reload
    _bootstrap._exec(spec, module)
  File "<frozen importlib._bootstrap>", line 619, in _exec
  File "<frozen importlib._bootstrap_external>", line 883, in exec_module
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
  File "/Users/lukakuma/miniconda3/envs/ml/lib/python3.10/site-packages/pkg_resources/__init__.py", line 3095, in <module>
    class RequirementParseError(packaging.requirements.InvalidRequirement):
AttributeError: module 'pkg_resources._vendor.pac

In [13]:
# Build both splits of the sorting task, train first, then test.
train_dataset, test_dataset = (SortDataset(split) for split in ('train', 'test'))

In [14]:
# Look at one (x, y) training pair. The -1 entries in y appear to be masked
# target positions; confirm in SortDataset.
x, y = train_dataset[1]
print(x, y, sep='\n')

tensor([1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1])
tensor([-1, -1, -1, -1, -1,  0,  0,  0,  0,  1,  1])


In [15]:
# Vocabulary size and context window come from the dataset. The 'gpt-nano'
# model_type fills in the remaining sizes (see n_embd/n_head/n_layer in the repr).
model_config = ModelConfig(
    model_type='gpt-nano',
    vocab_size=train_dataset.get_vocab_size(),
    context_window=train_dataset.get_block_size(),
)
model_config

ModelConfig(vocab_size=3, context_window=11, n_embd=48, n_head=3, n_layer=3, model_type='gpt-nano', attn_pdrop=0.1, recid_pdrop=0.1, embd_pdrop=0.1)

In [16]:
# Build the GPT from the config above; constructing it prints the parameter count.
model = GPT(model_config)

number of parameters: 0.086M


In [17]:
# Start from the default TrainConfig and override only what this run changes.
TRAIN_OVERRIDES = {'learning_rate': 5e-4, 'max_iters': 2000}

train_config = TrainConfig()
for override_name, override_value in TRAIN_OVERRIDES.items():
    setattr(train_config, override_name, override_value)
print(train_config)

# The Trainer reports the device it picked (device='auto' in the config).
trainer = Trainer(train_config, model, train_dataset, test_dataset)

TrainConfig(device='auto', num_workers=0, batch_size=64, max_iters=2000, learning_rate=0.0005, betas=(0.9, 0.95), weight_decay=0.1, grad_norm_clip=1.0)
model is running on cpu


In [18]:
def print_batch_loss(trainer: 'Trainer', every: int = 100) -> None:
    """Print the current iteration and training loss every ``every`` iterations.

    Intended as an ``on_train_batch_end`` callback, invoked with the trainer
    after each training batch.

    Args:
        trainer: the running trainer; its ``n_iter`` and ``loss`` (an object
            exposing ``.item()``, e.g. a scalar tensor) are read.
        every: print interval in iterations. The default of 100 keeps the
            original cadence.

    Raises:
        ValueError: if ``every`` is not a positive integer.
    """
    if every <= 0:
        raise ValueError(f'every must be positive, got {every}')
    if trainer.n_iter % every == 0:
        print(f'{trainer.n_iter=}, {trainer.loss.item()=:.4f}')

# Hook the loss printer into the end of every training batch.
trainer.add_callback(
    TrainerCallbackEvent.on_train_batch_end,
    print_batch_loss,
)

In [19]:
# Run training. It presumably stops at train_config.max_iters; the loss log
# printed by the registered callback ends at n_iter=2000.
trainer.run()

trainer.n_iter=0, trainer.loss.item()=1.0634
trainer.n_iter=100, trainer.loss.item()=0.1920
trainer.n_iter=200, trainer.loss.item()=0.0838
trainer.n_iter=300, trainer.loss.item()=0.0350
trainer.n_iter=400, trainer.loss.item()=0.0286
trainer.n_iter=500, trainer.loss.item()=0.0474
trainer.n_iter=600, trainer.loss.item()=0.0090
trainer.n_iter=700, trainer.loss.item()=0.0120
trainer.n_iter=800, trainer.loss.item()=0.0353
trainer.n_iter=900, trainer.loss.item()=0.0920
trainer.n_iter=1000, trainer.loss.item()=0.0280
trainer.n_iter=1100, trainer.loss.item()=0.0114
trainer.n_iter=1200, trainer.loss.item()=0.0229
trainer.n_iter=1300, trainer.loss.item()=0.0044
trainer.n_iter=1400, trainer.loss.item()=0.0256
trainer.n_iter=1500, trainer.loss.item()=0.0102
trainer.n_iter=1600, trainer.loss.item()=0.0009
trainer.n_iter=1700, trainer.loss.item()=0.0005
trainer.n_iter=1800, trainer.loss.item()=0.0291
trainer.n_iter=1900, trainer.loss.item()=0.0131
trainer.n_iter=2000, trainer.loss.item()=0.0008


In [20]:
# Score the trained model on both splits, train first, then test.
for split in ('train', 'test'):
    trainer.eval(split)

train final score: 3200.0/3200 = 100.00% correct
test final score: 5000.0/5000 = 100.00% correct


In [21]:
# Run one hand-picked (not random) sequence through the trained model.
# Switch explicitly to eval mode first: the config uses dropout (pdrop=0.1), and
# leaving it active could make greedy decoding non-deterministic. This is harmless
# if trainer.eval already did it.
model.eval()

n = train_dataset.sequence_length
inp = torch.tensor([[0, 0, 2, 1, 0, 1]], dtype=torch.long).to(trainer.device)
assert inp[0].nelement() == n, f'expected {n} input tokens, got {inp[0].nelement()}'

with torch.no_grad():
    # Greedily generate n new tokens. The slice below assumes generate returns
    # the prompt followed by the generated tokens.
    cat = model.generate(inp, n, do_sample=False)

sol = torch.sort(inp[0])[0]
sol_candidate = cat[:, n:]
# Compare 1-D against 1-D. torch.equal also requires equal shapes, unlike the
# broadcasting `==`, which could silently pass on mismatched lengths.
print('input sequence  :', inp[0].tolist())
print('predicted sorted:', sol_candidate[0].tolist())
print('gt sort         :', sol.tolist())
print('matches         :', torch.equal(sol, sol_candidate[0]))

input sequence  : [[0, 0, 2, 1, 0, 1]]
predicted sorted: [[0, 0, 0, 1, 1, 2]]
gt sort         : [0, 0, 0, 1, 1, 2]
matches         : True
