# NanoSchnack Model

## Setup

- Install dependencies.
- Pick the compute device and print its details, e.g. to verify that MPS is available (for Apple Silicon GPUs).

In [1]:
import torch
from device import device_info, pick_device, print_device_info

# Choose the device once; later cells move the model and every batch onto it.
device = pick_device()
# Gather the device details and print them as a quick sanity check of the setup.
info = device_info(device)
print_device_info(info)

True

## Loading a tokenizer with Hugging Face's tokenizer library

- Compare: https://github.com/huggingface/tokenizers
- Tiktokenizer: https://tiktokenizer.vercel.app/?model=gpt2

In [3]:
from tokenizer import load_tokenizer

# Load the project tokenizer; later cells add a [PAD] token to it and
# configure its truncation/padding behaviour.
tokenizer = load_tokenizer()



  from .autonotebook import tqdm as notebook_tqdm


## Instantiating the NanoSchnack model

In [5]:
from gpt import GPT
from autotune import find_max_batch_size
from config import (
    BATCH_SIZE,
    CHECKPOINT_INTERVAL_SECS,
    CHECKPOINT_WARMUP_SECS,
    CONTEXT_LEN,
    EMBED_SIZE,
    HIDDEN_SIZE,
    LEARNING_RATE,
    LOG_INTERVAL_SECS,
    MAX_NEW_TOKENS,
    NUM_HEADS,
    NUM_LAYERS,
    PLOT_INTERVAL_SECS,
    PLOT_WARMUP_SECS,
    TEMPERATURE,
    TOP_K,
    WARMUP_WINDOW_SECS,
    print_training_hyperparams,
)

# Sequence length shared by the model and by tokenization in later cells.
context_len = CONTEXT_LEN

# Register the padding token before the vocabulary size is queried below,
# and keep its id around for padding and for masking the loss.
tokenizer.add_special_tokens(["[PAD]"])
pad_id = tokenizer.token_to_id("[PAD]")

# Build the network from the configured hyperparameters ...
model = GPT(
    vocab_size=tokenizer.get_vocab_size(),
    embed_size=EMBED_SIZE,
    num_layers=NUM_LAYERS,
    num_heads=NUM_HEADS,
    hidden_size=HIDDEN_SIZE,
    context_len=CONTEXT_LEN,
)
# ... then move it to the selected device and switch it to training mode.
model = model.to(device)
model = model.train()

print_training_hyperparams(model)



## Load the Training Data

In [6]:
# Resolve model paths so relative data/checkpoint locations are stable.
# setup_paths is imported from the `model` package when that is importable;
# otherwise fall back to importing the package's __init__ module directly
# (presumably the case when the notebook runs from inside the package
# directory — TODO confirm).
try:
    from model import setup_paths
except ModuleNotFoundError:
    from __init__ import setup_paths
# The three directories are used below for the dataset shards (data_dir) and
# for checkpoints (checkpoint_dir).
model_dir, data_dir, checkpoint_dir = setup_paths()

from datasets.utils.logging import enable_progress_bar, set_verbosity_warning
from loader import ShardedBatchLoader

# Configure Hugging Face `datasets` logging: warnings only, progress bars on.
# NOTE(review): the ShardedBatchLoader created further down presumably downloads
# shards on demand and shuffles within each shard — confirm in loader.py.
set_verbosity_warning()
enable_progress_bar()

# Choose how texts longer than context_len are handled:
#   True  -> split them into fixed-size windows (optionally overlapping),
#   False -> truncate them to context_len (shorter texts are padded).
USE_CHUNKING = False

if USE_CHUNKING:
    max_len = context_len
    stride = context_len//4  # overlap; set to 0 for no overlap

    # Windowing and padding are done manually below, so switch off the
    # tokenizer's own truncation and padding.
    tokenizer.disable_truncation()
    tokenizer.disable_padding()

    def chunk_ids(ids, max_len, stride):
        """Split a list of token ids into windows of exactly ``max_len`` ids.

        Consecutive windows start ``max_len - stride`` ids apart, i.e. they
        overlap by ``stride`` ids. The last window is right-padded with
        ``pad_id``, and iteration stops as soon as a window reaches the end of
        ``ids`` so no window consisting only of overlap is emitted.

        Args:
            ids: list of token ids for one text.
            max_len: window length; must be positive.
            stride: number of ids shared by neighbouring windows;
                must satisfy ``0 <= stride < max_len``.

        Returns:
            A list of windows (lists of ``max_len`` ids); empty for empty input.

        Raises:
            ValueError: if ``stride`` is outside ``[0, max_len)``. Without this
                check a stride equal to ``max_len`` makes ``range()`` fail with
                a zero step, a larger stride silently yields no windows, and a
                negative stride silently skips tokens between windows.
        """
        if not 0 <= stride < max_len:
            raise ValueError(
                f"stride must satisfy 0 <= stride < max_len, got stride={stride}, max_len={max_len}"
            )
        if len(ids) == 0:
            return []
        step = max_len - stride
        chunks = []
        for start in range(0, len(ids), step):
            # start < len(ids) is guaranteed by range(), so the slice is never empty.
            chunk = ids[start:start + max_len]
            if len(chunk) < max_len:
                chunk = chunk + [pad_id] * (max_len - len(chunk))
            chunks.append(chunk)
            if start + max_len >= len(ids):
                break
        return chunks

    def tokenizer_batch(batch):
        """Tokenize the texts in ``batch["result"]`` into fixed-size windows.

        Returns:
            A dict with "input_ids" (one id list per window) and
            "attention_mask" (1 for real tokens, 0 for padding; any id equal
            to ``pad_id`` counts as padding).
        """
        input_ids = []
        for text in batch["result"]:
            ids = tokenizer.encode(text).ids
            input_ids.extend(chunk_ids(ids, max_len=max_len, stride=stride))
        attention_mask = [
            [1 if t != pad_id else 0 for t in chunk] for chunk in input_ids
        ]
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask
        }
else:
    # Let the tokenizer truncate every text to context_len and pad shorter
    # ones up to context_len with the [PAD] token.
    tokenizer.enable_truncation(max_length=context_len)
    tokenizer.enable_padding(length=context_len, pad_id=pad_id, pad_token="[PAD]")

    # Wrap Hugging Face tokenizer for batch processing
    def tokenizer_batch(batch):
        """Tokenize the texts in ``batch["result"]`` with truncation and padding.

        Returns:
            A dict with "input_ids" and "attention_mask" (1 for real tokens,
            0 for padding), one entry per input text.
        """
        token_batch = tokenizer.encode_batch(batch["result"])
        return {
            "input_ids": [e.ids for e in token_batch],
            "attention_mask": [e.attention_mask for e in token_batch], # marks real tokens (1) vs padding (0)
        }

# Auto-tune the batch size for this model/device, starting the search at the
# configured BATCH_SIZE. (The search strategy lives in autotune.py and is not
# visible here — presumably it looks for the largest size that fits in memory.)
tuned_batch_size = find_max_batch_size(
    model,
    vocab_size=tokenizer.get_vocab_size(),
    seq_len=context_len,
    device=device,
    start=BATCH_SIZE,
)
# Fall back to the configured BATCH_SIZE when tuning returns a falsy value
# (e.g. None or 0).
batch_size = tuned_batch_size or BATCH_SIZE
print(f"Tuned batch_size={batch_size}")
# Create the loader for the training dataset. It receives the tokenizer_batch
# function defined above and the (tuned) batch size; shards are stored under
# data_dir. The fixed seed presumably makes shuffling reproducible — confirm
# in loader.py.
sharded_loader = ShardedBatchLoader(
    repo_id="pdelobelle/fineweb-german-edu-mt",
    data_dir=data_dir,
    tokenizer_batch=tokenizer_batch,
    batch_size=batch_size,
    seed=42,
)
print(f"Sharded loader ready ({sharded_loader.num_shards} shards).", flush=True)

## Run the Training

In [None]:
from plot import ascii_loss_plot
from progress import ProgressLogger
from checkpointer import Checkpointer
import math
import torch
import time

# Set up optimizer, learning-rate scheduler, and loss function
epochs = 1 # 1 to 3 epochs are usually sufficient for good results; prefer 1 over 3.
# The dataset size is only an estimate, so the step counts derived from it
# (and therefore the scheduler horizon) are approximate.
estimated_total_samples = sharded_loader.estimate_total_samples()
steps_per_epoch = math.ceil(estimated_total_samples / batch_size)
total_steps = steps_per_epoch * epochs
print(f"Estimated steps per epoch: {steps_per_epoch} (total {total_steps}).", flush=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
# Cosine annealing of the learning rate with the estimated total step count as T_max.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps)
# Targets equal to pad_id are ignored, so padding does not contribute to the loss.
lossFn = torch.nn.CrossEntropyLoss(ignore_index=pad_id)

# The checkpointer will save and load model/optimizer/scheduler states to/from disk.
checkpointer = Checkpointer(checkpoint_dir, model, optimizer, scheduler, device=device)
# Restore the latest checkpoint's training state: epoch, step, global step,
# loader position, and number of samples seen so far.
# NOTE(review): what load_latest returns when no checkpoint exists is not
# visible here — presumably start-from-scratch values; confirm in checkpointer.py.
resume_epoch, resume_step, global_step, resume_position, total_samples = checkpointer.load_latest()

# Wall-clock time of the most recent checkpoint; the training loop uses it to
# throttle how often checkpoints are written.
last_ckpt_time = time.time()

# Initialize the progress logger to display training progress and loss
# Its counters continue from the values restored from the checkpoint. The
# interval/window settings come from config.py (seconds, per their *_SECS
# names): a plot interval for the warm-up window and another one afterwards.
progress = ProgressLogger(
    ascii_loss_plot,
    start_global_step=global_step,
    start_total_samples=total_samples,
    log_interval=LOG_INTERVAL_SECS,
    warmup_plot_interval=PLOT_WARMUP_SECS,
    plot_interval=PLOT_INTERVAL_SECS,
    warmup_window_secs=WARMUP_WINDOW_SECS,
)

# Resume state. last_step and current_position are only advanced once a batch
# has been fully processed, so an interrupt in the middle of a batch
# checkpoints the last *completed* step and the unfinished batch is read again
# on resume.
last_epoch = resume_epoch
last_step = resume_step
current_position = resume_position

# Reference point for the checkpoint warm-up window. It must not be reset when
# a checkpoint is written: measuring the window from the last checkpoint would
# keep the loop in warm-up mode forever whenever the warm-up interval is
# shorter than the window.
train_start_time = time.time()
try:
    print("Starting training loop...", flush=True)
    for epoch in range(resume_epoch, epochs):
        last_epoch = epoch
        loader = sharded_loader.iter_batches(start_position=current_position)
        # NOTE(review): `step` restarts at 0 here even when resuming in the
        # middle of an epoch, so it does not continue from resume_step —
        # confirm whether that is intended.
        for step, (batch, batch_position, shard_index, shard_len) in enumerate(loader):
            # Get the input IDs and attention mask, and move them to the GPU
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)

            # Next-token prediction
            inputs = input_ids[:, :-1] # everything from the first token except the last
            targets = input_ids[:, 1:] # everything from the second token onward

            # Clear the gradients left over from the previous step; PyTorch
            # accumulates gradients across backward() calls otherwise.
            optimizer.zero_grad()

            # Forward pass (the mask is cut to match the shortened inputs)
            logits = model(inputs, attention_mask=attention_mask[:, :-1])

            # Compute (average) loss of the predicted next tokens and apply backpropagation.
            # reshape to (batch_size * seq_len, vocab_size) and (batch_size * seq_len)
            loss = lossFn(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
            loss.backward()

            # Update weights, then advance the learning-rate schedule.
            optimizer.step()
            scheduler.step()

            # Log progress and plot loss history
            progress.tick(
                loss.item(),
                input_ids.size(0),
                epoch,
                step,
                shard_index=shard_index,
                shard_count=sharded_loader.num_shards,
                shard_len=shard_len,
            )
            total_samples += input_ids.size(0)

            # The batch is done: commit the resume markers.
            last_step = step
            current_position = batch_position

            # Time-based checkpointing: use the warm-up interval during the
            # first WARMUP_WINDOW_SECS of this run, the regular interval afterwards.
            now = time.time()
            in_warmup = (now - train_start_time) < WARMUP_WINDOW_SECS
            ckpt_interval = CHECKPOINT_WARMUP_SECS if in_warmup else CHECKPOINT_INTERVAL_SECS
            if now - last_ckpt_time >= ckpt_interval:
                checkpointer.save_latest(
                    epoch,
                    step,
                    progress.global_step,
                    current_position,
                    total_samples,
                )
                last_ckpt_time = now
        # The next epoch starts reading from the beginning of the data again.
        current_position = (0, 0)
except KeyboardInterrupt:
    # Save a checkpoint so training can resume from the last completed step.
    print("Interrupted: saving checkpoint...")
    checkpointer.save_latest(
        last_epoch,
        last_step,
        progress.global_step,
        current_position,
        total_samples,
    )
    print("Interrupted: checkpoint saved, exiting.")

Resuming from /Users/sts/Quellen/nanoschnack/checkpoints/latest.pt at epoch 0, step 98.
Epoch 1 (Step 100, Global 100), Loss: 6.3413, Samples/s: 5.2
