# NanoSchnack Model

## Setup

- Install dependencies.
- Verify that MPS is available (for Apple Silicon GPUs).

In [1]:
from pickletools import optimize

import torch

torch.backends.mps.is_available()
torch.backends.mps.is_built()

True

## Trying out MPS

In [2]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")


## Loading a tokenizer with Hugging Face's tokenizer library

- Compare: https://github.com/huggingface/tokenizers
- Tiktokenizer: https://tiktokenizer.vercel.app/?model=gpt2

In [3]:
from huggingface_hub import hf_hub_download
from tokenizers import Tokenizer

tokenizer_path = hf_hub_download(repo_id="openai-community/gpt2", filename="tokenizer.json")


  from .autonotebook import tqdm as notebook_tqdm


### Testing the tokenizer

In [4]:
tokenizer = Tokenizer.from_file(tokenizer_path)
print(tokenizer.encode("Hello, World!").ids)

[15496, 11, 2159, 0]


## Instantiating the NanoSchnack model

In [5]:
from gpt import GPT

# add special tokens
tokenizer.add_special_tokens(["[PAD]"])
pad_id = tokenizer.token_to_id("[PAD]")

context_len = 256
model = GPT(vocab_size=tokenizer.get_vocab_size()).to(device).train()

## Load the Training Data

In [6]:
from datasets import load_dataset
from torch.utils.data import DataLoader

# Load dataset in streaming mode (does not load everything into memory at once)
# Note(sttts): I am using https://huggingface.co/datasets/pdelobelle/fineweb-german-edu-mt.
raw_ds = load_dataset(
    "parquet",
    data_files={"train": "../data/*.parquet"},
    split="train",
    streaming=True,
)

# Shuffle the dataset with a buffer for approximate shuffling
shuffled = raw_ds.shuffle(buffer_size=10_000, seed=42) # lazy shuffle (approximate) with a buffer

# do or not do chunking of the input text, instead of truncating.
if False:
    max_len = context_len
    stride = context_len/4  # overlap; set to 0 for no overlap

    tokenizer.disable_truncation()
    tokenizer.disable_padding()

    # Split long sequences into fixed windows, optionally with overlap.
    def chunk_ids(ids, max_len, stride):
        if len(ids) == 0:
            return []
        step = max_len - stride
        chunks = []
        for start in range(0, len(ids), step):
            chunk = ids[start:start + max_len]
            if len(chunk) == 0:
                continue
            if len(chunk) < max_len:
                chunk = chunk + [pad_id] * (max_len - len(chunk))
            chunks.append(chunk)
            if start + max_len >= len(ids):
                break
        return chunks

    def tokenizer_batch(batch):
        input_ids = []
        attention_mask = [] # marks real tokens (1) vs padding (0)
        for text in batch["result"]:
            ids = tokenizer.encode(text).ids
            for chunk in chunk_ids(ids, max_len=max_len,
                                   stride=stride):
                input_ids.append(chunk)
                attention_mask.append([1 if t != pad_id else 0 for t
                                       in chunk])
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask
        }
else:
    # Enable truncation and padding
    tokenizer.enable_truncation(max_length=context_len)
    tokenizer.enable_padding(length=context_len, pad_id=pad_id, pad_token="[PAD]")

    # Wrap Hugging Face tokenizer for batch processing
    def tokenizer_batch(batch):
        token_batch = tokenizer.encode_batch(batch["result"])
        return {
            "input_ids": [e.ids for e in token_batch],
            "attention_mask": [e.attention_mask for e in token_batch], # marks real tokens (1) vs padding (0)
        }

# Shuffle deterministically (only way for streaming datasets)
dataset = shuffled.map(tokenizer_batch, batched=True)

# Set the dataset format to PyTorch tensors
dataset = dataset.with_format(type="torch")

# Tokenize the dataset
batch_size = 32
loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

## Run the Training

In [None]:
from plot import ascii_loss_plot
import time
from collections import deque

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10_000)
lossFn = torch.nn.CrossEntropyLoss()

epochs = 1 # epochs between 1 and 3 are usually sufficient for good results, rather 1 than 3.
start_time = time.time()
last_log_time = start_time
last_plot_time = start_time
samples_since_log = 0
loss_history = deque()
for epoch in range(epochs):
    for step, batch in enumerate(loader):
        # Get the input IDs and attention mask, and move them to the GPU
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        # Next-token prediction
        inputs = input_ids[:, :-1] # everything from the first token except the last
        targets = input_ids[:, 1:] # everything from the second token onward

        # Clear accumulated gradients from the previous step (which torch does automatically otherwise)
        optimizer.zero_grad()

        # Forward pass
        logits = model(inputs, attention_mask=attention_mask[:, :-1])

        # Compute (average) loss of the predicted next tokens and apply backpropagation.
        # reshape to (batch_size * seq_len, vocab_size) and (batch_size * seq_len)
        loss = lossFn(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        loss.backward()

        # Update weights, then advance the learning-rate schedule.
        optimizer.step()
        scheduler.step()

        # Record loss for logging and plotting
        now = time.time()
        loss_history.append((now, loss.item()))
        while loss_history and (now - loss_history[0][0]) > 3600:
            loss_history.popleft()

        # Every 10 seconds, log progress
        samples_since_log += input_ids.size(0)
        if now - last_log_time >= 10:
            elapsed = now - last_log_time
            samples_per_sec = samples_since_log / elapsed if elapsed > 0 else 0.0
            print(f"Epoch {epoch+1}, Step {step+1}, Loss: {loss.item():.4f}, Samples/s: {samples_per_sec:.1f}")
            last_log_time = now
            samples_since_log = 0

        # Every minute (or every 10 minutes after 10 minutes), plot loss history
        plot_interval = 60 if (now - start_time) < 600 else 600
        if now - last_plot_time >= plot_interval:
            print(ascii_loss_plot(list(loss_history)))
            last_plot_time = now


Epoch 1 (Step 7), Loss: 10.1352, Samples/s: 20.2
Epoch 1 (Step 14), Loss: 9.4491, Samples/s: 20.0
Epoch 1 (Step 21), Loss: 9.0201, Samples/s: 18.7
Epoch 1 (Step 27), Loss: 8.5242, Samples/s: 18.1
Epoch 1 (Step 33), Loss: 8.0164, Samples/s: 17.7
loss (last hour, min 7.8498 max 10.9798)
   10.98  ┤
   10.70  ┼──╮
   10.41  ┤  ╰───╮
   10.13  ┤      ╰──╮
    9.84  ┤         ╰────╮
    9.56  ┤              ╰───╮
    9.27  ┤                  ╰──────╮
    8.99  ┤                         ╰─────────╮
    8.70  ┤                                   ╰──────╮
    8.42  ┤                                          ╰──╮ ╭─╮
    8.13  ┤                                             ╰─╯ ╰──────╮
    7.85  ┤                                                        ╰──
           08:22:45                 08:23:15                  08:23:45
