## Train GPT on addition

Train a GPT model on a dedicated addition dataset to see if a Transformer can learn to add.

In [1]:
# set up logging
import logging
logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
)

In [2]:
# make deterministic
from mingpt.utils import set_seed
set_seed(42)

In [3]:
import numpy as np
import torch
import string
import torch.nn as nn
from torch.nn import functional as F

In [6]:
from torch.utils.data import Dataset

class AdditionDataset(Dataset):
    """
    Returns addition problems of up to some number of digits in the inputs. Recall
    that all GPT cares about are sequences of integers, and completing them according to
    patterns in the data. Therefore, we have to somehow encode addition problems
    as a sequence of integers.
    
    """

    def __init__(self, fname, split):
        self.split = split # train/test
        self.vocab = ['pad', 'answer', 'end'] + list(' ' + string.punctuation + string.digits + string.ascii_uppercase + string.ascii_lowercase)
        self.vocab_size = len(self.vocab) # 10 possible digits 0..9
        # Max input characters plus max answer characters
        # self.block_size = 160 + 32
        self.t = {k: v for v, k in enumerate(self.vocab)} # Character to ID
        self.idx = {v: k for k, v in self.t.items()} # ID to Character
        
        # split up all addition problems into either training data or test data
        with open(fname, "r") as file:
            text = file.read()[:-1] # Excluding the final linebreak
            text_list = text.split('\n')
            self.src = text_list[0:][::2]
            self.trg = text_list[1:][::2]
            self.src_trg = [src+trg for src,trg in zip(self.src,self.trg)]
            self.max_trg = np.ceil((sum(map(len, self.trg)) / len(self.trg)))
        
        self.block_size = len(max(self.src_trg, key=len)) + 1
        data_len = len(self.src) # total number of possible combinations
        r = np.random.RandomState(1337) # make deterministic
        perm = r.permutation(data_len)
       
        num_test = int(data_len*0.1) # 20% of the whole dataset, or only up to 1000
        
        # Sort test data by lenght to batch predictions
        test_data_by_length = []
        for index in perm[num_test:]:
            test_data_by_length.append([index, len(self.src[index])])
        test_data_by_length = sorted(test_data_by_length, key=lambda x: x[1])
        test_data_by_length = [i[0] for i in test_data_by_length]
        
        self.ixes = np.array(test_data_by_length) if split == 'test' else perm[num_test:]
        

    def __len__(self):
        return self.ixes.size

    def __getitem__(self, idx):
        # given a problem index idx, first recover the associated a + b
        idx = self.ixes[idx]
        src = self.src[idx]
        trg = self.trg[idx]

        src_trg = list(src) + ['answer'] + list(trg) + ['end']
        src_trg = [self.t[tok] for tok in src_trg] # convert each character to its token index
        
        # x will be input to GPT and y will be the associated expected outputs
        x = [self.t['pad']] * self.block_size
        y = [self.t['pad']] * self.block_size
  
        x[:len(src_trg[:-1])] = src_trg[:-1]
        y[:len(src_trg[1:])] = src_trg[1:] # predict the next token in the sequence
        y = [-100 if tok == self.t['pad'] else tok for tok in y] # -100 will mask loss to zero

        x = torch.tensor(x, dtype=torch.long)
        y = torch.tensor(y, dtype=torch.long) 
        y[:len(src)] = -100 # we will only train in the output locations. -100 will mask loss to zero
        
        return x, y


In [7]:
# create a dataset 
easy = 'data/numbers__place_value.txt'
medium = 'data/numbers__is_prime.txt'
hard = 'data/numbers__list_prime_factors.txt'

train_dataset = AdditionDataset(fname=easy, split='train')
test_dataset = AdditionDataset(fname=easy, split='test')

In [8]:
# for i in range(0, len(train_dataset)):
#     if len(train_dataset[i][0]) != 52 or len(train_dataset[i][1]) != 52:
#         print(train_dataset.block_size)
#         print(len(train_dataset[i][0]))
#         print(len(train_dataset[i][1]))
#         print(train_dataset[i])

train_dataset[0] # sample a training instance just to see what one raw example looks like

(tensor([68, 79, 72, 91,  3, 80, 90,  3, 91, 79, 76,  3, 84, 80, 83, 83, 80, 86,
         85, 90,  3, 75, 80, 78, 80, 91,  3, 86, 77,  3, 37, 45, 44, 40, 43, 41,
         44, 39, 24,  1, 45,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]),
 tensor([-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100,   45,    2, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100]))

In [19]:
from mingpt.model import GPT, GPTConfig, GPT1Config

# initialize a baby GPT model
mconf = GPTConfig(train_dataset.vocab_size, train_dataset.block_size, 
                  n_layer=2, n_head=4, n_embd=128)
model = GPT(mconf)

11/14/2020 17:36:13 - INFO - mingpt.model -   number of parameters: 4.285440e+05


In [125]:
from mingpt.trainer import Trainer, TrainerConfig

# initialize a trainer instance and kick off training
tconf = TrainerConfig(max_epochs=1, batch_size=512, learning_rate=6e-4,
                      lr_decay=True, warmup_tokens=1024, final_tokens=50*len(train_dataset)*(14+1),
                      num_workers=0)
trainer = Trainer(model, train_dataset, test_dataset, tconf)
trainer.train()

epoch 1 iter 1171: train loss 0.02204. lr 5.999895e-04: 100%|██████████| 1172/1172 [37:16<00:00,  1.91s/it]
11/13/2020 11:44:05 - INFO - mingpt.trainer -   test loss: 0.003682


In [52]:
# now let's give the trained model an addition exam
from torch.utils.data.dataloader import DataLoader
from mingpt.utils import sample

def give_exam(dataset, batch_size=1, max_batch_size=64, max_batches=-1):
    
    results = []
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    prev_src_len = 0
    predict = 0
    batch = 0
    x_in = 0
    
    for b, (x, y) in enumerate(loader):
        
        src_len = x[0].tolist().index(dataset.t['answer']) + 1
        x_in = leftover if prev_src_len == -1 else x_in
        
        # Concat input source with same length
        if prev_src_len == src_len:
            print(x_in.size())
            print(x.size())
            print(x_in.size())
            x_in = torch.cat((x_in, x), 0)
        elif prev_src_len == 0:
            x_in = x
        else:
            prev_src_len = -1
            predict = 1
            leftover = x
        prev_src_len = src_len
        batch += 1
        
        # Make prediction when the size increses or it reaches max_batch
        if predict or batch == max_batch_size:
            batch = 0
            predict = 0
            prev_src_len = 0
            x_in = x_in[:, :src_len]
            pred = sample(model, x_in, int(dataset.max_trg+1))

            for i in range(x_in.size(0)):
                print(x_in.size())

                pad = -2 if dataset.t['pad'] not in x_in[i].tolist() else x_in[i].tolist().index(dataset.t['pad'])
                x_list = x_in[i][src_len:pad].tolist()
                y_list = pred[i][src_len:pad].tolist()
                x_str = ''.join([dataset.idx[tok] for tok in x_list])
                y_str = ''.join([dataset.idx[tok] for tok in y_list])

                correct = 1 if x_str == y_str else 0
                results.append(correct)
                question =  x_in[i][:src_len-2].tolist()
                question_str = ''.join([dataset.idx[tok] for tok in question])

                print("Q: %s\nP:%s\nG:%s" % (question_str, y_str, x_str))
    
        if max_batches >= 0 and b+1 >= max_batches:
            break

    print("final score: %d/%d = %.2f%% correct" % (np.sum(results), len(results), 100*np.mean(results)))

In [None]:
# training set: how well did we memorize?
give_exam(test_dataset, batch_size=1, max_batches=-1)

torch.Size([1, 52])
torch.Size([1, 52])
torch.Size([1, 52])
torch.Size([2, 52])
torch.Size([1, 52])
torch.Size([2, 52])
torch.Size([3, 52])
torch.Size([1, 52])
torch.Size([3, 52])
torch.Size([4, 52])
torch.Size([1, 52])
torch.Size([4, 52])
torch.Size([5, 33])
Q: What is the units digit of 269?
P:
G:
torch.Size([5, 33])
Q: What is the tens digit of 6893?
P:
G:
torch.Size([5, 33])
Q: What is the tens digit of 8332?
P:
G:
torch.Size([5, 33])
Q: What is the tens digit of 4652?
P:
G:
torch.Size([5, 33])
Q: What is the tens digit of 6101?
P:
G:
torch.Size([1, 52])
torch.Size([1, 52])
torch.Size([1, 52])
torch.Size([2, 52])
torch.Size([1, 52])
torch.Size([2, 52])
torch.Size([3, 52])
torch.Size([1, 52])
torch.Size([3, 52])
torch.Size([4, 52])
torch.Size([1, 52])
torch.Size([4, 52])
torch.Size([5, 52])
torch.Size([1, 52])
torch.Size([5, 52])
torch.Size([6, 52])
torch.Size([1, 52])
torch.Size([6, 52])
torch.Size([7, 52])
torch.Size([1, 52])
torch.Size([7, 52])
torch.Size([8, 52])
torch.Size([1, 

In [29]:
# test set: how well did we generalize?
give_exam(train_dataset, batch_size=1024, max_batches=1)

RuntimeError: each element in list of batch should be of equal size

In [None]:
# well that's amusing... our model learned everything except 55 + 45

In [9]:
import itertools as it

In [21]:
f = ['-1', '-1', '2', '1', '1']

it.takewhile(lambda x: x!='2', f)

<itertools.takewhile at 0x7f8192cabb40>

In [22]:
f

['-1', '-1', '2', '1', '1']