## Train GPT on gym

Train a GPT model on a dedicated addition dataset to see if a Transformer can learn to add.

In [7]:
# set up logging
import logging
logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
)

In [8]:
# make deterministic
from mingpt.utils import set_seed
set_seed(42)

In [9]:
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

In [10]:
env = gym.make('DiscreteGridworld-v0')

In [11]:
env.reset()

array([ 4, 10], dtype=uint8)

In [78]:
from torch.utils.data import Dataset
import sys
sys.path.append('../DRQN_pt')
import gym, envs


class GymDataset(Dataset):
    """
    Returns addition problems of up to some number of digits in the inputs. Recall
    that all GPT cares about are sequences of integers, and completing them according to
    patterns in the data. Therefore, we have to somehow encode addition problems
    as a sequence of integers.
    
    The sum of two n-digit numbers gives a third up to (n+1)-digit number. So our
    encoding will simply be the n-digit first number, n-digit second number, 
    and (n+1)-digit result, all simply concatenated together. Because each addition
    problem is so structured, there is no need to bother the model with encoding
    +, =, or other tokens. Each possible sequence has the same length, and simply
    contains the raw digits of the addition problem.
    
    As a few examples, the 2-digit problems:
    - 85 + 50 = 135 becomes the sequence [8, 5, 5, 0, 1, 3, 5]
    - 6 + 39 = 45 becomes the sequence [0, 6, 3, 9, 0, 4, 5]
    etc.
    
    We will also only train GPT on the final (n+1)-digits because the first
    two n-digits are always assumed to be given. So when we give GPT an exam later,
    we will e.g. feed it the sequence [0, 6, 3, 9], which encodes that we'd like
    to add 6 + 39, and hope that the model completes the integer sequence with [0, 4, 5]
    in 3 sequential steps.
    
    fun exercise: does it help if the result is asked to be produced in reverse order?
    """

    def __init__(self, split: str, env_name: str ="DiscreteGridworld-v0"):
        self.env = gym.make(env_name)
        self.split = split # train/test
        self.vocab_size = 12 # 12 possible... -1 through 10
        # TODO: should be env.action_space.shape[0] as well instead of 1
        self.block_size = self.env.observation_space.shape[0] * 5
        
        # split up all addition problems into either training data or test data
        # Let's start with 50k samples
        # num = (10**self.ndigit)**2 # total number of possible combinations
        # r = np.random.RandomState(1337) # make deterministic
        # perm = r.permutation(num)
        # num_test = min(int(num*0.2), 1000) # 20% of the whole dataset, or only up to 1000
        # self.ixes = perm[:num_test] if split == 'test' else perm[num_test:]

    def __len__(self):
        #return self.ixes.size
        return 50_000

    def __getitem__(self, idx):
        obs = self.env.reset()
        history = [torch.tensor(obs.copy(), dtype=torch.long)]
        for _ in range(5):
            action = env.action_space.sample()
            history.append(torch.tensor([action], dtype=torch.long))
            obs, _, _, _ = env.step(action)
            history.append(torch.tensor(obs.copy(), dtype=torch.long))
        h = torch.cat(([hist for hist in history]))
        x = h[:-1]
        y = h[1:]
        return x, y

        obs = torch.tensor(self.env.reset(), dtype=torch.long)
        action = self.env.action_space.sample()
        next_obs, _, _, _ = self.env.step(action)
        x = torch.cat((obs, torch.tensor([action, next_obs[0]], dtype=torch.long)))
        y = torch.cat((torch.tensor([-100, -100], dtype=torch.long), torch.tensor(next_obs, dtype=torch.long)))
        return x, y



In [79]:
train_dataset = GymDataset(split='train', env_name='DiscreteHorseshoe-v0')
test_dataset = GymDataset(split='test', env_name='DiscreteHorseshoe-v0')

In [85]:
env = gym.make('DiscreteHorseshoe-v0')

In [90]:
env.reset()

array([255, 255], dtype=uint8)

In [87]:
env.reset()

array([1, 5], dtype=uint8)

In [84]:
train_dataset[190] # sample a training instance just to see what one raw example looks like

(tensor([3, 1, 1, 6, 8, 3, 7, 8, 0, 7, 9, 2, 6, 9, 3, 7]),
 tensor([1, 1, 6, 8, 3, 7, 8, 0, 7, 9, 2, 6, 9, 3, 7, 9]))

In [65]:
from mingpt.model import GPT, GPTConfig, GPT1Config

# initialize a baby GPT model
mconf = GPTConfig(train_dataset.vocab_size, train_dataset.block_size, 
                  n_layer=2, n_head=4, n_embd=128)
model = GPT(mconf)

01/03/2022 18:57:37 - INFO - mingpt.model -   number of parameters: 4.001280e+05


In [66]:
from mingpt.trainer import Trainer, TrainerConfig

# initialize a trainer instance and kick off training
tconf = TrainerConfig(max_epochs=20, batch_size=512, learning_rate=6e-4,
                      lr_decay=True, warmup_tokens=1024, final_tokens=50*len(train_dataset)*(4),
                      num_workers=4)
trainer = Trainer(model, train_dataset, test_dataset, tconf)
trainer.train()

epoch 1 iter 97: train loss 0.10494. lr 5.998550e-04: 100%|██████████| 98/98 [00:01<00:00, 86.35it/s]
01/03/2022 18:57:41 - INFO - mingpt.trainer -   test loss: 0.041783
epoch 2 iter 97: train loss 0.01925. lr 5.994139e-04: 100%|██████████| 98/98 [00:01<00:00, 86.77it/s]
01/03/2022 18:57:43 - INFO - mingpt.trainer -   test loss: 0.002651
epoch 3 iter 97: train loss 0.03455. lr 5.986774e-04: 100%|██████████| 98/98 [00:01<00:00, 86.44it/s]
01/03/2022 18:57:46 - INFO - mingpt.trainer -   test loss: 0.001013
epoch 4 iter 97: train loss 0.00785. lr 5.976460e-04: 100%|██████████| 98/98 [00:01<00:00, 81.43it/s]
01/03/2022 18:57:48 - INFO - mingpt.trainer -   test loss: 0.000514
epoch 5 iter 97: train loss 0.00858. lr 5.963208e-04: 100%|██████████| 98/98 [00:01<00:00, 80.11it/s]
01/03/2022 18:57:51 - INFO - mingpt.trainer -   test loss: 0.000469
epoch 6 iter 97: train loss 0.03811. lr 5.947032e-04: 100%|██████████| 98/98 [00:01<00:00, 87.74it/s]
01/03/2022 18:57:53 - INFO - mingpt.trainer -   

In [71]:
from mingpt.utils import sample

sample(model, torch.tensor([0, 9, 3], dtype=torch.long, device=trainer.device)[None, ...], 4)

tensor([[0, 9, 3, 1, 9, 2, 0]], device='cuda:0')

In [34]:
env.reset()

array([ 0, 10], dtype=uint8)

In [72]:
env = gym.make('DiscreteHorseshoe-v0')

In [74]:
env.step(0)

(array([5, 9], dtype=uint8),
 -1.0,
 False,
 {'state': array([5, 9], dtype=uint8)})

In [28]:
x = torch.tensor([1, 2, 0], dtype=torch.long, device=trainer.device)
x.shape

torch.Size([3])

In [None]:
# now let's give the trained model an addition exam
from torch.utils.data.dataloader import DataLoader
from mingpt.utils import sample

def give_exam(dataset, batch_size=32, max_batches=-1):
    
    results = []
    loader = DataLoader(dataset, batch_size=batch_size)
    for b, (x, y) in enumerate(loader):
        x = x.to(trainer.device)
        d1d2 = x[:, :ndigit*2]
        d1d2d3 = sample(model, d1d2, ndigit+1)
        d3 = d1d2d3[:, -(ndigit+1):]
        factors = torch.tensor([[10**i for i in range(ndigit+1)][::-1]]).to(trainer.device)
        # decode the integers from individual digits
        d1i = (d1d2[:,:ndigit] * factors[:,1:]).sum(1)
        d2i = (d1d2[:,ndigit:ndigit*2] * factors[:,1:]).sum(1)
        d3i_pred = (d3 * factors).sum(1)
        d3i_gt = d1i + d2i
        correct = (d3i_pred == d3i_gt).cpu() # Software 1.0 vs. Software 2.0 fight RIGHT on this line, lol
        for i in range(x.size(0)):
            results.append(int(correct[i]))
            judge = 'YEP!!!' if correct[i] else 'NOPE'
            if not correct[i]:
                print("GPT claims that %03d + %03d = %03d (gt is %03d; %s)" 
                      % (d1i[i], d2i[i], d3i_pred[i], d3i_gt[i], judge))
        
        if max_batches >= 0 and b+1 >= max_batches:
            break

    print("final score: %d/%d = %.2f%% correct" % (np.sum(results), len(results), 100*np.mean(results)))

In [None]:
# training set: how well did we memorize?
give_exam(train_dataset, batch_size=1024, max_batches=10)

In [None]:
# test set: how well did we generalize?
give_exam(test_dataset, batch_size=1024, max_batches=-1)

In [None]:
# well that's amusing... our model learned everything except 55 + 45