In [18]:
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

In [None]:
!curl https://storage.googleapis.com/dm-math-dataset/numbers__list_prime_factors.txt > numbers__list_prime_factors.txt
!curl https://storage.googleapis.com/dm-math-dataset/numbers__is_prime.txt > numbers__is_prime.txt
!curl https://storage.googleapis.com/dm-math-dataset/numbers__place_value.txt > numbers__place_value.txt 

In [56]:
from torch.utils.data import Dataset

class AdditionDataset(Dataset):
    """
    Returns addition problems of up to some number of digits in the inputs. Recall
    that all GPT cares about are sequences of integers, and completing them according to
    patterns in the data. Therefore, we have to somehow encode addition problems
    as a sequence of integers.
    
    """

    def __init__(self, fname, split):
        self.split = split # train/test
        self.vocab = ['pad', 'answer', 'end'] + list(' ' + string.punctuation + string.digits + string.ascii_uppercase + string.ascii_lowercase)
        self.vocab_size = len(self.vocab) # 10 possible digits 0..9
        # Max input characters plus max answer characters
        self.block_size = 160 + 32
        self.t = {k: v for v, k in enumerate(vocab)} # Character to ID
        self.idx = {v: k for k, v in self.t.items()} # ID to Character
        
        # split up all addition problems into either training data or test data
        with open(fname, "r") as file:
            text = file.read()[:-1] # Excluding the final linebreak
            text_list = text.split('\n')
            self.src = text_list[0:][::2]
            self.trg = text_list[1:][::2]
            self.src_trg = [src+trg for src,trg in zip(self.src,self.trg)]
        
        self.block_size = len(max(self.src_trg, key=len)) 
        data_len = len(self.src) # total number of possible combinations
        r = np.random.RandomState(1337) # make deterministic
        perm = r.permutation(data_len)
        num_test = int(data_len*0.1) # 20% of the whole dataset, or only up to 1000
        self.ixes = perm[:num_test] if split == 'test' else perm[num_test:]
        

    def __len__(self):
        return self.ixes.size

    def __getitem__(self, idx):
        # given a problem index idx, first recover the associated a + b
        idx = self.ixes[idx]
        src = self.src[idx]
        trg = self.trg[idx]
        print(src, trg)
        src_trg = list(src) + ['answer'] + list(trg) + ['end']
        
        src_trg = [self.t[tok] for tok in src_trg] # convert each character to its token index
        # x will be input to GPT and y will be the associated expected outputs
        x = [self.t['pad']] * self.block_size
        y = [self.t['pad']] * self.block_size
        x[:len(src_trg[:-1])] = src_trg[:-1]
        y[:len(src_trg[1:])] = src_trg[1:] # predict the next token in the sequence
        
        print(x,y)
        
        x = torch.tensor(x, dtype=torch.long)
        y = torch.tensor(y, dtype=torch.long) 
        y[:len(src)] = -100 # we will only train in the output locations. -100 will mask loss to zero
        return x, y


In [57]:
easy = 'data/numbers__place_value.txt'
medium = 'data/numbers__is_prime.txt'
hard = 'data/numbers__list_prime_factors.txt'

In [58]:
train_dataset = AdditionDataset(fname=easy, split='train')
test_dataset = AdditionDataset(fname=easy, split='test')

In [59]:
train_dataset[0]

What is the millions digit of 19847583? 9
[68, 79, 72, 91, 3, 80, 90, 3, 91, 79, 76, 3, 84, 80, 83, 83, 80, 86, 85, 90, 3, 75, 80, 78, 80, 91, 3, 86, 77, 3, 37, 45, 44, 40, 43, 41, 44, 39, 24, 1, 45, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] [79, 72, 91, 3, 80, 90, 3, 91, 79, 76, 3, 84, 80, 83, 83, 80, 86, 85, 90, 3, 75, 80, 78, 80, 91, 3, 86, 77, 3, 37, 45, 44, 40, 43, 41, 44, 39, 24, 1, 45, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]


(tensor([68, 79, 72, 91,  3, 80, 90,  3, 91, 79, 76,  3, 84, 80, 83, 83, 80, 86,
         85, 90,  3, 75, 80, 78, 80, 91,  3, 86, 77,  3, 37, 45, 44, 40, 43, 41,
         44, 39, 24,  1, 45,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]),
 tensor([-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100,   45,    2,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0]))