## Train GPT on addition

Train a GPT model on a dedicated addition dataset to see if a Transformer can learn to add.

In [None]:
# make deterministic
from mingpt.utils import set_seed
set_seed(42)

In [8]:
import numpy as np
import torch
import string
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.sampler import SequentialSampler

In [15]:
from torch.utils.data import Dataset

class AdditionDataset(SequentialSampler):
    """
    Returns addition problems of up to some number of digits in the inputs. Recall
    that all GPT cares about are sequences of integers, and completing them according to
    patterns in the data. Therefore, we have to somehow encode addition problems
    as a sequence of integers.
    
    """

    def __init__(self, fname, split):
        self.split = split # train/test
        self.vocab = ['pad', 'answer', 'end'] + list(' ' + string.punctuation + string.digits + string.ascii_uppercase + string.ascii_lowercase)
        self.vocab_size = len(self.vocab) # 10 possible digits 0..9
        # Max input characters plus max answer characters
        # self.block_size = 160 + 32
        self.t = {k: v for v, k in enumerate(self.vocab)} # Character to ID
        self.idx = {v: k for k, v in self.t.items()} # ID to Character
        
        # split up all addition problems into either training data or test data
        with open(fname, "r") as file:
            text = file.read()[:-1] # Excluding the final linebreak
            text_list = text.split('\n')
            self.src = text_list[0:][::2]
            self.trg = text_list[1:][::2]
            self.src_trg = [src+trg for src,trg in zip(self.src,self.trg)]
            self.max_trg = np.ceil((sum(map(len, self.trg)) / len(self.trg)))
        
        self.block_size = len(max(self.src_trg, key=len)) + 1
        data_len = len(self.src) # total number of possible combinations
        r = np.random.RandomState(1337) # make deterministic
        perm = r.permutation(data_len)
       
        num_test = int(data_len*0.1) # 20% of the whole dataset, or only up to 1000
        
        # Sort test data by lenght to batch predictions
        test_data_by_lenght = []
        for index in perm[num_test:]:
            test_data_by_lenght.append([index, len(self.src[index])])
        test_data_by_lenght = sorted(test_data_by_lenght, key=lambda x: x[1])
        test_data_by_lenght = [i[0] for i in test_data_by_lenght]
        
        self.ixes = np.array(test_data_by_lenght) if split == 'test' else perm[num_test:]
        

    def __len__(self):
        return self.ixes.size

    def __getitem__(self, idx):
        # given a problem index idx, first recover the associated a + b
        idx = self.ixes[idx]
        src = self.src[idx]
        trg = self.trg[idx]

        src_trg = list(src) + ['answer'] + list(trg) + ['end']
        src_trg = [self.t[tok] for tok in src_trg] # convert each character to its token index
        
        # x will be input to GPT and y will be the associated expected outputs
        x = [self.t['pad']] * self.block_size
        y = [self.t['pad']] * self.block_size
  
        x[:len(src_trg[:-1])] = src_trg[:-1]
        y[:len(src_trg[1:])] = src_trg[1:] # predict the next token in the sequence
        y = [-100 if tok == self.t['pad'] else tok for tok in y] # -100 will mask loss to zero

        x = torch.tensor(x, dtype=torch.long)
        y = torch.tensor(y, dtype=torch.long) 
        y[:len(src)] = -100 # we will only train in the output locations. -100 will mask loss to zero
        
        return x, y


In [16]:
# create a dataset 
easy = 'data/numbers__place_value.txt'
medium = 'data/numbers__is_prime.txt'
hard = 'data/numbers__list_prime_factors.txt'

#train_dataset = AdditionDataset(fname=hard, split='train')
test_dataset = AdditionDataset(fname=hard, split='test')

<class 'numpy.ndarray'>
265367
List the prime factors of 737.
List the prime factors of 8812.
List the prime factors of 6140.
List the prime factors of 2981.
List the prime factors of 6375.
List the prime factors of 5746.
List the prime factors of 2611.
List the prime factors of 5050.
List the prime factors of 8255.
List the prime factors of 9419.
List the prime factors of 3095.
List the prime factors of 5648.
List the prime factors of 91530.
List the prime factors of 10555.
List the prime factors of 52123.
List the prime factors of 70014.
List the prime factors of 96701.
List the prime factors of 21729.
List the prime factors of 52256.
List the prime factors of 82715.
List the prime factors of 67608.
List the prime factors of 54251.
List the prime factors of 23251.
List the prime factors of 75291.
List the prime factors of 54774.
List the prime factors of 82729.
List the prime factors of 94351.
List the prime factors of 67747.
List the prime factors of 13489.
List the prime factors of

In [34]:
# for i in range(0, len(train_dataset)):
#     if len(train_dataset[i][0]) != 52 or len(train_dataset[i][1]) != 52:
#         print(train_dataset.block_size)
#         print(len(train_dataset[i][0]))
#         print(len(train_dataset[i][1]))
#         print(train_dataset[i])

# loader = DataLoader(test_dataset, shuffle=False)
# i = 0
# for b, (x, y) in enumerate(loader):
#     test = x.tolist()
#     x_str = ''.join([test_dataset.idx[tok] for tok in test[0]])
#     print(x_str)
#     i += 1
#     if i == 10:
#         break

print(''.join([test_dataset.idx[tok] for tok in test_dataset[0].tolist()[0]]))
# for i in range(0, 10):
#     print(test_dataset[i][0])
    #print(count)
    #print(len(count)) # sample a training instance just to see what one raw example looks like

AttributeError: 'tuple' object has no attribute 'tolist'

In [None]:
from mingpt.model import GPT, GPTConfig, GPT1Config

# initialize a baby GPT model
mconf = GPTConfig(train_dataset.vocab_size, train_dataset.block_size, 
                  n_layer=2, n_head=4, n_embd=128)
model = GPT(mconf)

In [None]:
from mingpt.trainer import Trainer, TrainerConfig

# initialize a trainer instance and kick off training
tconf = TrainerConfig(max_epochs=1, batch_size=512, learning_rate=6e-4,
                      lr_decay=True, warmup_tokens=1024, final_tokens=50*len(train_dataset)*(14+1),
                      num_workers=0)
trainer = Trainer(model, train_dataset, test_dataset, tconf)
trainer.train()

In [None]:
trainer.save_checkpoint()

In [None]:
# now let's give the trained model an addition exam
from torch.utils.data.dataloader import DataLoader
from mingpt.utils import sample

def give_exam(dataset, batch_size=32, max_batches=-1):
    
    results = []
    loader = DataLoader(dataset, batch_size=batch_size)
    for b, (x, y) in enumerate(loader):
        
        cut = x[0].tolist().index(dataset.t['answer']) + 1
        pad = -2 if dataset.t['pad'] not in x[0].tolist() else x[0].tolist().index(dataset.t['pad'])
        x_in = x[:, :cut]
        print(x_in)
        pred = sample(model, x_in, int(dataset.max_trg+1))

        for i in range(x.size(0)):

            #x_list = x[:, cut+1:pad].tolist()[0]
            #y_list = pred[:, cut+1:pad].tolist()[0]
            x_list = x[:, cut:pad].tolist()[0]
            y_list = pred[:, cut:pad].tolist()[0]
            x_str = ''.join([dataset.idx[tok] for tok in x_list])
            y_str = ''.join([dataset.idx[tok] for tok in y_list])
            
            correct = 1 if x_str == y_str else 0
     
            results.append(correct)
            judge = 'YEP!!!' if correct else 'NOPE'
            #if not correct:

            question =  x[:, :cut-1].tolist()[0]
            question_str = ''.join([dataset.idx[tok] for tok in question])
                
            print("Q: %s\nP:%s\nG:%s" % (question_str, y_str, x_str))
        
        if max_batches >= 0 and b+1 >= max_batches:
            break

    print("final score: %d/%d = %.2f%% correct" % (np.sum(results), len(results), 100*np.mean(results)))

In [None]:
# training set: how well did we memorize?
give_exam(test_dataset, batch_size=1, max_batches=-1)

In [None]:
# test set: how well did we generalize?
give_exam(train_dataset, batch_size=1024, max_batches=1)

In [None]:
# well that's amusing... our model learned everything except 55 + 45

In [None]:
import itertools as it

In [None]:
f = ['-1', '-1', '2', '1', '1']

it.takewhile(lambda x: x!='2', f)

In [None]:
r = np.random.RandomState(1338)
perm = r.permutation(10)

In [None]:
perm