# LLaMA from Scratch: Train & Generate on Google Colab

This notebook trains a **~15M parameter LLaMA model** on the [TinyStories](https://huggingface.co/datasets/roneneldan/TinyStories) dataset, end-to-end:

1. **Train a BPE tokenizer** (SentencePiece, 4096 vocab)
2. **Tokenize the dataset** into memory-mapped binary files
3. **Train the model** with mixed precision, gradient accumulation, cosine LR schedule
4. **Evaluate** on the validation set (loss + perplexity)
5. **Generate text** from prompts with temperature/top-k/top-p sampling
6. **Save & download** the trained checkpoint

**Architecture:** Decoder-only transformer following LLaMA (Meta AI)
- RMSNorm (pre-normalization)
- Rotary Positional Embeddings (RoPE)
- SwiGLU activation in FFN
- Grouped Query Attention (GQA, 6 query heads / 2 KV heads)
- KV cache for efficient inference

**Requirements:** A Colab GPU runtime (T4 or A100). Go to *Runtime > Change runtime type > GPU*.

The entire LLaMA implementation is self-contained in this notebook — no external model package is needed (only the dependencies installed in the next cell).

In [None]:
#@title Install Dependencies
!pip install -q torch sentencepiece datasets tqdm

In [None]:
#@title Configuration — Edit these parameters!

# ── Training ─────────────────────────────────────────────────────────────
TRAINING_STEPS = 3000          # Total optimizer steps (~15 min on T4)
BATCH_SIZE = 64                # Sequences per micro-batch
GRADIENT_ACCUMULATION_STEPS = 4  # Effective batch = 64 * 4 = 256 sequences
LEARNING_RATE = 3e-4           # Peak LR (after warmup)
MIN_LEARNING_RATE = 3e-5       # Floor LR (10% of peak)
WARMUP_STEPS = 200             # Linear warmup steps
WEIGHT_DECAY = 0.1             # AdamW weight decay
MAX_GRAD_NORM = 1.0            # Gradient clipping

# ── Evaluation & Logging ─────────────────────────────────────────────────
EVAL_INTERVAL = 500            # Evaluate every N steps
EVAL_STEPS = 20                # Batches per evaluation
LOG_INTERVAL = 50              # Print loss every N steps
SAVE_INTERVAL = 1000           # Save checkpoint every N steps

# ── Model Architecture ──────────────────────────────────────────────────
VOCAB_SIZE = 4096
DIM = 384
N_LAYERS = 8
N_HEADS = 6
N_KV_HEADS = 2
MAX_SEQ_LEN = 512
HIDDEN_DIM = 1024

# ── Paths ────────────────────────────────────────────────────────────────
DATA_DIR = "data/"
CHECKPOINT_DIR = "checkpoints/"
# Derived from DATA_DIR so that relocating the data directory also relocates
# the tokenizer; with the default this is "data/tokenizer" (-> data/tokenizer.model).
TOKENIZER_PREFIX = f"{DATA_DIR.rstrip('/')}/tokenizer"
SEED = 42

print(f"Training for {TRAINING_STEPS} steps")
print(f"Effective batch size: {BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS} sequences")
print(f"Tokens per step: {BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS * MAX_SEQ_LEN:,}")

In [None]:
#@title Device Detection
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.amp
from torch.utils.checkpoint import checkpoint as gradient_checkpoint

# Pick the best available accelerator and report what was found.
if torch.cuda.is_available():
    device = torch.device("cuda")
    props = torch.cuda.get_device_properties(device)
    print(f"GPU: {props.name}")
    # The device-properties attribute is `total_memory` (in bytes);
    # `total_mem` does not exist and raised AttributeError on every CUDA runtime.
    print(f"VRAM: {props.total_memory / 1024**3:.1f} GB")
    print(f"Compute Capability: {props.major}.{props.minor}")
    print(f"BF16 Support: {torch.cuda.is_bf16_supported()}")
    print(f"CUDA Version: {torch.version.cuda}")
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Device: MPS (Apple Silicon)")
else:
    device = torch.device("cpu")
    print("WARNING: No GPU detected! Training will be very slow.")
    print("Go to Runtime > Change runtime type > GPU")

print(f"\nDevice: {device}")
print(f"PyTorch: {torch.__version__}")

## Core Implementation

The next few cells define the full LLaMA implementation inline:
- **Config** — the `ModelConfig` dataclass (training settings are the constants in the Configuration cell above)
- **Model** — RMSNorm, RoPE, SwiGLU FFN, GQA Attention, Transformer blocks
- **Tokenizer** — SentencePiece BPE wrapper
- **Data Pipeline** — Download, tokenize, DataLoader
- **Utilities** — Device helpers, checkpointing, generation, training helpers

In [None]:
#@title Config Classes
import os
import math
import time
import json
import random
from dataclasses import dataclass, asdict
from typing import Optional, Tuple
from contextlib import nullcontext

import numpy as np


@dataclass
class ModelConfig:
    """Hyperparameters describing the shape of the LLaMA network.

    The defaults match the constants in the notebook's configuration cell.
    Instances can be converted to and from plain dicts and JSON files.
    """
    vocab_size: int = 4096
    dim: int = 384
    n_layers: int = 8
    n_heads: int = 6
    n_kv_heads: int = 2
    max_seq_len: int = 512
    hidden_dim: int = 1024
    norm_eps: float = 1e-5
    rope_theta: float = 10000.0
    dropout: float = 0.0
    use_gradient_checkpointing: bool = False
    weight_tying: bool = False

    @property
    def head_dim(self) -> int:
        """Width of a single attention head (`dim` must divide evenly by `n_heads`)."""
        per_head, remainder = divmod(self.dim, self.n_heads)
        assert remainder == 0
        return per_head

    @property
    def n_kv_groups(self) -> int:
        """Number of query heads that share each key/value head."""
        groups, remainder = divmod(self.n_heads, self.n_kv_heads)
        assert remainder == 0
        return groups

    def validate(self) -> None:
        """Raise AssertionError if any hyperparameter is inconsistent."""
        # Positivity checks come first so the modulo checks below never divide by zero.
        assert self.dim > 0
        assert self.n_layers > 0
        assert self.n_heads > 0
        assert self.n_kv_heads > 0
        assert self.n_kv_heads <= self.n_heads
        assert self.dim % self.n_heads == 0
        assert self.n_heads % self.n_kv_heads == 0
        # RoPE rotates (even, odd) pairs, so each head needs an even width.
        assert self.head_dim % 2 == 0
        assert self.vocab_size > 0
        assert self.max_seq_len > 0
        assert self.hidden_dim > 0
        assert 0.0 <= self.dropout < 1.0

    def to_dict(self) -> dict:
        """Return the fields as a plain dict."""
        return asdict(self)

    @classmethod
    def from_dict(cls, d: dict) -> "ModelConfig":
        """Build a config from a dict whose keys are field names."""
        return cls(**d)

    def save(self, path: str) -> None:
        """Write the config as indented JSON, creating parent directories as needed."""
        parent = os.path.dirname(path) or "."
        os.makedirs(parent, exist_ok=True)
        with open(path, "w") as handle:
            handle.write(json.dumps(self.to_dict(), indent=2))

    @classmethod
    def load(cls, path: str) -> "ModelConfig":
        """Read a config previously written by `save`."""
        with open(path, "r") as handle:
            payload = json.loads(handle.read())
        return cls.from_dict(payload)


print("Config classes defined.")

In [None]:
#@title Model Architecture (RMSNorm, RoPE, SwiGLU, GQA, LLaMA)

class RMSNorm(nn.Module):
    """RMS normalisation over the last dimension with a learned per-channel gain."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The statistics are computed in float32, then cast back to the input
        # dtype before the learned gain is applied.
        x_fp32 = x.float()
        mean_square = x_fp32.pow(2).mean(-1, keepdim=True)
        normalised = x_fp32 * torch.rsqrt(mean_square + self.eps)
        return normalised.type_as(x) * self.weight


def precompute_rope_frequencies(
    head_dim: int, max_seq_len: int, theta: float = 10000.0,
    device: Optional[torch.device] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build the RoPE lookup tables.

    Returns a `(cos, sin)` pair, each of shape `(max_seq_len, head_dim // 2)`,
    holding the rotation angle for every (position, channel-pair) combination.
    """
    assert head_dim % 2 == 0
    pair_offsets = torch.arange(0, head_dim, 2, device=device).float()
    inv_freq = 1.0 / (theta ** (pair_offsets / head_dim))
    pos = torch.arange(max_seq_len, device=device).float()
    # Outer product: one row per position, one column per channel pair.
    angle_table = pos[:, None] * inv_freq[None, :]
    return torch.cos(angle_table), torch.sin(angle_table)


def apply_rotary_embeddings(
    x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor,
) -> torch.Tensor:
    """Rotate adjacent (even, odd) channel pairs of `x` by position-dependent angles.

    The rotation is computed in float32 and the result is cast back to the
    dtype of `x`. The cos/sin tables get new axes inserted at positions 0 and 2
    so they broadcast against `x`.
    """
    pairs = x.float().reshape(*x.shape[:-1], -1, 2)
    even, odd = pairs[..., 0], pairs[..., 1]
    cos = freqs_cos[None, :, None]
    sin = freqs_sin[None, :, None]
    rotated = torch.stack(
        [even * cos - odd * sin, even * sin + odd * cos],
        dim=-1,
    )
    # Re-interleave the rotated pairs back into the last dimension.
    return rotated.flatten(-2).type_as(x)


class FeedForward(nn.Module):
    """Gated feed-forward block: `w_down(silu(w_gate(x)) * w_up(x))`, all bias-free."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        model_dim, inner_dim = config.dim, config.hidden_dim
        self.w_gate = nn.Linear(model_dim, inner_dim, bias=False)
        self.w_up = nn.Linear(model_dim, inner_dim, bias=False)
        self.w_down = nn.Linear(inner_dim, model_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.w_gate(x))
        gated = gate * self.w_up(x)
        return self.w_down(gated)


class Attention(nn.Module):
    """Grouped Query Attention with an optional key/value cache.

    `n_heads` query heads share `n_kv_heads` key/value heads; each K/V head is
    repeated `n_kv_groups` times before the attention product.
    """
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.n_heads = config.n_heads
        self.n_kv_heads = config.n_kv_heads
        self.head_dim = config.head_dim
        self.n_kv_groups = config.n_kv_groups
        self.wq = nn.Linear(config.dim, config.n_heads * config.head_dim, bias=False)
        self.wk = nn.Linear(config.dim, config.n_kv_heads * config.head_dim, bias=False)
        self.wv = nn.Linear(config.dim, config.n_kv_heads * config.head_dim, bias=False)
        self.wo = nn.Linear(config.n_heads * config.head_dim, config.dim, bias=False)

    def forward(self, x, freqs_cos, freqs_sin, mask=None, kv_cache=None, start_pos=0):
        """Attend over `x` plus any cached keys/values.

        Args:
            x: activations, unpacked here as (batch, seq_len, dim).
            freqs_cos, freqs_sin: RoPE tables already sliced to the positions of `x`.
            mask: optional attention mask passed straight to
                `F.scaled_dot_product_attention`; when given, no causal mask is added.
            kv_cache: optional `(k, v)` pair from earlier positions, laid out as
                (batch, past_len, n_kv_heads, head_dim).
            start_pos: accepted for interface symmetry; not used in this method.

        Returns:
            `(output, (k, v))` where the second item is the updated cache
            (post-RoPE keys, before head repetition).
        """
        batch_size, seq_len, _ = x.shape
        q = self.wq(x).view(batch_size, seq_len, self.n_heads, self.head_dim)
        k = self.wk(x).view(batch_size, seq_len, self.n_kv_heads, self.head_dim)
        v = self.wv(x).view(batch_size, seq_len, self.n_kv_heads, self.head_dim)

        q = apply_rotary_embeddings(q, freqs_cos, freqs_sin)
        k = apply_rotary_embeddings(k, freqs_cos, freqs_sin)

        if kv_cache is not None:
            cached_k, cached_v = kv_cache
            k = torch.cat([cached_k, k], dim=1)
            v = torch.cat([cached_v, v], dim=1)
        new_kv_cache = (k, v)
        total_len = k.size(1)  # cached positions + new positions

        # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        if self.n_kv_heads != self.n_heads:
            k = k.repeat_interleave(self.n_kv_groups, dim=1)
            v = v.repeat_interleave(self.n_kv_groups, dim=1)

        if mask is not None:
            output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        elif seq_len > 1 and total_len > seq_len:
            # Several new tokens on top of a non-empty cache: `is_causal=True`
            # would align the triangle to the top-left of the (seq_len, total_len)
            # score matrix, hiding cached positions from the new tokens. Build a
            # bottom-right-aligned mask instead: query i may see the whole cached
            # prefix plus new tokens 0..i.
            past_len = total_len - seq_len
            causal_mask = torch.ones(
                seq_len, total_len, dtype=torch.bool, device=q.device,
            ).tril(diagonal=past_len)
            output = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)
        else:
            # Square case (no cache) or single-token decode (attends to everything).
            output = F.scaled_dot_product_attention(q, k, v, is_causal=(seq_len > 1))

        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.wo(output), new_kv_cache


class TransformerBlock(nn.Module):
    """A single pre-norm decoder layer: attention then feed-forward, each with a residual."""

    def __init__(self, layer_id: int, config: ModelConfig):
        super().__init__()
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(config.dim, config.norm_eps)
        self.attention = Attention(config)
        self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
        self.feed_forward = FeedForward(config)

    def forward(self, x, freqs_cos, freqs_sin, mask=None, kv_cache=None, start_pos=0):
        """Return `(hidden_states, new_kv_cache)` for this layer."""
        # Attention sub-layer (normalise -> attend -> add residual).
        normed_for_attn = self.attention_norm(x)
        attn_out, updated_cache = self.attention(
            normed_for_attn, freqs_cos, freqs_sin, mask, kv_cache, start_pos
        )
        after_attn = x + attn_out

        # Feed-forward sub-layer (normalise -> FFN -> add residual).
        ffn_out = self.feed_forward(self.ffn_norm(after_attn))
        return after_attn + ffn_out, updated_cache


class LLaMA(nn.Module):
    """Complete LLaMA decoder-only transformer language model.

    Token embedding -> `n_layers` TransformerBlocks -> final RMSNorm -> linear
    output head. RoPE cos/sin tables for `max_seq_len` positions are stored as
    non-persistent buffers (they are not written to the state dict).
    """
    def __init__(self, config: ModelConfig):
        super().__init__()
        config.validate()
        self.config = config
        self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
        self.layers = nn.ModuleList([
            TransformerBlock(layer_id=i, config=config)
            for i in range(config.n_layers)
        ])
        self.norm = RMSNorm(config.dim, config.norm_eps)
        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)

        if config.weight_tying:
            # Share one matrix between the input embedding and the output head.
            self.output.weight = self.tok_embeddings.weight

        freqs_cos, freqs_sin = precompute_rope_frequencies(
            config.head_dim, config.max_seq_len, config.rope_theta
        )
        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

        self.apply(self._init_weights)
        # Projections that write into the residual stream get a smaller std,
        # scaled by 1/sqrt(2 * n_layers).
        scale = 1.0 / math.sqrt(2 * config.n_layers)
        for layer in self.layers:
            nn.init.normal_(layer.attention.wo.weight, mean=0.0, std=0.02 * scale)
            nn.init.normal_(layer.feed_forward.w_down.weight, mean=0.0, std=0.02 * scale)

    def _init_weights(self, module: nn.Module) -> None:
        """Normal(0, 0.02) init for Linear/Embedding weights; zero any Linear bias."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, tokens, targets=None, kv_caches=None, start_pos=0):
        """Run the model.

        Args:
            tokens: integer tensor unpacked here as (batch, seq_len).
            targets: optional tensor of next-token ids; entries equal to -1 are
                ignored by the cross-entropy loss.
            kv_caches: `None` to disable caching, otherwise one entry per layer
                (each entry `None` or a `(k, v)` pair).
            start_pos: absolute position of `tokens[:, 0]`, used to select the
                RoPE table rows.

        Returns:
            `(logits, loss)` when `kv_caches` is None, otherwise
            `(logits, loss, new_kv_caches)`. `loss` is None without `targets`.

        Raises:
            ValueError: if `[start_pos, start_pos + seq_len)` is not inside
                `[0, max_seq_len)`.
        """
        batch_size, seq_len = tokens.shape
        end_pos = start_pos + seq_len
        # Without this check an out-of-range slice of the RoPE tables is
        # silently short: a length-1 slice broadcasts over the whole sequence
        # (wrong positions), an empty slice fails later with an obscure
        # reshape error.
        if start_pos < 0 or end_pos > self.freqs_cos.size(0):
            raise ValueError(
                f"Positions [{start_pos}, {end_pos}) exceed the model context "
                f"(max_seq_len={self.freqs_cos.size(0)})"
            )

        h = self.tok_embeddings(tokens)
        freqs_cos = self.freqs_cos[start_pos:end_pos]
        freqs_sin = self.freqs_sin[start_pos:end_pos]

        use_cache = kv_caches is not None
        new_kv_caches = [] if use_cache else None
        for i, layer in enumerate(self.layers):
            layer_kv_cache = kv_caches[i] if use_cache else None
            if self.config.use_gradient_checkpointing and self.training:
                # The checkpointed path ignores any KV cache and always runs
                # the layer with kv_cache=None, start_pos=0.
                def create_custom_forward(module):
                    def custom_forward(h, freqs_cos, freqs_sin):
                        return module(h, freqs_cos, freqs_sin, None, None, 0)
                    return custom_forward
                h_out, _ = gradient_checkpoint(
                    create_custom_forward(layer), h, freqs_cos, freqs_sin,
                    use_reentrant=False,
                )
                h = h_out
                if use_cache:
                    new_kv_caches.append(None)
            else:
                h, new_kv = layer(h, freqs_cos, freqs_sin, None, layer_kv_cache, start_pos)
                if use_cache:
                    new_kv_caches.append(new_kv)

        h = self.norm(h)
        logits = self.output(h)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1,
            )

        if use_cache:
            return logits, loss, new_kv_caches
        return logits, loss

    def configure_optimizers(self, learning_rate, weight_decay, betas, device):
        """Create an AdamW optimizer with two parameter groups.

        Trainable parameters with 2 or more dimensions get `weight_decay`;
        all others get no decay. The fused AdamW kernel is requested when
        `device.type == "cuda"`.
        """
        decay_params = []
        no_decay_params = []
        for name, param in self.named_parameters():
            if not param.requires_grad:
                continue
            if param.dim() >= 2:
                decay_params.append(param)
            else:
                no_decay_params.append(param)

        param_groups = [
            {"params": decay_params, "weight_decay": weight_decay},
            {"params": no_decay_params, "weight_decay": 0.0},
        ]
        n_decay = sum(p.numel() for p in decay_params)
        n_no_decay = sum(p.numel() for p in no_decay_params)
        print(f"Optimizer: {n_decay:,} decay params, {n_no_decay:,} no-decay params")

        use_fused = device.type == "cuda"
        return torch.optim.AdamW(param_groups, lr=learning_rate, betas=betas, fused=use_fused)


print("Model architecture defined.")

In [None]:
#@title Tokenizer
import sentencepiece as spm


def train_tokenizer(
    input_file: str, model_prefix: str, vocab_size: int = 4096,
) -> str:
    """Train a SentencePiece BPE model on `input_file`.

    SentencePiece writes its output next to `model_prefix`; the function
    reloads the resulting `<model_prefix>.model`, checks that one sample
    sentence survives an encode/decode round trip, and returns the model path.
    """
    prefix_dir = os.path.dirname(model_prefix) or "."
    os.makedirs(prefix_dir, exist_ok=True)
    print(f"Training tokenizer (vocab_size={vocab_size})...")

    trainer_options = dict(
        input=input_file,
        model_prefix=model_prefix,
        model_type="bpe",
        vocab_size=vocab_size,
        character_coverage=1.0,
        byte_fallback=True,
        num_threads=4,
        max_sentence_length=4192,
        split_digits=True,
        allow_whitespace_only_pieces=True,
        normalization_rule_name="identity",
        remove_extra_whitespaces=False,
        # Special-token ids: <unk>=0, <s>=1, </s>=2; pad_id=-1 (negative id).
        unk_id=0,
        bos_id=1,
        eos_id=2,
        pad_id=-1,
    )
    spm.SentencePieceTrainer.Train(**trainer_options)

    model_path = f"{model_prefix}.model"
    processor = spm.SentencePieceProcessor()
    processor.Load(model_path)
    print(f"Tokenizer trained: {model_path} (vocab={processor.GetPieceSize()})")

    # Verify roundtrip
    sample = "Once upon a time, there was a little cat."
    restored = processor.Decode(processor.Encode(sample))
    assert restored == sample, "Roundtrip failed!"
    print("Roundtrip test: PASSED")
    return model_path


class Tokenizer:
    """Thin convenience layer over a loaded SentencePiece processor."""

    def __init__(self, model_path: str):
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Tokenizer model not found: {model_path}")
        self._sp = spm.SentencePieceProcessor()
        self._sp.Load(model_path)

    def encode(self, text: str, bos: bool = True, eos: bool = True) -> list:
        """Encode `text` to token ids, optionally wrapped in BOS / EOS ids."""
        prefix = [self.bos_id] if bos else []
        suffix = [self.eos_id] if eos else []
        return prefix + self._sp.Encode(text) + suffix

    def decode(self, tokens: list) -> str:
        """Turn a list of token ids back into text."""
        return self._sp.Decode(tokens)

    @property
    def vocab_size(self) -> int:
        """Number of pieces in the loaded model."""
        return self._sp.GetPieceSize()

    @property
    def bos_id(self) -> int:
        """Id of the beginning-of-sequence token."""
        return self._sp.bos_id()

    @property
    def eos_id(self) -> int:
        """Id of the end-of-sequence token."""
        return self._sp.eos_id()

    @property
    def unk_id(self) -> int:
        """Id of the unknown token."""
        return self._sp.unk_id()

    @property
    def pad_id(self) -> int:
        """Id of the padding token; any negative value is reported as -1."""
        return max(self._sp.pad_id(), -1)

    def id_to_piece(self, token_id: int) -> str:
        """Return the piece string for a token id."""
        return self._sp.IdToPiece(token_id)

    def piece_to_id(self, piece: str) -> int:
        """Return the token id for a piece string."""
        return self._sp.PieceToId(piece)

    def __len__(self) -> int:
        return self.vocab_size


print("Tokenizer defined.")

In [None]:
#@title Data Pipeline
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm


def _write_stories_atomically(dataset, text_file: str, desc: str) -> None:
    """Write each non-empty, stripped `example["text"]` to `text_file`, one per write.

    The data goes to `<text_file>.tmp` first and is renamed into place only
    after the whole export succeeded. An interrupted export therefore never
    leaves a truncated file that later runs would treat as a finished cache.
    """
    # NOTE(review): only the ends of each text are stripped; if a story contains
    # internal newlines it will span several lines of the output file. Confirm
    # whether per-line processing downstream is intended.
    tmp_file = text_file + ".tmp"
    try:
        with open(tmp_file, "w", encoding="utf-8") as f:
            for example in tqdm(dataset, desc=desc):
                text = example["text"].strip()
                if text:
                    f.write(text + "\n")
        os.replace(tmp_file, text_file)
    finally:
        # Only present if the export failed before the rename.
        if os.path.exists(tmp_file):
            os.remove(tmp_file)


def download_tinystories(data_dir: str) -> str:
    """Download the TinyStories train split and export it to a text file.

    Returns the path `<data_dir>/tinystories_train.txt`. If that file already
    exists it is reused and nothing is downloaded.
    """
    os.makedirs(data_dir, exist_ok=True)
    text_file = os.path.join(data_dir, "tinystories_train.txt")
    if os.path.exists(text_file):
        print(f"Training text already exists: {text_file}")
        return text_file
    # Imported lazily so the cached path does not require `datasets`.
    from datasets import load_dataset
    print("Downloading TinyStories dataset...")
    dataset = load_dataset("roneneldan/TinyStories", split="train")
    print(f"Downloaded {len(dataset):,} stories")
    _write_stories_atomically(dataset, text_file, desc="Exporting")
    print(f"Saved: {text_file} ({os.path.getsize(text_file) / 1024**2:.1f} MB)")
    return text_file


def download_tinystories_val(data_dir: str) -> str:
    """Download the TinyStories validation split and export it to a text file.

    Returns the path `<data_dir>/tinystories_val.txt`. If that file already
    exists it is reused and nothing is downloaded.
    """
    os.makedirs(data_dir, exist_ok=True)
    text_file = os.path.join(data_dir, "tinystories_val.txt")
    if os.path.exists(text_file):
        print(f"Validation text already exists: {text_file}")
        return text_file
    # Imported lazily so the cached path does not require `datasets`.
    from datasets import load_dataset
    print("Downloading TinyStories validation split...")
    dataset = load_dataset("roneneldan/TinyStories", split="validation")
    print(f"Downloaded {len(dataset):,} validation stories")
    _write_stories_atomically(dataset, text_file, desc="Exporting val")
    return text_file


def tokenize_and_save(
    text_file: str, output_bin: str, tokenizer: Tokenizer, chunk_size: int = 10000,
) -> int:
    """Tokenize a text file line by line and save the ids as a flat uint16 binary.

    Every non-empty (stripped) line is encoded without BOS and with a trailing
    EOS. Tokens are flushed to disk every `chunk_size` lines, so memory use is
    bounded by one chunk rather than by the whole corpus.

    Args:
        text_file: UTF-8 input text.
        output_bin: destination file; reused as-is if it already exists.
        tokenizer: object providing `encode(text, bos=..., eos=...)`.
        chunk_size: number of input lines to buffer between flushes
            (values below 1 are treated as 1).

    Returns:
        The number of tokens in `output_bin`.

    Raises:
        ValueError: if a token id does not fit in uint16.
    """
    if os.path.exists(output_bin):
        data = np.memmap(output_bin, dtype=np.uint16, mode="r")
        print(f"Tokenized data already exists: {output_bin} ({len(data):,} tokens)")
        return len(data)
    print(f"Tokenizing {text_file} -> {output_bin}")
    # First pass only counts lines so the progress bar has a total.
    with open(text_file, "r", encoding="utf-8") as f:
        n_lines = sum(1 for _ in f)

    lines_per_flush = max(1, chunk_size)
    max_token_id = np.iinfo(np.uint16).max
    os.makedirs(os.path.dirname(output_bin) or ".", exist_ok=True)
    # Written to a temporary name and renamed at the end, so an interrupted
    # run cannot leave a partial file that the cache check above would accept.
    tmp_bin = output_bin + ".tmp"

    def _flush(buffer: list, out) -> int:
        """Append `buffer` to `out` as uint16 and return how many ids were written."""
        if not buffer:
            return 0
        chunk = np.asarray(buffer, dtype=np.int64)
        if chunk.min() < 0 or chunk.max() > max_token_id:
            raise ValueError(
                f"Token id outside uint16 range [0, {max_token_id}] in {text_file}"
            )
        chunk.astype(np.uint16).tofile(out)
        return int(chunk.size)

    n_tokens = 0
    try:
        with open(text_file, "r", encoding="utf-8") as src, open(tmp_bin, "wb") as out:
            buffer, buffered_lines = [], 0
            for line in tqdm(src, total=n_lines, desc="Tokenizing"):
                text = line.strip()
                if not text:
                    continue
                buffer.extend(tokenizer.encode(text, bos=False, eos=True))
                buffered_lines += 1
                if buffered_lines >= lines_per_flush:
                    n_tokens += _flush(buffer, out)
                    buffer, buffered_lines = [], 0
            n_tokens += _flush(buffer, out)
        os.replace(tmp_bin, output_bin)
    finally:
        # Only present if tokenization failed before the rename.
        if os.path.exists(tmp_bin):
            os.remove(tmp_bin)
    print(f"Saved: {output_bin} ({n_tokens:,} tokens, {os.path.getsize(output_bin) / 1024**2:.1f} MB)")
    return n_tokens


def prepare_data(data_dir: str, tokenizer: Tokenizer) -> tuple:
    """Make sure `train.bin` and `val.bin` exist under `data_dir`.

    A missing binary is produced by downloading the matching split and
    tokenizing it; an existing one is only reported. The token count shown for
    an existing file is its size divided by two.

    Returns:
        `(train_bin, val_bin)` paths.
    """
    train_bin = os.path.join(data_dir, "train.bin")
    val_bin = os.path.join(data_dir, "val.bin")

    # (label used in the log line, binary path, function fetching the raw text)
    splits = (
        ("Train", train_bin, download_tinystories),
        ("Val", val_bin, download_tinystories_val),
    )
    for label, bin_path, fetch_text in splits:
        if os.path.exists(bin_path):
            n = os.path.getsize(bin_path) // 2
            print(f"{label} data ready: {bin_path} ({n:,} tokens)")
            continue
        tokenize_and_save(fetch_text(data_dir), bin_path, tokenizer)

    return train_bin, val_bin


class TokenDataset(Dataset):
    """Sliding-window dataset over a flat uint16 token file.

    Item `idx` is the pair `(x, y)` where `x` is `seq_len` consecutive tokens
    starting at `idx` and `y` is the same window shifted right by one token.
    The file is opened read-only through `np.memmap`.
    """
    def __init__(self, data_path: str, seq_len: int):
        super().__init__()
        self.seq_len = seq_len
        self.data = np.memmap(data_path, dtype=np.uint16, mode="r")
        self.n_tokens = len(self.data)
        print(f"TokenDataset: {data_path} ({self.n_tokens:,} tokens, seq_len={seq_len})")

    def __len__(self) -> int:
        # A window needs seq_len + 1 tokens, so valid starts are
        # 0 .. n_tokens - seq_len - 1, i.e. n_tokens - seq_len windows.
        # Clamp at zero: a negative length makes len() raise ValueError.
        return max(0, self.n_tokens - self.seq_len)

    def __getitem__(self, idx: int):
        # Widen to int64 because the ids are consumed as index tensors.
        chunk = self.data[idx: idx + self.seq_len + 1].astype(np.int64)
        x = torch.from_numpy(chunk[:-1])
        y = torch.from_numpy(chunk[1:])
        return x, y


def create_dataloader(
    data_path: str, seq_len: int, batch_size: int,
    shuffle: bool = True, num_workers: int = 2, pin_memory: bool = True,
) -> DataLoader:
    """Create a DataLoader over a tokenized binary file.

    Args:
        data_path: path to the uint16 token file.
        seq_len: window length passed to `TokenDataset`.
        batch_size: sequences per batch; incomplete final batches are dropped.
        shuffle: draw window start positions at random instead of in order.
        num_workers, pin_memory: forwarded to `DataLoader`.

    Returns:
        The configured `DataLoader`.
    """
    from torch.utils.data import RandomSampler

    dataset = TokenDataset(data_path, seq_len)
    # `DataLoader(shuffle=True)` uses a RandomSampler without replacement, which
    # materialises a full permutation of all indices as a Python list. With one
    # window per token, that is hundreds of millions of entries and exhausts
    # host memory. Sampling with replacement draws indices in small batches.
    sampler = RandomSampler(dataset, replacement=True) if shuffle else None
    return DataLoader(
        dataset, batch_size=batch_size, shuffle=False, sampler=sampler,
        num_workers=num_workers, pin_memory=pin_memory, drop_last=True,
    )


print("Data pipeline defined.")

In [None]:
#@title Utilities: Device, Checkpoints, Generation, Training Helpers

# ── Device Utilities ─────────────────────────────────────────────────────

def get_dtype(requested: str, device: torch.device) -> torch.dtype:
    """Map a dtype name to a `torch.dtype`.

    `"auto"` resolves to bfloat16 on CUDA when supported, float16 on CUDA
    otherwise, and float32 on every other device. Any other name must be one
    of `"float16"`, `"bfloat16"`, `"float32"` (KeyError otherwise).
    """
    if requested != "auto":
        named_dtypes = {
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,
            "float32": torch.float32,
        }
        return named_dtypes[requested]

    if device.type != "cuda":
        return torch.float32
    if torch.cuda.is_bf16_supported():
        return torch.bfloat16
    return torch.float16


def get_autocast_context(device: torch.device, dtype: torch.dtype):
    """Return the autocast context manager matching `device` and `dtype`.

    Autocast is enabled only on CUDA or MPS and only when `dtype` is a reduced
    precision type (float16 / bfloat16); the requested dtype is always
    honoured. Every other combination returns a no-op `nullcontext()`.

    Previously the MPS branch ignored `dtype` and always enabled float16
    autocast, even when float32 had been resolved for that device.
    """
    reduced_precision = dtype in (torch.float16, torch.bfloat16)
    if device.type in ("cuda", "mps") and reduced_precision:
        return torch.amp.autocast(device_type=device.type, dtype=dtype)
    return nullcontext()


def get_grad_scaler(device: torch.device, dtype: torch.dtype):
    """Return a CUDA `GradScaler` for float16 training on CUDA, otherwise None."""
    needs_scaling = device.type == "cuda" and dtype == torch.float16
    if not needs_scaling:
        return None
    return torch.amp.GradScaler(device="cuda")


def get_memory_usage(device: torch.device) -> dict:
    """Report allocated / reserved CUDA memory in MiB (zeros for non-CUDA devices)."""
    if device.type != "cuda":
        return {"allocated_mb": 0.0, "reserved_mb": 0.0}
    bytes_per_mb = 1024**2
    allocated = torch.cuda.memory_allocated(device) / bytes_per_mb
    reserved = torch.cuda.memory_reserved(device) / bytes_per_mb
    return {"allocated_mb": allocated, "reserved_mb": reserved}


# ── Reproducibility ──────────────────────────────────────────────────────

def set_seed(seed: int) -> None:
    """Seed Python's `random`, NumPy's global RNG, and PyTorch (CPU and, if present, all CUDA devices)."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


# ── Model Diagnostics ────────────────────────────────────────────────────

def count_parameters(model: nn.Module, trainable_only: bool = True) -> int:
    """Total number of parameter elements; frozen ones are skipped when `trainable_only`."""
    total = 0
    for param in model.parameters():
        if trainable_only and not param.requires_grad:
            continue
        total += param.numel()
    return total


def print_model_summary(model: nn.Module) -> str:
    """Print and return a table with one row per named parameter.

    Each row shows the element count and its share of all parameters; the
    footer lists the trainable total and the size of all parameters at
    4 bytes and at 2 bytes per element.
    """
    heavy_rule = "=" * 65
    light_rule = "-" * 65

    named = list(model.named_parameters())
    grand_total = sum(p.numel() for _, p in named)
    trainable = sum(p.numel() for _, p in named if p.requires_grad)

    rows = []
    for name, param in named:
        n = param.numel()
        pct = 100.0 * n / grand_total if grand_total > 0 else 0
        rows.append(f"  {name:<38} {n:>12,d} ({pct:>5.1f}%)")

    report = [
        heavy_rule,
        "Model Parameter Summary",
        heavy_rule,
        f"{'Name':<40} {'Params':>12} {'%':>7}",
        light_rule,
        *rows,
        light_rule,
        f"  {'TOTAL (trainable)':<38} {trainable:>12,d}",
        f"  {'Memory (fp32)':<38} {grand_total * 4 / 1024**2:>10.1f} MB",
        f"  {'Memory (fp16/bf16)':<38} {grand_total * 2 / 1024**2:>10.1f} MB",
        heavy_rule,
    ]
    summary = "\n".join(report)
    print(summary)
    return summary


# ── Checkpoints ──────────────────────────────────────────────────────────

def save_checkpoint(model, optimizer, step, val_loss, model_config, train_config, path):
    """Serialise model, optimizer, progress counters, configs and RNG states to `path`.

    Parent directories are created as needed. The CUDA RNG states are included
    only when CUDA is available.
    """
    target_dir = os.path.dirname(path) or "."
    os.makedirs(target_dir, exist_ok=True)

    state = dict(
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        step=step,
        val_loss=val_loss,
        model_config=model_config,
        train_config=train_config,
        rng_state_python=random.getstate(),
        rng_state_numpy=np.random.get_state(),
        rng_state_torch=torch.random.get_rng_state(),
    )
    if torch.cuda.is_available():
        state["rng_state_cuda"] = torch.cuda.get_rng_state_all()

    torch.save(state, path)
    print(f"Checkpoint saved: {path} (step {step}, val_loss {val_loss:.4f})")


def load_checkpoint(path, model, optimizer=None, device=None):
    """Restore model (and optionally optimizer) state plus RNG states from `path`.

    Args:
        path: checkpoint file written with `torch.save`.
        model: module whose `load_state_dict` receives `"model_state_dict"`.
        optimizer: if given and the checkpoint holds `"optimizer_state_dict"`,
            its state is restored as well.
        device: `map_location` for loading; defaults to `"cpu"`.

    Returns:
        A dict with `"step"`, `"val_loss"` and `"model_config"`; missing
        entries fall back to 0, infinity and `{}`.
    """
    map_location = device if device is not None else "cpu"
    # SECURITY: weights_only=False unpickles arbitrary Python objects (needed
    # here for the Python/NumPy RNG states). Only load checkpoints you trust.
    checkpoint = torch.load(path, map_location=map_location, weights_only=False)
    model.load_state_dict(checkpoint["model_state_dict"])
    if optimizer is not None and "optimizer_state_dict" in checkpoint:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    if "rng_state_python" in checkpoint:
        random.setstate(checkpoint["rng_state_python"])
    if "rng_state_numpy" in checkpoint:
        np.random.set_state(checkpoint["rng_state_numpy"])
    # `map_location` also relocates the saved RNG-state tensors; the RNG
    # setters require CPU byte tensors, so move them back before restoring.
    if "rng_state_torch" in checkpoint:
        torch.random.set_rng_state(checkpoint["rng_state_torch"].cpu())
    if "rng_state_cuda" in checkpoint and torch.cuda.is_available():
        torch.cuda.set_rng_state_all(
            [state.cpu() for state in checkpoint["rng_state_cuda"]]
        )
    info = {
        "step": checkpoint.get("step", 0),
        "val_loss": checkpoint.get("val_loss", float("inf")),
        "model_config": checkpoint.get("model_config", {}),
    }
    print(f"Checkpoint loaded: {path} (step {info['step']})")
    return info


# ── Generation ────────────────────────────────────────────────────────────

def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
    probs_sorted, sorted_indices = torch.sort(probs, descending=True)
    cumsum = torch.cumsum(probs_sorted, dim=-1)
    mask = cumsum - probs_sorted > p
    probs_sorted[mask] = 0.0
    probs_sorted /= probs_sorted.sum()
    sampled_idx = torch.multinomial(probs_sorted, num_samples=1)
    return sorted_indices[sampled_idx]


def _sample_token(logits, temperature, top_k, top_p):
    logits = logits.squeeze(0)
    if temperature == 0.0:
        return logits.argmax()
    logits = logits / temperature
    if top_k > 0:
        top_k = min(top_k, logits.size(-1))
        kth_value = torch.topk(logits, top_k).values[-1]
        logits[logits < kth_value] = float("-inf")
    probs = F.softmax(logits, dim=-1)
    if top_p < 1.0:
        return sample_top_p(probs, top_p)
    return torch.multinomial(probs, num_samples=1).squeeze(0)


@torch.inference_mode()
def generate(
    model, tokenizer, prompt, max_new_tokens=200,
    temperature=0.8, top_k=40, top_p=0.9, device=None,
) -> str:
    """Generate text from a prompt using KV-cached autoregressive decoding.

    Args:
        model: module exposing `config.n_layers`, `config.max_seq_len` and a
            forward accepting `kv_caches=` / `start_pos=`.
        tokenizer: object with `encode`, `decode` and `eos_id`.
        prompt: text to continue.
        max_new_tokens: upper bound on sampled tokens. Generation also stops
            at EOS or when the model's context window is full.
        temperature, top_k, top_p: forwarded to `_sample_token`.
        device: device for the token tensors; defaults to the device of the
            model's first parameter.

    Returns:
        The decoded prompt (without BOS) followed by the generated tokens.

    Raises:
        ValueError: if the encoded prompt is longer than `max_seq_len`.
    """
    # Remember the mode so a call made mid-training does not leave the model
    # in eval mode.
    was_training = model.training
    model.eval()
    try:
        if device is None:
            device = next(model.parameters()).device
        prompt_tokens = tokenizer.encode(prompt, bos=True, eos=False)
        max_len = model.config.max_seq_len
        if len(prompt_tokens) > max_len:
            raise ValueError(
                f"Prompt has {len(prompt_tokens)} tokens but the model context "
                f"is {max_len}"
            )
        tokens = torch.tensor([prompt_tokens], dtype=torch.long, device=device)
        prompt_len = tokens.shape[1]

        # Prefill: process entire prompt
        n_layers = model.config.n_layers
        logits, _, kv_caches = model(tokens, kv_caches=[None] * n_layers, start_pos=0)
        next_logits = logits[:, -1, :]
        cur_pos = prompt_len

        # Decode: generate one token at a time
        generated_tokens = []
        for step in range(max_new_tokens):
            next_token = _sample_token(next_logits, temperature, top_k, top_p)
            token_id = int(next_token.item())
            generated_tokens.append(token_id)
            if token_id == tokenizer.eos_id:
                break
            # Skip the forward pass when its logits would never be sampled
            # (last step) or when position `cur_pos` is outside the context
            # window the RoPE tables cover.
            if step == max_new_tokens - 1 or cur_pos >= max_len:
                break
            # reshape(1, 1) works whether the sampler returned a 0-dim or a
            # 1-element tensor.
            new_token_tensor = next_token.reshape(1, 1)
            logits, _, kv_caches = model(
                new_token_tensor, kv_caches=kv_caches, start_pos=cur_pos,
            )
            next_logits = logits[:, -1, :]
            cur_pos += 1
    finally:
        model.train(was_training)

    all_tokens = prompt_tokens[1:] + generated_tokens  # Skip BOS
    return tokenizer.decode(all_tokens)


# ── Training Helpers ──────────────────────────────────────────────────────

def get_lr(step, warmup_steps, max_steps, learning_rate, min_learning_rate):
    """Learning rate at `step`: linear warmup, cosine decay, then a constant floor.

    - `step < warmup_steps`: ramps linearly from 0 towards `learning_rate`.
    - `warmup_steps <= step < max_steps`: half-cosine from `learning_rate`
      down to `min_learning_rate`.
    - `step >= max_steps`: `min_learning_rate`.
    """
    in_warmup = step < warmup_steps
    if in_warmup:
        return learning_rate * (step / warmup_steps)

    finished = step >= max_steps
    if finished:
        return min_learning_rate

    decay_span = max_steps - warmup_steps
    progress = (step - warmup_steps) / decay_span
    amplitude = learning_rate - min_learning_rate
    return min_learning_rate + amplitude * 0.5 * (1.0 + math.cos(math.pi * progress))


@torch.no_grad()
def evaluate(model, val_dataloader, device, autocast_ctx, max_steps=20):
    """Compute the average validation loss over at most ``max_steps`` batches.

    Args:
        model: ``nn.Module`` callable as ``model(x, targets=y)`` and returning
            a 2-tuple whose second item is a scalar loss tensor.
        val_dataloader: Iterable yielding ``(x, y)`` tensor pairs.
        device: Device the batches are moved to before the forward pass.
        autocast_ctx: Context manager entered around each forward pass.
        max_steps: Maximum number of batches to evaluate.

    Returns:
        Mean of the per-batch losses as a Python float, or ``0.0`` when no
        batch was evaluated.

    Side effects:
        The model runs in eval mode for the duration of the call; its previous
        train/eval mode is restored on exit, including when an exception
        propagates.
    """
    was_training = model.training  # remember the caller's mode
    model.eval()
    total_loss = 0.0
    n_steps = 0
    try:
        for x, y in val_dataloader:
            if n_steps >= max_steps:
                break
            x, y = x.to(device), y.to(device)
            with autocast_ctx:
                _, loss = model(x, targets=y)
            total_loss += loss.item()
            n_steps += 1
    finally:
        # Restore the mode the model had on entry instead of forcing train mode.
        model.train(was_training)
    # max(..., 1) avoids division by zero for an empty loader or max_steps <= 0.
    return total_loss / max(n_steps, 1)


# Confirmation message: reaching this line means every definition above in
# this cell executed without raising.
print("All utilities defined.")

---
## Step 1: Train Tokenizer

We train a **BPE tokenizer** (Byte-Pair Encoding) with SentencePiece on the TinyStories training data.

- **Vocab size:** 4096 tokens (small, matching our tiny model)
- **Byte fallback:** Unknown characters are encoded as UTF-8 bytes (no `<unk>` tokens)
- **Special tokens:** `<s>` (BOS, id=1), `</s>` (EOS, id=2)

In [None]:
#@title Step 1: Train Tokenizer

# Fetch the TinyStories training text; it is the corpus the tokenizer learns from.
train_text_path = download_tinystories(DATA_DIR)

# Train only when no model file exists yet, so re-running the cell is cheap.
tokenizer_model_path = TOKENIZER_PREFIX + ".model"
if not os.path.exists(tokenizer_model_path):
    train_tokenizer(
        input_file=train_text_path,
        model_prefix=TOKENIZER_PREFIX,
        vocab_size=VOCAB_SIZE,
    )
else:
    print(f"Tokenizer already exists: {tokenizer_model_path}")

# Load the tokenizer from disk and report its vocabulary size.
tokenizer = Tokenizer(tokenizer_model_path)
print(f"\nTokenizer loaded: vocab_size={tokenizer.vocab_size}")

# Sanity check: encoding then decoding must reproduce the sentence
# (modulo surrounding whitespace).
test_text = "Once upon a time, there was a little cat."
tokens = tokenizer.encode(test_text, bos=True, eos=True)
decoded = tokenizer.decode(tokens)
print(f"Encode: '{test_text}'")
print(f"  -> {tokens[:15]}... ({len(tokens)} tokens)")
print(f"Decode: '{decoded}'")
assert decoded.strip() == test_text, "Roundtrip failed!"
print("Roundtrip: PASSED")

In [None]:
#@title Step 2: Prepare Data (Tokenize to Binary)

# Download val split and tokenize both splits to .bin files.
# prepare_data returns the paths of the train and validation binaries.
train_bin, val_bin = prepare_data(DATA_DIR, tokenizer)

# Print dataset stats.
# Token counts are derived from file size, assuming each token is stored as a
# 2-byte uint16 -- this must match the dtype prepare_data writes (verify there).
train_tokens = os.path.getsize(train_bin) // 2  # uint16 = 2 bytes
val_tokens = os.path.getsize(val_bin) // 2
print(f"\nTrain tokens: {train_tokens:,}")
print(f"Val tokens:   {val_tokens:,}")
# File sizes are reported in MiB (bytes / 1024**2).
print(f"Train file:   {os.path.getsize(train_bin) / 1024**2:.1f} MB")
print(f"Val file:     {os.path.getsize(val_bin) / 1024**2:.1f} MB")

In [None]:
#@title Step 3: Create Model

# Seed before building the model -- presumably so that weight initialization
# is reproducible (depends on what set_seed seeds; confirm in its definition).
set_seed(SEED)

# Build model config (update vocab_size from tokenizer).
# vocab_size is taken from the trained tokenizer rather than a config constant,
# so the model's vocabulary always matches the tokenizer actually loaded.
model_config = ModelConfig(
    vocab_size=tokenizer.vocab_size,
    dim=DIM,
    n_layers=N_LAYERS,
    n_heads=N_HEADS,
    n_kv_heads=N_KV_HEADS,
    max_seq_len=MAX_SEQ_LEN,
    hidden_dim=HIDDEN_DIM,
)
# Check the configuration before any weights are allocated.
model_config.validate()

# Create model and move it to the selected device.
model = LLaMA(model_config).to(device)
n_params = count_parameters(model)
print(f"\nModel parameters: {n_params:,}")
print_model_summary(model)

## Step 4: Train

The training loop below runs for `TRAINING_STEPS` optimizer steps with:
- **Mixed precision** (bf16 on Ampere+, fp16 on T4, fp32 on CPU)
- **Gradient accumulation** (effective batch = `BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS`)
- **Cosine LR schedule** with linear warmup
- **Gradient clipping** (global gradient norm capped at `MAX_GRAD_NORM`) to prevent exploding gradients
- **Periodic evaluation** on the validation set

In [None]:
#@title Step 4: Train the Model

# Overview of this cell:
#   1. Pick the compute dtype, autocast context and (optional) gradient scaler.
#   2. Build train/val dataloaders and the optimizer.
#   3. Run TRAINING_STEPS optimizer steps, each accumulating gradients over
#      GRADIENT_ACCUMULATION_STEPS micro-batches.
#   4. Periodically log, evaluate, and checkpoint; save a final checkpoint.

# ── Setup ────────────────────────────────────────────────────────────────
dtype = get_dtype("auto", device)
autocast_ctx = get_autocast_context(device, dtype)
# scaler may be None: every use below is guarded by `scaler is not None`.
scaler = get_grad_scaler(device, dtype)
print(f"Training dtype: {dtype}")
print(f"GradScaler: {'enabled' if scaler else 'disabled'}")

# ── DataLoaders ──────────────────────────────────────────────────────────
# Pinned host memory is requested on any non-CPU device; it pairs with the
# non_blocking=True transfers in the training loop.
train_loader = create_dataloader(
    train_bin, seq_len=MAX_SEQ_LEN, batch_size=BATCH_SIZE,
    shuffle=True, pin_memory=(device.type != "cpu"),
)
val_loader = create_dataloader(
    val_bin, seq_len=MAX_SEQ_LEN, batch_size=BATCH_SIZE,
    shuffle=False, pin_memory=(device.type != "cpu"),
)

# ── Optimizer ────────────────────────────────────────────────────────────
# The learning rate passed here is only the initial value; it is overwritten
# on every step by the schedule below.
optimizer = model.configure_optimizers(
    learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY,
    betas=(0.9, 0.95), device=device,
)

# ── Training Loop ────────────────────────────────────────────────────────
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

model.train()
train_iter = iter(train_loader)
# Tokens consumed per optimizer step (all micro-batches combined); used for
# the throughput figure.
tokens_per_step = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS * MAX_SEQ_LEN
# Stays at +inf until the first in-loop evaluation improves on it.
best_val_loss = float("inf")
train_losses = []  # one entry per optimizer step (mean micro-batch loss)

print(f"\nStarting training: {TRAINING_STEPS} steps")
print(f"Effective batch: {BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS} sequences = {tokens_per_step:,} tokens/step")
print("=" * 70)

pbar = tqdm(range(TRAINING_STEPS), desc="Training", unit="step")
for step in pbar:
    step_start = time.perf_counter()

    # ── Learning Rate Schedule ───────────────────────────────────────────
    # The schedule is applied manually to every param group.
    # NOTE: with WARMUP_STEPS > 0, get_lr returns 0 at step 0, so the very
    # first optimizer step runs with a learning rate of zero.
    lr = get_lr(step, WARMUP_STEPS, TRAINING_STEPS, LEARNING_RATE, MIN_LEARNING_RATE)
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

    # ── Gradient Accumulation ────────────────────────────────────────────
    accumulated_loss = 0.0
    for micro_step in range(GRADIENT_ACCUMULATION_STEPS):
        # Restart the iterator when the loader is exhausted, so training can
        # run for more steps than one pass over the data provides.
        try:
            x, y = next(train_iter)
        except StopIteration:
            train_iter = iter(train_loader)
            x, y = next(train_iter)

        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        with autocast_ctx:
            _, loss = model(x, targets=y)
            # Scale so that gradients summed over the micro-batches equal the
            # gradient of the mean loss.
            loss = loss / GRADIENT_ACCUMULATION_STEPS

        # Backward pass; goes through the scaler when one is in use.
        if scaler is not None:
            scaler.scale(loss).backward()
        else:
            loss.backward()

        # loss is already divided by the accumulation count, so this sum is
        # the mean micro-batch loss for the step.
        accumulated_loss += loss.item()

    # ── Gradient Clipping + Optimizer Step ───────────────────────────────
    # Unscale before clipping so MAX_GRAD_NORM is compared against the true
    # (unscaled) gradient norm.
    if scaler is not None:
        scaler.unscale_(optimizer)
    nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)

    if scaler is not None:
        scaler.step(optimizer)
        scaler.update()
    else:
        optimizer.step()

    optimizer.zero_grad(set_to_none=True)

    # ── Timing ───────────────────────────────────────────────────────────
    # Wait for queued CUDA work so the wall-clock reading covers the GPU
    # computation of this step, not just the kernel launches.
    if device.type == "cuda":
        torch.cuda.synchronize()
    step_time = time.perf_counter() - step_start
    tok_per_sec = tokens_per_step / step_time

    # ── Progress Bar ─────────────────────────────────────────────────────
    train_losses.append(accumulated_loss)
    pbar.set_postfix(loss=f"{accumulated_loss:.4f}", lr=f"{lr:.2e}", tps=f"{tok_per_sec:,.0f}")

    # ── Log ──────────────────────────────────────────────────────────────
    # Logs at step 0 as well (0 % LOG_INTERVAL == 0).
    if step % LOG_INTERVAL == 0:
        mem = get_memory_usage(device)
        print(
            f"step {step:>5d}/{TRAINING_STEPS} | "
            f"loss {accumulated_loss:.4f} | "
            f"lr {lr:.2e} | "
            f"{tok_per_sec:>8,.0f} tok/s | "
            f"mem {mem['allocated_mb']:>6.0f} MB"
        )

    # ── Evaluation ───────────────────────────────────────────────────────
    # Step 0 is skipped. Because `step` ranges over 0..TRAINING_STEPS-1, an
    # evaluation at exactly TRAINING_STEPS never happens inside this loop
    # (with the defaults of 3000 steps / interval 500, the last one is at 2500).
    if step > 0 and step % EVAL_INTERVAL == 0:
        val_loss = evaluate(model, val_loader, device, autocast_ctx, max_steps=EVAL_STEPS)
        perplexity = math.exp(val_loss)
        print(f"{'─' * 60}")
        print(f"EVAL step {step} | val_loss {val_loss:.4f} | perplexity {perplexity:.2f}")
        print(f"{'─' * 60}")

        # Keep a separate "best" checkpoint whenever validation loss improves.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_checkpoint(
                model, optimizer, step, val_loss,
                model_config.to_dict(), {},
                os.path.join(CHECKPOINT_DIR, "best.pt"),
            )
        # Defensive: re-assert training mode after evaluation.
        model.train()

    # ── Periodic Checkpoint ──────────────────────────────────────────────
    # Records best_val_loss (not the current step's validation loss).
    if step > 0 and step % SAVE_INTERVAL == 0:
        save_checkpoint(
            model, optimizer, step, best_val_loss,
            model_config.to_dict(), {},
            os.path.join(CHECKPOINT_DIR, f"step_{step:06d}.pt"),
        )

# ── Final ────────────────────────────────────────────────────────────────
# Always written, even if no evaluation ran (best_val_loss is then +inf).
print("\nTraining complete!")
final_path = os.path.join(CHECKPOINT_DIR, "final.pt")
save_checkpoint(
    model, optimizer, TRAINING_STEPS, best_val_loss,
    model_config.to_dict(), {}, final_path,
)
print(f"Final checkpoint saved: {final_path}")

In [None]:
#@title Step 5: Evaluate

# Validation loss of the model as it stands after training, over at most
# EVAL_STEPS batches. Perplexity is exp(loss).
val_loss = evaluate(model, val_loader, device, autocast_ctx, max_steps=EVAL_STEPS)
perplexity = math.exp(val_loss)

print(f"Final Validation Loss: {val_loss:.4f}")
print(f"Final Perplexity:      {perplexity:.2f}")
# best_val_loss comes from the training cell and is NOT updated by the
# evaluation above; it is still +inf if no in-loop evaluation ever ran.
print(f"Best Validation Loss:  {best_val_loss:.4f}")
print(f"Best Perplexity:       {math.exp(best_val_loss):.2f}")

In [None]:
#@title Step 6: Generate Text

prompts = [
    "Once upon a time",
    "The little dog",
    "She looked at the sky and",
    "One day, a boy named Tom",
]

# Settings shared by every generate() call; only the temperature varies.
shared_sampling = dict(max_new_tokens=150, top_k=40, top_p=0.9, device=device)

# Sample every prompt once per temperature so the outputs can be compared.
for temp in [0.7, 1.0]:
    print(f"\n{'=' * 60}")
    print(f"Temperature = {temp}")
    print(f"{'=' * 60}")
    for prompt in prompts:
        text = generate(model, tokenizer, prompt, temperature=temp, **shared_sampling)
        print(f"\n--- Prompt: \"{prompt}\" ---")
        print(text)
    # generate() switches the model to eval mode; switch back afterwards.
    model.train()

In [None]:
#@title Step 7: Save & Download Model

# Save model config alongside checkpoint. The path is built once so the file
# written and the path reported are guaranteed to be the same string.
config_path = os.path.join(CHECKPOINT_DIR, "model_config.json")
model_config.save(config_path)
print(f"Model config saved to {config_path}")

# List every file in the checkpoint directory with its size in MiB.
print("\nCheckpoints:")
for filename in sorted(os.listdir(CHECKPOINT_DIR)):
    path = os.path.join(CHECKPOINT_DIR, filename)
    size_mb = os.path.getsize(path) / 1024**2
    print(f"  {filename}: {size_mb:.1f} MB")

# Download the best checkpoint (Colab only).
# Only the import is inside `try`, so an ImportError raised later by the
# download itself is not mistaken for "not running on Colab".
try:
    from google.colab import files
except ImportError:
    print("\nNot running on Colab — skipping download.")
    print(f"Checkpoints are at: {os.path.abspath(CHECKPOINT_DIR)}")
else:
    best_path = os.path.join(CHECKPOINT_DIR, "best.pt")
    if os.path.exists(best_path):
        print(f"\nDownloading {best_path}...")
        files.download(best_path)
    else:
        # No best.pt exists when no in-loop evaluation improved on the
        # initial best loss; fall back to the final checkpoint.
        print("\nDownloading final checkpoint...")
        files.download(os.path.join(CHECKPOINT_DIR, "final.pt"))

In [None]:
#@title Bonus: Load Checkpoint & Generate (proves the save works)

# Rebuild an untrained model from the saved configuration file.
loaded_config = ModelConfig.load(os.path.join(CHECKPOINT_DIR, "model_config.json"))
loaded_model = LLaMA(loaded_config).to(device)

# Prefer the best checkpoint; use the final one when best.pt is absent.
best_ckpt = os.path.join(CHECKPOINT_DIR, "best.pt")
ckpt_path = best_ckpt if os.path.exists(best_ckpt) else os.path.join(CHECKPOINT_DIR, "final.pt")

info = load_checkpoint(ckpt_path, loaded_model, device=device)
print(f"Loaded checkpoint from step {info['step']}, val_loss {info['val_loss']:.4f}")

# Sample from the restored weights to show the round trip works end to end.
print("\n--- Generation from loaded checkpoint ---")
demo_prompts = ("Once upon a time", "The little cat was")
for prompt in demo_prompts:
    text = generate(
        loaded_model, tokenizer, prompt,
        max_new_tokens=100, temperature=0.8, device=device,
    )
    print(f"\nPrompt: \"{prompt}\"")
    print(text)

print("\nCheckpoint load + generate: SUCCESS")