In [1]:
import os
import math
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, Dataset
import lightning as L
import finalnlp
from finalnlp.gpt_model import GPT
from finalnlp.replacer import replace_linears_in_pytorch_model
from finalnlp import bitnet1
from finalnlp import bitnet158
from lightning.pytorch import loggers as pl_loggers
%load_ext autoreload

In [5]:
import pickle

L.seed_everything(42)

class SortDataset(Dataset):
    """ 
    Dataset for the Sort problem. E.g. for problem length 6:
    Input: 0 0 2 1 0 1 -> Output: 0 0 0 1 1 2
    Which will feed into the transformer concatenated as:
    input:  0 0 2 1 0 1 0 0 0 1 1
    output: I I I I I 0 0 0 1 1 2
    where I is "ignore", as the transformer is reading the input sequence
    """

    def __init__(self, split, length=6, num_digits=3):
        assert split in {'train', 'test'}
        self.split = split
        self.length = length
        self.num_digits = num_digits
    
    def __len__(self):
        return 10000 # ...
    
    def get_vocab_size(self):
        return self.num_digits
    
    def get_block_size(self):
        # the length of the sequence that will feed into transformer, 
        # containing concatenated input and the output, but -1 because
        # the transformer starts making predictions at the last input element
        return self.length * 2 - 1

    def __getitem__(self, idx):
        
        # use rejection sampling to generate an input example from the desired split
        while True:
            # generate some random integers
            inp = torch.randint(self.num_digits, size=(self.length,), dtype=torch.long)
            # half of the time let's try to boost the number of examples that 
            # have a large number of repeats, as this is what the model seems to struggle
            # with later in training, and they are kind of rate
            if torch.rand(1).item() < 0.5:
                if inp.unique().nelement() > self.length // 2:
                    # too many unqiue digits, re-sample
                    continue
            # figure out if this generated example is train or test based on its hash
            h = hash(pickle.dumps(inp.tolist()))
            inp_split = 'test' if h % 4 == 0 else 'train' # designate 25% of examples as test
            if inp_split == self.split:
                break # ok
        
        # solve the task: i.e. sort
        sol = torch.sort(inp)[0]

        # concatenate the problem specification and the solution
        cat = torch.cat((inp, sol), dim=0)

        # the inputs to the transformer will be the offset sequence
        x = cat[:-1].clone()
        y = cat[1:].clone()
        # we only want to predict at output locations, mask out the loss at the input locations
        y[:self.length-1] = -1
        return x, y

Seed set to 42


In [15]:
# print an example instance of the dataset
train_dataset = SortDataset('train')
test_dataset = SortDataset('test')
x, y = train_dataset[0]
for a, b in zip(x,y):
    print(int(a),int(b))

1 -1
2 -1
1 -1
0 -1
2 -1
1 0
0 1
1 1
1 1
1 2
2 2


In [10]:
# create a GPT instance

# model_1b =  replace_linears_in_pytorch_model(GPT(model_config), bitnet1.BitLinear1B)
# model_158b =  replace_linears_in_pytorch_model(GPT(model_config), bitnet158.BitLinear158B)

In [16]:
class LitGPT(L.LightningModule):
    def __init__(self, model_config, linear_replacer = None):
        super().__init__()
        self.model = GPT(model_config)
        if linear_replacer:
            replace_linears_in_pytorch_model(self.model, linear_replacer)
        self.save_hyperparameters()

    def training_step(self, batch: Tensor, batch_idx: int):
        x, y = batch
        z, loss = self.model(x, targets=y)        
        self.log("train_loss", loss, on_step=True)
        return loss
    
    def validation_step(self, batch: Tensor, batch_idx: int):
        x, y = batch
        z, loss = self.model(x, targets=y)        
        self.log("val_loss", loss)
        
        n = test_dataset.length
        inp = x[:, :n]
        sol = y[:, -n:]
        # let the model sample the rest of the sequence
        cat = self.model.generate(inp, n, do_sample=False) # using greedy argmax, not sampling
        sol_candidate = cat[:, n:] # isolate the filled in sequence
        # compare the predicted sequence to the true sequence
        correct = (sol == sol_candidate).all(1).cpu()
        mean_correct = correct.float().mean()
        self.log("val_acc", mean_correct)
    
        return loss

    def test_step(self, batch: Tensor, batch_idx: int):
        x, y = batch
        z, loss = self.model(x, targets=y)
        self.log("test_loss", loss)
        return loss
        

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

## Plain Linear Model

In [15]:
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

# autoencoder = LitAutoEncoder(Encoder(), Decoder())
L.seed_everything(42, workers=True)
model_config = GPT.get_default_config()
model_config.model_type = 'gpt-nano'
model_config.vocab_size = train_dataset.get_vocab_size()
model_config.block_size = train_dataset.get_block_size()
model = LitGPT(model_config)

wandb_logger = pl_loggers.WandbLogger("GPT-Sort-Problem")
wandb_logger.experiment.config.update(model_config)
wandb_logger.experiment.config.update({"problem": "sort", "linear_replacer": "none"})

trainer = L.Trainer(
    callbacks=[EarlyStopping(monitor="train_loss", mode="min")],
    logger=wandb_logger,
    max_epochs=10,
    max_steps=3000,
)
wandb_logger.watch(model)
torch.set_float32_matmul_precision('medium')
trainer.fit(
    model=model,
    train_dataloaders=DataLoader(train_dataset, num_workers=15),
    val_dataloaders=DataLoader(test_dataset, num_workers=15),    
)

Seed set to 42
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`


number of parameters: 0.09M


RuntimeError: CUDA unknown error - this may be due to an incorrectly set up environment, e.g. changing env variable CUDA_VISIBLE_DEVICES after program start. Setting the available devices to be zero.

In [5]:
trainer.validate(model=model, dataloaders=DataLoader(test_dataset, num_workers=15))

NameError: name 'trainer' is not defined

## BitLinear-1B

In [17]:
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

# autoencoder = LitAutoEncoder(Encoder(), Decoder())
L.seed_everything(42, workers=True)
model_config = GPT.get_default_config()
model_config.model_type = 'gpt-nano'
model_config.vocab_size = train_dataset.get_vocab_size()
model_config.block_size = train_dataset.get_block_size()
model1b = LitGPT(model_config, linear_replacer=bitnet1.BitLinear1B)

wandb_logger = pl_loggers.WandbLogger("GPT-Sort-Problem-BitNet1B")
wandb_logger.experiment.config.update(model_config)
wandb_logger.experiment.config.update({"problem": "sort", "linear_replacer": "bitnet1b"})

trainer = L.Trainer(
    callbacks=[EarlyStopping(monitor="train_loss", mode="min")],
    logger=wandb_logger,
    max_epochs=10,
    max_steps=3000,
)
wandb_logger.watch(model1b)
torch.set_float32_matmul_precision('medium')
trainer.fit(
    model=model1b,
    train_dataloaders=DataLoader(train_dataset, num_workers=15),
    val_dataloaders=DataLoader(test_dataset, num_workers=15),    
)

Seed set to 42


number of parameters: 0.09M


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mcandrewlee14[0m ([33mandrews-org[0m). Use [1m`wandb login --relogin`[0m to force relogin


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`


RuntimeError: CUDA unknown error - this may be due to an incorrectly set up environment, e.g. changing env variable CUDA_VISIBLE_DEVICES after program start. Setting the available devices to be zero.

In [6]:
trainer.validate(model=model1b, dataloaders=DataLoader(test_dataset, num_workers=15))

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validation: |          | 0/? [00:00<?, ?it/s]

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

In [14]:
def eval_split(model, trainer, split, max_batches):
    model.eval()
    dataset = {'train':train_dataset, 'test':test_dataset}[split]
    n = train_dataset.length # naugy direct access shrug
    results = []
    mistakes_printed_already = 0
    loader = DataLoader(dataset, batch_size=100, num_workers=0, drop_last=False)
    for b, (x, y) in enumerate(loader):
        # isolate the input pattern alone
        inp = x[:, :n]
        sol = y[:, -n:]
        # let the model sample the rest of the sequence
        cat = model.model.generate(inp, n, do_sample=False) # using greedy argmax, not sampling
        sol_candidate = cat[:, n:] # isolate the filled in sequence
        # compare the predicted sequence to the true sequence
        correct = (sol == sol_candidate).all(1).cpu() # Software 1.0 vs. Software 2.0 fight RIGHT on this line haha
        for i in range(x.size(0)):
            results.append(int(correct[i]))
            if not correct[i] and mistakes_printed_already < 3: # only print up to 5 mistakes to get a sense
                mistakes_printed_already += 1
                print("GPT claims that %s sorted is %s but gt is %s" % (inp[i].tolist(), sol_candidate[i].tolist(), sol[i].tolist()))
        if max_batches is not None and b+1 >= max_batches:
            break
    rt = torch.tensor(results, dtype=torch.float)
    print("%s final score: %d/%d = %.2f%% correct" % (split, rt.sum(), len(results), 100*rt.mean()))
    return rt.sum()

# run a lot of examples from both train and test through the model and verify the output correctness
with torch.no_grad():
    print("Plain Linear:")
    train_score = eval_split(model, trainer, 'train', max_batches=50)
    test_score  = eval_split(model, trainer, 'test',  max_batches=50)
    print("BitNet1B:")
    train_score = eval_split(model1b, trainer, 'train', max_batches=50)
    test_score  = eval_split(model1b, trainer, 'test',  max_batches=50)

GPT claims that [1, 1, 2, 1, 1, 1] sorted is [1, 1, 1, 1, 2, 2] but gt is [1, 1, 1, 1, 1, 2]
GPT claims that [2, 2, 2, 0, 2, 0] sorted is [0, 0, 1, 2, 2, 2] but gt is [0, 0, 2, 2, 2, 2]
GPT claims that [2, 2, 2, 2, 0, 2] sorted is [0, 1, 2, 2, 2, 2] but gt is [0, 2, 2, 2, 2, 2]
train final score: 4719/5000 = 94.38% correct
GPT claims that [2, 2, 0, 2, 0, 2] sorted is [0, 0, 1, 2, 2, 2] but gt is [0, 0, 2, 2, 2, 2]
GPT claims that [0, 2, 0, 2, 2, 2] sorted is [0, 0, 1, 2, 2, 2] but gt is [0, 0, 2, 2, 2, 2]
GPT claims that [2, 2, 2, 0, 0, 2] sorted is [0, 0, 1, 2, 2, 2] but gt is [0, 0, 2, 2, 2, 2]
test final score: 4693/5000 = 93.86% correct


In [15]:
# let's run a random given sequence through the model as well
n = train_dataset.length # naugy direct access shrug
inp = torch.tensor([[0, 0, 2, 1, 0, 1]], dtype=torch.long)
assert inp[0].nelement() == n
with torch.no_grad():
    cat = model.model.generate(inp, n, do_sample=False)
sol = torch.sort(inp[0])[0]
sol_candidate = cat[:, n:]
print('input sequence  :', inp.tolist())
print('predicted sorted:', sol_candidate.tolist())
print('gt sort         :', sol.tolist())
print('matches         :', bool((sol == sol_candidate).all()))

input sequence  : [[0, 0, 2, 1, 0, 1]]
predicted sorted: [[0, 0, 0, 1, 1, 2]]
gt sort         : [0, 0, 0, 1, 1, 2]
matches         : True
