## Train GPT on addition

Train a GPT model on a dedicated addition dataset to see if a Transformer can learn to add.

In [1]:
# set up logging
import logging
logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
)

In [2]:
# make deterministic
from mingpt.utils import set_seed
set_seed(42)

In [203]:
import numpy as np
import torch
import string
import os
from tqdm.auto import tqdm
import torch.nn as nn
from torch.nn import functional as F
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [204]:
test = list(range(8))

In [212]:
if test[2] not in test[2:]:
    print('Hey')
test[2:-1] + test[2:-1]

[2, 3, 4, 5, 6, 2, 3, 4, 5, 6]

In [175]:
from torch.utils.data import Dataset

class MathDataset(Dataset):
    """
    Returns addition problems of up to some number of digits in the inputs. Recall
    that all GPT cares about are sequences of integers, and completing them according to
    patterns in the data. Therefore, we have to somehow encode addition problems
    as a sequence of integers.
    
    """

    def __init__(self, fname, MD):
        self.CD = MD
        self.mem_slots = MD.mem_slots
        self.dataset = self.MD.prepare_data(fname) # Extract data, source, memory, and target        
        self.ixes = np.array(list(range(len(self.dataset)))) 
        
    def __len__(self):
        return self.ixes.size

    def __getitem__(self, idx):
        
        xy = []
        for i in range(self.mem_slots + 3):
            xy.append(self.dataset[i][idx])
            
        src, mem, trg = self.MD.create_x_y_pair(xy)
        src_mem_trg = self.MD.list2tokens(src + mem + trg)
        x = self.MD.x2Canvas(src_mem_trg)
        y = self.MD.y2Canvas(src_mem_trg)
        y = self.MD.mask_padding(y)

        x = torch.tensor(x, dtype=torch.long)
        y = torch.tensor(y, dtype=torch.long) 
        y = self.MD.mask_question_memory(y, src, mem) # we will only train in the output locations. -100 will mask loss to zero
        
        return x, y


In [193]:
# create a dataset 
easy = 'run/numbers__place_value.txt'
medium = 'run/numbers__is_prime.txt'
hard = 'run/numbers__list_prime_factors.txt'

In [194]:
!rm -rf run
!cp -r data run

In [195]:
from mingpt.md import MemData
memory_slots = 7
MD = MemData(memory_slots)
MD.initiate_mem_slot_data(hard)

In [196]:
# create a dataset 
easy_test = 'run/test_numbers__place_value.txt'
medium_test = 'run/test_numbers__is_prime.txt'
hard_test = 'run/test_numbers__list_prime_factors.txt'

In [197]:
test_dataset = MathDataset(fname=hard_test, CD=CD)

In [199]:
#test_dataset[23]

In [200]:
CD.block_size

283

In [84]:
# for i in range(0, len(train_dataset)):
#     if len(train_dataset[i][0]) != 52 or len(train_dataset[i][1]) != 52:
#         print(train_dataset.block_size)
#         print(len(train_dataset[i][0]))
#         print(len(train_dataset[i][1]))
#         print(train_dataset[i])

train_dataset[0] # sample a training instance just to see what one raw example looks like

(tensor([68, 79, 72, 91,  3, 80, 90,  3, 91, 79, 76,  3, 84, 80, 83, 83, 80, 86,
         85, 90,  3, 75, 80, 78, 80, 91,  3, 86, 77,  3, 37, 45, 44, 40, 43, 41,
         44, 39, 24,  1, 45,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]),
 tensor([-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100,   45,    2, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100]))

In [6]:
from mingpt.model import GPT, GPTConfig, GPT1Config

# initialize a baby GPT model
mconf = GPTConfig(train_dataset.vocab_size, train_dataset.block_size, 
                  n_layer=2, n_head=4, n_embd=128)
model = GPT(mconf)

11/19/2020 10:03:05 - INFO - mingpt.model -   number of parameters: 4.285440e+05


In [7]:
from mingpt.trainer import Trainer, TrainerConfig

# initialize a trainer instance and kick off training
tconf = TrainerConfig(max_epochs=1, batch_size=512, learning_rate=6e-4,
                      lr_decay=True, warmup_tokens=1024, final_tokens=50*len(train_dataset)*(14+1),
                      num_workers=0)
trainer = Trainer(model, train_dataset, test_dataset, tconf)
trainer.train()

epoch 1 iter 1171: train loss 0.02204. lr 5.999895e-04: 100%|██████████| 1172/1172 [37:05<00:00,  1.90s/it]
11/19/2020 10:41:20 - INFO - mingpt.trainer -   test loss: 0.003768
11/19/2020 10:41:20 - INFO - mingpt.trainer -   saving model.pth


In [87]:
trainer.save_checkpoint()

11/18/2020 06:13:04 - INFO - mingpt.trainer -   saving model.pth


In [183]:
# now let's give the trained model an addition exam
from torch.utils.data.dataloader import DataLoader
from mingpt.utils import sample, Tokenizer

def give_exam(dataset, batch_size=1, max_batch_size=512, max_batches=-1):
    
    t = Tokenizer(dataset)
    results, examples = [], []
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    prev_src_len, predict, batch, x_in = 0, 0, 0, 0
    pbar = tqdm(enumerate(loader), total=len(loader))
    
    for b, (x, y) in pbar:
        
        src_len = t.locateToken('answer', x[0]) 
        x_in = leftover if prev_src_len == -1 else x_in
        
        # Concat input source with same length
        if prev_src_len == src_len:
            x_in = torch.cat((x_in, x), 0)
        elif prev_src_len == 0:
            x_in = x
        else:
            prev_src_len = -1
            predict = 1
            leftover = x
        prev_src_len = src_len
        batch += 1
        
        # Make prediction when the size increses or it reaches max_batch
        if predict or batch == max_batch_size:
            src_len = t.locateToken('answer', x_in[0]) + 1
            batch, predict, prev_src_len, = 0, 0, 0
            x_cut = x_in[:, :src_len]
            
            pred = x_cut.to(trainer.device)
            pred = sample(model, pred, int(dataset.max_trg+1))
            

            for i in range(x_in.size(0)):
                
                pad, end= t.locateToken('pad', x_in[i]), t.locateToken('end', pred[i])

                x, out = x_in[i][src_len:pad], pred[i][src_len:end]
                x, out = t.tensor2string(x), t.tensor2string(out) 

                correct = 1 if x == out else 0
                results.append(correct)
                
                question =  x_in[i][:src_len-1]
                question_str = tensor2string(dataset.idx, question)
                if not correct:
                    examples.append([question_str, x, out, t.tensor2string(x_cut[i]), t.tensor2string(x_in[i]), t.tensor2string(pred[i]), pad, end])
    
        if max_batches >= 0 and b+1 >= max_batches:
            break
    
    print("final score: %d/%d = %.2f%% correct" % (np.sum(results), len(results), 100*np.mean(results)))
    return examples

In [184]:
# training set: how well did we memorize?
examples = give_exam(test_dataset, batch_size=1, max_batches=-1)
print("Q: %s\nX:%s\nO:%s\n" % (examples[0][0], examples[0][1] , examples[0][2]))

100%|██████████| 66666/66666 [00:12<00:00, 5532.29it/s]

final score: 66249/66431 = 99.73% correct
Q: What is the units digit of 848134?
X:4
O:7






In [187]:
for item in examples:
    print("Question:", item[0])
    print("X:", item[1])
    print("Out:", item[2])



Question: What is the units digit of 848134?
X: 4
Out: 7
Question: What is the units digit of 660595?
X: 5
Out: 7
Question: What is the units digit of 889809?
X: 9
Out: 7
Question: What is the units digit of 745396?
X: 6
Out: 7
Question: What is the units digit of 240695?
X: 5
Out: 7
Question: What is the units digit of 443490?
X: 0
Out: 7
Question: What is the units digit of 943311?
X: 1
Out: 7
Question: What is the units digit of 143439?
X: 9
Out: 7
Question: What is the units digit of 300689?
X: 9
Out: 4
Question: What is the units digit of 211201?
X: 1
Out: 7
Question: What is the units digit of 981840?
X: 0
Out: 7
Question: What is the units digit of 499008?
X: 8
Out: 7
Question: What is the units digit of 195911?
X: 1
Out: 7
Question: What is the units digit of 316829?
X: 9
Out: 2
Question: What is the units digit of 967627?
X: 7
Out: 2
Question: What is the units digit of 449959?
X: 9
Out: 7
Question: What is the units digit of 896952?
X: 2
Out: 7
Question: What is the units dig

In [None]:
# test set: how well did we generalize?
give_exam(train_dataset, batch_size=1024, max_batches=1)

In [None]:
# well that's amusing... our model learned everything except 55 + 45

In [None]:
import itertools as it

In [None]:
f = ['-1', '-1', '2', '1', '1']

it.takewhile(lambda x: x!='2', f)

In [None]:
f