# LLaMA from Scratch: Train & Generate on Google Colab

This notebook trains a **~15M-parameter LLaMA-style model** from scratch in pure PyTorch, end to end. After you pick a dataset in the Configuration cell (TinyStories, GSM8K, SimpleMath, AQUA-RAT, or mixed), the steps are:

1. **Train a BPE tokenizer** (SentencePiece, covering all selected datasets)
2. **Tokenize the dataset** into memory-mapped binary files
3. **Create the model** from the configured architecture
4. **Train the model** with mixed precision, gradient accumulation, and a cosine LR schedule
5. **Evaluate** on the validation set (loss + perplexity)
6. **Generate text** from prompts with temperature/top-k/top-p sampling
7. **Save & download** the trained checkpoint

**Datasets:**

| Name | Description | Size |
|------|-------------|------|
| `tinystories` | Short children's stories (default) | ~2.1M examples |
| `gsm8k` | Grade school math word problems | 8.5K examples |
| `simplemath` | Basic arithmetic problems | 100K examples |
| `aqua_rat` | Word problems with reasoning | 98K examples |
| `mixed` | All of the above combined | — |

**Architecture:** Decoder-only transformer following LLaMA (Meta AI)
- RMSNorm (pre-normalization)
- Rotary Positional Embeddings (RoPE)
- SwiGLU activation in FFN
- Grouped Query Attention (GQA, 6 query heads / 2 KV heads)
- KV cache for efficient inference
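
With 6 query heads and 2 KV heads, each key/value head is shared by a group of three query heads. The snippet below is a rough illustration in plain Python, independent of the repo's attention code. The helper name `kv_head_for_query` is made up for this sketch, and it assumes the common convention that consecutive query heads share a KV head.

```python
# Illustrative sketch: which KV head each query head reads from under GQA.
# Assumption: consecutive query heads are grouped onto the same KV head.
N_HEADS = 6      # query heads
N_KV_HEADS = 2   # key/value heads


def kv_head_for_query(q_head: int, n_heads: int = N_HEADS, n_kv_heads: int = N_KV_HEADS) -> int:
    """Return the index of the KV head used by query head `q_head`."""
    if n_heads % n_kv_heads != 0:
        raise ValueError("n_heads must be divisible by n_kv_heads")
    group_size = n_heads // n_kv_heads  # query heads per KV head
    return q_head // group_size


mapping = [kv_head_for_query(q) for q in range(N_HEADS)]
print(mapping)  # [0, 0, 0, 1, 1, 1]
```

Because only 2 heads' worth of keys and values are stored, the KV cache is a third of the size it would be with 6 full KV heads.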

**Requirements:** A Colab GPU runtime (T4 or A100). Go to *Runtime > Change runtime type > GPU*.

In [None]:
#@title Setup: Install Dependencies & Clone Repo
!pip install -q torch sentencepiece datasets tqdm
!git clone -q https://github.com/manojkgorle/smol-llama2.git
%cd smol-llama2

In [None]:
#@title Configuration — Edit these parameters!

# ── Dataset Selection ────────────────────────────────────────────────────
# Choose which dataset to train on:
#   "tinystories"  — ~2.1M short stories (default)
#   "gsm8k"        — 8.5K grade school math word problems
#   "simplemath"   — 100K basic arithmetic problems
#   "aqua_rat"     — 98K word problems with reasoning
#   "mixed"        — All of the above combined
DATASET = "tinystories"

# Which datasets to include in tokenizer training.
# The tokenizer should cover ALL text the model will ever see.
# Only used when DATASET == "mixed"; for a single dataset this list is
# overridden to [DATASET] further down in this cell.
TOKENIZER_DATASETS = ["tinystories", "gsm8k", "simplemath", "aqua_rat"]

# ── Training ─────────────────────────────────────────────────────────────
TRAINING_STEPS = 3000          # Total optimizer steps (~15 min on T4)
BATCH_SIZE = 64                # Sequences per micro-batch
GRADIENT_ACCUMULATION_STEPS = 4  # Effective batch = 64 * 4 = 256 sequences
LEARNING_RATE = 3e-4           # Peak LR (after warmup)
MIN_LEARNING_RATE = 3e-5       # Floor LR (10% of peak)
WARMUP_STEPS = 200             # Linear warmup steps
WEIGHT_DECAY = 0.1             # AdamW weight decay
MAX_GRAD_NORM = 1.0            # Gradient clipping

# ── Evaluation & Logging ─────────────────────────────────────────────────
EVAL_INTERVAL = 500            # Evaluate every N steps
EVAL_STEPS = 20                # Batches per evaluation
LOG_INTERVAL = 50              # Print loss every N steps
SAVE_INTERVAL = 1000           # Save checkpoint every N steps

# ── Model Architecture ──────────────────────────────────────────────────
VOCAB_SIZE = 4096
DIM = 384
N_LAYERS = 8
N_HEADS = 6
N_KV_HEADS = 2
MAX_SEQ_LEN = 512
HIDDEN_DIM = 1024

# ── Paths ────────────────────────────────────────────────────────────────
DATA_DIR = "data/"
CHECKPOINT_DIR = "checkpoints/"
TOKENIZER_PREFIX = "data/tokenizer"  # produces data/tokenizer.model
SEED = 42

# ── Auto-configure tokenizer datasets for single-dataset mode ────────────
if DATASET != "mixed":
    TOKENIZER_DATASETS = [DATASET]

print(f"Dataset: {DATASET}")
print(f"Tokenizer trained on: {TOKENIZER_DATASETS}")
print(f"Training for {TRAINING_STEPS} steps")
print(f"Effective batch size: {BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS} sequences")
print(f"Tokens per step: {BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS * MAX_SEQ_LEN:,}")

In [None]:
#@title Device Detection
import torch

if torch.cuda.is_available():
    device = torch.device("cuda")
    props = torch.cuda.get_device_properties(device)
    print(f"GPU: {props.name}")
    print(f"VRAM: {props.total_memory / 1024**3:.1f} GB")
    print(f"Compute Capability: {props.major}.{props.minor}")
    print(f"BF16 Support: {torch.cuda.is_bf16_supported()}")
    print(f"CUDA Version: {torch.version.cuda}")
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Device: MPS (Apple Silicon)")
else:
    device = torch.device("cpu")
    print("WARNING: No GPU detected! Training will be very slow.")
    print("Go to Runtime > Change runtime type > GPU")

print(f"\nDevice: {device}")
print(f"PyTorch: {torch.__version__}")

In [None]:
#@title Import Modules
import os
import math
import time

import torch.nn as nn
from tqdm.auto import tqdm

from llama_vc.config import ModelConfig
from llama_vc.model import LLaMA
from llama_vc.tokenizer import Tokenizer, train_tokenizer
from llama_vc.dataset import (
    download_dataset, prepare_dataset, prepare_mixed,
    create_dataloader, DATASETS, DATASET_NAMES,
)
from llama_vc.device import (
    get_dtype, get_autocast_context,
    get_grad_scaler, get_memory_usage,
)
from llama_vc.generate import generate
from llama_vc.train import get_lr, evaluate
from llama_vc.utils import (
    set_seed, count_parameters, print_model_summary,
    save_checkpoint, load_checkpoint,
)

print("All modules imported successfully!")

## Step 1: Train Tokenizer

We train a **BPE tokenizer** (Byte-Pair Encoding) with SentencePiece on text from all selected datasets (`TOKENIZER_DATASETS`).

- **Vocab size:** 4096 tokens (small, matching our tiny model)
- **Byte fallback:** Unknown characters are encoded as UTF-8 bytes (no `<unk>` tokens)
- **Special tokens:** `<s>` (BOS, id=1), `</s>` (EOS, id=2)
- **Multi-dataset:** SentencePiece accepts comma-separated input files, so the vocabulary covers all selected text sources
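
To make the byte-fallback bullet concrete: a character that is not in the vocabulary is split into one piece per UTF-8 byte, and SentencePiece writes those pieces as `<0xNN>`. The sketch below only mimics that naming in plain Python; `byte_fallback_pieces` is a hypothetical helper, not part of SentencePiece or this repo.

```python
# Illustrative sketch of byte fallback: a character missing from the vocabulary
# is represented by one piece per UTF-8 byte instead of a single <unk> piece.
def byte_fallback_pieces(char: str) -> list[str]:
    """Return SentencePiece-style byte pieces (e.g. '<0xE2>') for `char`."""
    return [f"<0x{b:02X}>" for b in char.encode("utf-8")]


print(byte_fallback_pieces("€"))  # ['<0xE2>', '<0x82>', '<0xAC>']
```

The cost is that rare characters take several tokens each, which is usually an acceptable trade for never losing information on decode.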

In [None]:
#@title Step 1: Train Tokenizer

# Download text files for each dataset the tokenizer needs to cover
text_files = []
for ds_name in TOKENIZER_DATASETS:
    print(f"--- Downloading {ds_name} ---")
    train_txt, _ = download_dataset(ds_name, DATA_DIR)
    text_files.append(train_txt)

# Train the tokenizer on all text sources
# SentencePiece natively accepts comma-separated input files
tokenizer_model_path = TOKENIZER_PREFIX + ".model"
if os.path.exists(tokenizer_model_path):
    print(f"\nTokenizer already exists: {tokenizer_model_path}")
else:
    combined_input = ",".join(text_files)
    print(f"\nTraining tokenizer on {len(text_files)} source(s):")
    for f in text_files:
        size_mb = os.path.getsize(f) / 1024**2
        print(f"  {f} ({size_mb:.1f} MB)")
    train_tokenizer(
        input_file=combined_input,
        model_prefix=TOKENIZER_PREFIX,
        vocab_size=VOCAB_SIZE,
    )

# Load and verify
tokenizer = Tokenizer(tokenizer_model_path)
print(f"\nTokenizer loaded: vocab_size={tokenizer.vocab_size}")

# Test roundtrip on story + math text
test_texts = [
    "Once upon a time, there was a little cat.",
    "Question: What is 48 / 2?\nAnswer: 24",
]
for text in test_texts:
    tokens = tokenizer.encode(text, bos=False, eos=False)
    decoded = tokenizer.decode(tokens)
    status = "PASS" if decoded == text else "FAIL"
    # Truncate long texts so the log line stays readable
    preview = f"{text[:50]}..." if len(text) > 50 else text
    print(f"[{status}] Roundtrip: '{preview}' ({len(tokens)} tokens)")
print("Tokenizer ready!")

In [None]:
#@title Step 2: Prepare Data (Tokenize to Binary)

# Prepare the selected dataset (download + tokenize to .bin)
if DATASET == "mixed":
    train_bin, val_bin = prepare_mixed(DATA_DIR, tokenizer)
else:
    train_bin, val_bin = prepare_dataset(DATASET, DATA_DIR, tokenizer)

# Validate — remove empty/corrupt files so re-run will regenerate them
for bin_path in [train_bin, val_bin]:
    if os.path.exists(bin_path) and os.path.getsize(bin_path) == 0:
        print(f"WARNING: {bin_path} is empty, removing so it will be regenerated.")
        os.remove(bin_path)

# Re-run if needed
if not os.path.exists(train_bin) or not os.path.exists(val_bin):
    if DATASET == "mixed":
        train_bin, val_bin = prepare_mixed(DATA_DIR, tokenizer)
    else:
        train_bin, val_bin = prepare_dataset(DATASET, DATA_DIR, tokenizer)

# Print dataset stats
train_tokens = os.path.getsize(train_bin) // 2  # uint16 = 2 bytes
val_tokens = os.path.getsize(val_bin) // 2
print(f"\nDataset: {DATASET}")
print(f"Train tokens: {train_tokens:,}")
print(f"Val tokens:   {val_tokens:,}")
print(f"Train file:   {os.path.getsize(train_bin) / 1024**2:.1f} MB")
print(f"Val file:     {os.path.getsize(val_bin) / 1024**2:.1f} MB")

assert train_tokens > 0, "train.bin is empty — data preparation failed!"
assert val_tokens > 0, "val.bin is empty — data preparation failed!"
print("Data validation: PASSED")
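
The token counts above come from dividing the file size by 2, because each token id is stored as a `uint16`. The toy example below shows that layout with only the standard library. The file name, the token ids, little-endian byte order, and the shifted input/target split are all assumptions for illustration; the repo's `prepare_dataset` and `create_dataloader` may differ in detail.

```python
# Illustrative sketch: store token ids as raw uint16 values and read them back.
import os
import struct
import tempfile

tokens = [1, 517, 4095, 23, 2]           # hypothetical token ids
assert max(tokens) < 2**16               # uint16 can hold any id when VOCAB_SIZE = 4096

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy.bin")  # hypothetical file name
    with open(path, "wb") as f:
        # "<H" = little-endian unsigned 16-bit integer (byte order is an assumption)
        f.write(struct.pack(f"<{len(tokens)}H", *tokens))

    n_tokens = os.path.getsize(path) // 2  # 2 bytes per token
    with open(path, "rb") as f:
        restored = list(struct.unpack(f"<{n_tokens}H", f.read()))

# Next-token prediction pairs each position with the token that follows it.
seq_len = 4
x, y = restored[:seq_len], restored[1:seq_len + 1]
print(n_tokens, x, y)  # 5 [1, 517, 4095, 23] [517, 4095, 23, 2]
```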

In [None]:
#@title Step 3: Create Model

set_seed(SEED)

# Build model config (update vocab_size from tokenizer)
model_config = ModelConfig(
    vocab_size=tokenizer.vocab_size,
    dim=DIM,
    n_layers=N_LAYERS,
    n_heads=N_HEADS,
    n_kv_heads=N_KV_HEADS,
    max_seq_len=MAX_SEQ_LEN,
    hidden_dim=HIDDEN_DIM,
)
model_config.validate()

# Create model
model = LLaMA(model_config).to(device)
n_params = count_parameters(model)
print(f"\nModel parameters: {n_params:,}")
print_model_summary(model)
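
As a sanity check on the printed parameter count, the default configuration can be tallied by hand. This is a back-of-the-envelope sketch, not the repo's `count_parameters`: `estimate_params` is a hypothetical helper that assumes bias-free linear layers, three FFN matrices for SwiGLU, and two RMSNorm weight vectors per layer plus a final one. Whether the output head shares weights with the embedding is not stated here, so both cases are shown.

```python
# Illustrative sketch: rough parameter count for a LLaMA-style decoder.
# Assumptions: no bias terms; SwiGLU FFN with three weight matrices;
# two RMSNorm vectors per layer plus one final RMSNorm.
def estimate_params(vocab_size, dim, n_layers, n_heads, n_kv_heads, hidden_dim, tied_embeddings):
    head_dim = dim // n_heads
    kv_dim = n_kv_heads * head_dim
    attention = 2 * dim * dim + 2 * dim * kv_dim  # query/output + key/value projections
    ffn = 3 * dim * hidden_dim                    # gate, up, and down projections
    norms = 2 * dim                               # attention norm + FFN norm
    per_layer = attention + ffn + norms
    embedding = vocab_size * dim
    output_head = 0 if tied_embeddings else vocab_size * dim
    return n_layers * per_layer + dim + embedding + output_head


tied = estimate_params(4096, 384, 8, 6, 2, 1024, tied_embeddings=True)
untied = estimate_params(4096, 384, 8, 6, 2, 1024, tied_embeddings=False)
print(f"{tied:,} {untied:,}")  # 14,162,304 15,735,168
```

Either figure is consistent with the "~15M" headline; if the printed count is far from both, the configuration was probably changed.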

## Step 4: Train

The training loop below runs for `TRAINING_STEPS` optimizer steps with:
- **Mixed precision** (bf16 on Ampere+, fp16 on T4, fp32 on CPU)
- **Gradient accumulation** (effective batch = `BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS`)
- **Cosine LR schedule** with linear warmup
- **Gradient clipping** to prevent explosions
- **Periodic evaluation** on the validation set
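
For intuition about the schedule, here is one common way to write linear warmup followed by cosine decay. It is a sketch with the same argument order as the `get_lr` call in the training cell, but the function name `cosine_lr_sketch` and the exact warmup convention (`step + 1`) are assumptions; the repo's `get_lr` may differ at the boundaries.

```python
# Illustrative sketch: linear warmup, then cosine decay from max_lr to min_lr.
import math


def cosine_lr_sketch(step, warmup_steps, max_steps, max_lr, min_lr):
    """Learning rate for `step` (0-indexed)."""
    if step < warmup_steps:
        # Linear ramp; "+ 1" avoids a zero learning rate on the first step
        return max_lr * (step + 1) / warmup_steps
    if step >= max_steps:
        return min_lr
    progress = (step - warmup_steps) / (max_steps - warmup_steps)  # 0 -> 1
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))            # 1 -> 0
    return min_lr + cosine * (max_lr - min_lr)


for s in (0, 199, 1600, 2999):
    print(s, f"{cosine_lr_sketch(s, 200, 3000, 3e-4, 3e-5):.2e}")
```

With the defaults (200 warmup steps, 3000 total), the rate peaks at `LEARNING_RATE` when warmup ends and is halfway between peak and floor at step 1600.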

In [None]:
#@title Step 4: Train the Model

# ── Setup ────────────────────────────────────────────────────────────────
dtype = get_dtype("auto", device)
autocast_ctx = get_autocast_context(device, dtype)
scaler = get_grad_scaler(device, dtype)
print(f"Training dtype: {dtype}")
print(f"GradScaler: {'enabled' if scaler is not None else 'disabled'}")

# ── DataLoaders ──────────────────────────────────────────────────────────
train_loader = create_dataloader(
    train_bin, seq_len=MAX_SEQ_LEN, batch_size=BATCH_SIZE,
    shuffle=True, pin_memory=(device.type != "cpu"),
)
val_loader = create_dataloader(
    val_bin, seq_len=MAX_SEQ_LEN, batch_size=BATCH_SIZE,
    shuffle=False, pin_memory=(device.type != "cpu"),
)

# ── Optimizer ────────────────────────────────────────────────────────────
optimizer = model.configure_optimizers(
    learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY,
    betas=(0.9, 0.95), device=device,
)

# ── Training Loop ────────────────────────────────────────────────────────
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

model.train()
train_iter = iter(train_loader)
tokens_per_step = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS * MAX_SEQ_LEN
best_val_loss = float("inf")
train_losses = []

print(f"\nStarting training: {TRAINING_STEPS} steps")
print(f"Effective batch: {BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS} sequences = {tokens_per_step:,} tokens/step")
print("=" * 70)

pbar = tqdm(range(TRAINING_STEPS), desc="Training", unit="step")
for step in pbar:
    step_start = time.perf_counter()

    # ── Learning Rate Schedule ───────────────────────────────────────────
    lr = get_lr(step, WARMUP_STEPS, TRAINING_STEPS, LEARNING_RATE, MIN_LEARNING_RATE)
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

    # ── Gradient Accumulation ────────────────────────────────────────────
    accumulated_loss = 0.0
    for micro_step in range(GRADIENT_ACCUMULATION_STEPS):
        try:
            x, y = next(train_iter)
        except StopIteration:
            train_iter = iter(train_loader)
            x, y = next(train_iter)

        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        with autocast_ctx:
            _, loss = model(x, targets=y)
            loss = loss / GRADIENT_ACCUMULATION_STEPS

        if scaler is not None:
            scaler.scale(loss).backward()
        else:
            loss.backward()

        accumulated_loss += loss.item()

    # ── Gradient Clipping + Optimizer Step ───────────────────────────────
    if scaler is not None:
        scaler.unscale_(optimizer)
    nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)

    if scaler is not None:
        scaler.step(optimizer)
        scaler.update()
    else:
        optimizer.step()

    optimizer.zero_grad(set_to_none=True)

    # ── Timing ───────────────────────────────────────────────────────────
    if device.type == "cuda":
        torch.cuda.synchronize()
    step_time = time.perf_counter() - step_start
    tok_per_sec = tokens_per_step / step_time

    # ── Progress Bar ─────────────────────────────────────────────────────
    train_losses.append(accumulated_loss)
    pbar.set_postfix(loss=f"{accumulated_loss:.4f}", lr=f"{lr:.2e}", tps=f"{tok_per_sec:,.0f}")

    # ── Log ──────────────────────────────────────────────────────────────
    if step % LOG_INTERVAL == 0:
        mem = get_memory_usage(device)
        print(
            f"step {step:>5d}/{TRAINING_STEPS} | "
            f"loss {accumulated_loss:.4f} | "
            f"lr {lr:.2e} | "
            f"{tok_per_sec:>8,.0f} tok/s | "
            f"mem {mem['allocated_mb']:>6.0f} MB"
        )

    # ── Evaluation ───────────────────────────────────────────────────────
    if step > 0 and step % EVAL_INTERVAL == 0:
        val_loss = evaluate(model, val_loader, device, autocast_ctx, max_steps=EVAL_STEPS)
        perplexity = math.exp(val_loss)
        print(f"{'─' * 60}")
        print(f"EVAL step {step} | val_loss {val_loss:.4f} | perplexity {perplexity:.2f}")
        print(f"{'─' * 60}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_checkpoint(
                model, optimizer, step, val_loss,
                model_config.to_dict(), {},
                os.path.join(CHECKPOINT_DIR, "best.pt"),
            )
        model.train()

    # ── Periodic Checkpoint ──────────────────────────────────────────────
    if step > 0 and step % SAVE_INTERVAL == 0:
        save_checkpoint(
            model, optimizer, step, best_val_loss,
            model_config.to_dict(), {},
            os.path.join(CHECKPOINT_DIR, f"step_{step:06d}.pt"),
        )

# ── Final evaluation + checkpoint ────────────────────────────────────────
print("\nTraining complete! Running final evaluation...")
val_loss = evaluate(model, val_loader, device, autocast_ctx, max_steps=EVAL_STEPS)
perplexity = math.exp(val_loss)
print(f"Final val_loss: {val_loss:.4f} | perplexity: {perplexity:.2f}")

if val_loss < best_val_loss:
    best_val_loss = val_loss
    save_checkpoint(
        model, optimizer, TRAINING_STEPS, val_loss,
        model_config.to_dict(), {},
        os.path.join(CHECKPOINT_DIR, "best.pt"),
    )
    print(f"New best checkpoint saved at step {TRAINING_STEPS}!")

final_path = os.path.join(CHECKPOINT_DIR, "final.pt")
save_checkpoint(
    model, optimizer, TRAINING_STEPS, best_val_loss,
    model_config.to_dict(), {}, final_path,
)
print(f"Final checkpoint saved: {final_path}")
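
The loop divides each micro-batch loss by `GRADIENT_ACCUMULATION_STEPS` before calling `backward()`. The toy check below, using a one-parameter model in plain Python, suggests why: with equal-sized micro-batches, the summed scaled gradients match the gradient of the mean loss over the full batch. The data and the `mse_grad` helper are invented for this sketch.

```python
# Illustrative sketch: accumulated micro-batch gradients vs. one full-batch gradient.
def mse_grad(w, xs, ys):
    """Gradient of mean((w*x - y)^2) with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


w = 0.5
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0]

full_batch_grad = mse_grad(w, xs, ys)

accum_steps = 4
micro = len(xs) // accum_steps  # samples per micro-batch
accumulated_grad = 0.0
for i in range(accum_steps):
    chunk = slice(i * micro, (i + 1) * micro)
    # Scaling by 1/accum_steps mirrors `loss / GRADIENT_ACCUMULATION_STEPS`
    accumulated_grad += mse_grad(w, xs[chunk], ys[chunk]) / accum_steps

print(full_batch_grad, accumulated_grad)
```

Without the division, the accumulated gradient would be `accum_steps` times too large, which behaves like an unintended learning-rate increase.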

In [None]:
#@title Step 5: Evaluate

val_loss = evaluate(model, val_loader, device, autocast_ctx, max_steps=EVAL_STEPS)
perplexity = math.exp(val_loss)

print(f"Final Validation Loss: {val_loss:.4f}")
print(f"Final Perplexity:      {perplexity:.2f}")
print(f"Best Validation Loss:  {best_val_loss:.4f}")
print(f"Best Perplexity:       {math.exp(best_val_loss):.2f}")
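
A useful reference point when reading these numbers: perplexity is `exp(loss)` when the loss is a mean cross-entropy in nats, so a model that guesses uniformly over the vocabulary scores a perplexity equal to the vocabulary size. The short calculation below assumes the default `VOCAB_SIZE` of 4096.

```python
# Illustrative sketch: loss and perplexity of a uniform guess over the vocabulary.
import math

VOCAB_SIZE = 4096
uniform_loss = math.log(VOCAB_SIZE)   # cross-entropy of a uniform distribution, in nats
uniform_ppl = math.exp(uniform_loss)  # perplexity = exp(loss)
print(f"{uniform_loss:.4f} {uniform_ppl:.0f}")  # 8.3178 4096
```

A validation loss near 8.3 therefore means the model has learned essentially nothing, and anything well below it reflects real progress.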

In [None]:
#@title Step 6: Generate Text

# Dataset-appropriate prompts
PROMPTS = {
    "tinystories": [
        "Once upon a time",
        "The little dog",
        "She looked at the sky and",
        "One day, a boy named Tom",
    ],
    "gsm8k": [
        "Question: Sarah has 24 apples and gives half to her friend. How many does she have left?\nAnswer:",
        "Question: A train travels 60 miles per hour for 3 hours. How far does it go?\nAnswer:",
    ],
    "simplemath": [
        "245 + 378\nAnswer:",
        "1000 - 457\nAnswer:",
    ],
    "aqua_rat": [
        "Question: If a shirt costs $40 and is on sale for 25% off, what is the sale price?\nOptions:",
        "Question: A car travels 180 miles in 3 hours. What is its average speed?\nOptions:",
    ],
}

# Use matching prompts, or story prompts for mixed
prompts = PROMPTS.get(DATASET, PROMPTS["tinystories"])

for temp in [0.7, 1.0]:
    print(f"\n{'=' * 60}")
    print(f"Temperature = {temp}")
    print(f"{'=' * 60}")
    for prompt in prompts:
        result = generate(
            model, tokenizer, prompt,
            max_new_tokens=150, temperature=temp,
            top_k=40, top_p=0.9, device=device,
        )
        print(f"\n--- Prompt: \"{prompt[:60]}{'...' if len(prompt) > 60 else ''}\" ---")
        print(result.text)
        print(result.stats_string())
    model.train()
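
The `temperature`, `top_k`, and `top_p` arguments above control how the next token is drawn. The plain-Python sketch below shows one common ordering (temperature, then top-k, then top-p over the surviving tokens). It is not the repo's `generate` implementation; `filter_probs` and `sample` are hypothetical names, and other libraries order or normalize these steps slightly differently.

```python
# Illustrative sketch: temperature scaling, top-k, then top-p (nucleus) filtering.
import math
import random


def filter_probs(logits, temperature=1.0, top_k=None, top_p=None):
    """Return {token_id: probability} after filtering and renormalizing."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [(logit / temperature, i) for i, logit in enumerate(logits)]
    scaled.sort(reverse=True)                  # most likely token first
    if top_k is not None:
        scaled = scaled[:top_k]                # keep the k highest logits
    peak = scaled[0][0]
    exps = [(math.exp(s - peak), i) for s, i in scaled]  # numerically stable softmax
    total = sum(e for e, _ in exps)
    probs = [(e / total, i) for e, i in exps]
    if top_p is not None:
        kept, cumulative = [], 0.0
        for p, i in probs:                     # smallest prefix with mass >= top_p
            kept.append((p, i))
            cumulative += p
            if cumulative >= top_p:
                break
        probs = kept
    norm = sum(p for p, _ in probs)
    return {i: p / norm for p, i in probs}


def sample(dist, rng=random):
    """Draw one token id from a {token_id: probability} dict."""
    ids = list(dist)
    return rng.choices(ids, weights=[dist[i] for i in ids], k=1)[0]


logits = [2.0, 1.0, 0.5, -1.0, -3.0]           # made-up scores for 5 tokens
dist = filter_probs(logits, temperature=0.7, top_k=3, top_p=0.9)
print(dist, sample(dist))
```

Lower temperatures sharpen the distribution toward the top token, which is why the 0.7 run above tends to read as more coherent and the 1.0 run as more varied.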

In [None]:
#@title Step 7: Save & Download Model

import zipfile

# Save model config alongside checkpoint
config_path = os.path.join(CHECKPOINT_DIR, "model_config.json")
model_config.save(config_path)
print(f"Model config saved to {config_path}")

# Determine which checkpoint to use as primary
best_path = os.path.join(CHECKPOINT_DIR, "best.pt")
final_path = os.path.join(CHECKPOINT_DIR, "final.pt")
primary_ckpt = best_path if os.path.exists(best_path) else final_path

# List all checkpoints
print("\nCheckpoints:")
for f in sorted(os.listdir(CHECKPOINT_DIR)):
    path = os.path.join(CHECKPOINT_DIR, f)
    size_mb = os.path.getsize(path) / 1024**2
    print(f"  {f}: {size_mb:.1f} MB")

# Bundle all files needed for inference into a single zip:
#   - Model weights (best.pt or final.pt)
#   - Tokenizer model (tokenizer.model)
#   - Model config (model_config.json)
tokenizer_path = TOKENIZER_PREFIX + ".model"

artifacts = {
    "checkpoint": primary_ckpt,
    "tokenizer": tokenizer_path,
    "config": config_path,
}

# Verify all files exist
print("\nInference artifacts:")
for name, path in artifacts.items():
    exists = os.path.exists(path)
    size_mb = os.path.getsize(path) / 1024**2 if exists else 0
    status = f"{size_mb:.1f} MB" if exists else "MISSING"
    print(f"  {name:12s}: {path} ({status})")

zip_path = "llama_vc_model.zip"
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
    for name, path in artifacts.items():
        if os.path.exists(path):
            zf.write(path, os.path.basename(path))

zip_size_mb = os.path.getsize(zip_path) / 1024**2
print(f"\nBundled into: {zip_path} ({zip_size_mb:.1f} MB)")

# Download (Colab only)
try:
    from google.colab import files
    print("Downloading zip...")
    files.download(zip_path)
except ImportError:
    print("Not running on Colab — skipping download.")
    print(f"Zip is at: {os.path.abspath(zip_path)}")

In [None]:
#@title Bonus: Load Checkpoint & Generate (proves the save works)

# Create a fresh model from config
loaded_config = ModelConfig.load(os.path.join(CHECKPOINT_DIR, "model_config.json"))
loaded_model = LLaMA(loaded_config).to(device)

# Load the best checkpoint
ckpt_path = os.path.join(CHECKPOINT_DIR, "best.pt")
if not os.path.exists(ckpt_path):
    ckpt_path = os.path.join(CHECKPOINT_DIR, "final.pt")

info = load_checkpoint(ckpt_path, loaded_model, device=device)
print(f"Loaded checkpoint from step {info['step']}, val_loss {info['val_loss']:.4f}")

# Generate with loaded model
print("\n--- Generation from loaded checkpoint ---")
for prompt in ["Once upon a time", "The little cat was"]:
    result = generate(
        loaded_model, tokenizer, prompt,
        max_new_tokens=100, temperature=0.8, device=device,
    )
    print(f"\nPrompt: \"{prompt}\"")
    print(result.text)
    print(result.stats_string())

print("\nCheckpoint load + generate: SUCCESS")