A cute little demo showing the simplest usage of minBERT. 

In [1]:
import torch
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
from minbert.utils import set_seed
import numpy as np
set_seed(3407)

In [2]:
import pickle

class OrderedRepeatingDataset(Dataset):
    """
    Synthetic dataset for a toy masked-language-modeling task.

    Each sample is a run of `length` tokens that cycles through
    0, 1, ..., num_chars-1, 0, 1, ... starting at a random offset.
    Every position is then replaced, with probability `p_mask`, by a mask
    token whose id is `num_chars`, i.e. one past the regular vocabulary.
    E.g. for length 10 with 3 characters (0,1,2) and 3 as the mask token:
    Input: '1201201301' -> Output: '1201201201'
    """

    def __init__(self, split, length=6, num_chars = 3, p_mask = 0.2):
        # `split` is validated and stored; sample generation below does not read it
        assert split in {'train', 'test'}
        self.split = split
        self.length = length
        self.num_chars = num_chars
        self.p_mask = p_mask

    def __len__(self):
        # nominal size only: samples are generated on the fly in __getitem__
        return 10000

    def get_vocab_size(self):
        # number of regular characters; the mask token id (num_chars) is not counted
        return self.num_chars

    def get_block_size(self):
        # the sequence length fed to the model: masked input and target
        # both have exactly `length` positions
        return self.length

    def toy_language_model(self):
        """Return `length` tokens of the cycle 0..num_chars-1 starting at a random offset."""
        # tile the alphabet often enough that a window of `length` tokens fits
        # for every possible offset below num_chars
        n_copies = 1 + (self.length + self.num_chars - 1) // self.num_chars
        tiled = torch.arange(self.num_chars).repeat(n_copies)
        offset = int(np.random.random() * self.num_chars)
        return tiled[offset : offset + self.length]

    def __getitem__(self, idx):
        """Return (masked input, unmasked target, boolean mask); `idx` is ignored."""
        tokens = self.toy_language_model()
        # a position is masked when its uniform draw falls below p_mask
        mask = torch.rand(size=(self.length,)) < self.p_mask
        # masked positions carry the mask token id, the others keep their token
        inp_masked = tokens.masked_fill(mask, self.num_chars)

        # the MLM target is the complete, unmasked sequence
        sol = tokens.clone()

        return inp_masked, sol, mask

In [3]:
# build both splits, then show one training sample position by position
train_dataset = OrderedRepeatingDataset('train', length=10, p_mask=0.75)
test_dataset = OrderedRepeatingDataset('test', length=10, p_mask=0.75)
x, y, z = train_dataset[0]
# columns: x = masked input, y = target, z = 1 where the position was masked
print('x','y','z')
print('-----')
for position in range(len(x)):
    print(int(x[position]), int(y[position]), int(z[position]))

x y z
-----
1 1 0
3 2 1
3 0 1
1 1 0
3 2 1
3 0 1
3 1 1
3 2 1
0 0 0
3 1 1


In [4]:
# nominal dataset sizes; __len__ returns a fixed count regardless of split
len(train_dataset), len(test_dataset)

(10000, 10000)

In [5]:
# create a BERT instance
from minbert.model import BERT

model_config = BERT.get_default_config()
model_config.model_type = 'BERT-nano'
# NOTE(review): vocab_size counts only the regular characters, yet inputs also
# contain the mask token id (== vocab_size); presumably BERT reserves an extra
# embedding slot for it — confirm against minbert.model
model_config.vocab_size = train_dataset.get_vocab_size()
model_config.block_size = train_dataset.get_block_size()
model = BERT(model_config)

number of parameters: 0.09M


In [6]:
# create a Trainer object
from minbert.trainer import Trainer

train_config = Trainer.get_default_config()
train_config.learning_rate = 5e-4 # the model we're using is so small that we can go a bit faster
# cap on the number of training iterations
train_config.max_iters = 500
# presumably forwarded to the training DataLoader (0 = load in the main process) — confirm in Trainer
train_config.num_workers = 0
trainer = Trainer(train_config, model, train_dataset)

running on device cpu


In [None]:
def batch_end_callback(trainer):
    """Print the iteration time (in ms) and the training loss every 100 iterations."""
    if trainer.iter_num % 100 == 0:
        print(f"iter_dt {trainer.iter_dt * 1000:.2f}ms; iter {trainer.iter_num}: train loss {trainer.loss.item():.5f}")
# register the logging hook under the 'on_batch_end' event, then start training
trainer.set_callback('on_batch_end', batch_end_callback)
trainer.run()

iter_dt 0.00ms; iter 0: train loss 1.09724
iter_dt 25.61ms; iter 100: train loss 0.48035
iter_dt 25.70ms; iter 200: train loss 0.16396
iter_dt 25.65ms; iter 300: train loss 0.09977


In [8]:
# now let's perform some evaluation
# presumably switches the model to inference mode as torch.nn.Module.eval() does — confirm BERT is an nn.Module
# (the trailing ';' keeps the notebook from echoing the returned object)
model.eval();

In [9]:
def eval_split(trainer, split, max_batches):
    """
    Measure how often the model restores the masked positions of a split.

    Args:
        trainer: object whose `device` attribute says where the batches are moved to.
        split: 'train' or 'test'; selects the module-level train_dataset / test_dataset.
        max_batches: stop after this many batches of 100 samples; None runs the whole dataset.

    Returns:
        Fraction of masked positions predicted correctly, or NaN when no
        position was masked in the evaluated batches.

    Uses the module-level `model` and prints "correct total accuracy" before returning.
    """
    dataset = {'train':train_dataset, 'test':test_dataset}[split]
    loader = DataLoader(dataset, batch_size=100, num_workers=0, drop_last=False)
    # plain Python counters instead of a growing CPU tensor: nothing is
    # concatenated across devices and nothing is re-copied every batch
    correct = 0
    total = 0
    for b, (x, y, z) in enumerate(loader):
        # b is just batch number
        # z is the boolean mask for the batch
        x = x.to(trainer.device)
        y = y.to(trainer.device)
        z = z.to(trainer.device)
        # assumes generate_output returns predictions shaped like x on the same device — TODO confirm
        sol_candidate = model.generate_output(x)
        # only the masked positions are scored
        hits = (y[z] == sol_candidate[z])
        correct += int(hits.sum())
        total += hits.numel()
        if max_batches is not None and b+1 >= max_batches:
            break
    # with no masked position at all (e.g. p_mask = 0) the ratio is undefined
    accuracy = correct/total if total > 0 else float('nan')
    print(correct, total, accuracy)
    return accuracy

# run a lot of examples from both train and test through the model and verify the output correctness
# gradient tracking is switched off for the duration of the evaluation;
# each split is capped at 50 batches
with torch.no_grad():
    train_score = eval_split(trainer, 'train', max_batches=50)
    test_score  = eval_split(trainer, 'test',  max_batches=50)

18883 22529 0.8381641439921879
18939 22419 0.8447745216111334


In [25]:
# let's run a hand-written masked sequence through the model as well
# (3 is the mask token; the input carries a leading batch dimension of 1)
inp = torch.tensor([[0, 3, 3, 0, 3, 2]], dtype=torch.long).to(trainer.device)
# keep the reference on the same device as the model input so the comparison
# below also works when trainer.device is not the CPU
sol = torch.tensor([0, 1, 2, 0, 1, 2], dtype=torch.long).to(trainer.device)
with torch.no_grad():
    sol_candidate = model.generate_output(inp)
print('input sequence  :', inp.tolist())
print('predicted sequence:', sol_candidate.tolist())
print('ground truth      :', sol.tolist())
# sol (6,) broadcasts against the (1, 6) prediction
print('matches         :', bool((sol == sol_candidate).all()))

input sequence  : [[0, 3, 3, 0, 3, 2]]
predicted sequence: [[0, 1, 2, 0, 1, 2]]
ground truth      : [0, 1, 2, 0, 1, 2]
matches         : True
