A cute little demo showing the simplest usage of minGPT. Configured to run fine on a MacBook Air in about a minute.

In [1]:
import torch
from mingpt.utils import set_seed
from torch.utils.data.dataloader import DataLoader
from projects.sort.sort import SortDataset
# Seed the RNGs via minGPT's helper so weight init and sampling are repeatable
# (presumably seeds python/numpy/torch — see mingpt.utils.set_seed to confirm).
set_seed(3407)

In [2]:
# Peek at the first training example. Each printed row is one position:
# the (x, y) pair on the left, the (x_l, y_l) pair on the right.
# NOTE(review): the -1 entries in the target columns are presumably positions
# excluded from the loss — confirm against SortDataset.
train_dataset = SortDataset("train")
test_dataset = SortDataset("test")
x, y, x_l, y_l = train_dataset[0]
for pos_x, pos_y, pos_xl, pos_yl in zip(x, y, x_l, y_l):
    print(int(pos_x), int(pos_y), "\t", int(pos_xl), int(pos_yl))

1 -1 	 1 -1
0 -1 	 0 -1
1 -1 	 1 -1
0 -1 	 0 -1
0 -1 	 0 -1
0 0 	 0 1
0 0 	 1 0
0 0 	 0 1
0 0 	 1 0
0 1 	 0 0
1 1 	 0 0


In [3]:
from mingpt.model import GPT

model_config = GPT.get_default_config()
# Vocabulary size and maximum sequence length are taken from the dataset so the
# model matches the data it will be trained on.
model_config.vocab_size = train_dataset.get_vocab_size()
model_config.block_size = train_dataset.get_block_size()
model_config.model_type = 'gpt-nano'  # tiny preset (~0.09M parameters per the output below)
model = GPT(model_config)

number of parameters: 0.09M


In [4]:
# create a Trainer object
from mingpt.trainer import Trainer

def batch_end_callback(trainer):
    """Print step time (ms), iteration number and training loss every 100 iterations.

    Reads ``trainer.iter_num``, ``trainer.iter_dt`` (seconds) and ``trainer.loss``
    (a tensor exposing ``.item()``); prints nothing on other iterations.
    """
    if trainer.iter_num % 100 != 0:
        return
    step_ms = trainer.iter_dt * 1000
    loss_value = trainer.loss.item()
    print(f"iter_dt {step_ms:.2f}ms; iter {trainer.iter_num}: train loss {loss_value:.5f}")

train_config = Trainer.get_default_config()
train_config.max_iters = 500
# the model we're using is so small that we can go a bit faster
train_config.learning_rate = 5e-4
train_config.num_workers = 0  # no DataLoader worker processes

# Build the trainer and log progress from the batch-end hook.
trainer = Trainer(train_config, model, train_dataset)
trainer.set_callback('on_batch_end', batch_end_callback)

running on device cuda


In [5]:
# Turn on the DPO loss option (custom Trainer config keys; the meaning of
# dpo_alpha / dpo_beta is defined in the Trainer, not visible here).
# NOTE(review): the Trainer was constructed in the previous cell, so this relies on
# it reading `train_config` lazily at run() time. The ~0.693 (= ln 2) starting loss
# in the training output suggests it does — confirm before reordering cells.
train_config.dpo_beta = 0.01
train_config.dpo_alpha = 10
train_config.dpo_loss = True

In [6]:
# Run training; batch_end_callback prints progress every 100 iterations.
# NOTE(review): presumably stops at train_config.max_iters — confirm in Trainer.run.
trainer.run()

iter_dt 0.00ms; iter 0: train loss 0.69384
iter_dt 31.69ms; iter 100: train loss 0.46285
iter_dt 28.75ms; iter 200: train loss 0.45303
iter_dt 39.65ms; iter 300: train loss 0.43603
iter_dt 26.68ms; iter 400: train loss 0.42619


In [7]:
# now let's perform some evaluation
# Switch the model to eval mode; assigning to `_` hides the returned module's repr.
_ = model.eval()

In [8]:
def eval_split(trainer, split, max_batches, do_sample=True):
    """Score the model on one dataset split and report (x, y) vs (x_l, y_l) log-prob stats.

    For each batch of 100, the first ``n`` tokens of ``x`` are given to
    ``model.generate`` to produce ``n`` more tokens. These are compared with the last
    ``n`` entries of ``y``. Up to 3 mismatches are printed. ``model.log_prob`` is also
    evaluated on ``(x, y)`` ("w") and ``(x_l, y_l)`` ("l"). Their mean probabilities
    and mean log-prob difference are printed.

    Uses the notebook-global ``model``, ``train_dataset`` and ``test_dataset``.

    Args:
        trainer: the Trainer; only ``trainer.device`` is read.
        split: ``"train"`` or ``"test"``; selects which global dataset to score.
        max_batches: stop after this many batches; ``None`` scores the whole split.
        do_sample: forwarded to ``model.generate``. ``True`` (default, the previous
            hard-coded behaviour) samples; ``False`` uses greedy argmax decoding.

    Returns:
        The number of examples whose generated sequence exactly matches the target,
        as a 0-dim float tensor.
    """
    dataset = {"train": train_dataset, "test": test_dataset}[split]
    n = dataset.length  # naughty direct access to the split's sequence length, shrug
    results = []
    # Collect per-batch log-probs and concatenate once at the end, instead of
    # re-copying an ever-growing tensor on every batch.
    lp_w_batches = []
    lp_l_batches = []
    mistakes_printed_already = 0
    loader = DataLoader(dataset, batch_size=100, num_workers=0, drop_last=False)
    for b, (x, y, x_l, y_l) in enumerate(loader):
        x = x.to(trainer.device)
        y = y.to(trainer.device)
        x_l = x_l.to(trainer.device)
        y_l = y_l.to(trainer.device)
        # isolate the input pattern alone
        inp = x[:, :n]
        sol = y[:, -n:]
        # Let the model fill in the rest of the sequence.
        # do_sample=True samples; do_sample=False uses greedy argmax.
        cat = model.generate(idx=inp, max_new_tokens=n, do_sample=do_sample)
        sol_candidate = cat[:, n:]  # isolate the filled-in sequence
        # compare the predicted sequence to the true sequence
        # (Software 1.0 vs. Software 2.0 fight RIGHT on this line haha)
        correct = (sol == sol_candidate).all(1).cpu()
        for i in range(x.size(0)):
            results.append(int(correct[i]))
            # only print up to 3 mistakes to get a sense
            if not correct[i] and mistakes_printed_already < 3:
                mistakes_printed_already += 1
                print(
                    "GPT claims that %s sorted is %s but gt is %s"
                    % (inp[i].tolist(), sol_candidate[i].tolist(), sol[i].tolist())
                )

        with torch.no_grad():
            lp_w_batches.append(model.log_prob(x, y))
            lp_l_batches.append(model.log_prob(x_l, y_l))
        if max_batches is not None and b + 1 >= max_batches:
            break

    # An empty split leaves the lists empty; use empty tensors so the means below
    # print nan rather than torch.cat raising on an empty list.
    if lp_w_batches:
        log_probs_w = torch.cat(lp_w_batches, dim=-1)
        log_probs_l = torch.cat(lp_l_batches, dim=-1)
    else:
        log_probs_w = torch.empty(0)
        log_probs_l = torch.empty(0)
    log_prob_diff = log_probs_w - log_probs_l

    rt = torch.tensor(results, dtype=torch.float)
    print(
        "%s final score: %d/%d = %.2f%% correct"
        % (split, rt.sum(), len(results), 100 * rt.mean())
    )
    print(f"prob w: {round(float(torch.exp(log_probs_w).mean()), 4)}")
    print(f"prob l: {round(float(torch.exp(log_probs_l).mean()), 4)}")
    print(f"avg log prob diff (w - l): {round(float(log_prob_diff.mean()), 4)}")
    return rt.sum()


# Score up to 50 batches from each split (train first, then test) and check that
# the generated sequences are correctly sorted; no gradients are needed.
with torch.no_grad():
    train_score, test_score = (
        eval_split(trainer, split, max_batches=50) for split in ("train", "test")
    )

GPT claims that [0, 2, 2, 1, 2, 2] sorted is [0, 1, 2, 2, 1, 2] but gt is [0, 1, 2, 2, 2, 2]
GPT claims that [1, 2, 1, 0, 1, 1] sorted is [0, 0, 1, 1, 1, 2] but gt is [0, 1, 1, 1, 1, 2]
GPT claims that [1, 0, 1, 1, 0, 1] sorted is [0, 0, 0, 1, 1, 1] but gt is [0, 0, 1, 1, 1, 1]
train final score: 4950/5000 = 99.00% correct
prob w: 0.9883
prob l: 0.0171
avg log prob diff (w - l): 29.0978
GPT claims that [1, 1, 1, 0, 0, 0] sorted is [0, 0, 0, 0, 1, 1] but gt is [0, 0, 0, 1, 1, 1]
GPT claims that [1, 2, 1, 2, 2, 1] sorted is [1, 1, 2, 2, 2, 2] but gt is [1, 1, 1, 2, 2, 2]
GPT claims that [0, 2, 2, 0, 0, 2] sorted is [2, 0, 0, 2, 2, 2] but gt is [0, 0, 0, 2, 2, 2]
test final score: 4943/5000 = 98.86% correct
prob w: 0.9869
prob l: 0.0208
avg log prob diff (w - l): 28.4964


In [9]:
# Spot check: run one hand-picked sequence through the model and compare the
# generated continuation with torch.sort's answer.
n = train_dataset.length  # naughty direct access to the dataset's sequence length, shrug
inp = torch.tensor([[0, 0, 2, 1, 0, 1]], dtype=torch.long).to(trainer.device)
assert inp[0].nelement() == n
with torch.no_grad():
    cat = model.generate(idx=inp, max_new_tokens=n, do_sample=True)
sol = torch.sort(inp[0]).values
sol_candidate = cat[:, n:]
for label, value in (
    ('input sequence  :', inp.tolist()),
    ('predicted sorted:', sol_candidate.tolist()),
    ('gt sort         :', sol.tolist()),
    ('matches         :', bool((sol == sol_candidate).all())),
):
    print(label, value)

input sequence  : [[0, 0, 2, 1, 0, 1]]
predicted sorted: [[0, 0, 0, 1, 1, 2]]
gt sort         : [0, 0, 0, 1, 1, 2]
matches         : True
