In [1]:
#| default_exp datasets
%load_ext autoreload
%autoreload 2


In [2]:
#| export
import pickle
import zlib

import torch
from torch.utils.data import Dataset

In [18]:
#| export

class SortDataset(Dataset):
    """ 
    Dataset for the Sort problem. E.g. for problem length 6:
    Input: 0 0 2 1 0 1 -> Output: 0 0 0 1 1 2
    Which will feed into the transformer concatenated as:
    input:  0 0 2 1 0 1 0 0 0 1 1
    output: I I I I I 0 0 0 1 1 2
    where I is "ignore", as the transformer is reading the input sequence

    Examples are sampled randomly on every access; the index passed to
    `__getitem__` is not used. Each candidate input is assigned to 'train' or
    'test' from a checksum of its digits, so a given input always belongs to
    the same split, including across interpreter runs and worker processes.

    Args:
        split: which split to draw examples from, 'train' or 'test'.
        length: number of digits in the unsorted input.
        num_digits: digits are drawn from the range [0, num_digits).
    """

    def __init__(self, split, length=6, num_digits=3):
        assert split in {'train', 'test'}
        self.split = split
        self.length = length
        self.num_digits = num_digits
    
    def __len__(self):
        # nominal epoch size; examples are generated on the fly
        return 10000 # ...
    
    def get_vocab_size(self):
        """Number of distinct tokens, i.e. the number of possible digits."""
        return self.num_digits
    
    def get_block_size(self):
        """Length of the x / y sequences returned by `__getitem__`."""
        # the length of the sequence that will feed into transformer, 
        # containing concatenated input and the output, but -1 because
        # the transformer starts making predictions at the last input element
        return self.length * 2 - 1

    @staticmethod
    def _split_of(digits):
        """Return 'train' or 'test' for a list of int digits, deterministically."""
        # The builtin hash() of bytes is salted per interpreter process, so it
        # would put the same input in different splits in different runs or
        # worker processes. crc32 over a protocol-pinned pickle is stable.
        checksum = zlib.crc32(pickle.dumps(digits, protocol=4))
        return 'test' if checksum % 4 == 0 else 'train' # designate 25% of examples as test

    def __getitem__(self, idx):
        """
        Sample one example of this dataset's split (`idx` is ignored).

        Returns:
            (x, y): two 1-D long tensors of length `get_block_size()`. x is the
            input digits followed by all but the last sorted digit; y is x
            shifted left by one, with the first `length - 1` positions set to
            -1 (the "I" positions in the class docstring).
        """
        # use rejection sampling to generate an input example from the desired split
        # NOTE: if no possible input falls in the requested split (tiny
        # length / num_digits), this loop never terminates.
        while True:
            # draw `length` random digits
            inp = torch.randint(self.num_digits, size=(self.length,), dtype=torch.long)
            # half of the time let's try to boost the number of examples that 
            # have a large number of repeats, as this is what the model seems to struggle
            # with later in training, and they are kind of rare
            if torch.rand(1).item() < 0.5:
                if inp.unique().nelement() > self.length // 2:
                    # too many unique digits, discard this draw and re-sample
                    continue
            # keep the example only if it belongs to the split we serve
            if self._split_of(inp.tolist()) == self.split:
                break # ok
        
        # solve the task: i.e. sort
        sol = torch.sort(inp)[0]

        # concatenate the problem specification and the solution
        cat = torch.cat((inp, sol), dim=0)

        # the inputs to the transformer will be the offset sequence
        x = cat[:-1].clone()
        y = cat[1:].clone()
        # we only want to predict at output locations, mask out the loss at the input locations
        y[:self.length-1] = -1
        return x, y 
    


In [19]:
# Build one dataset per split, then show a single training pair position by
# position: the token fed to the model and the target at that position
# (-1 marks positions that carry no target).
input_length, num_digits = 6, 4
train_dataset, test_dataset = (
    SortDataset(split_name, length=input_length, num_digits=num_digits)
    for split_name in ('train', 'test')
)
x, y = train_dataset[0]
for pos in range(len(x)):
    print(f"{pos} : {int(x[pos])} {int(y[pos])}")

0 : 0 -1
1 : 1 -1
2 : 3 -1
3 : 1 -1
4 : 3 -1
5 : 3 0
6 : 0 1
7 : 1 1
8 : 1 3
9 : 3 3
10 : 3 3


In [20]:
# Print the first ten training pairs. Index explicitly instead of iterating
# over the dataset: __getitem__ never raises IndexError, so direct iteration
# has no natural end, and the loop variable would shadow the tensor `x` above.
for sample_idx in range(10):
    print(train_dataset[sample_idx])

(tensor([3, 1, 2, 3, 0, 0, 0, 0, 1, 2, 3]), tensor([-1, -1, -1, -1, -1,  0,  0,  1,  2,  3,  3]))
(tensor([1, 3, 2, 1, 1, 3, 1, 1, 1, 2, 3]), tensor([-1, -1, -1, -1, -1,  1,  1,  1,  2,  3,  3]))
(tensor([2, 3, 0, 0, 1, 2, 0, 0, 1, 2, 2]), tensor([-1, -1, -1, -1, -1,  0,  0,  1,  2,  2,  3]))
(tensor([3, 0, 1, 2, 3, 3, 0, 1, 2, 3, 3]), tensor([-1, -1, -1, -1, -1,  0,  1,  2,  3,  3,  3]))
(tensor([1, 0, 1, 2, 2, 1, 0, 1, 1, 1, 2]), tensor([-1, -1, -1, -1, -1,  0,  1,  1,  1,  2,  2]))
(tensor([2, 3, 1, 1, 3, 3, 1, 1, 2, 3, 3]), tensor([-1, -1, -1, -1, -1,  1,  1,  2,  3,  3,  3]))
(tensor([2, 0, 1, 2, 1, 0, 0, 0, 1, 1, 2]), tensor([-1, -1, -1, -1, -1,  0,  0,  1,  1,  2,  2]))
(tensor([1, 3, 0, 0, 1, 3, 0, 0, 1, 1, 3]), tensor([-1, -1, -1, -1, -1,  0,  0,  1,  1,  3,  3]))
(tensor([0, 0, 2, 2, 1, 2, 0, 0, 1, 2, 2]), tensor([-1, -1, -1, -1, -1,  0,  0,  1,  2,  2,  2]))
(tensor([3, 2, 0, 1, 0, 1, 0, 0, 1, 1, 2]), tensor([-1, -1, -1, -1, -1,  0,  0,  1,  1,  2,  3]))


In [22]:
# x and y cover the concatenated input + output minus one position
expected_block_size = 2 * input_length - 1
assert train_dataset.get_block_size() == expected_block_size

In [23]:
#| hide
# Export the cells marked `#| export` to the library module.
import nbdev
nbdev.nbdev_export()