# Module 4: Training Basics

In this notebook, we'll learn how to train a language model from scratch. We'll cover the essential components of a modern training pipeline.

## Learning Objectives

By the end of this notebook, you will:
1. Understand the data pipeline for language model training
2. Configure optimizers (AdamW) and learning rate schedules
3. Implement mixed precision training for efficiency
4. Build a basic training loop with gradient accumulation
5. Monitor training metrics and diagnose issues
6. Save and load checkpoints
7. Train a small model end to end (on synthetic token data in this notebook)

## What We'll Build

A complete training pipeline including:
- Custom dataset and dataloader
- Training loop with validation
- Learning rate scheduling (warmup + cosine decay)
- Gradient clipping and accumulation
- Checkpointing and resumption

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch import autocast  # device-agnostic autocast; accepts device_type=...
from torch.cuda.amp import GradScaler
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from tqdm.auto import tqdm
import math

# Set seeds
torch.manual_seed(42)
np.random.seed(42)

# Configure plotting
sns.set_style("whitegrid")
plt.rcParams["figure.figsize"] = (12, 6)

# Check device
device = torch.device(
    "mps"
    if torch.backends.mps.is_available()
    else "cuda"
    if torch.cuda.is_available()
    else "cpu"
)
print(f"Using device: {device}")

## Part 1: Understanding the Data Pipeline

### Language Model Training Data

For autoregressive language models, we:
1. Tokenize text into integer sequences
2. Create sliding windows of length `max_seq_len`
3. Use `tokens[:-1]` as the input and `tokens[1:]` as the target (next-token prediction)

Example:
```
Text: "The cat sat on the mat"
Tokens: [1, 45, 23, 67, 12, 45, 89]
Input:  [1, 45, 23, 67, 12, 45]  <- predict next token at each position
Target: [45, 23, 67, 12, 45, 89] <- shifted by 1
```
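
The windowing and shifting steps above can be sketched in plain Python. This is a minimal illustration rather than the notebook's `TextDataset`: the helper name `make_windows` and the tiny window length are assumptions chosen for readability.

```python
def make_windows(tokens, window_len, stride=None):
    """Split a token list into full windows of `window_len` tokens (hypothetical helper)."""
    stride = stride or window_len
    last_start = len(tokens) - window_len
    return [tokens[i:i + window_len] for i in range(0, last_start + 1, stride)]


tokens = [1, 45, 23, 67, 12, 45, 89]

# Next-token prediction: the target is the input shifted left by one position.
inputs, targets = tokens[:-1], tokens[1:]
print(inputs)
print(targets)

# Overlapping windows of 4 tokens with a stride of 2.
windows = make_windows(tokens, window_len=4, stride=2)
print(windows)
```

With a stride smaller than the window length, consecutive windows overlap, which yields more training examples from the same text at the cost of repeated tokens.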

In [None]:
class TextDataset(Dataset):
    """
    Dataset for language model training.

    Creates sliding windows over tokenized text.
    """

    def __init__(
        self,
        data_path: str,
        max_seq_len: int = 256,
        stride: int = None,
    ):
        """
        Args:
            data_path: Path to tokenized data file (whitespace-separated token IDs)
            max_seq_len: Maximum sequence length
            stride: Stride for sliding window (default: max_seq_len, no overlap)
        """
        self.max_seq_len = max_seq_len
        self.stride = stride or max_seq_len

        # Load tokenized data
        print(f"Loading data from {data_path}...")
        with open(data_path, "r") as f:
            # Assume data is space-separated token IDs
            self.tokens = [int(x) for line in f for x in line.strip().split()]

        print(f"Loaded {len(self.tokens):,} tokens")

        # Calculate number of sequences
        self.num_sequences = max(1, (len(self.tokens) - max_seq_len) // self.stride + 1)
        print(f"Created {self.num_sequences:,} sequences of length {max_seq_len}")

    def __len__(self):
        return self.num_sequences

    def __getitem__(self, idx):
        """
        Get a single training example.

        Returns:
            dict with 'input_ids', 'attention_mask', and 'labels'
        """
        # Get starting position
        start_idx = idx * self.stride
        end_idx = start_idx + self.max_seq_len

        # Extract sequence
        sequence = self.tokens[start_idx:end_idx]

        # Pad if necessary (for last sequence)
        if len(sequence) < self.max_seq_len:
            sequence = sequence + [0] * (self.max_seq_len - len(sequence))

        # Convert to tensor
        input_ids = torch.tensor(sequence, dtype=torch.long)

        # Create attention mask (1 for real tokens, 0 for padding)
        attention_mask = (input_ids != 0).long()

        # Labels are same as input (we'll shift inside the model)
        labels = input_ids.clone()

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }


# Create a dummy dataset for demonstration
# In practice, you'd use real tokenized data
dummy_data_path = Path("dummy_tokens.txt")
with open(dummy_data_path, "w") as f:
    # Write random token IDs
    tokens = np.random.randint(1, 1000, size=10000)
    f.write(" ".join(map(str, tokens)))

# Create dataset
dataset = TextDataset(str(dummy_data_path), max_seq_len=128)

# Test it
sample = dataset[0]
print("\nSample:")
print(f"  input_ids shape: {sample['input_ids'].shape}")
print(f"  attention_mask shape: {sample['attention_mask'].shape}")
print(f"  labels shape: {sample['labels'].shape}")
print(f"  First 10 tokens: {sample['input_ids'][:10].tolist()}")

### Creating DataLoaders

In [None]:
# Split into train/val
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size]
)

# Create dataloaders
batch_size = 8
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")
print(f"Effective batch size: {batch_size}")

# Test loading a batch
batch = next(iter(train_loader))
print("\nBatch shapes:")
for key, val in batch.items():
    print(f"  {key}: {val.shape}")

## Part 2: Building a Simple Model

For this tutorial, we'll use a small GPT-style model (from notebook 02).

In [None]:
# Simplified model for demonstration
class TinyGPT(nn.Module):
    """
    A tiny GPT-style model for training demonstration.
    """

    def __init__(
        self,
        vocab_size: int = 10000,
        d_model: int = 256,
        num_layers: int = 4,
        num_heads: int = 4,
        d_ff: int = 1024,
        max_seq_len: int = 256,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.d_model = d_model

        # Embeddings
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embedding = nn.Embedding(max_seq_len, d_model)

        # Transformer blocks (simplified)
        self.blocks = nn.ModuleList(
            [
                nn.TransformerEncoderLayer(
                    d_model=d_model,
                    nhead=num_heads,
                    dim_feedforward=d_ff,
                    dropout=dropout,
                    batch_first=True,
                )
                for _ in range(num_layers)
            ]
        )

        # Output
        self.ln_f = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

        self.dropout = nn.Dropout(dropout)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize model weights."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, input_ids, attention_mask=None, labels=None):
        """
        Forward pass.

        Args:
            input_ids: Token indices (batch, seq_len)
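            attention_mask: Padding mask (batch, seq_len). Not used in this
                simplified model: with right-padding, the causal mask already
                keeps real tokens from attending to padding, and the loss
                skips padding targets via ignore_index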
            labels: Target tokens for loss computation (batch, seq_len)
        """
        batch_size, seq_len = input_ids.size()

        # Create position IDs
        positions = torch.arange(0, seq_len, device=input_ids.device).unsqueeze(0)

        # Embed
        token_emb = self.token_embedding(input_ids)
        pos_emb = self.pos_embedding(positions)
        x = self.dropout(token_emb + pos_emb)

        # Create causal mask
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=input_ids.device) * float("-inf"),
            diagonal=1,
        )

        # Transformer blocks
        for block in self.blocks:
            x = block(x, src_mask=causal_mask)

        # Output
        x = self.ln_f(x)
        logits = self.lm_head(x)

        # Compute loss if labels provided
        loss = None
        if labels is not None:
            # Shift logits and labels for next token prediction
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()

            # Flatten and compute cross entropy
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=0,  # Ignore padding
            )

        return {"logits": logits, "loss": loss}


# Create model
model = TinyGPT(
    vocab_size=1000,
    d_model=256,
    num_layers=4,
    num_heads=4,
    d_ff=1024,
    max_seq_len=128,
).to(device)

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print("Model created!")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
print(f"  Model size: ~{total_params * 4 / 1e6:.1f} MB (float32)")

## Part 3: Optimizer Configuration

### AdamW: The Standard Choice

**AdamW** (Adam with decoupled weight decay) is the most common optimizer for transformers.

Key hyperparameters:
- **Learning rate**: 1e-4 to 3e-4 typical
- **Weight decay**: 0.01 to 0.1 (for regularization)
- **Betas**: (0.9, 0.95) or (0.9, 0.999)
- **Epsilon**: 1e-8
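
To make "decoupled" concrete, here is a minimal single-parameter sketch of one AdamW update in plain Python, following the update rule from the Loshchilov & Hutter paper listed under Further Reading. The function name `adamw_step` and the scalar setup are assumptions for illustration; in practice you would use `torch.optim.AdamW`.

```python
import math


def adamw_step(p, g, m, v, t, lr=3e-4, betas=(0.9, 0.95), eps=1e-8, weight_decay=0.1):
    """One AdamW update for a scalar parameter p with gradient g (illustrative sketch)."""
    beta1, beta2 = betas

    # Decoupled weight decay: shrink the weight directly, independent of the gradient.
    p = p * (1 - lr * weight_decay)

    # Exponential moving averages of the gradient and the squared gradient.
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g

    # Bias correction for the zero-initialized moving averages.
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)

    # Adam step using the bias-corrected moments.
    p = p - lr * m_hat / (math.sqrt(v_hat) + eps)
    return p, m, v


p_new, m_new, v_new = adamw_step(p=1.0, g=0.5, m=0.0, v=0.0, t=1)
print(p_new, m_new, v_new)
```

On the first step the Adam part of the update has magnitude close to `lr` regardless of the gradient's scale, while the decay term shrinks the weight by `lr * weight_decay` on its own.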

### Weight Decay Groups

We typically **don't** apply weight decay to:
- Layer norm parameters
- Biases
- Embeddings (sometimes)

In [None]:
def configure_optimizer(model, learning_rate=3e-4, weight_decay=0.1, betas=(0.9, 0.95)):
    """
    Configure AdamW optimizer with proper weight decay groups.
    """
    # Separate parameters into weight decay and no weight decay groups.
    # Lists (not sets) keep the parameter order deterministic, so the optimizer
    # state saved in a checkpoint maps back to the same parameters on reload.
    decay = []
    no_decay = []

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        # No weight decay for layer norms and biases
        if "ln" in name or "bias" in name or "norm" in name:
            no_decay.append(param)
        else:
            decay.append(param)

    # Create parameter groups
    param_groups = [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

    print("Optimizer groups:")
    print(f"  With weight decay: {len(decay)} params")
    print(f"  Without weight decay: {len(no_decay)} params")

    # Create optimizer
    optimizer = torch.optim.AdamW(
        param_groups,
        lr=learning_rate,
        betas=betas,
        eps=1e-8,
    )

    return optimizer


optimizer = configure_optimizer(model, learning_rate=3e-4, weight_decay=0.1)

## Part 4: Learning Rate Scheduling

### Cosine Schedule with Warmup

Best practice for transformer training:

1. **Warmup** (first 5-10% of training):
   - Linearly increase LR from 0 to max_lr
   - Stabilizes training at the start

2. **Cosine Decay** (remaining training):
   - Smoothly decrease LR following cosine curve
   - Often to 0.1 × max_lr (not zero)

$$\text{lr}(t) = \text{lr}_{\text{min}} + \frac{1}{2}(\text{lr}_{\text{max}} - \text{lr}_{\text{min}})\left(1 + \cos\left(\frac{t - t_{\text{warmup}}}{t_{\text{max}} - t_{\text{warmup}}}\pi\right)\right)$$
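
As a quick check of the formula, the sketch below evaluates the schedule at a few steps in plain Python. The function name `lr_at` and the concrete values (`max_lr=3e-4`, `min_lr=3e-5`, 100 warmup steps out of 1000) are assumptions picked for illustration. The linear warmup branch follows the rule described in the list above, since the formula itself covers only the decay phase.

```python
import math


def lr_at(step, max_lr=3e-4, min_lr=3e-5, warmup_steps=100, total_steps=1000):
    """Learning rate at a given step: linear warmup, then cosine decay to min_lr."""
    if step < warmup_steps:
        # Linear warmup from 0 to max_lr.
        return max_lr * step / warmup_steps

    # Fraction of the decay phase completed, in [0, 1].
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))


for step in (0, 50, 100, 550, 1000):
    print(step, f"{lr_at(step):.2e}")
```

The steps printed are the start of warmup, the middle of warmup, the peak at the end of warmup, the midpoint of the decay, and the final step, where the rate reaches `min_lr`.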

In [None]:
def get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    min_lr_ratio: float = 0.1,
):
    """
    Create a learning rate scheduler with linear warmup and cosine decay.

    Args:
        optimizer: The optimizer
        num_warmup_steps: Number of warmup steps
        num_training_steps: Total number of training steps
        min_lr_ratio: Minimum LR as a ratio of max LR
    """

    def lr_lambda(current_step):
        # Warmup
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))

        # Cosine decay
        progress = float(current_step - num_warmup_steps) / float(
            max(1, num_training_steps - num_warmup_steps)
        )
        return min_lr_ratio + (1 - min_lr_ratio) * 0.5 * (
            1.0 + math.cos(math.pi * progress)
        )

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


# Create scheduler
num_epochs = 3
num_training_steps = len(train_loader) * num_epochs
num_warmup_steps = int(0.1 * num_training_steps)  # 10% warmup

scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
)

# Visualize the schedule
lrs = []
for step in range(num_training_steps):
    lrs.append(scheduler.get_last_lr()[0])
    scheduler.step()

# Reset scheduler
scheduler = get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps, num_training_steps
)

# Plot
plt.figure(figsize=(12, 5))
plt.plot(lrs, linewidth=2)
plt.axvline(
    x=num_warmup_steps,
    color="red",
    linestyle="--",
    label=f"End of warmup (step {num_warmup_steps})",
)
plt.xlabel("Training Step")
plt.ylabel("Learning Rate")
plt.title("Learning Rate Schedule: Warmup + Cosine Decay")
plt.legend()
plt.grid(alpha=0.3)
plt.show()

print(f"Total training steps: {num_training_steps}")
print(f"Warmup steps: {num_warmup_steps}")
print(f"Max LR: {max(lrs):.2e}")
print(f"Min LR: {min(lrs):.2e}")

## Part 5: Mixed Precision Training

### Why Mixed Precision?

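- **Faster**: FP16/BF16 matrix multiplies can run roughly 2-3x faster on GPUs with dedicated hardware (e.g. Tensor Cores)
- **Less memory**: Activations are stored in 16 bits, about half the size of FP32 (under `autocast` the weights themselves stay in FP32)
- **Comparable accuracy**: With proper loss scaling for FP16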

### FP16 vs BF16:

- **FP16**: Narrow exponent range, so it needs loss scaling; supported on a wider range of hardware
- **BF16**: Same exponent range as FP32, so no loss scaling is needed; requires newer GPUs (e.g. NVIDIA Ampere or later)

We'll use PyTorch's `autocast` and `GradScaler`.
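
To see why FP16 needs loss scaling, the sketch below rounds a small gradient value to half precision using only the standard library. The helper name `to_fp16` and the scale factor of 2**16 are assumptions for illustration; `GradScaler` manages its own scale dynamically.

```python
import struct


def to_fp16(x):
    """Round a Python float to IEEE half precision and back (struct format 'e')."""
    return struct.unpack("e", struct.pack("e", x))[0]


tiny_grad = 1e-8       # a gradient value too small for FP16
loss_scale = 65536.0   # 2**16, an illustrative scale factor

# Without scaling, the value underflows to zero in FP16.
print(to_fp16(tiny_grad))

# Scaling the loss scales every gradient, moving it into FP16's range.
scaled = to_fp16(tiny_grad * loss_scale)

# Unscaling in higher precision recovers a value close to the original.
recovered = scaled / loss_scale
print(recovered)
```

This is the idea behind scaling the loss before `backward()` and unscaling the gradients before the optimizer step.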

In [None]:
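# Enable mixed precision only on CUDA; autocast and GradScaler support on MPS
# depends on the PyTorch version, so other devices fall back to full precision.
use_amp = device.type == "cuda"

# Query BF16 support only when CUDA is in use; on some PyTorch versions the
# check fails on machines without a CUDA device.
dtype = (
    torch.bfloat16
    if (use_amp and torch.cuda.is_bf16_supported())
    else torch.float16
)

# Create gradient scaler (only needed for FP16)
scaler = GradScaler() if (use_amp and dtype == torch.float16) else None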

print(f"Mixed precision training: {use_amp}")
if use_amp:
    print(f"  Using dtype: {dtype}")
    print(f"  Using GradScaler: {scaler is not None}")

## Part 6: Training Loop

Now let's implement a complete training loop with:
- Mixed precision
- Gradient accumulation
- Gradient clipping
- Progress tracking
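
Before the full loop, here is a minimal plain-Python sketch of two of these ideas on a one-parameter least-squares problem: accumulating scaled micro-batch gradients reproduces the full-batch gradient, and clipping rescales gradients by their global norm. The helper names `grad_mse` and `clip_by_global_norm` and the toy data are assumptions for illustration.

```python
import math


def grad_mse(w, xs, ys):
    """Gradient of mean((w*x - y)^2) with respect to the scalar weight w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def clip_by_global_norm(grads, max_norm):
    """Scale gradients so their combined L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    scale = min(1.0, max_norm / (total_norm + 1e-6))
    return [g * scale for g in grads]


w = 0.5
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]

# Full-batch gradient over all four examples.
full_grad = grad_mse(w, xs, ys)

# Two micro-batches of two examples each. Dividing each micro-batch gradient
# by the number of accumulation steps makes the sum match the full batch.
accumulation_steps = 2
accumulated = 0.0
for i in range(accumulation_steps):
    micro_xs = xs[2 * i:2 * i + 2]
    micro_ys = ys[2 * i:2 * i + 2]
    accumulated += grad_mse(w, micro_xs, micro_ys) / accumulation_steps

print(full_grad, accumulated)

# Clip the accumulated gradient to a maximum norm of 1.0.
clipped = clip_by_global_norm([accumulated], max_norm=1.0)
print(clipped)
```

The equivalence is exact here because the micro-batches have equal size; with unequal micro-batches, dividing by the step count is only an approximation.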

In [None]:
def train_epoch(
    model,
    train_loader,
    optimizer,
    scheduler,
    device,
    scaler=None,
    gradient_accumulation_steps=1,
    max_grad_norm=1.0,
):
    """
    Train for one epoch.
    """
    model.train()
    total_loss = 0

    pbar = tqdm(train_loader, desc="Training")

    for step, batch in enumerate(pbar):
        # Move to device
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        # Forward pass with mixed precision
        with autocast(device_type=device.type, dtype=dtype, enabled=use_amp):
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )
            loss = outputs["loss"] / gradient_accumulation_steps

        # Backward pass
        if scaler is not None:
            scaler.scale(loss).backward()
        else:
            loss.backward()

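        # Gradient accumulation: step every N batches, and also on the final
        # batch so leftover gradients are not carried into the next epoch
        is_last_batch = (step + 1) == len(train_loader)
        if (step + 1) % gradient_accumulation_steps == 0 or is_last_batch: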
            # Gradient clipping
            if scaler is not None:
                scaler.unscale_(optimizer)

            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

            # Optimizer step
            if scaler is not None:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()

            scheduler.step()
            optimizer.zero_grad()

        # Track metrics
        total_loss += loss.item() * gradient_accumulation_steps

        # Update progress bar
        pbar.set_postfix(
            {
                "loss": f"{loss.item() * gradient_accumulation_steps:.4f}",
                "lr": f"{scheduler.get_last_lr()[0]:.2e}",
            }
        )

    return total_loss / len(train_loader)


@torch.no_grad()
def evaluate(model, val_loader, device):
    """
    Evaluate on validation set.
    """
    model.eval()
    total_loss = 0

    for batch in tqdm(val_loader, desc="Evaluating"):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        with autocast(device_type=device.type, dtype=dtype, enabled=use_amp):
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )

        total_loss += outputs["loss"].item()

    avg_loss = total_loss / len(val_loader)
    perplexity = math.exp(avg_loss)

    return {"loss": avg_loss, "perplexity": perplexity}

## Part 7: Run Training!

Let's train our model for a few epochs.

In [None]:
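# Training configuration
num_epochs = 3
gradient_accumulation_steps = 4
max_grad_norm = 1.0

# The scheduler advances once per optimizer step, not once per batch, and there
# are at most ceil(batches / accumulation steps) optimizer steps per epoch.
# Rebuild it so warmup and decay span the steps that will actually be taken.
num_training_steps = (
    math.ceil(len(train_loader) / gradient_accumulation_steps) * num_epochs
)
num_warmup_steps = int(0.1 * num_training_steps)
scheduler = get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps, num_training_steps
)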

# Track metrics
train_losses = []
val_losses = []
val_perplexities = []

print(f"Starting training for {num_epochs} epochs...")
print(f"Effective batch size: {batch_size * gradient_accumulation_steps}")
print()

for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")

    # Train
    train_loss = train_epoch(
        model,
        train_loader,
        optimizer,
        scheduler,
        device,
        scaler,
        gradient_accumulation_steps,
        max_grad_norm,
    )
    train_losses.append(train_loss)

    # Evaluate
    val_metrics = evaluate(model, val_loader, device)
    val_losses.append(val_metrics["loss"])
    val_perplexities.append(val_metrics["perplexity"])

    print(f"  Train Loss: {train_loss:.4f}")
    print(f"  Val Loss: {val_metrics['loss']:.4f}")
    print(f"  Val Perplexity: {val_metrics['perplexity']:.2f}")
    print()

print("Training complete!")

### Visualize Training Progress

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Plot losses
epochs = range(1, num_epochs + 1)
ax1.plot(epochs, train_losses, "o-", label="Train Loss", linewidth=2)
ax1.plot(epochs, val_losses, "s-", label="Val Loss", linewidth=2)
ax1.set_xlabel("Epoch")
ax1.set_ylabel("Loss")
ax1.set_title("Training and Validation Loss")
ax1.legend()
ax1.grid(alpha=0.3)

# Plot perplexity
ax2.plot(epochs, val_perplexities, "o-", color="green", linewidth=2)
ax2.set_xlabel("Epoch")
ax2.set_ylabel("Perplexity")
ax2.set_title("Validation Perplexity")
ax2.grid(alpha=0.3)

plt.tight_layout()
plt.show()

print(f"Final validation perplexity: {val_perplexities[-1]:.2f}")
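# Note: the dummy tokens are uniformly random, so the model has nothing to learn
# and perplexity should stay close to the number of distinct tokens.
print("Lower is better! (Perfect model = 1.0, uniform guessing = vocab_size)")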

## Part 8: Checkpointing

Always save checkpoints during training!

In [None]:
def save_checkpoint(model, optimizer, scheduler, epoch, loss, path):
    """
    Save a training checkpoint.
    """
    checkpoint = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "loss": loss,
    }

    torch.save(checkpoint, path)
    print(f"Checkpoint saved to {path}")


def load_checkpoint(model, optimizer, scheduler, path):
    """
    Load a training checkpoint.
    """
    checkpoint = torch.load(path, map_location=device)

    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    scheduler.load_state_dict(checkpoint["scheduler_state_dict"])

    print(f"Checkpoint loaded from {path}")
    print(f"  Resuming from epoch {checkpoint['epoch']}")
    print(f"  Previous loss: {checkpoint['loss']:.4f}")

    return checkpoint["epoch"]


# Save a checkpoint
checkpoint_dir = Path("checkpoints")
checkpoint_dir.mkdir(exist_ok=True)

save_checkpoint(
    model,
    optimizer,
    scheduler,
    epoch=num_epochs,
    loss=val_losses[-1],
    path=checkpoint_dir / "model_final.pt",
)

## Summary and Key Takeaways

In this notebook, you learned:

1. **Data Pipeline**:
   - Creating sliding windows over text
   - Proper batching and data loading
   - Input/target pairs for next token prediction

2. **Optimizer Setup**:
   - AdamW with weight decay
   - Separating parameters into decay groups
   - Typical hyperparameters

3. **Learning Rate Scheduling**:
   - Warmup for stability
   - Cosine decay for smooth convergence
   - Typical warmup ratios (5-10%)

4. **Mixed Precision Training**:
   - FP16/BF16 for efficiency
   - Using autocast and GradScaler
   - Up to 2-3x speedup on supporting hardware, with comparable accuracy

5. **Training Loop**:
   - Gradient accumulation for larger effective batch sizes
   - Gradient clipping for stability
   - Progress tracking and validation

6. **Checkpointing**:
   - Saving model state
   - Resuming training

### Key Metrics:

- **Loss**: Cross-entropy loss (lower is better)
- **Perplexity**: exp(loss), interpretable as "branching factor" (lower is better)
- **Learning Rate**: Monitor to ensure proper scheduling
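
The relationship between loss and perplexity can be checked with a short plain-Python sketch. The helper name `cross_entropy` and the vocabulary size of 1000 are assumptions for illustration; the function takes the probability the model assigned to each correct token.

```python
import math


def cross_entropy(correct_token_probs):
    """Mean negative log-probability assigned to the correct tokens."""
    return -sum(math.log(p) for p in correct_token_probs) / len(correct_token_probs)


vocab_size = 1000

# A model that guesses uniformly assigns 1/vocab_size to every token.
uniform_loss = cross_entropy([1 / vocab_size] * 5)
uniform_ppl = math.exp(uniform_loss)
print(f"uniform: loss={uniform_loss:.4f}, perplexity={uniform_ppl:.1f}")

# A model that is always certain and correct assigns probability 1.0.
perfect_loss = cross_entropy([1.0] * 5)
perfect_ppl = math.exp(perfect_loss)
print(f"perfect: loss={perfect_loss:.4f}, perplexity={perfect_ppl:.1f}")
```

A uniform guesser has loss `ln(vocab_size)` and perplexity equal to the vocabulary size, which makes a useful baseline when judging whether a model has learned anything.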

### What's Next?

In the final notebook, we'll:
- Train a complete MoE model on real data
- Integrate MLflow for experiment tracking
- Implement text generation (inference)
- Generate stories with our trained model!

### Further Reading

- [Mixed Precision Training](https://arxiv.org/abs/1710.03740) (Micikevicius et al., 2017)
- [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101) (Loshchilov & Hutter, 2017)
- [Accurate, Large Minibatch SGD](https://arxiv.org/abs/1706.02677) (Goyal et al., 2017)

## Exercise: Tune Hyperparameters

Try different hyperparameters and observe the effects:

1. **Learning rate**: Try 1e-4, 3e-4, 1e-3
2. **Warmup ratio**: Try 0.05, 0.1, 0.2
3. **Weight decay**: Try 0, 0.01, 0.1
4. **Gradient accumulation**: Try 1, 2, 4, 8

Questions to explore:
- How does learning rate affect convergence speed?
- What happens without warmup?
- How does weight decay affect generalization (train vs val loss)?
- What's the effect of larger effective batch sizes?