# Initial experiment: small vanilla transformer (NanoGPT) trained on plain addition

In [1]:
# %load_ext autoreload
# %autoreload 2

In [2]:
import math
from pathlib import Path


import torch
import lightning as L
from torch import nn, Tensor
from torch.utils.data import Dataset

from arithmetic_lm.tokenizer import CharTokenizer, Tokenizer
from arithmetic_lm.utils import get_torch_device, set_seed
from arithmetic_lm.constants import DATA_DIR, ROOT_DIR, CHECKPOINTS_DIR

In [3]:
# Pick the compute device through the project helper (the recorded run reports "mps").
DEVICE = get_torch_device()
print(f"Using device: {DEVICE}")

# Fix the random seed for reproducibility — presumably seeds torch and the other RNGs;
# the exact behaviour lives in arithmetic_lm.utils.set_seed (not visible here).
set_seed(1337)

Using device: mps


In [4]:
# Hyperparameters shared by the dataset, the data module and the model below.
SEQ_LEN = 256  # tokens per dataset sequence; also passed as the model's context_len
BATCH_SIZE = 32  # batch size handed to the LightningDataModule
N_LAYERS = 6  # number of transformer layers
N_HEAD = 6  # attention heads per layer
N_EMBD = 384  # embedding (model) dimensionality

## NanoGPT model

In [5]:
# from https://pytorch.org/tutorials/beginner/transformer_tutorial.html
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding followed by dropout (sequence-first layout)."""

    def __init__(self, d_model: int, max_len: int, dropout: float):
        """
        Arguments:
            d_model: embedding dimensionality of the inputs
            max_len: number of positions for which encodings are precomputed
            dropout: dropout probability applied after adding the encoding
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # column vector of positions times row vector of inverse frequencies
        positions = torch.arange(max_len).unsqueeze(1)
        even_dims = torch.arange(0, d_model, 2)
        inv_freq = torch.exp(even_dims * (-math.log(10000.0) / d_model))
        angles = positions * inv_freq

        # sines on even channels, cosines on odd channels; the singleton middle
        # axis broadcasts over the batch dimension in forward()
        table = torch.zeros(max_len, 1, d_model)
        table[:, 0, 0::2] = torch.sin(angles)
        table[:, 0, 1::2] = torch.cos(angles)

        # buffer: moves with the module across devices but is not a trainable parameter
        self.register_buffer("pe", table)

    def forward(self, x: Tensor) -> Tensor:
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``

        Returns:
            Tensor of the same shape: ``x`` plus the encodings of the first
            ``seq_len`` positions, passed through dropout.
        """
        n_positions = x.size(0)
        encoded = x + self.pe[:n_positions]
        return self.dropout(encoded)

In [6]:
class NanoGPT(nn.Module):
    """Simple small decoder-only transformer model using nn.TransformerDecoder."""

    def __init__(
        self,
        context_len: int,
        n_embd: int,
        n_head: int,
        n_layers: int,
        vocab_size: int,
        ff_factor: int = 4,
        dropout: float = 0.1,
    ):
        """
        Arguments:
            context_len: context length, i.e. the number of expected features in the input
            n_embd: dimensionality of model embeddings
            n_head: number of heads in the multi-head attention
            n_layers: number of nn.TransformerDecoderLayer layers
            vocab_size: size of the vocabulary
            ff_factor: factor by which to scale the hidden layer dimensionality in the feedforward layer
            dropout: dropout probability
        """

        super().__init__()
        self.context_len = context_len
        self.n_embd = n_embd
        self.n_head = n_head
        self.n_layers = n_layers
        self.vocab_size = vocab_size
        self.ff_factor = ff_factor
        self.dropout = dropout

        # embedding
        self.embedding = nn.Embedding(vocab_size, n_embd)
        self.pos_encoder = PositionalEncoding(n_embd, max_len=context_len, dropout=dropout)

        # same as decoder layer essentially, but without cross attention
        self.layer = nn.TransformerEncoderLayer(
            d_model=n_embd,
            nhead=n_head,
            dim_feedforward=n_embd * ff_factor,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            self.layer,
            num_layers=n_layers,
            norm=nn.LayerNorm(n_embd),
        )

        # output to vocab dim
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)

        # weight tying
        self.lm_head.weight = self.embedding.weight

        # TODO init weights
    
    def forward(self, x: Tensor) -> Tensor:
        """
        Arguments:
            x: Tensor, shape ``[batch_size, seq_len]``
        
        Returns:
            logits: Tensor, shape ``[batch_size, seq_len, vocab_size]``
        """
        x = self.embedding(x)
        x = self.pos_encoder(x)

        x = self.transformer_encoder(
            x,
            is_causal=True,
            mask=nn.Transformer.generate_square_subsequent_mask(
                self.context_len, device=x.device
            ),
        )
        x = self.lm_head(x)
        return x
    
    def param_count(self) -> int:
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
    
    @torch.no_grad()
    def generate(
        self,
        idx: Tensor,
        max_new_tokens: int,
        temperature: float = 1.0,
        top_k: int = None,
        stop_token: int = None
    ) -> Tensor:
        """
        Take a conditioning sequence of indices idx (tensor of shape [batch, seq_len]) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
        """

        assert isinstance(idx, torch.Tensor), "idx must be a torch.Tensor"
        assert idx.dim() == 2, "idx must be a 2D tensor of shape [batch, seq_len]"
        assert idx.size(1) <= self.context_len, "sequence length must be <= context_len"
        assert idx.size(0) == 1, "only batch_size=1 is supported for now"

        for _ in range(max_new_tokens):
            # crop to context_len if necessary
            idx_cond = idx if idx.size(1) <= self.context_len else idx[:, -self.context_len:]

            # logits shape: [batch, seq_len, vocab_size]
            logits = self.forward(idx_cond)

            # get logits at final step and apply temperature
            logits = logits[:, -1, :] / temperature

            # optionally apply top-k filtering
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float("inf")

            # apply softmax
            probs = nn.functional.softmax(logits, dim=-1)

            # sample from the distribution
            next_token = torch.multinomial(probs, num_samples=1)

            # append to the sequence
            idx = torch.cat([idx, next_token], dim=1)

            # stop if stop_token is generated
            if stop_token is not None and next_token.item() == stop_token:
                break
        return idx

## Test one batch overfitting

In [7]:
# Character-level tokenizer from arithmetic_lm; reused below for the encode/decode
# sanity check and to supply the model's vocab_size.
tokenizer = CharTokenizer()

In [8]:
# Round-trip sanity check: encode a short string to token ids, then decode it back.
text = "hello world"
tokens = tokenizer.encode(text)
# Bare tuple as the last expression so the notebook displays ids, length and decoded text.
tokens, "len:", len(tokens), tokenizer.decode(tokens)

([17, 14, 21, 21, 24, 94, 32, 24, 27, 21, 13], 'len:', 11, 'hello world')

In [9]:
# convert to a batch-of-one tensor on the target device
# Built from `text` rather than from the current value of `tokens`, so re-running this
# cell gives the same [1, seq_len] tensor instead of wrapping a tensor a second time.
tokens = torch.tensor(tokenizer.encode(text)).unsqueeze(0).to(DEVICE)
tokens.shape

torch.Size([1, 11])

### Try overfitting on one batch

In [10]:
# net = NanoGPT(
#     context_len=SEQ_LEN,
#     n_embd=N_EMBD,
#     n_head=N_HEAD,
#     n_layers=N_LAYERS,
#     vocab_size=tokenizer.vocab_size,
# ).to(DEVICE)

In [11]:
# # simplest train loop
# criterion = nn.CrossEntropyLoss()
# optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

# # create target by shifting tokens by 1 and adding a padding token at the end
# target = torch.cat([tokens[0, 1:], torch.tensor([65]).to(DEVICE)]).unsqueeze(0)
# test_text = "hel"
# test_tokens = tokenizer.encode(test_text)
# test_tokens = torch.tensor(test_tokens).unsqueeze(0).to(DEVICE)

# losses = []

# for i in range(10000):
#     optimizer.zero_grad()
#     y = net(tokens)
#     # y shape: (batch_size, seq_len, vocab_size)
#     loss = criterion(y.view(-1, y.size(-1)), target.view(-1))
#     loss.backward()
#     optimizer.step()
#     losses.append(loss.item())

#     if i % 100 == 0:
#         print(f"[{i}] loss: {loss.item():.5f}  ", test_text + " -> " + tokenizer.decode(net.generate(test_tokens, max_new_tokens=10).squeeze().tolist()))

In [12]:
# import matplotlib.pyplot as plt

# %matplotlib inline

# plt.plot(losses)
# plt.xlabel("iteration")
# plt.ylabel("loss")

In [13]:
# test_prompt = "hello w"
# tokens = tokenizer.encode(test_prompt)
# tokens = torch.tensor(tokens).unsqueeze(0).to(DEVICE)
# print(tokens.shape)

# net.eval()
# generated_tokens = net.generate(tokens, max_new_tokens=10, temperature=1.0, top_k=1)
# print(generated_tokens.shape)

# tokenizer.decode(generated_tokens.squeeze(0).cpu().tolist())

## Dataset

In [14]:
class ArithmeticDataset(Dataset):
    """Concatenate lines in file and split into sequences of length seq_len."""
    # TODO transforms (adding $, formatting, reversing)
    def __init__(self, txtfile: str | Path, tokenizer: Tokenizer, seq_len: int):
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        with open(txtfile, "r") as f:
            text = f.read()
        tokens = self.tokenizer.encode(text)
        # make batches of same length (truncate if necessary)
        n_batches = len(tokens) // seq_len
        self.batches = [tokens[i*seq_len:(i+1)*seq_len] for i in range(n_batches)]

    def __len__(self) -> int:
        return len(self.batches)

    def __getitem__(self, idx: int) -> list[int]:
        # return tensors
        return torch.tensor(self.batches[idx])

In [15]:
class ArithmeticEvalDataset(Dataset):
    """Dataset but instead of pure language modeling, we want to evaluate each example (line)"""
    # TODO transforms (adding $, formatting, reversing)
    def __init__(self, txtfile: str | Path, tokenizer: Tokenizer, seq_len: int, format_str: str = None):
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        with open(txtfile, "r") as f:
            text = f.readlines()

    def __len__(self) -> int:
        return len(self.batches)

    def __getitem__(self, idx: int) -> list[int]:
        # return tensors
        return torch.tensor(self.batches[idx])

In [16]:
# 10k balanced dataset
# Tokenise the 3-digit addition file and chunk it into sequences of SEQ_LEN tokens.
total_ds = ArithmeticDataset(DATA_DIR / "add_3digit_bal" / "add_3digit_10k_bal.txt" , CharTokenizer(), seq_len=SEQ_LEN)
print("total:", len(total_ds), "batches")

# train/val split
# 80/20 random split; the generator is seeded explicitly so the split is reproducible.
train_ds, val_ds = torch.utils.data.random_split(total_ds, [0.8, 0.2], generator=torch.Generator().manual_seed(42))
print("train:", len(train_ds))
print("val:", len(val_ds))
print("type(train_ds[0]):", type(train_ds[0]))

# Remove the unsplit dataset's name from the notebook namespace; only the splits are used below.
del total_ds

total: 468 batches
train: 375
val: 93
type(train_ds[0]): <class 'torch.Tensor'>


In [17]:
# Wrap the two splits in a LightningDataModule so the Trainer can build the dataloaders.
ldm = L.LightningDataModule.from_datasets(train_dataset=train_ds, val_dataset=val_ds, batch_size=BATCH_SIZE)

## Lightning module wrapper for model

In [18]:
class LightningNanoGPT(L.LightningModule):
    """Lightning wrapper around NanoGPT: next-token loss for train/val and AdamW setup."""

    def __init__(
        self,
        context_len: int,
        n_embd: int,
        n_head: int,
        n_layers: int,
        vocab_size: int,
        ff_factor: int = 4,
        dropout: float = 0.1,
        lr: float = 0.001,
        betas: tuple[float, float] = (0.9, 0.99),
        weight_decay: float = 0.1,
    ):
        """
        Arguments:
            context_len, n_embd, n_head, n_layers, vocab_size, ff_factor, dropout:
                forwarded unchanged to ``NanoGPT``
            lr: AdamW learning rate
            betas: AdamW beta coefficients
            weight_decay: weight decay applied to the "decay" parameter group only
        """
        super().__init__()
        self.model = NanoGPT(
            context_len=context_len,
            n_embd=n_embd,
            n_head=n_head,
            n_layers=n_layers,
            vocab_size=vocab_size,
            ff_factor=ff_factor,
            dropout=dropout,
        )
        self.lr = lr
        self.betas = betas
        self.weight_decay = weight_decay
        # record the constructor arguments in the checkpoint / logger
        self.save_hyperparameters()

    def forward(self, x: Tensor) -> Tensor:
        """Return the wrapped model's logits for token ids ``x`` of shape [batch, seq_len]."""
        return self.model(x)

    def _next_token_loss(self, batch: Tensor) -> Tensor:
        """Cross-entropy of predicting each token from the tokens before it.

        ``batch`` has shape [batch_size, seq_len]; the model input is the batch
        without its last token and the target is the batch shifted left by one.
        """
        x, y = batch[:, :-1], batch[:, 1:]
        logits = self.model(x)
        # flatten batch and time so every position is one classification example
        return nn.functional.cross_entropy(
            logits.reshape(-1, logits.size(-1)), y.reshape(-1)
        )

    def training_step(self, batch: Tensor, batch_idx: int) -> Tensor:
        """Compute and log the next-token loss for one training batch."""
        loss = self._next_token_loss(batch)
        self.log("train_loss", loss, on_step=True)
        return loss

    def validation_step(self, batch: Tensor, batch_idx: int) -> Tensor:
        """Compute and log the next-token loss for one validation batch."""
        loss = self._next_token_loss(batch)
        self.log("val_loss", loss, on_step=True)
        return loss

    @torch.no_grad()
    def test_step(self, batch: Tensor, batch_idx: int) -> Tensor:
        """Not implemented yet; always raises NotImplementedError."""
        # TODO accuracy
        raise NotImplementedError("Test step not implemented")

    def configure_optimizers(self):
        """Build AdamW with two parameter groups: weight-decayed and not decayed.

        Linear and attention weights are decayed; biases, LayerNorm weights and the
        (tied) embedding weight are not.
        """
        # separate out all parameters to those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear, )
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.model.named_modules():
            for pn, p in m.named_parameters():
                fpn = f"{mn}.{pn}" if mn else pn # full param name
                # random note: because named_modules and named_parameters are recursive
                # we will see the same tensors p many many times. but doing it this way
                # allows us to know which parent module any tensor p belongs to
                if pn.endswith('bias'):
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, torch.nn.MultiheadAttention):
                    # special case for multihead attention
                    decay.add(fpn)

        # subtle: 'embedding.weight' and 'lm_head.weight' are tied, so they
        # will appear in the no_decay and decay sets respectively after the above.
        # In addition, because named_parameters() doesn't return duplicates, it
        # will only return the first occurrence, key'd by 'embedding.weight', below.
        # so let's manually remove 'lm_head.weight' from decay set. This will include
        # this tensor into optimization via embedding.weight only, and not decayed.
        decay.remove('lm_head.weight')

        # validate that we considered every parameter
        param_dict = {pn: p for pn, p in self.model.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
        assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
                                                    % (str(param_dict.keys() - union_params), )

        # create the pytorch optimizer object
        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": self.weight_decay},
            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ]
        # fused AdamW is only requested on CUDA; compare the device *type* so that
        # both "cuda:0"-style strings and torch.device objects are recognised
        use_fused = torch.device(DEVICE).type == "cuda"
        extra_args = dict(fused=True) if use_fused else dict()
        optimizer = torch.optim.AdamW(optim_groups, lr=self.lr, betas=self.betas, **extra_args)
        return optimizer

    def param_count(self) -> int:
        """Number of trainable parameters of the wrapped model."""
        return self.model.param_count()

In [19]:
# Instantiate the Lightning wrapper from the notebook-level hyperparameters.
# Note that dropout=0.2 here overrides the class default of 0.1.
lmodel = LightningNanoGPT(
    context_len=SEQ_LEN,
    n_embd=N_EMBD,
    n_head=N_HEAD,
    n_layers=N_LAYERS,
    vocab_size=tokenizer.vocab_size,
    dropout=0.2,
    lr=0.001,
    betas=(0.9, 0.99),
    weight_decay=0.1,
)

In [20]:
# Checkpoints for this run are written to CHECKPOINTS_DIR/<run_name>.
run_name = "nanogpt_add_3digit_10k_bal"
run_dir = CHECKPOINTS_DIR / run_name
run_dir.mkdir(exist_ok=True, parents=True)
# Keep only the single best checkpoint, judged by the lowest val_loss.
checkpoint_callback = L.pytorch.callbacks.ModelCheckpoint(monitor="val_loss", save_top_k=1, mode="min", dirpath=run_dir, filename="{step}-{train_loss:.4f}-{val_loss:.4f}")
trainer = L.Trainer(
    logger=L.pytorch.loggers.WandbLogger(project="msc-thesis-pilot", name=run_name, save_dir=ROOT_DIR, log_model=True),
    callbacks=[checkpoint_callback],
    max_steps=1000,  # upper bound on the number of optimizer steps
    val_check_interval=10,  # run validation every 10 training batches
    log_every_n_steps=1,  # log on every step
    gradient_clip_val=1.0,  # clip gradients at 1.0
)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [21]:
# Launch training. The recorded run below was stopped manually (KeyboardInterrupt)
# during epoch 42, i.e. before max_steps was reached.
trainer.fit(lmodel, ldm)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33miibrahimli[0m. Use [1m`wandb login --relogin`[0m to force relogin



  | Name  | Type    | Params
----------------------------------
0 | model | NanoGPT | 12.5 M
----------------------------------
12.5 M    Trainable params
0         Non-trainable params
12.5 M    Total params
49.842    Total estimated model params size (MB)


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

/Users/imran/Library/Caches/pypoetry/virtualenvs/msc-thesis-P7I560r2-py3.12/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


                                                                           

/Users/imran/Library/Caches/pypoetry/virtualenvs/msc-thesis-P7I560r2-py3.12/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Epoch 42:  42%|████▏     | 5/12 [00:06<00:09,  0.72it/s, v_num=cmoj] 

/Users/imran/Library/Caches/pypoetry/virtualenvs/msc-thesis-P7I560r2-py3.12/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
