In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
from torch.utils.data.dataloader import DataLoader

import random_neural_net_models.mingpt.adder as adder
import random_neural_net_models.mingpt.model as gpt_model
import random_neural_net_models.mingpt.trainer as trainer
import random_neural_net_models.mingpt.utils as gpt_utils
from random_neural_net_models.mingpt.trainer import Trainer

In [None]:
# Number of digits per operand: with 2, problems look like "47 + 85 = 132"
# (two ndigit-digit operands, an (ndigit + 1)-digit sum).
NDIGIT = 2
data_config = adder.DataConfig(ndigit=NDIGIT)

In [None]:
# Build both splits from the same data configuration (train first, then test).
train_dataset, test_dataset = (
    adder.AdditionDataset(data_config, split=split_name)
    for split_name in ("train", "test")
)

In [None]:
# Build the run configuration from the package defaults. Vocabulary and block
# size are taken from the training dataset; max_iters is set explicitly here
# (a notebook has no command line, so these keyword arguments are the only
# overrides).
config = adder.get_config(
    vocab_size=train_dataset.get_vocab_size(),
    block_size=train_dataset.get_block_size(),
    max_iters=100,
)

print(config)

In [None]:
# Seed with the configured value before the model is constructed in the next
# cell, so its weight initialisation is reproducible (assumes set_seed covers
# torch's RNG — TODO confirm in gpt_utils).
gpt_utils.set_seed(config.system.seed)

In [None]:
# Build the GPT from the model section of the config (presumably carrying the
# vocab_size / block_size that get_config received from the training dataset).
model = gpt_model.GPT(config.model)

In [None]:
# Construct the trainer object. `Trainer` is used by name instead of
# `trainer.Trainer` because this assignment rebinds `trainer` (the module
# alias) to the instance; `trainer.Trainer(...)` would then fail on a re-run.
trainer = Trainer(config.trainer, model, train_dataset)

In [None]:
# helper function for the evaluation of a model


def eval_split(trainer, split, max_batches=None):
    """Count how many additions in a split the model solves exactly.

    For every example the first ``2 * ndigit`` tokens (the two operands) are
    used as the prompt. The model then greedily generates ``ndigit + 1`` more
    tokens, which are the sum's digits in reversed order. They are flipped,
    decoded into an integer, and compared with the true sum of the operands.
    Up to 5 mistakes are printed, followed by a summary line.

    The model is put into eval mode for the duration of the call, so dropout
    cannot perturb the greedy decoding. Afterwards it is restored to the mode
    it was in before.

    Args:
        trainer: provides ``trainer.device``; prompts and helper tensors are
            moved there.
        split: ``"train"`` or ``"test"``; selects ``train_dataset`` or
            ``test_dataset``.
        max_batches: stop after this many batches of 100 examples; ``None``
            evaluates the whole split.

    Returns:
        0-d float tensor with the number of correctly solved examples.
    """
    dataset = {"train": train_dataset, "test": test_dataset}[split]
    # Read the digit count from the config the datasets were built from, so the
    # token slicing below always matches the data.
    ndigit = data_config.ndigit
    # Place values, most significant first: [10**ndigit, ..., 10, 1]. The sum
    # has ndigit + 1 digits; each operand uses only the last ndigit factors.
    factors = torch.tensor(
        [[10**i for i in range(ndigit + 1)][::-1]], device=trainer.device
    )
    loader = DataLoader(dataset, batch_size=100, num_workers=0, drop_last=False)

    results = []
    mistakes_printed_already = 0
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for b, (x, _) in enumerate(loader):
                x = x.to(trainer.device)
                # the prompt: both operands, i.e. the first 2 * ndigit tokens
                d1d2 = x[:, : ndigit * 2]
                # greedy argmax decoding (no sampling) of the ndigit + 1 sum tokens
                d1d2d3 = model.generate(d1d2, ndigit + 1, do_sample=False)
                # the generated sum digits are reversed; flip to normal order
                d3 = d1d2d3[:, -(ndigit + 1) :].flip(1)
                # decode the integers from individual digits
                d1i = (d1d2[:, :ndigit] * factors[:, 1:]).sum(1)
                d2i = (d1d2[:, ndigit : ndigit * 2] * factors[:, 1:]).sum(1)
                d3i_pred = (d3 * factors).sum(1)
                d3i_gt = d1i + d2i  # ground truth computed directly
                correct = (d3i_pred == d3i_gt).cpu()
                results.extend(correct.int().tolist())
                # print only the first 5 mistakes overall to get a sense of them
                for i in torch.nonzero(~correct).flatten().tolist():
                    if mistakes_printed_already >= 5:
                        break
                    mistakes_printed_already += 1
                    print(
                        "GPT claims that %d + %d = %d but gt is %d"
                        % (
                            d1i[i].item(),
                            d2i[i].item(),
                            d3i_pred[i].item(),
                            d3i_gt[i].item(),
                        )
                    )
                if max_batches is not None and b + 1 >= max_batches:
                    break
    finally:
        # restore the caller's train/eval mode even if decoding raised
        model.train(was_training)

    rt = torch.tensor(results, dtype=torch.float)
    if not results:
        # avoid reporting a NaN percentage for an empty split
        print(f"{split} final score: 0/0 (no examples evaluated)")
        return rt.sum()
    print(
        "%s final score: %d/%d = %.2f%% correct"
        % (split, rt.sum(), len(results), 100 * rt.mean())
    )
    return rt.sum()

In [None]:
# iteration callback


def batch_end_callback(trainer):
    """Per-batch hook: log timing and loss every 10th iteration, then call model.train().

    NOTE(review): nothing in this notebook switches the model to eval mode
    during training, so the unconditional ``model.train()`` looks vestigial —
    confirm before removing.
    """
    log_this_step = trainer.iter_num % 10 == 0
    if log_this_step:
        elapsed_ms = trainer.iter_dt * 1000
        message = (
            f"iter_dt {elapsed_ms:.2f}ms; "
            f"iter {trainer.iter_num}: "
            f"train loss {trainer.loss.item():.5f}"
        )
        print(message)

    model.train()


trainer.set_callback("on_batch_end", batch_end_callback)

In [None]:
# Run the optimization; batch_end_callback, registered for "on_batch_end" in the
# previous cell, reports progress (presumably stops after config's max_iters).
trainer.run()

In [None]:
# Sanity check on one test example. Prompt with the two operands only (the first
# 2 * ndigit tokens, as eval_split does) and greedily decode the ndigit + 1 sum
# tokens. Passing the full `x` would leak answer tokens into the prompt.
ndigit = data_config.ndigit
x, y = next(iter(test_dataset))
print(f"x: {x}")
print(f"y: {y}")
prompt = x[: 2 * ndigit].unsqueeze(0).to(trainer.device)
was_training = model.training
model.eval()  # no dropout during greedy decoding
with torch.no_grad():
    pred = model.generate(prompt, ndigit + 1, do_sample=False)
model.train(was_training)
print(f"pred: {pred}")
# the sum digits are generated in reversed order; flip them for display
answer_digits = pred[0, -(ndigit + 1) :].flip(0).tolist()
print(f"predicted sum: {''.join(map(str, answer_digits))}")