Based on Andrej Karpathy's minGPT demo notebook: https://github.com/karpathy/minGPT/blob/master/demo.ipynb

In [None]:
# Re-import edited modules automatically before each cell executes (mode 2 =
# all modules), so changes to random_neural_net_models are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
from torch.utils.data.dataloader import DataLoader

import random_neural_net_models.mingpt.data as gpt_data
import random_neural_net_models.mingpt.model as gpt_model
import random_neural_net_models.mingpt.trainer as gpt_trainer
import random_neural_net_models.mingpt.utils as gpt_utils

# Seed the random number generators through the project helper so runs are
# reproducible. NOTE(review): presumably seeds python/numpy/torch; confirm in
# gpt_utils.set_seed.
gpt_utils.set_seed(3407)

In [None]:
# Build the train and test splits, then show one training example as two
# columns: input token on the left, target token on the right.
train_dataset, test_dataset = (
    gpt_data.SortDataset(split_name) for split_name in ("train", "test")
)
x, y = train_dataset[0]
for input_token, target_token in zip(x, y):
    print(int(input_token), int(target_token))

In [None]:
# Size the "gpt-nano" preset to the dataset: its vocabulary and the maximum
# sequence (block) length the model must attend over.
vocab_size = train_dataset.get_vocab_size()
block_size = train_dataset.get_block_size()
model_config = gpt_model.GPT.get_config(
    model_type="gpt-nano", vocab_size=vocab_size, block_size=block_size
)
model = gpt_model.GPT(model_config)

In [None]:
# Training schedule. The model is very small, so a comparatively high learning
# rate is safe and speeds things up. num_workers=0 avoids worker processes
# (presumably forwarded to the trainer's DataLoader; confirm in gpt_trainer).
train_config = gpt_trainer.Trainer.get_config(
    max_iters=2000,
    num_workers=0,
    learning_rate=5e-4,
)

trainer = gpt_trainer.Trainer(train_config, model, train_dataset)

In [None]:
def batch_end_callback(trainer):
    """Report step time and training loss on every 100th iteration.

    Meant to be registered as the trainer's ``"on_batch_end"`` hook; it reads
    ``iter_num``, ``iter_dt`` (seconds) and ``loss`` from *trainer* and prints
    nothing on the other iterations.
    """
    if trainer.iter_num % 100 != 0:
        return
    step_ms = trainer.iter_dt * 1000
    loss_value = trainer.loss.item()
    print(
        f"iter_dt {step_ms:.2f}ms; iter {trainer.iter_num}: "
        f"train loss {loss_value:.5f}"
    )


# Hook batch_end_callback into the trainer's "on_batch_end" event (it only
# prints every 100 iterations), then train for train_config.max_iters steps.
trainer.set_callback("on_batch_end", batch_end_callback)

trainer.run()

In [None]:
# now let's perform some evaluation
# Switch the module to evaluation mode (disables dropout, if the model uses
# any); the trailing ";" stops Jupyter from echoing the returned module.
model.eval();

In [None]:
def eval_split(trainer, split, max_batches, *, max_mistakes_printed=3):
    """Measure how often the model sorts examples of one split exactly right.

    For every example, the unsorted input prefix is fed to the notebook-global
    ``model``, which greedily generates the same number of tokens. The example
    counts as correct only if every generated token equals the target.

    Args:
        trainer: object whose ``device`` attribute the batches are moved to.
        split: ``"train"`` or ``"test"``; any other value raises ``KeyError``.
        max_batches: stop after this many batches of 100 examples;
            ``None`` evaluates the whole split.
        max_mistakes_printed: print at most this many wrong predictions.

    Returns:
        The number of correctly sorted examples, as a 0-d float tensor.
    """
    dataset = {"train": train_dataset, "test": test_dataset}[split]
    # Length of the sequence to sort, taken from the split actually being
    # evaluated (direct attribute access on the dataset).
    n = dataset.length
    results = []
    mistakes_printed_already = 0
    loader = DataLoader(dataset, batch_size=100, num_workers=0, drop_last=False)
    # Pure inference: skip autograd bookkeeping even if the caller did not.
    with torch.no_grad():
        for b, (x, y) in enumerate(loader):
            x = x.to(trainer.device)
            y = y.to(trainer.device)
            inp = x[:, :n]  # the unsorted input pattern alone
            sol = y[:, -n:]  # the expected sorted sequence
            # NOTE(review): uses the global `model`, not anything on `trainer`.
            # Greedy argmax decoding (no sampling) of the n missing tokens.
            cat = model.generate(inp, n, do_sample=False)
            sol_candidate = cat[:, n:]  # keep only the filled-in sequence
            # Correct only if the whole predicted sequence matches.
            correct = (sol == sol_candidate).all(1).cpu()
            for i, is_correct in enumerate(correct.tolist()):
                results.append(int(is_correct))
                if not is_correct and mistakes_printed_already < max_mistakes_printed:
                    mistakes_printed_already += 1
                    print(
                        "GPT claims that %s sorted is %s but gt is %s"
                        % (
                            inp[i].tolist(),
                            sol_candidate[i].tolist(),
                            sol[i].tolist(),
                        )
                    )
            if max_batches is not None and b + 1 >= max_batches:
                break
    rt = torch.tensor(results, dtype=torch.float)
    # Convert to Python numbers explicitly instead of %-formatting tensors.
    print(
        "%s final score: %d/%d = %.2f%% correct"
        % (split, int(rt.sum()), len(results), float(100 * rt.mean()))
    )
    return rt.sum()


# Score up to 50 batches from each split (train first, then test) and check
# that every generated sequence is correctly sorted.
with torch.no_grad():
    train_score, test_score = (
        eval_split(trainer, split_name, max_batches=50)
        for split_name in ("train", "test")
    )

In [None]:
# Run one hand-picked sequence through the model as a spot check.
n = train_dataset.length  # length of the sequences to sort (direct attribute access)
inp = torch.tensor([[0, 0, 2, 1, 0, 1]], dtype=torch.long, device=trainer.device)
assert inp.size(1) == n, f"expected {n} input tokens, got {inp.size(1)}"
with torch.no_grad():
    # greedy argmax decoding of the n sorted tokens
    cat = model.generate(inp, n, do_sample=False)
# Sort along the sequence axis but keep the (1, n) batch shape, so the
# ground truth prints in the same nested-list form as the input and the
# prediction and compares element-wise without broadcasting.
sol = torch.sort(inp, dim=1).values
sol_candidate = cat[:, n:]  # keep only the filled-in sequence
print("input sequence  :", inp.tolist())
print("predicted sorted:", sol_candidate.tolist())
print("gt sort         :", sol.tolist())
print("matches         :", bool((sol == sol_candidate).all()))