A cute little demo showing the simplest usage of babygpt. 

In [1]:
import torch
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
from babygpt.utils import set_seed
import numpy as np
set_seed(3407)

In [2]:
import pickle

class SortDataset(Dataset):
    """ 
    Dataset for the Sort problem. E.g. for problem length 6:
    Input: 0 0 2 1 0 1 -> Output: 0 0 0 1 1 2
    Which will feed into the transformer concatenated as:
    input:  0 0 2 1 0 1 0 0 0 1 1
    output: I I I I I 0 0 0 1 1 2
    where I is "ignore", as the transformer is reading the input sequence
    """

    def __init__(self, split, length=6, num_digits=3):
        assert split in {'train', 'test'}
        self.split = split
        self.length = length
        self.num_digits = num_digits
    
    def __len__(self):
        return 10000 # ...
    
    def get_vocab_size(self):
        return self.num_digits
    
    def get_block_size(self):
        # the length of the sequence that will feed into transformer, 
        # containing concatenated input and the output, but -1 because
        # the transformer starts making predictions at the last input element
        return self.length * 2 - 1

    def __getitem__(self, idx):
        
        # use rejection sampling to generate an input example from the desired split
        while True:
            # generate some random integers
            inp = torch.randint(self.num_digits, size=(self.length,), dtype=torch.long)
            # half of the time let's try to boost the number of examples that 
            # have a large number of repeats, as this is what the model seems to struggle
            # with later in training, and they are kind of rate
            if torch.rand(1).item() < 0.5:
                if inp.unique().nelement() > self.length // 2:
                    # too many unqiue digits, re-sample
                    continue
            # figure out if this generated example is train or test based on its hash
            h = hash(pickle.dumps(inp.tolist()))
            inp_split = 'test' if h % 4 == 0 else 'train' # designate 25% of examples as test
            if inp_split == self.split:
                break # ok
        
        # solve the task: i.e. sort
        sol = torch.sort(inp)[0]

        # concatenate the problem specification and the solution
        cat = torch.cat((inp, sol), dim=0)

        # the inputs to the transformer will be the offset sequence
        x = cat[:-1].clone()
        y = cat[1:].clone()
        # we only want to predict at output locations, mask out the loss at the input locations
        y[:self.length-1] = -1
        return x, y

In [3]:
import pickle

class OrderedRepeatingDataset(Dataset):
    """ 
    Dataset for the masked-language-modeling problem. E.g. for problem of length 10 with 3 characters (0,1,2), and 3 as a masking token:
    The data always contains the tokens in increasing order, and then repeats. 
    Some tokens as masked using a special token that isn't part of the vocab. 
    Input: '1201201301' -> Output: '1201201201'
    """

    def __init__(self, split, length=6, num_chars = 3, p_mask = 0.2):
        assert split in {'train', 'test'}
        self.split = split
        self.length = length
        self.num_chars = num_chars
        self.p_mask = p_mask
    
    def __len__(self):
        return 10000 # ...
    
    def get_vocab_size(self):
        return self.num_chars
    
    def get_block_size(self):
        # the length of the sequence that will feed into transformer, 
        # containing concatenated input and the output, but -1 because
        # the transformer starts making predictions at the last input element
        return self.length

    def toy_language_model(self):
        # generate a tensor of size self.length, using num_chars, following a simple correlation structure
        # the output is the first self.length elements a randomly rotated version of {0, 1, 2, .., n-1, 0, 1, 2..., n-1, 0, 1...}
        x = torch.cat([torch.arange(0, self.num_chars)]*(int(np.ceil((self.num_chars + self.length)/self.num_chars))))
        start_ind = int(np.random.random() * self.num_chars)
        return x[start_ind : start_ind+self.length]
        
    def __getitem__(self, idx):
        # generate some random characters as per 
        inp = self.toy_language_model()
        
        # solve the task: MLM
        sol = torch.cat([torch.tensor([-1]), inp[0:-1]])
        
        return inp, sol

In [4]:
# print an example instance of the dataset
train_dataset = OrderedRepeatingDataset('train')
test_dataset = OrderedRepeatingDataset('test')
x, y = train_dataset[0]
print('x','y')
print('-----')
for a, b in zip(x,y):
    print(int(a),int(b))

x y
-----
1 -1
2 1
0 2
1 0
2 1
0 2


In [5]:
len(train_dataset), len(test_dataset)

(10000, 10000)

In [6]:
# create a babyBERT instance
from babygpt.model import GPT
model = GPT(d_embed = 12, n_head = 3, num_layers = 3,  vocab_size = 3, block_size = train_dataset.get_block_size())



In [7]:
train_dataset.get_block_size()

6

In [8]:
# create a Trainer object
from babygpt.trainer import Trainer

train_config = Trainer.get_default_config()
train_config.learning_rate = 5e-4 # the model we're using is so small that we can go a bit faster
train_config.max_iters = 2000
train_config.num_workers = 0
trainer = Trainer(train_config, model, train_dataset)

running on device cpu


In [9]:
def batch_end_callback(trainer):
    if trainer.iter_num % 100 == 0:
        print(f"iter_dt {trainer.iter_dt * 1000:.2f}ms; iter {trainer.iter_num}: train loss {trainer.loss.item():.5f}")
trainer.set_callback('on_batch_end', batch_end_callback)
trainer.run()

iter_dt 0.00ms; iter 0: train loss 1.28633
iter_dt 13.19ms; iter 100: train loss 0.11858
iter_dt 12.08ms; iter 200: train loss 0.05632
iter_dt 14.20ms; iter 300: train loss 0.03390
iter_dt 14.51ms; iter 400: train loss 0.02268
iter_dt 11.44ms; iter 500: train loss 0.01612
iter_dt 13.26ms; iter 600: train loss 0.01194
iter_dt 13.70ms; iter 700: train loss 0.00929
iter_dt 11.38ms; iter 800: train loss 0.00723
iter_dt 12.82ms; iter 900: train loss 0.00582
iter_dt 14.00ms; iter 1000: train loss 0.00478
iter_dt 12.63ms; iter 1100: train loss 0.00400
iter_dt 11.95ms; iter 1200: train loss 0.00335
iter_dt 14.37ms; iter 1300: train loss 0.00294
iter_dt 12.22ms; iter 1400: train loss 0.00244
iter_dt 14.18ms; iter 1500: train loss 0.00212
iter_dt 12.28ms; iter 1600: train loss 0.00184
iter_dt 13.06ms; iter 1700: train loss 0.00163
iter_dt 12.83ms; iter 1800: train loss 0.00140
iter_dt 13.18ms; iter 1900: train loss 0.00129


In [10]:
# now let's perform some evaluation
model.eval();

In [11]:
len(train_dataset)

10000

In [12]:
def eval_split(trainer, split, max_batches):
    dataset = {'train':train_dataset, 'test':test_dataset}[split]
    n = train_dataset.length # naugy direct access shrug
    results = []
    mistakes_printed_already = 0
    loader = DataLoader(dataset, batch_size=100, num_workers=0, drop_last=False)
    for b, (x, y) in enumerate(loader):
        x = x.to(trainer.device)
        y = y.to(trainer.device)
        # isolate the input pattern alone
        inp = x[:, :n]
        sol = y[:, -n:]
        # let the model sample the rest of the sequence
        cat = model.generate(inp, n, do_sample=False) # using greedy argmax, not sampling
        sol_candidate = cat[:, n:] # isolate the filled in sequence
        # compare the predicted sequence to the true sequence
        correct = (sol == sol_candidate).all(1).cpu() # Software 1.0 vs. Software 2.0 fight RIGHT on this line haha
        for i in range(x.size(0)):
            results.append(int(correct[i]))
            if not correct[i] and mistakes_printed_already < 3: # only print up to 5 mistakes to get a sense
                mistakes_printed_already += 1
                print("GPT claims that %s sorted is %s but gt is %s" % (inp[i].tolist(), sol_candidate[i].tolist(), sol[i].tolist()))
        if max_batches is not None and b+1 >= max_batches:
            break
    rt = torch.tensor(results, dtype=torch.float)
    print("%s final score: %d/%d = %.2f%% correct" % (split, rt.sum(), len(results), 100*rt.mean()))
    return rt.sum()

# run a lot of examples from both train and test through the model and verify the output correctness
with torch.no_grad():
    train_score = eval_split(trainer, 'train', max_batches=50)
    test_score  = eval_split(trainer, 'test',  max_batches=50)

GPT claims that [2, 0, 1, 2, 0, 1] sorted is [0, 2, 1, 0, 2, 1] but gt is [-1, 2, 0, 1, 2, 0]
GPT claims that [2, 0, 1, 2, 0, 1] sorted is [0, 2, 1, 0, 2, 1] but gt is [-1, 2, 0, 1, 2, 0]
GPT claims that [2, 0, 1, 2, 0, 1] sorted is [0, 2, 1, 0, 2, 1] but gt is [-1, 2, 0, 1, 2, 0]
train final score: 0/5000 = 0.00% correct
GPT claims that [2, 0, 1, 2, 0, 1] sorted is [0, 2, 1, 0, 2, 1] but gt is [-1, 2, 0, 1, 2, 0]
GPT claims that [0, 1, 2, 0, 1, 2] sorted is [1, 0, 2, 1, 0, 2] but gt is [-1, 0, 1, 2, 0, 1]
GPT claims that [1, 2, 0, 1, 2, 0] sorted is [2, 1, 0, 2, 1, 0] but gt is [-1, 1, 2, 0, 1, 2]
test final score: 0/5000 = 0.00% correct


In [14]:
# let's run a random given sequence through the model as well
n = train_dataset.length # naugy direct access shrug
inp = torch.tensor([[0,1,2,0,1,2]], dtype=torch.long).to(trainer.device)
assert inp[0].nelement() == n
with torch.no_grad():
    cat = model.generate(inp, n, do_sample=False)
sol = torch.sort(inp[0])[0]
sol_candidate = cat[:, n:]
print('input sequence  :', inp.tolist())
print('predicted sorted:', sol_candidate.tolist())
print('gt sort         :', sol.tolist())
print('matches         :', bool((sol == sol_candidate).all()))

input sequence  : [[0, 1, 2, 0, 1, 2]]
predicted sorted: [[1, 0, 2, 1, 0, 2]]
gt sort         : [0, 0, 1, 1, 2, 2]
matches         : False


In [14]:
cat

tensor([[0, 0, 2, 1, 0, 1, 0, 0, 0, 1, 1, 2]])

In [15]:
n

6

In [35]:
inp = torch.tensor([[1, 2, 2, 1, 1]])
inp

tensor([[1, 2, 2, 1, 1]])

In [36]:
inp.size()

torch.Size([1, 5])

In [37]:
inp[0]

tensor([1, 2, 2, 1, 1])

In [40]:
model(inp)[0].size()

torch.Size([1, 5, 3])

In [41]:
logits = model(inp)[0]

In [42]:
logits.size()

torch.Size([1, 5, 3])

In [45]:
logits.view(-1, logits.size(-1))

tensor([[-2.3310e+00,  6.6854e+00, -4.1210e+00],
        [-2.9664e+00,  5.9037e+00, -2.7988e+00],
        [-3.5032e+00, -3.2251e+00,  6.6771e+00],
        [-4.7313e+00,  5.2067e-03,  4.9808e+00],
        [-3.2557e+00,  5.4126e+00, -2.2338e+00]], grad_fn=<ViewBackward0>)

In [15]:
import peft