## Train GPT on addition

Train a GPT model on a dedicated addition dataset to see if a Transformer can learn to add.

In [1]:
# set up logging
import logging
logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
)

In [2]:
# make deterministic
from mingpt.utils import set_seed
set_seed(42)

In [3]:
import numpy as np
import torch
import string
import os
from tqdm.auto import tqdm
import torch.nn as nn
from torch.nn import functional as F
%load_ext autoreload
%autoreload 2

In [12]:
test = []
for i in range(7):
    test.append([])

In [14]:
test[0]

[]

In [None]:
class PrepareData:
    """ Tokenizer helper functions """

    def __init__(self, mem_slots):
        self.mem_slots = mem_slots
        self.vocab = ['pad', 'answer', 'end'] + list(' ' + string.punctuation + string.digits + string.ascii_uppercase + string.ascii_lowercase)
        self.vocab_size = len(self.vocab) # 10 possible digits 0..9
        # Max input characters plus max answer characters
        self.src_max_size = 30
        self.max_trg = 5
        self.block_size = 160 + 32
        self.t = {k: v for v, k in enumerate(self.vocab)} # Character to ID
        self.idx = {v: k for k, v in self.t.items()} # ID to Character
    
    def initiate_mem_slot_data(self, fname):
        # split up all addition problems into either training data or test data
        # head_tail = os.path.split(fname) 
        src, trg = [], []
        with open(fname, "r") as file:
            text = file.read()[:-1] # Excluding the final linebreak
            text_list = text.split('\n')
            src = text_list[0:][::2]
            trg = text_list[1:][::2]
        
        os.remove(fname)
        with open(fname, "a") as file:
            for src, trg in zip(src, trg):
                file.write(src + '\n')
                file.write(trg + '\n')
                for _ in range(self.mem_slots):
                     file.write('\n')
    
    def prepare_data(self, fname):
        # split up all addition problems into either training data or test data
        # head_tail = os.path.split(fname)
        dataset = []
        for _ in range(self.mem_slot + 2):
            dataset.append([])
        with open(fname, "r") as file:
            text = file.read()[:-1] # Excluding the final linebreak
            text_list = text.split('\n')
            for i in range(self.mem_slot + 2):
                dataset[i] = text_list[i:][::self.mem_slot + 2]
        
        self.max_src = len(max(dataset[0], key=len))
        self.max_trg = len(max(dataset[1], key=len))
        return dataset
    
    def sort_data_by_len(self, indexes, data):
        test_data_by_length = []
        for index in indexes:
            test_data_by_length.append([index, len(data[index])])
        test_data_by_length = sorted(test_data_by_length, key=lambda x: x[1])
        return [i[0] for i in test_data_by_length]
        
    
    def src2Canvas(self, src):
        x = self.t['pad']] * self.src_max_size
        return x[:len(src_trg[:-1])] = src_trg[:-1]
    
    def trg2Canvas(self):
        y = [self.t['pad']] * self.trg_max_size
        return [self.t['pad']] * self.trg_max_size
        
    def tensor2string(self, tensor):
        return ''.join([self.idx[tok] for tok in tensor.tolist()])
    
    def string2digits(self, string):
        return ''.join([self.t[tok] for tok in string])
    
    def mask_padding(self, digits):
        return [-100 if tok == self.t['pad'] else tok for tok in digits]
    
    def mask_question(self, digits, src):
        return digits[:len(src)] = -100

    def locate_token(self, token, tensor):
        return None if self.t[token] not in tensor.tolist() else tensor.tolist().index(self.t[token])

In [4]:
from torch.utils.data import Dataset

class AdditionDataset(Dataset):
    """
    Returns addition problems of up to some number of digits in the inputs. Recall
    that all GPT cares about are sequences of integers, and completing them according to
    patterns in the data. Therefore, we have to somehow encode addition problems
    as a sequence of integers.
    
    """

    def __init__(self, fname, split):
        self.split = split # train/test
        self.vocab = ['pad', 'answer', 'end', 'right', 'wrong'] + list(' ' + string.punctuation + string.digits + string.ascii_uppercase + string.ascii_lowercase)
        self.vocab_size = len(self.vocab) # 10 possible digits 0..9
        # Max input characters plus max answer characters
        # self.block_size = 160 + 32
        self.t = {k: v for v, k in enumerate(self.vocab)} # Character to ID
        self.idx = {v: k for k, v in self.t.items()} # ID to Character
        
        # split up all addition problems into either training data or test data
        with open(fname, "r") as file:
            text = file.read()[:-1] # Excluding the final linebreak
            text_list = text.split('\n')
            self.src = text_list[0:][::2]
            self.trg = text_list[1:][::2]
            self.src_trg = [src+trg for src,trg in zip(self.src,self.trg)]
            self.max_trg = np.ceil((sum(map(len, self.trg)) / len(self.trg)))
        
        self.block_size = len(max(self.src_trg, key=len)) + 1
        data_len = len(self.src) # total number of possible combinations
        r = np.random.RandomState(1337) # make deterministic
        perm = r.permutation(data_len)
       
        num_test = int(data_len*0.1) # 20% of the whole dataset, or only up to 1000
        
        # Sort test data by lenght to batch predictions
        test_data_by_length = []
        for index in perm[:num_test]:
            test_data_by_length.append([index, len(self.src[index])])
        test_data_by_length = sorted(test_data_by_length, key=lambda x: x[1])
        test_data_by_length = [i[0] for i in test_data_by_length]
        
        self.ixes = np.array(test_data_by_length) if split == 'test' else perm[num_test:]
        

    def __len__(self):
        return self.ixes.size

    def __getitem__(self, idx):
        # given a problem index idx, first recover the associated a + b
        idx = self.ixes[idx]
        src = self.src[idx]
        trg = self.trg[idx]

        src_trg = list(src) + ['answer'] + list(trg) + ['end']
        src_trg = [self.t[tok] for tok in src_trg] # convert each character to its token index
        
        # x will be input to GPT and y will be the associated expected outputs
        x = [self.t['pad']] * self.block_size
        y = [self.t['pad']] * self.block_size
  
        x[:len(src_trg[:-1])] = src_trg[:-1]
        y[:len(src_trg[1:])] = src_trg[1:] # predict the next token in the sequence
        y = [-100 if tok == self.t['pad'] else tok for tok in y] # -100 will mask loss to zero

        x = torch.tensor(x, dtype=torch.long)
        y = torch.tensor(y, dtype=torch.long) 
        y[:len(src)] = -100 # we will only train in the output locations. -100 will mask loss to zero
        
        return x, y


In [5]:
# create a dataset 
easy = 'data/numbers__place_value.txt'
medium = 'data/numbers__is_prime.txt'
hard = 'data/numbers__list_prime_factors.txt'

train_dataset = AdditionDataset(fname=easy, split='train')
test_dataset = AdditionDataset(fname=easy, split='test')

In [84]:
# for i in range(0, len(train_dataset)):
#     if len(train_dataset[i][0]) != 52 or len(train_dataset[i][1]) != 52:
#         print(train_dataset.block_size)
#         print(len(train_dataset[i][0]))
#         print(len(train_dataset[i][1]))
#         print(train_dataset[i])

train_dataset[0] # sample a training instance just to see what one raw example looks like

(tensor([68, 79, 72, 91,  3, 80, 90,  3, 91, 79, 76,  3, 84, 80, 83, 83, 80, 86,
         85, 90,  3, 75, 80, 78, 80, 91,  3, 86, 77,  3, 37, 45, 44, 40, 43, 41,
         44, 39, 24,  1, 45,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]),
 tensor([-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100,   45,    2, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100]))

In [6]:
from mingpt.model import GPT, GPTConfig, GPT1Config

# initialize a baby GPT model
mconf = GPTConfig(train_dataset.vocab_size, train_dataset.block_size, 
                  n_layer=2, n_head=4, n_embd=128)
model = GPT(mconf)

11/19/2020 10:03:05 - INFO - mingpt.model -   number of parameters: 4.285440e+05


In [7]:
from mingpt.trainer import Trainer, TrainerConfig

# initialize a trainer instance and kick off training
tconf = TrainerConfig(max_epochs=1, batch_size=512, learning_rate=6e-4,
                      lr_decay=True, warmup_tokens=1024, final_tokens=50*len(train_dataset)*(14+1),
                      num_workers=0)
trainer = Trainer(model, train_dataset, test_dataset, tconf)
trainer.train()

epoch 1 iter 1171: train loss 0.02204. lr 5.999895e-04: 100%|██████████| 1172/1172 [37:05<00:00,  1.90s/it]
11/19/2020 10:41:20 - INFO - mingpt.trainer -   test loss: 0.003768
11/19/2020 10:41:20 - INFO - mingpt.trainer -   saving model.pth


In [87]:
trainer.save_checkpoint()

11/18/2020 06:13:04 - INFO - mingpt.trainer -   saving model.pth


In [183]:
# now let's give the trained model an addition exam
from torch.utils.data.dataloader import DataLoader
from mingpt.utils import sample, Tokenizer

def give_exam(dataset, batch_size=1, max_batch_size=512, max_batches=-1):
    
    t = Tokenizer(dataset)
    results, examples = [], []
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    prev_src_len, predict, batch, x_in = 0, 0, 0, 0
    pbar = tqdm(enumerate(loader), total=len(loader))
    
    for b, (x, y) in pbar:
        
        src_len = t.locateToken('answer', x[0]) 
        x_in = leftover if prev_src_len == -1 else x_in
        
        # Concat input source with same length
        if prev_src_len == src_len:
            x_in = torch.cat((x_in, x), 0)
        elif prev_src_len == 0:
            x_in = x
        else:
            prev_src_len = -1
            predict = 1
            leftover = x
        prev_src_len = src_len
        batch += 1
        
        # Make prediction when the size increses or it reaches max_batch
        if predict or batch == max_batch_size:
            src_len = t.locateToken('answer', x_in[0]) + 1
            batch, predict, prev_src_len, = 0, 0, 0
            x_cut = x_in[:, :src_len]
            
            pred = x_cut.to(trainer.device)
            pred = sample(model, pred, int(dataset.max_trg+1))
            

            for i in range(x_in.size(0)):
                
                pad, end= t.locateToken('pad', x_in[i]), t.locateToken('end', pred[i])

                x, out = x_in[i][src_len:pad], pred[i][src_len:end]
                x, out = t.tensor2string(x), t.tensor2string(out) 

                correct = 1 if x == out else 0
                results.append(correct)
                
                question =  x_in[i][:src_len-1]
                question_str = tensor2string(dataset.idx, question)
                if not correct:
                    examples.append([question_str, x, out, t.tensor2string(x_cut[i]), t.tensor2string(x_in[i]), t.tensor2string(pred[i]), pad, end])
    
        if max_batches >= 0 and b+1 >= max_batches:
            break
    
    print("final score: %d/%d = %.2f%% correct" % (np.sum(results), len(results), 100*np.mean(results)))
    return examples

In [184]:
# training set: how well did we memorize?
examples = give_exam(test_dataset, batch_size=1, max_batches=-1)
print("Q: %s\nX:%s\nO:%s\n" % (examples[0][0], examples[0][1] , examples[0][2]))

100%|██████████| 66666/66666 [00:12<00:00, 5532.29it/s]

final score: 66249/66431 = 99.73% correct
Q: What is the units digit of 848134?
X:4
O:7






In [187]:
for item in examples:
    print("Question:", item[0])
    print("X:", item[1])
    print("Out:", item[2])



Question: What is the units digit of 848134?
X: 4
Out: 7
Question: What is the units digit of 660595?
X: 5
Out: 7
Question: What is the units digit of 889809?
X: 9
Out: 7
Question: What is the units digit of 745396?
X: 6
Out: 7
Question: What is the units digit of 240695?
X: 5
Out: 7
Question: What is the units digit of 443490?
X: 0
Out: 7
Question: What is the units digit of 943311?
X: 1
Out: 7
Question: What is the units digit of 143439?
X: 9
Out: 7
Question: What is the units digit of 300689?
X: 9
Out: 4
Question: What is the units digit of 211201?
X: 1
Out: 7
Question: What is the units digit of 981840?
X: 0
Out: 7
Question: What is the units digit of 499008?
X: 8
Out: 7
Question: What is the units digit of 195911?
X: 1
Out: 7
Question: What is the units digit of 316829?
X: 9
Out: 2
Question: What is the units digit of 967627?
X: 7
Out: 2
Question: What is the units digit of 449959?
X: 9
Out: 7
Question: What is the units digit of 896952?
X: 2
Out: 7
Question: What is the units dig

In [None]:
# test set: how well did we generalize?
give_exam(train_dataset, batch_size=1024, max_batches=1)

In [None]:
# well that's amusing... our model learned everything except 55 + 45

In [None]:
import itertools as it

In [None]:
f = ['-1', '-1', '2', '1', '1']

it.takewhile(lambda x: x!='2', f)

In [None]:
f