# NanoSchnack Model

## Setup

- Install dependencies.
- Verify that MPS is available (for Apple Silicon GPUs).

In [10]:
from pickletools import optimize

import torch

torch.backends.mps.is_available()
torch.backends.mps.is_built()

True

## Trying out MPS

In [22]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")


## Loading a tokenizer with Hugging Face's tokenizer library

- Compare: https://github.com/huggingface/tokenizers
- Tiktokenizer: https://tiktokenizer.vercel.app/?model=gpt2

In [24]:
from huggingface_hub import hf_hub_download
from tokenizers import Tokenizer

tokenizer_path = hf_hub_download(repo_id="openai-community/gpt2", filename="tokenizer.json")




### Testing the tokenizer

In [25]:
tokenizer = Tokenizer.from_file(tokenizer_path)
print(tokenizer.encode("Hello, World!").ids)

[15496, 11, 2159, 0]


## Instantiating the NanoSchnack model

In [30]:
from gpt import GPT

# add special tokens
tokenizer.add_special_tokens(["[PAD]"])
pad_id = tokenizer.token_to_id("[PAD]")

context_len = 256
model = GPT(vocab_size=tokenizer.get_vocab_size()).to(device).train()

Epoch 1, Loss: 10.98109245300293, Loss per bit: 0.247537
Epoch 2, Loss: 10.98193645477295, Loss per bit: 0.247556
Epoch 3, Loss: 11.02159309387207, Loss per bit: 0.248450
Epoch 4, Loss: 11.055585861206055, Loss per bit: 0.249216
Epoch 5, Loss: 11.043811798095703, Loss per bit: 0.248951
Epoch 6, Loss: 11.025858879089355, Loss per bit: 0.248546
Epoch 7, Loss: 11.04798412322998, Loss per bit: 0.249045
Epoch 8, Loss: 10.961246490478516, Loss per bit: 0.247090
Epoch 9, Loss: 10.976751327514648, Loss per bit: 0.247439
Epoch 10, Loss: 10.938393592834473, Loss per bit: 0.246574


## Load the Training Data

In [60]:
from datasets import load_dataset
from torch.utils.data import DataLoader

# Load dataset in streaming mode (does not load everything into memory at once)
# Note(sttts): I am using https://huggingface.co/datasets/pdelobelle/fineweb-german-edu-mt.
raw_ds = load_dataset(
    "parquet",
    data_files={"train": "../data/*.parquet"},
    split="train",
    streaming=True,
)

# Shuffle the dataset with a buffer for approximate shuffling
shuffled = raw_ds.shuffle(buffer_size=10_000, seed=42) # lazy shuffle (approximate) with a buffer

# do or not do chunking of the input text, instead of truncating.
if False:
    max_len = context_len
    stride = context_len/4  # overlap; set to 0 for no overlap

    tokenizer.disable_truncation()
    tokenizer.disable_padding()

    # Split long sequences into fixed windows, optionally with overlap.
    def chunk_ids(ids, max_len, stride):
        if len(ids) == 0:
            return []
        step = max_len - stride
        chunks = []
        for start in range(0, len(ids), step):
            chunk = ids[start:start + max_len]
            if len(chunk) == 0:
                continue
            if len(chunk) < max_len:
                chunk = chunk + [pad_id] * (max_len - len(chunk))
            chunks.append(chunk)
            if start + max_len >= len(ids):
                break
        return chunks

    def tokenizer_batch(batch):
        input_ids = []
        attention_mask = [] # marks real tokens (1) vs padding (0)
        for text in batch["result"]:
            ids = tokenizer.encode(text).ids
            for chunk in chunk_ids(ids, max_len=max_len,
                                   stride=stride):
                input_ids.append(chunk)
                attention_mask.append([1 if t != pad_id else 0 for t
                                       in chunk])
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask
        }
else:
    # Enable truncation and padding
    tokenizer.enable_truncation(max_length=context_len)
    tokenizer.enable_padding(length=context_len, pad_id=pad_id, pad_token="[PAD]")

    # Wrap Hugging Face tokenizer for batch processing
    def tokenizer_batch(batch):
        token_batch = tokenizer.encode_batch(batch["result"])
        return {
            "input_ids": [e.ids for e in token_batch],
            "attention_mask": [e.attention_mask for e in token_batch], # marks real tokens (1) vs padding (0)
        }

# Shuffle deterministically (only way for streaming datasets)
dataset = shuffled.map(tokenizer_batch, batched=True)

# Set the dataset format to PyTorch tensors
dataset = dataset.with_format(type="torch")

# Tokenize the dataset
batch_size = 32
loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

## Run the Training

In [59]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10_000)
lossFn = torch.nn.CrossEntropyLoss()

epochs = 1 # epochs between 1 and 3 are usually sufficient for good results, rather 1 than 3.
steps_per_epoch = 10
for epoch in range(epochs):
    for step, batch in enumerate(loader):
        if step >= steps_per_epoch:
            break

        # Get the input IDs and attention mask, and move them to the GPU
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        # Next-token prediction
        inputs = input_ids[:, :-1] # everything from the first token except the last
        targets = input_ids[:, 1:] # everything from the second token onward

        # Clear accumulated gradients from the previous step (which torch does automatically otherwise)
        optimizer.zero_grad()

        # Forward pass
        logits = model(inputs, attention_mask=attention_mask[:, :-1])

        # Compute (average) loss of the predicted next tokens and apply backpropagation.
        # reshape to (batch_size * seq_len, vocab_size) and (batch_size * seq_len)
        loss = lossFn(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        loss.backward()

        # Update weights, then advance the learning-rate schedule.
        optimizer.step()
        scheduler.step()

    lossPerBit = loss.item() / (16 * 4) / torch.log(torch.tensor(2.0))
    print(f"Epoch {epoch+1}, Loss: {loss.item()}, Loss per bit: {lossPerBit:.6f}")

Epoch 1, Loss: 5.674805641174316, Loss per bit: 0.127922
Epoch 2, Loss: 5.538754940032959, Loss per bit: 0.124855
Epoch 3, Loss: 5.391127109527588, Loss per bit: 0.121527
Epoch 4, Loss: 5.268251419067383, Loss per bit: 0.118758
Epoch 5, Loss: 5.146932601928711, Loss per bit: 0.116023
Epoch 6, Loss: 5.035008907318115, Loss per bit: 0.113500
Epoch 7, Loss: 4.936888217926025, Loss per bit: 0.111288
Epoch 8, Loss: 4.842888832092285, Loss per bit: 0.109169
Epoch 9, Loss: 4.747250556945801, Loss per bit: 0.107013
Epoch 10, Loss: 4.656844139099121, Loss per bit: 0.104975
