## Train GPT on addition

Train a GPT model on a dedicated addition dataset to see if a Transformer can learn to add.

In [1]:
# set up logging
import logging
logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
)

In [2]:
# make deterministic
from mingpt.utils import set_seed
SEED = 43  # single place to change the global seed passed to minGPT's set_seed helper
set_seed(SEED)

In [3]:
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

In [4]:
from torch.utils.data import Dataset

class AdditionDataset(Dataset):
    """
    Returns addition problems of up to some number of digits in the inputs. Recall
    that all GPT cares about are sequences of integers, and completing them according to
    patterns in the data. Therefore, we have to somehow encode addition problems
    as a sequence of integers.
    
    The sum of two n-digit numbers gives a third up to (n+1)-digit number. So our
    encoding will simply be the n-digit first number, n-digit second number,
    and (n+1)-digit result, each zero-padded to its fixed width and all simply
    concatenated together. Because each addition problem is so structured, there is
    no need to bother the model with encoding +, =, or other tokens. Each possible
    sequence has the same length, and simply contains the raw digits of the addition
    problem.
    
    As a few examples, the 2-digit problems:
    - 85 + 50 = 135 becomes the sequence [8, 5, 5, 0, 1, 3, 5]
    - 6 + 39 = 45 becomes the sequence [0, 6, 3, 9, 0, 4, 5]
    etc.
    
    We will also only train GPT on the final (n+1) digits because the two n-digit
    operands are always assumed to be given. So when we give GPT an exam later,
    we will e.g. feed it the sequence [0, 6, 3, 9], which encodes that we'd like
    to add 6 + 39, and hope that the model completes the integer sequence with [0, 4, 5]
    in 3 sequential steps.
    
    Train/test split: all (10**n)**2 problems are shuffled with a fixed seed (1337);
    the first min(20% of them, 1000) form the 'test' split and the rest form the
    'train' split, so the two splits never overlap and are identical across runs.
    
    Args:
        ndigit: number of digits n in each operand.
        split: 'train' or 'test'. Any other value raises ValueError.
    
    fun exercise: does it help if the result is asked to be produced in reverse order?
    """

    def __init__(self, ndigit, split):
        # Reject unknown splits: previously anything other than 'test' (e.g. a typo like
        # 'val' or 'Test') silently returned the training problems.
        if split not in ('train', 'test'):
            raise ValueError(f"split must be 'train' or 'test', got {split!r}")
        self.split = split # train/test
        self.ndigit = ndigit
        self.vocab_size = 10 # 10 possible digits 0..9
        # +1 due to potential carry overflow, but then -1 because very last digit doesn't plug back
        self.block_size = ndigit + ndigit + ndigit + 1 - 1
        
        # split up all addition problems into either training data or test data
        num = (10**self.ndigit)**2 # total number of possible combinations
        r = np.random.RandomState(1337) # make deterministic
        perm = r.permutation(num)
        num_test = min(int(num*0.2), 1000) # 20% of the whole dataset, or only up to 1000
        self.ixes = perm[:num_test] if split == 'test' else perm[num_test:]

    def __len__(self):
        return self.ixes.size

    def __getitem__(self, idx):
        """Return (x, y) long tensors of length 3*ndigit for the idx-th problem of this split."""
        # given a problem index idx, first recover the associated a + b
        idx = self.ixes[idx]
        nd = 10**self.ndigit
        a = idx // nd
        b = idx %  nd
        c = a + b
        render = f'%0{self.ndigit}d%0{self.ndigit}d%0{self.ndigit+1}d' % (a,b,c) # e.g. 03+25=28 becomes "0325028" 
        dix = [int(s) for s in render] # convert each character to its token index
        # x will be input to GPT and y will be the associated expected outputs
        x = torch.tensor(dix[:-1], dtype=torch.long)
        y = torch.tensor(dix[1:], dtype=torch.long) # predict the next token in the sequence
        # The first 2*ndigit-1 targets are operand digits, which are always given, so only
        # the final ndigit+1 (result) positions are supervised. -100 is the default
        # ignore_index of torch's cross_entropy; presumably that is how the model's loss
        # skips them (the loss itself is not shown here).
        y[:self.ndigit*2-1] = -100
        return x, y


In [5]:
# build the train and test splits for e.g. 2-digit addition; both are carved out of the
# same fixed shuffle inside AdditionDataset, so they never overlap
ndigit = 2
train_dataset, test_dataset = (AdditionDataset(ndigit=ndigit, split=part) for part in ('train', 'test'))

In [6]:
# look at one raw training example: (x, y), where y is the token sequence shifted by one
# and -100 marks the operand positions that receive no loss
first_x, first_y = train_dataset[0]
first_x, first_y

(tensor([4, 7, 1, 7, 0, 6]), tensor([-100, -100, -100,    0,    6,    4]))

In [7]:
from mingpt.model import GPT, GPTConfig, GPT1Config

# initialize a baby GPT model (n_layer=2, n_head=4, n_embd=128); vocab_size (10 digits)
# and block_size (3*ndigit) come straight from the dataset.
# NOTE(review): scratch notes previously left here recorded 98.80%, 99.07% and 99.50%
# for unlabeled experiments, and commented-out overrides tried embd_pdrop / resid_pdrop /
# attn_pdrop = 0.0. No dropout kwargs are passed, so GPTConfig's defaults apply.
mconf = GPTConfig(
    train_dataset.vocab_size,
    train_dataset.block_size,
    n_layer=2,
    n_head=4,
    n_embd=128,
)
model = GPT(mconf)

12/01/2021 14:18:12 - INFO - mingpt.model -   number of parameters: 4.001280e+05


In [8]:
from mingpt.trainer import Trainer, TrainerConfig

# initialize a trainer instance and kick off training
epochs = 300
# AdditionDataset masks all but the last ndigit+1 targets of each example, so that is
# the number of supervised tokens per example.
# NOTE(review): presumably final_tokens is the token count at which lr decay bottoms
# out — confirm in mingpt.trainer.
supervised_tokens_per_epoch = len(train_dataset) * (ndigit + 1)
tconf = TrainerConfig(
    max_epochs=epochs,
    batch_size=512,
    learning_rate=6e-4,
    lr_decay=True,
    warmup_tokens=1024,
    final_tokens=epochs * supervised_tokens_per_epoch,
    num_workers=4,
)
trainer = Trainer(model, train_dataset, test_dataset, tconf)
trainer.train()

epoch 1 iter 17: train loss 1.96685. lr 5.999848e-04: 100%|██████████| 18/18 [00:09<00:00,  1.88it/s]
12/01/2021 14:18:25 - INFO - mingpt.trainer -   test loss: 1.836349
epoch 2 iter 17: train loss 1.76151. lr 5.999367e-04: 100%|██████████| 18/18 [00:00<00:00, 28.43it/s]
12/01/2021 14:18:27 - INFO - mingpt.trainer -   test loss: 1.637177
epoch 3 iter 17: train loss 1.63700. lr 5.998557e-04: 100%|██████████| 18/18 [00:00<00:00, 24.70it/s]
12/01/2021 14:18:28 - INFO - mingpt.trainer -   test loss: 1.489183
epoch 4 iter 17: train loss 1.54155. lr 5.997417e-04: 100%|██████████| 18/18 [00:00<00:00, 27.08it/s]
12/01/2021 14:18:30 - INFO - mingpt.trainer -   test loss: 1.361760
epoch 5 iter 17: train loss 1.48659. lr 5.995950e-04: 100%|██████████| 18/18 [00:00<00:00, 26.34it/s]
12/01/2021 14:18:31 - INFO - mingpt.trainer -   test loss: 1.287460
epoch 6 iter 17: train loss 1.39373. lr 5.994153e-04: 100%|██████████| 18/18 [00:00<00:00, 26.64it/s]
12/01/2021 14:18:32 - INFO - mingpt.trainer -   

epoch 47 iter 17: train loss 0.24275. lr 5.644385e-04: 100%|██████████| 18/18 [00:00<00:00, 25.90it/s]
12/01/2021 14:19:30 - INFO - mingpt.trainer -   test loss: 0.028594
epoch 48 iter 17: train loss 0.22470. lr 5.629402e-04: 100%|██████████| 18/18 [00:00<00:00, 26.86it/s]
12/01/2021 14:19:31 - INFO - mingpt.trainer -   test loss: 0.023940
epoch 49 iter 17: train loss 0.24933. lr 5.614131e-04: 100%|██████████| 18/18 [00:00<00:00, 26.59it/s]
12/01/2021 14:19:33 - INFO - mingpt.trainer -   test loss: 0.023341
epoch 50 iter 17: train loss 0.22519. lr 5.598573e-04: 100%|██████████| 18/18 [00:00<00:00, 26.80it/s]
12/01/2021 14:19:34 - INFO - mingpt.trainer -   test loss: 0.019271
epoch 51 iter 17: train loss 0.23255. lr 5.582729e-04: 100%|██████████| 18/18 [00:00<00:00, 25.47it/s]
12/01/2021 14:19:35 - INFO - mingpt.trainer -   test loss: 0.024150
epoch 52 iter 17: train loss 0.21736. lr 5.566603e-04: 100%|██████████| 18/18 [00:00<00:00, 27.36it/s]
12/01/2021 14:19:37 - INFO - mingpt.traine

epoch 95 iter 17: train loss 0.12762. lr 4.634600e-04: 100%|██████████| 18/18 [00:00<00:00, 26.73it/s]
12/01/2021 14:20:37 - INFO - mingpt.trainer -   test loss: 0.002380
epoch 96 iter 17: train loss 0.13238. lr 4.608164e-04: 100%|██████████| 18/18 [00:00<00:00, 25.96it/s]
12/01/2021 14:20:39 - INFO - mingpt.trainer -   test loss: 0.001953
epoch 97 iter 17: train loss 0.14967. lr 4.581553e-04: 100%|██████████| 18/18 [00:00<00:00, 27.23it/s]
12/01/2021 14:20:40 - INFO - mingpt.trainer -   test loss: 0.002322
epoch 98 iter 17: train loss 0.13206. lr 4.554767e-04: 100%|██████████| 18/18 [00:00<00:00, 25.83it/s]
12/01/2021 14:20:41 - INFO - mingpt.trainer -   test loss: 0.002619
epoch 99 iter 17: train loss 0.12263. lr 4.527811e-04: 100%|██████████| 18/18 [00:00<00:00, 26.28it/s]
12/01/2021 14:20:43 - INFO - mingpt.trainer -   test loss: 0.002383
epoch 100 iter 17: train loss 0.12614. lr 4.500688e-04: 100%|██████████| 18/18 [00:00<00:00, 26.11it/s]
12/01/2021 14:20:44 - INFO - mingpt.train

epoch 143 iter 17: train loss 0.07530. lr 3.220337e-04: 100%|██████████| 18/18 [00:00<00:00, 26.65it/s]
12/01/2021 14:21:44 - INFO - mingpt.trainer -   test loss: 0.001086
epoch 144 iter 17: train loss 0.10555. lr 3.188990e-04: 100%|██████████| 18/18 [00:00<00:00, 26.70it/s]
12/01/2021 14:21:45 - INFO - mingpt.trainer -   test loss: 0.000876
epoch 145 iter 17: train loss 0.08584. lr 3.157623e-04: 100%|██████████| 18/18 [00:00<00:00, 22.88it/s]
12/01/2021 14:21:47 - INFO - mingpt.trainer -   test loss: 0.001008
epoch 146 iter 17: train loss 0.08591. lr 3.126238e-04: 100%|██████████| 18/18 [00:00<00:00, 26.99it/s]
12/01/2021 14:21:48 - INFO - mingpt.trainer -   test loss: 0.000870
epoch 147 iter 17: train loss 0.08928. lr 3.094840e-04: 100%|██████████| 18/18 [00:00<00:00, 26.58it/s]
12/01/2021 14:21:49 - INFO - mingpt.trainer -   test loss: 0.000764
epoch 148 iter 17: train loss 0.07990. lr 3.063431e-04: 100%|██████████| 18/18 [00:00<00:00, 26.68it/s]
12/01/2021 14:21:51 - INFO - mingpt.

epoch 191 iter 17: train loss 0.07119. lr 1.751551e-04: 100%|██████████| 18/18 [00:00<00:00, 26.99it/s]
12/01/2021 14:22:50 - INFO - mingpt.trainer -   test loss: 0.000328
epoch 192 iter 17: train loss 0.05852. lr 1.723050e-04: 100%|██████████| 18/18 [00:00<00:00, 26.14it/s]
12/01/2021 14:22:52 - INFO - mingpt.trainer -   test loss: 0.000310
epoch 193 iter 17: train loss 0.06385. lr 1.694689e-04: 100%|██████████| 18/18 [00:00<00:00, 25.83it/s]
12/01/2021 14:22:53 - INFO - mingpt.trainer -   test loss: 0.000370
epoch 194 iter 17: train loss 0.05569. lr 1.666472e-04: 100%|██████████| 18/18 [00:00<00:00, 27.14it/s]
12/01/2021 14:22:55 - INFO - mingpt.trainer -   test loss: 0.000378
epoch 195 iter 17: train loss 0.08377. lr 1.638400e-04: 100%|██████████| 18/18 [00:00<00:00, 26.52it/s]
12/01/2021 14:22:56 - INFO - mingpt.trainer -   test loss: 0.000313
epoch 196 iter 17: train loss 0.07272. lr 1.610478e-04: 100%|██████████| 18/18 [00:00<00:00, 24.33it/s]
12/01/2021 14:22:57 - INFO - mingpt.

epoch 239 iter 17: train loss 0.05894. lr 6.000000e-05: 100%|██████████| 18/18 [00:00<00:00, 24.65it/s]
12/01/2021 14:23:58 - INFO - mingpt.trainer -   test loss: 0.000190
epoch 240 iter 17: train loss 0.05932. lr 6.000000e-05: 100%|██████████| 18/18 [00:00<00:00, 26.69it/s]
12/01/2021 14:23:59 - INFO - mingpt.trainer -   test loss: 0.000198
epoch 241 iter 17: train loss 0.06459. lr 6.000000e-05: 100%|██████████| 18/18 [00:00<00:00, 26.31it/s]
12/01/2021 14:24:01 - INFO - mingpt.trainer -   test loss: 0.000194
epoch 242 iter 17: train loss 0.05445. lr 6.000000e-05: 100%|██████████| 18/18 [00:00<00:00, 26.17it/s]
12/01/2021 14:24:02 - INFO - mingpt.trainer -   test loss: 0.000205
epoch 243 iter 17: train loss 0.06977. lr 6.000000e-05: 100%|██████████| 18/18 [00:00<00:00, 26.06it/s]
12/01/2021 14:24:03 - INFO - mingpt.trainer -   test loss: 0.000175
epoch 244 iter 17: train loss 0.05742. lr 6.000000e-05: 100%|██████████| 18/18 [00:00<00:00, 27.07it/s]
12/01/2021 14:24:05 - INFO - mingpt.

epoch 287 iter 17: train loss 0.06108. lr 6.000000e-05: 100%|██████████| 18/18 [00:00<00:00, 26.51it/s]
12/01/2021 14:25:05 - INFO - mingpt.trainer -   test loss: 0.000168
epoch 288 iter 17: train loss 0.05117. lr 6.000000e-05: 100%|██████████| 18/18 [00:00<00:00, 26.68it/s]
12/01/2021 14:25:07 - INFO - mingpt.trainer -   test loss: 0.000158
epoch 289 iter 17: train loss 0.05335. lr 6.000000e-05: 100%|██████████| 18/18 [00:00<00:00, 24.20it/s]
12/01/2021 14:25:08 - INFO - mingpt.trainer -   test loss: 0.000174
epoch 290 iter 17: train loss 0.07408. lr 6.000000e-05: 100%|██████████| 18/18 [00:00<00:00, 26.35it/s]
12/01/2021 14:25:10 - INFO - mingpt.trainer -   test loss: 0.000172
epoch 291 iter 17: train loss 0.06583. lr 6.000000e-05: 100%|██████████| 18/18 [00:00<00:00, 25.92it/s]
12/01/2021 14:25:11 - INFO - mingpt.trainer -   test loss: 0.000166
epoch 292 iter 17: train loss 0.05550. lr 6.000000e-05: 100%|██████████| 18/18 [00:00<00:00, 27.02it/s]
12/01/2021 14:25:12 - INFO - mingpt.

In [9]:
# now let's give the trained model an addition exam
from torch.utils.data.dataloader import DataLoader
from mingpt.utils import sample

def give_exam(dataset, batch_size=32, max_batches=-1):
    """
    Quiz the trained model on the addition problems in `dataset` and print its score.

    For each problem only the 2*ndigit operand digits are handed to `sample`, which
    appends ndigit+1 tokens. Those tokens are decoded as the predicted sum and compared
    with the true sum of the operands. Every wrong answer is printed, followed by a
    final score line.

    Args:
        dataset: yields (x, y) pairs laid out like AdditionDataset. Its `ndigit`
            attribute is used when present, else the notebook-level `ndigit`.
        batch_size: number of problems per batch.
        max_batches: grade at most this many batches; any negative value grades the
            whole dataset, and 0 grades nothing.

    Relies on the notebook globals `model`, `trainer` (for `trainer.device`) and `sample`.
    """
    import itertools  # local import keeps this cell self-contained

    nd = getattr(dataset, 'ndigit', None)
    if nd is None:
        nd = ndigit  # e.g. a torch Subset, which does not carry the attribute
    device = trainer.device
    # place-value weights [10**nd, ..., 10, 1]; the operands use the last nd entries.
    # Loop-invariant, so built once instead of once per batch.
    factors = torch.tensor([[10**i for i in range(nd + 1)][::-1]]).to(device)

    results = []
    loader = DataLoader(dataset, batch_size=batch_size)
    # Limit batches up front: the old post-grading check still graded one batch for
    # max_batches=0, and islice also avoids loading a batch that would be discarded.
    batches = loader if max_batches < 0 else itertools.islice(loader, max_batches)
    for x, _ in batches:
        x = x.to(device)
        d1d2 = x[:, :nd*2]                  # the two operands, which are always given
        d1d2d3 = sample(model, d1d2, nd+1)  # model completes the nd+1 answer digits
        d3 = d1d2d3[:, -(nd+1):]
        # decode the integers from individual digits
        d1i = (d1d2[:, :nd] * factors[:, 1:]).sum(1)
        d2i = (d1d2[:, nd:nd*2] * factors[:, 1:]).sum(1)
        d3i_pred = (d3 * factors).sum(1)
        d3i_gt = d1i + d2i
        correct = (d3i_pred == d3i_gt).cpu() # Software 1.0 vs. Software 2.0 fight RIGHT on this line, lol
        for i, ok in enumerate(correct.tolist()):
            results.append(int(ok))
            if not ok:
                print("GPT claims that %03d + %03d = %03d (gt is %03d; NOPE)"
                      % (int(d1i[i]), int(d2i[i]), int(d3i_pred[i]), int(d3i_gt[i])))

    if not results:
        # np.mean([]) would be nan (plus a RuntimeWarning); report an empty exam explicitly
        print("final score: 0/0 (no problems graded)")
        return
    print("final score: %d/%d = %.2f%% correct" % (np.sum(results), len(results), 100*np.mean(results)))

In [10]:
# training set: how well did we memorize? (10 batches x 1024 covers all 9000 training problems)
give_exam(dataset=train_dataset, max_batches=10, batch_size=1024)

final score: 9000/9000 = 100.00% correct


In [11]:
# test set: how well did we generalize? (max_batches=-1 grades every test problem)
give_exam(dataset=test_dataset, max_batches=-1, batch_size=1024)

final score: 1000/1000 = 100.00% correct


In [12]:
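# well that's amusing... this note originally said the model learned everything except 55 + 45,
# but in this run both exams above scored 100% (9000/9000 train, 1000/1000 test)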