# Training Transformers: From Random Weights to Language Understanding

How does a transformer learn to understand and generate language? In this notebook, we'll explore the training process step by step, from loss functions to optimization strategies.

## What You'll Learn

1. **Language Modeling Objective** - Next token prediction and cross-entropy loss
2. **Training Loop** - Forward pass, backward pass, optimization
3. **Monitoring Progress** - Loss curves, perplexity, and sample generation
4. **Training Techniques** - Learning rates, warmup, gradient clipping
5. **Overfitting & Regularization** - Dropout, weight decay, early stopping
6. **Scaling Laws** - How performance scales with data and model size (a brief note in the summary)

Let's train a transformer from scratch!

In [None]:
import sys
import os
sys.path.append('..')

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Tuple, List, Dict
import time
from tqdm import tqdm
import math

# Set style for better plots
plt.style.use('default')
sns.set_palette("husl")

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

print("Environment setup complete!")
print(f"PyTorch version: {torch.__version__}")
print(f"Device: {'GPU' if torch.cuda.is_available() else 'CPU'}")

## 1. Language Modeling Objective

Language models are trained to predict the next token given the previous tokens. Doing this well across a large corpus pushes the model to pick up grammar, word meaning, factual associations and, at scale, some reasoning-like behavior.

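Given a sequence $[x_1, x_2, \dots, x_n]$, we want to maximize the probability the model assigns to each token given everything before it:
$$P(x_2 \mid x_1) \cdot P(x_3 \mid x_1, x_2) \cdots P(x_n \mid x_1, \dots, x_{n-1})$$

Taking the negative logarithm turns this product into a sum, so maximizing it is equivalent to minimizing the average negative log-likelihood:
$$\mathcal{L} = -\frac{1}{n-1} \sum_{t=2}^{n} \log P(x_t \mid x_1, \dots, x_{t-1})$$

Each term is the cross-entropy between the model's predicted distribution and the actual next token, which is why, in practice, we apply cross-entropy loss at every position.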
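To make that equivalence concrete, the short check below computes the loss two ways on made-up logits: with `F.cross_entropy`, and by hand as the negative mean log-probability of each target token. The sizes and values are arbitrary; the point is that the two agree, and that exponentiating the summed log-probabilities recovers the product of probabilities in the formula above.

```python
import torch
import torch.nn.functional as F

# Arbitrary toy sizes: 4 positions, vocabulary of 6 tokens
g = torch.Generator().manual_seed(0)
demo_logits = torch.randn(4, 6, generator=g)   # [seq_len, vocab_size] unnormalized scores
demo_targets = torch.tensor([2, 3, 4, 5])      # the "next token" at each position

# 1) Built-in cross-entropy (averages over positions by default)
ce = F.cross_entropy(demo_logits, demo_targets)

# 2) By hand: log-softmax, pick each target's log-probability, negate, average
log_probs = F.log_softmax(demo_logits, dim=-1)
target_log_probs = log_probs[torch.arange(len(demo_targets)), demo_targets]
nll = -target_log_probs.mean()

# The sequence likelihood is the product of per-position probabilities,
# which equals exp(-n * average loss) for n predicted positions
likelihood = target_log_probs.exp().prod()

print(f"F.cross_entropy : {ce.item():.4f}")
print(f"manual NLL      : {nll.item():.4f}")
print(f"likelihood      : {likelihood.item():.6f} vs exp(-n*loss) = {torch.exp(-len(demo_targets) * ce).item():.6f}")
```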

In [None]:
def demonstrate_language_modeling_objective():
    """Show how the language modeling objective works."""
    
    # Example sentence: "The cat sat"
    sentence = "The cat sat"
    tokens = sentence.split()
    
    print("Language Modeling: Predicting Next Token")
    print("=" * 45)
    
    print("Training examples created from 'The cat sat':")
    print()
    print("Input → Target")
    print("-" * 20)
    
    # Show how we create input-target pairs: wrap the sentence in <START>/<END>,
    # then every prefix predicts the token that follows it
    padded = ["<START>"] + tokens + ["<END>"]
    for i in range(1, len(padded)):
        input_str = " ".join(padded[:i])
        target = padded[i]
        print(f"{input_str:<20} → {target}")
    
    # Demonstrate with actual tensors
    vocab = {"<PAD>": 0, "<START>": 1, "The": 2, "cat": 3, "sat": 4, "<END>": 5}
    vocab_size = len(vocab)
    
    # Convert to IDs
    sequence = [vocab["<START>"], vocab["The"], vocab["cat"], vocab["sat"], vocab["<END>"]]
    
    print(f"\nTokenized sequence: {sequence}")
    print(f"Vocabulary: {vocab}")
    
    # Create input and target tensors
    input_ids = torch.tensor(sequence[:-1])  # All except last
    target_ids = torch.tensor(sequence[1:])  # All except first
    
    print(f"\nInput IDs:  {input_ids.tolist()}")
    print(f"Target IDs: {target_ids.tolist()}")
    
    # Simulate model predictions (random for demonstration)
    seq_len = len(input_ids)
    fake_logits = torch.randn(seq_len, vocab_size)  # [seq_len, vocab_size]
    
    # Calculate loss
    loss = F.cross_entropy(fake_logits, target_ids)
    
    print(f"\nModel logits shape: {fake_logits.shape}")
    print(f"Cross-entropy loss: {loss.item():.3f}")
    
    # Show probabilities for first prediction
    first_probs = F.softmax(fake_logits[0], dim=0)
    print(f"\nPredicted probabilities for first position:")
    for word, idx in vocab.items():
        prob = first_probs[idx].item()
        marker = " ← TARGET" if idx == target_ids[0] else ""
        print(f"  {word:<8}: {prob:.3f}{marker}")
    
    return input_ids, target_ids, fake_logits, loss

input_ids, target_ids, logits, loss = demonstrate_language_modeling_objective()

print("\nKey Insights:")
print("• Each position predicts the next token")
print("• Loss is computed for all positions in one pass (a causal mask stops each position from seeing later tokens)")
print("• Model learns patterns by minimizing prediction errors")
print("• At scale, the same objective teaches grammar, facts, and some reasoning-like skills")

## 2. Training Loop Implementation

Let's implement a complete training loop and see how a transformer learns step by step.
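`SimpleTextDataset` comes from the project's `src` package and its internals aren't shown in this notebook. A common way such datasets build training pairs, assumed here, is to slide a window of `block_size + 1` tokens over the token stream and split it into an input and a target shifted by one position, the same shift as in Section 1. `ShiftedWindowDataset` is a hypothetical stand-in, not the project's class:

```python
import torch
from torch.utils.data import Dataset


class ShiftedWindowDataset(Dataset):
    """Hypothetical stand-in for SimpleTextDataset: fixed-length (input, target) windows.

    Each item is block_size input tokens plus the same window shifted right by one.
    """

    def __init__(self, token_ids, block_size):
        self.data = torch.tensor(token_ids, dtype=torch.long)
        self.block_size = block_size

    def __len__(self):
        # Number of start positions that still leave room for block_size + 1 tokens
        return max(0, len(self.data) - self.block_size)

    def __getitem__(self, idx):
        chunk = self.data[idx : idx + self.block_size + 1]
        return chunk[:-1], chunk[1:]  # (input_ids, target_ids)


ds = ShiftedWindowDataset(list(range(10)), block_size=4)
x, y = ds[0]
print(len(ds), x.tolist(), y.tolist())
```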

In [None]:
from src.model.transformer import GPTModel, create_model_config
from src.data.tokenizer import create_tokenizer
from src.data.dataset import SimpleTextDataset, create_dataloader

class TrainingTracker:
    """Track training metrics and visualizations."""
    
    def __init__(self):
        self.losses = []
        self.learning_rates = []
        self.steps = []
        self.samples = []
        
    def log(self, step: int, loss: float, lr: float, sample: str = None):
        self.steps.append(step)
        self.losses.append(loss)
        self.learning_rates.append(lr)
        if sample:
            self.samples.append((step, sample))
    
    def plot_progress(self):
        """Plot training progress."""
        fig, axes = plt.subplots(1, 2, figsize=(15, 5))
        
        # Loss curve
        axes[0].plot(self.steps, self.losses, 'b-', alpha=0.7)
        axes[0].set_xlabel('Training Step')
        axes[0].set_ylabel('Loss')
        axes[0].set_title('Training Loss')
        axes[0].grid(True, alpha=0.3)
        
        # Learning rate
        axes[1].plot(self.steps, self.learning_rates, 'r-', alpha=0.7)
        axes[1].set_xlabel('Training Step')
        axes[1].set_ylabel('Learning Rate')
        axes[1].set_title('Learning Rate Schedule')
        axes[1].grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.show()

def train_step(model, batch, optimizer, criterion, device):
    """Single training step."""
    model.train()
    
    input_ids, target_ids = batch
    input_ids = input_ids.to(device)
    target_ids = target_ids.to(device)
    
    # Forward pass
    optimizer.zero_grad()
    logits, _ = model(input_ids)
    
    # Calculate loss
    batch_size, seq_len, vocab_size = logits.shape
    logits_flat = logits.view(-1, vocab_size)
    targets_flat = target_ids.view(-1)
    
    loss = criterion(logits_flat, targets_flat)
    
    # Backward pass
    loss.backward()
    
    # Gradient clipping (prevent exploding gradients)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    
    optimizer.step()
    
    return loss.item()

def generate_sample(model, tokenizer, prompt="The", max_length=20, temperature=0.8):
    """Generate a sample during training (max_length = number of new tokens)."""
    model.eval()
    device = next(model.parameters()).device  # inputs must live on the model's device
    
    with torch.no_grad():
        # Encode prompt
        tokens = tokenizer.encode(prompt, add_special_tokens=False)
        input_ids = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0)
        
        # Generate
        generated = model.generate(input_ids, max_new_tokens=max_length, temperature=temperature)
        
        # Decode
        generated_text = tokenizer.decode(generated[0].tolist(), skip_special_tokens=True)
        
    return generated_text

# Create training setup
def setup_training():
    """Setup model, data, and optimizer for training."""
    
    # Data
    text = """The quick brown fox jumps over the lazy dog. The cat sat on the mat. 
    A bird in the hand is worth two in the bush. Time flies like an arrow. 
    The early bird catches the worm. Actions speak louder than words.
    The pen is mightier than the sword. All that glitters is not gold.
    Rome was not built in a day. The grass is always greener on the other side."""
    
    tokenizer = create_tokenizer("simple")
    dataset = SimpleTextDataset(text, tokenizer, block_size=32)
    dataloader = create_dataloader(dataset, batch_size=4, shuffle=True)
    
    # Model configuration: the tiny preset keeps training fast. The embedding table
    # must cover every token ID the tokenizer can emit, so take vocab_size from it.
    config = create_model_config("tiny")
    config["vocab_size"] = tokenizer.vocab_size
    model = GPTModel(**config)
    
    # Optimizer and loss
    optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
    criterion = nn.CrossEntropyLoss()
    
    # Device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
    print(f"Training samples: {len(dataset)}")
    print(f"Vocabulary size: {tokenizer.vocab_size}")
    print(f"Device: {device}")
    
    return model, dataloader, optimizer, criterion, tokenizer, device

# Setup training
model, dataloader, optimizer, criterion, tokenizer, device = setup_training()

# Test initial generation (before training)
print("\nBefore training:")
initial_sample = generate_sample(model, tokenizer, "The cat")
print(f"Generated: '{initial_sample}'")

## 3. Training the Model

Now let's train the model and watch it learn! We'll monitor the loss and see how text generation improves.

In [None]:
def train_model(model, dataloader, optimizer, criterion, tokenizer, device, num_epochs=5):
    """Train the model and track progress."""
    
    tracker = TrainingTracker()
    step = 0
    
    print(f"Training for {num_epochs} epochs...")
    print("=" * 50)
    
    for epoch in range(num_epochs):
        epoch_losses = []
        
        # Progress bar for epoch
        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
        
        for batch in pbar:
            # Training step
            loss = train_step(model, batch, optimizer, criterion, device)
            epoch_losses.append(loss)
            
            # Log progress
            current_lr = optimizer.param_groups[0]['lr']
            tracker.log(step, loss, current_lr)
            
            # Update progress bar
            pbar.set_postfix({'loss': f'{loss:.3f}', 'lr': f'{current_lr:.2e}'})
            
            step += 1
        
        # Epoch summary
        avg_loss = np.mean(epoch_losses)
        print(f"Epoch {epoch+1} - Average Loss: {avg_loss:.3f}")
        
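        # Generate a sample and record it with the current step
        # (logging avg_loss here would add a duplicate point to the per-step loss curve)
        sample = generate_sample(model, tokenizer, "The cat", max_length=15)
        print(f"Sample: '{sample}'")
        tracker.samples.append((step, sample))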
        print()
    
    return tracker

# Train the model
tracker = train_model(model, dataloader, optimizer, criterion, tokenizer, device, num_epochs=5)

# Plot training progress
print("Training completed! Here's the progress:")
tracker.plot_progress()

# Test final generation
print("\nAfter training:")
final_sample = generate_sample(model, tokenizer, "The cat", max_length=20)
print(f"Generated: '{final_sample}'")

# Compare before and after
print(f"\nImprovement:")
print(f"Before:  '{initial_sample}'")
print(f"After:   '{final_sample}'")

## 4. Understanding Loss and Perplexity

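Loss tells us how well the model is learning, but perplexity is easier to interpret. Perplexity is the exponential of the average cross-entropy loss, $\text{PPL} = \exp(\text{loss})$, and can be read as the number of equally likely choices the model is effectively picking from at each step: a model that is uniformly unsure among $k$ tokens has perplexity exactly $k$.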
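Two quick checks of that reading, using an arbitrary vocabulary size of 200. A model that spreads probability uniformly over all tokens has loss $\ln V$ and perplexity $V$; a model that splits its probability evenly among 4 candidates, one of them correct, has perplexity 4.

```python
import math

import torch
import torch.nn.functional as F

V = 200                              # arbitrary vocabulary size for illustration
demo_target = torch.tensor([17])     # any token ID works for the uniform case

# Uniform model: all logits equal, so every token gets probability 1/V
uniform_logits = torch.zeros(1, V)
uniform_loss = F.cross_entropy(uniform_logits, demo_target).item()
print(f"uniform: loss = {uniform_loss:.4f} (ln V = {math.log(V):.4f}), perplexity = {math.exp(uniform_loss):.1f}")

# 4-way model: equal logits on 4 candidates (including the target), the rest effectively -inf
four_way_logits = torch.full((1, V), -1e9)
four_way_logits[0, [17, 3, 42, 99]] = 0.0
four_way_loss = F.cross_entropy(four_way_logits, demo_target).item()
print(f"4-way:   loss = {four_way_loss:.4f} (ln 4 = {math.log(4):.4f}), perplexity = {math.exp(four_way_loss):.2f}")
```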

In [None]:
def analyze_loss_and_perplexity(tracker):
    """Analyze training loss and compute perplexity."""
    
    # Convert loss to perplexity
    perplexities = [math.exp(loss) for loss in tracker.losses]
    
    print("Loss and Perplexity Analysis")
    print("=" * 40)
    
    # Show progression
    print("Training progression:")
    print("Step | Loss  | Perplexity | Interpretation")
    print("-" * 50)
    
    checkpoints = [0, len(tracker.losses)//4, len(tracker.losses)//2, 
                   3*len(tracker.losses)//4, len(tracker.losses)-1]
    
    for i in checkpoints:
        loss = tracker.losses[i]
        perplexity = perplexities[i]
        
        if perplexity > 100:
            interpretation = "Very confused"
        elif perplexity > 50:
            interpretation = "Quite confused"
        elif perplexity > 20:
            interpretation = "Somewhat confused"
        elif perplexity > 10:
            interpretation = "Getting better"
        else:
            interpretation = "Fairly confident"
        
        print(f"{tracker.steps[i]:4} | {loss:5.2f} | {perplexity:10.1f} | {interpretation}")
    
    # Plot loss and perplexity
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))
    
    # Loss
    axes[0].plot(tracker.steps, tracker.losses, 'b-', alpha=0.7)
    axes[0].set_xlabel('Training Step')
    axes[0].set_ylabel('Cross-Entropy Loss')
    axes[0].set_title('Training Loss')
    axes[0].grid(True, alpha=0.3)
    
    # Perplexity
    axes[1].plot(tracker.steps, perplexities, 'r-', alpha=0.7)
    axes[1].set_xlabel('Training Step')
    axes[1].set_ylabel('Perplexity')
    axes[1].set_title('Perplexity (exp(loss))')
    axes[1].set_yscale('log')
    axes[1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    print(f"\nFinal metrics:")
    print(f"Initial loss: {tracker.losses[0]:.3f}, perplexity: {perplexities[0]:.1f}")
    print(f"Final loss: {tracker.losses[-1]:.3f}, perplexity: {perplexities[-1]:.1f}")
    print(f"Improvement: {tracker.losses[0] - tracker.losses[-1]:.3f} loss units")
    
    print("\nPerplexity Interpretation:")
    print("• Perplexity = exp(average cross-entropy): the effective number of choices per step")
    print("• Uniform random guessing: perplexity = vocabulary size")
    print("• Perfect prediction: perplexity = 1")
    print("• Rules of thumb like 'good < 50, great < 20' only hold for a fixed tokenizer and evaluation set")
    print("• These are training-set numbers; perplexity on held-out text is the real test")

analyze_loss_and_perplexity(tracker)
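The perplexities above are measured on the training batches. To check generalization you would hold out some text and measure loss on it without updating weights. A minimal sketch, assuming a held-out loader that yields `(input_ids, target_ids)` like the training one; `evaluate` and `TinyLM` are hypothetical names, and `TinyLM` only mimics `GPTModel`'s `(logits, extra)` return shape so the sketch runs on its own. It leaves the model in eval mode, which is harmless here because `train_step` calls `model.train()` at every step.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


@torch.no_grad()
def evaluate(model, dataloader, device):
    """Token-weighted average next-token loss and perplexity on held-out data."""
    model.eval()
    total_loss, total_tokens = 0.0, 0
    for input_ids, target_ids in dataloader:
        input_ids, target_ids = input_ids.to(device), target_ids.to(device)
        logits, _ = model(input_ids)
        # Sum (not mean) so batches of different sizes are weighted by token count
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)), target_ids.reshape(-1), reduction="sum"
        )
        total_loss += loss.item()
        total_tokens += target_ids.numel()
    avg_loss = total_loss / max(total_tokens, 1)
    return avg_loss, math.exp(avg_loss)


class TinyLM(nn.Module):
    """Hypothetical toy model returning (logits, None), like GPTModel's interface."""

    def __init__(self, vocab_size=10, dim=16):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, dim)
        self.head = nn.Linear(dim, vocab_size)

    def forward(self, input_ids):
        return self.head(self.emb(input_ids)), None


g = torch.Generator().manual_seed(0)
toy_model = TinyLM()
val_batches = [
    (torch.randint(0, 10, (2, 5), generator=g), torch.randint(0, 10, (2, 5), generator=g))
    for _ in range(3)
]
val_loss, val_ppl = evaluate(toy_model, val_batches, torch.device("cpu"))
print(f"val loss {val_loss:.3f}, val perplexity {val_ppl:.1f}")
```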

## 5. Learning Rate Scheduling

The learning rate is often the most sensitive hyperparameter: too high and training diverges, too low and the loss barely moves. Let's explore different learning rate schedules and see their effects.

In [None]:
def demonstrate_learning_rate_schedules():
    """Show different learning rate schedules."""
    
    max_steps = 100
    base_lr = 1e-3
    
    steps = np.arange(max_steps)
    
    # Different schedules
    schedules = {}
    
    # Constant
    schedules['Constant'] = np.full(max_steps, base_lr)
    
    # Linear decay
    schedules['Linear Decay'] = base_lr * (1 - steps / max_steps)
    
    # Cosine decay
    schedules['Cosine Decay'] = base_lr * 0.5 * (1 + np.cos(np.pi * steps / max_steps))
    
    # Warmup + cosine
    warmup_steps = 10
    warmup_cosine = np.zeros(max_steps)
    for i in range(max_steps):
        if i < warmup_steps:
            warmup_cosine[i] = base_lr * i / warmup_steps
        else:
            progress = (i - warmup_steps) / (max_steps - warmup_steps)
            warmup_cosine[i] = base_lr * 0.5 * (1 + np.cos(np.pi * progress))
    schedules['Warmup + Cosine'] = warmup_cosine
    
    # Exponential decay
    gamma = 0.95
    schedules['Exponential'] = base_lr * (gamma ** (steps / 10))
    
    # Plot schedules
    plt.figure(figsize=(12, 8))
    
    for name, schedule in schedules.items():
        plt.plot(steps, schedule, label=name, linewidth=2)
    
    plt.xlabel('Training Step')
    plt.ylabel('Learning Rate')
    plt.title('Learning Rate Schedules')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.yscale('log')
    plt.show()
    
    print("Learning Rate Schedule Guidelines:")
    print("=" * 40)
    print("• Constant: Simple but may not converge optimally")
    print("• Linear Decay: Smooth reduction, good baseline")
    print("• Cosine Decay: Popular for transformers, smooth curves")
    print("• Warmup + Cosine: Stabilizes early training, then smooth decay")
    print("• Exponential: Fast initial decay, then slow")
    print()
    print("Best Practices:")
    print("• Use warmup for large models (prevents early instability)")
    print("• Cosine decay works well for most transformer training")
    print("• End with small LR for fine-tuning convergence")
    print("• Monitor loss curves to adjust schedule")

demonstrate_learning_rate_schedules()

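# Implement cosine warmup schedule
class CosineWarmupScheduler:
    """Linear warmup to base_lr, then cosine decay to min_lr.
    
    Call step() once after each optimizer.step(); it sets the LR for the next update.
    """
    
    def __init__(self, optimizer, warmup_steps: int, max_steps: int, base_lr: float, min_lr: float = 0):
        self.optimizer = optimizer
        self.warmup_steps = warmup_steps
        self.max_steps = max_steps
        self.base_lr = base_lr
        self.min_lr = min_lr
        self.step_count = 0
        # Apply the warmup LR right away so the very first update isn't taken at full base_lr
        self._set_lr(self.get_lr(0))
    
    def get_lr(self, step: int) -> float:
        if step < self.warmup_steps:
            # Warmup phase: ramp linearly from base_lr / warmup_steps up to base_lr
            return self.base_lr * (step + 1) / self.warmup_steps
        # Cosine decay phase (max(1, ...) guards against max_steps == warmup_steps)
        progress = (step - self.warmup_steps) / max(1, self.max_steps - self.warmup_steps)
        progress = min(progress, 1.0)
        return self.min_lr + (self.base_lr - self.min_lr) * 0.5 * (1 + math.cos(math.pi * progress))
    
    def _set_lr(self, lr: float):
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
    
    def step(self):
        self.step_count += 1
        lr = self.get_lr(self.step_count)
        self._set_lr(lr)
        return lr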

print("\n✅ Learning rate scheduler implemented!")
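PyTorch also ships `torch.optim.lr_scheduler.LambdaLR`, which multiplies the optimizer's initial learning rate by a user-supplied function of the step count. As an alternative to the hand-written `CosineWarmupScheduler` above, a warmup + cosine schedule can be written as such a multiplier. `warmup_cosine_lambda` is a hypothetical helper name, and the dummy parameter exists only so the optimizer has something to manage.

```python
import math

import torch
from torch.optim.lr_scheduler import LambdaLR


def warmup_cosine_lambda(warmup_steps, max_steps, min_ratio=0.0):
    """Multiplier for LambdaLR: linear warmup, then cosine decay to min_ratio * initial lr."""

    def lr_lambda(step):
        if step < warmup_steps:
            return (step + 1) / warmup_steps
        progress = min(1.0, (step - warmup_steps) / max(1, max_steps - warmup_steps))
        return min_ratio + (1 - min_ratio) * 0.5 * (1 + math.cos(math.pi * progress))

    return lr_lambda


# Dummy parameter so the optimizer has something to hold
param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.AdamW([param], lr=3e-4)
sched = LambdaLR(opt, lr_lambda=warmup_cosine_lambda(warmup_steps=10, max_steps=100))

lrs = []
for _ in range(100):
    lrs.append(opt.param_groups[0]["lr"])  # LR used by this step's update
    opt.step()    # a real loop would call loss.backward() before this
    sched.step()  # advance the schedule after the optimizer step
print(f"first: {lrs[0]:.2e}, peak: {max(lrs):.2e}, last: {lrs[-1]:.2e}")
```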

## 6. Regularization Techniques

Transformers can easily overfit, especially on small datasets. Let's explore regularization techniques.
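Early stopping, one of the techniques listed below, is simple to sketch: evaluate on held-out data periodically, remember the best score and weights, and stop once the score has failed to improve for `patience` evaluations in a row. The class is a hypothetical helper and the validation curve is made up to show the mechanics.

```python
import copy

import torch.nn as nn


class EarlyStopping:
    """Stop when validation loss hasn't improved for `patience` evaluations.

    Keeps a copy of the best weights so they can be restored after training.
    """

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.best_state = None
        self.bad_evals = 0

    def step(self, val_loss, model):
        """Record one validation result; return True if training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.best_state = copy.deepcopy(model.state_dict())
            self.bad_evals = 0
        else:
            self.bad_evals += 1
        return self.bad_evals >= self.patience


# Made-up validation curve: improves, then rises as the model starts to overfit
toy_layer = nn.Linear(4, 4)
stopper = EarlyStopping(patience=2)
val_curve = [3.0, 2.5, 2.2, 2.3, 2.4, 2.1]
stopped_at = None
for epoch, epoch_val_loss in enumerate(val_curve):
    if stopper.step(epoch_val_loss, toy_layer):
        stopped_at = epoch
        break
print(f"stopped at epoch {stopped_at}, best val loss {stopper.best_loss}")
```

After stopping, `toy_layer.load_state_dict(stopper.best_state)` would restore the best weights.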

In [None]:
def demonstrate_overfitting():
    """Show how overfitting manifests in language models."""
    
    print("Overfitting in Language Models")
    print("=" * 35)
    
    # Create a tiny dataset to encourage overfitting
    tiny_text = "The cat sat on the mat. The cat sat on the hat."
    tokenizer = create_tokenizer("simple")
    
    # Encode the text
    tokens = tokenizer.encode(tiny_text, add_special_tokens=False)
    print(f"Training text: '{tiny_text}'")
    print(f"Tokens: {tokens}")
    print(f"Unique tokens: {len(set(tokens))}")
    
    # Create dataset
    dataset = SimpleTextDataset(tiny_text, tokenizer, block_size=8)
    dataloader = create_dataloader(dataset, batch_size=2, shuffle=False)
    
    print(f"Training samples: {len(dataset)}")
    
    # Show samples
    print("\nTraining samples:")
    for i, (input_ids, target_ids) in enumerate(dataloader):
        if i < 3:  # Show first 3 batches
            print(f"Batch {i}: input shape {input_ids.shape}")
            for j in range(input_ids.shape[0]):
                input_text = tokenizer.decode(input_ids[j].tolist(), skip_special_tokens=True)
                target_text = tokenizer.decode(target_ids[j].tolist(), skip_special_tokens=True)
                print(f"  Input:  '{input_text}'")
                print(f"  Target: '{target_text}'")
    
    print("\nSigns of Overfitting:")
    print("• Training loss goes to zero but validation loss increases")
    print("• Model memorizes training data exactly")
    print("• Poor generalization to new text")
    print("• Generated text reproduces training passages verbatim instead of new combinations")
    
    return dataset, dataloader

tiny_dataset, tiny_dataloader = demonstrate_overfitting()

def demonstrate_regularization_techniques():
    """Show different regularization methods."""
    
    print("\nRegularization Techniques for Transformers")
    print("=" * 45)
    
    techniques = {
        "Dropout": {
            "description": "Randomly zero out neurons during training",
            "implementation": "nn.Dropout(p=0.1) in attention and FFN",
            "effect": "Prevents co-adaptation of neurons"
        },
        "Weight Decay": {
            "description": "Shrink weights toward zero at every step",
            "implementation": "weight_decay=0.01 in AdamW (decoupled from the gradient, not an L2 term in the loss)",
            "effect": "Keeps weights small, improves generalization"
        },
        "Gradient Clipping": {
            "description": "Rescale gradients when their global norm exceeds a threshold",
            "implementation": "clip_grad_norm_(params, max_norm=1.0)",
            "effect": "Prevents exploding gradients (a stabilizer rather than a regularizer)"
        },
        "Early Stopping": {
            "description": "Stop when validation loss stops improving",
            "implementation": "Monitor validation loss, save best model",
            "effect": "Prevents overfitting to training data"
        },
        "Data Augmentation": {
            "description": "Increase effective dataset size",
            "implementation": "Paraphrasing, back-translation, masking",
            "effect": "More diverse training examples"
        }
    }
    
    for technique, info in techniques.items():
        print(f"\n{technique}:")
        print(f"  Description: {info['description']}")
        print(f"  Implementation: {info['implementation']}")
        print(f"  Effect: {info['effect']}")
    
    # Visualize dropout effect
    print("\nDropout Visualization:")
    
    # Simulate dropout on a tensor
    x = torch.ones(4, 8)  # 4x8 tensor of ones
    dropout = nn.Dropout(p=0.3)
    
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    
    # Original
    axes[0].imshow(x, cmap='Blues')
    axes[0].set_title('Original Activations')
    axes[0].set_xlabel('Feature')
    axes[0].set_ylabel('Example')
    
    # With dropout (training mode)
    dropout.train()
    x_dropout = dropout(x)
    axes[1].imshow(x_dropout, cmap='Blues')
    axes[1].set_title('With Dropout (Training)')
    axes[1].set_xlabel('Feature')
    
    # Without dropout (eval mode)
    dropout.eval()
    x_eval = dropout(x)
    axes[2].imshow(x_eval, cmap='Blues')
    axes[2].set_title('Without Dropout (Evaluation)')
    axes[2].set_xlabel('Feature')
    
    plt.tight_layout()
    plt.show()
    
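    print("Notice how dropout zeros about 30% (p=0.3) of activations during training and scales the survivors by 1/(1-p), keeping the expected value unchanged; in evaluation mode it does nothing.")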

demonstrate_regularization_techniques()
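The optimizers in this notebook apply `weight_decay=0.01` to every parameter, including biases and LayerNorm weights. A widely used refinement, which this notebook does not do, is to decay only matrix-shaped weights. The sketch below assumes the simple heuristic that parameters with fewer than two dimensions are biases or normalization parameters; `make_param_groups` is a hypothetical helper and the toy network stands in for `GPTModel`.

```python
import torch
import torch.nn as nn


def make_param_groups(model, weight_decay=0.01):
    """Split parameters into decayed (matrices, embeddings) and non-decayed (biases, norms).

    Heuristic: parameters with fewer than 2 dimensions are biases or
    LayerNorm scales/shifts and are excluded from weight decay.
    """
    decay, no_decay = [], []
    for _, p in model.named_parameters():
        if not p.requires_grad:
            continue
        (decay if p.dim() >= 2 else no_decay).append(p)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


toy_net = nn.Sequential(nn.Embedding(50, 16), nn.LayerNorm(16), nn.Linear(16, 50))
groups = make_param_groups(toy_net)
optimizer_wd = torch.optim.AdamW(groups, lr=1e-3)
print([len(g["params"]) for g in groups])
```

Here the embedding and linear weights are decayed, while the LayerNorm weight and bias and the linear bias are not.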

## 7. Monitoring Training Progress

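Loss alone can hide problems such as exploding gradients or a misconfigured schedule. Below we track the learning rate, gradient norm, and weight norm at every step, and generate a sample every 10 steps.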
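The loop below relies on `torch.nn.utils.clip_grad_norm_`. A tiny hand-checkable example of what it does: it computes the global L2 norm over all parameter gradients, returns that pre-clipping norm, and, when it exceeds `max_norm`, scales every gradient by the same factor so the global norm becomes `max_norm` while the direction is preserved.

```python
import torch

# Two parameters with hand-set gradients so the arithmetic is easy to follow
w1 = torch.nn.Parameter(torch.zeros(2))
w2 = torch.nn.Parameter(torch.zeros(1))
w1.grad = torch.tensor([3.0, 0.0])
w2.grad = torch.tensor([4.0])

# Global norm = sqrt(3^2 + 0^2 + 4^2) = 5, which exceeds max_norm = 1.0,
# so every gradient is multiplied by about 1/5
total_norm = torch.nn.utils.clip_grad_norm_([w1, w2], max_norm=1.0)
clipped_norm = torch.cat([w1.grad, w2.grad]).norm()

print(f"norm before clipping: {total_norm.item():.3f}")
print(f"norm after clipping:  {clipped_norm.item():.3f}")
print(f"clipped grads: {w1.grad.tolist()} {w2.grad.tolist()}")
```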

In [None]:
def comprehensive_training_monitoring():
    """Demonstrate comprehensive training monitoring."""
    
    print("Comprehensive Training Monitoring")
    print("=" * 40)
    
    # More comprehensive training text
    training_text = """
    The transformer architecture revolutionized natural language processing.
    Attention mechanisms allow models to focus on relevant parts of the input.
    Large language models demonstrate emergent capabilities at scale.
    Training requires careful optimization and regularization techniques.
    Deep learning continues to advance the field of artificial intelligence.
    Neural networks learn complex patterns from vast amounts of data.
    Machine learning algorithms can generalize to unseen examples.
    The future of AI depends on responsible development and deployment.
    """
    
    # Build the tokenizer first so the model's vocabulary can match it
    tokenizer = create_tokenizer("simple")
    dataset = SimpleTextDataset(training_text, tokenizer, block_size=24)
    dataloader = create_dataloader(dataset, batch_size=3, shuffle=True)
    
    # Create a slightly larger model for realistic monitoring
    config = create_model_config("small")
    config["vocab_size"] = tokenizer.vocab_size  # must cover every token ID the tokenizer emits
    config["n_layers"] = 3  # Smaller for faster training
    model = GPTModel(**config)
    
    # Setup optimizer with proper settings
    optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01, betas=(0.9, 0.95))
    criterion = nn.CrossEntropyLoss()
    
    # Learning rate scheduler
    total_steps = len(dataloader) * 3  # 3 epochs
    scheduler = CosineWarmupScheduler(optimizer, warmup_steps=10, max_steps=total_steps, base_lr=3e-4)
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    
    print(f"Model: {sum(p.numel() for p in model.parameters()):,} parameters")
    print(f"Dataset: {len(dataset)} samples")
    print(f"Total training steps: {total_steps}")
    
    # Training loop with comprehensive monitoring
    metrics = {
        'steps': [],
        'losses': [],
        'learning_rates': [],
        'grad_norms': [],
        'weight_norms': [],
        'samples': []
    }
    
    step = 0
    
    for epoch in range(3):
        print(f"\nEpoch {epoch + 1}/3")
        
        for batch_idx, batch in enumerate(dataloader):
            # Training step
            model.train()
            input_ids, target_ids = batch
            input_ids, target_ids = input_ids.to(device), target_ids.to(device)
            
            optimizer.zero_grad()
            logits, _ = model(input_ids)
            
            # Calculate loss
            loss = criterion(logits.view(-1, logits.size(-1)), target_ids.view(-1))
            loss.backward()
            
            # Gradient clipping; clip_grad_norm_ returns the total L2 norm of all
            # gradients *before* clipping, which is exactly what we want to monitor
            total_grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item()
            
            optimizer.step()
            current_lr = scheduler.step()
            
            # Monitor weights
            total_weight_norm = 0
            for p in model.parameters():
                total_weight_norm += p.data.norm(2).item() ** 2
            total_weight_norm = total_weight_norm ** 0.5
            
            # Record metrics
            metrics['steps'].append(step)
            metrics['losses'].append(loss.item())
            metrics['learning_rates'].append(current_lr)
            metrics['grad_norms'].append(total_grad_norm)
            metrics['weight_norms'].append(total_weight_norm)
            
            # Generate sample every 10 steps
            if step % 10 == 0:
                sample = generate_sample(model, tokenizer, "The", max_length=10)
                metrics['samples'].append((step, sample))
                print(f"Step {step:3d}: Loss={loss.item():.3f}, LR={current_lr:.2e}, Sample='{sample}'")
            
            step += 1
    
    return metrics

# Run comprehensive training
metrics = comprehensive_training_monitoring()

# Plot comprehensive metrics
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Loss
axes[0, 0].plot(metrics['steps'], metrics['losses'], 'b-', alpha=0.7)
axes[0, 0].set_xlabel('Step')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].set_title('Training Loss')
axes[0, 0].grid(True, alpha=0.3)

# Learning rate
axes[0, 1].plot(metrics['steps'], metrics['learning_rates'], 'r-', alpha=0.7)
axes[0, 1].set_xlabel('Step')
axes[0, 1].set_ylabel('Learning Rate')
axes[0, 1].set_title('Learning Rate Schedule')
axes[0, 1].grid(True, alpha=0.3)

# Gradient norms
axes[1, 0].plot(metrics['steps'], metrics['grad_norms'], 'g-', alpha=0.7)
axes[1, 0].set_xlabel('Step')
axes[1, 0].set_ylabel('Gradient Norm')
axes[1, 0].set_title('Gradient Norms')
axes[1, 0].grid(True, alpha=0.3)

# Weight norms
axes[1, 1].plot(metrics['steps'], metrics['weight_norms'], 'm-', alpha=0.7)
axes[1, 1].set_xlabel('Step')
axes[1, 1].set_ylabel('Weight Norm')
axes[1, 1].set_title('Weight Norms')
axes[1, 1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\nTraining Quality Indicators:")
print(f"• Loss decreased from {metrics['losses'][0]:.3f} to {metrics['losses'][-1]:.3f}")
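print(f"• Mean gradient norm (pre-clipping): {np.mean(metrics['grad_norms']):.3f} (should be stable; frequent spikes far above the 1.0 clip threshold signal instability)")
print(f"• Weight norm: {metrics['weight_norms'][0]:.2f} → {metrics['weight_norms'][-1]:.2f} (slow drift is normal; updates push it up, weight decay pulls it down)")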
print(f"• Learning rate properly scheduled: {metrics['learning_rates'][0]:.2e} → {metrics['learning_rates'][-1]:.2e}")
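Per-batch losses are noisy, especially with `batch_size=3`. A common way to read the trend, sketched here as a hypothetical helper, is an exponential moving average with bias correction, so early values aren't dragged toward the 0.0 starting point.

```python
def ema_smooth(values, beta=0.9):
    """Exponential moving average of a sequence, with bias correction.

    Without correction the average starts at 0 and underestimates early values;
    dividing by (1 - beta**t) removes that bias (the same trick Adam uses).
    """
    smoothed, avg = [], 0.0
    for t, v in enumerate(values, start=1):
        avg = beta * avg + (1 - beta) * v
        smoothed.append(avg / (1 - beta ** t))
    return smoothed


noisy = [4.0, 3.0, 3.5, 2.5, 3.0, 2.0, 2.4, 1.8]
print([round(x, 3) for x in ema_smooth(noisy, beta=0.5)])
```

In this notebook the smoothed curve could be overlaid on the raw one, e.g. `axes[0, 0].plot(metrics['steps'], ema_smooth(metrics['losses']), 'k-')`. A larger `beta` gives a smoother but more delayed curve.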

## Summary

In this notebook, we've explored the complete transformer training process:

1. **Language Modeling Objective** - Next token prediction with cross-entropy loss
2. **Training Loop** - Forward pass, loss calculation, backpropagation, optimization
3. **Progress Monitoring** - Loss curves, perplexity, sample generation
4. **Learning Rate Scheduling** - Warmup, cosine decay, and their effects
5. **Regularization** - Dropout, weight decay, gradient clipping
6. **Comprehensive Monitoring** - Tracking gradients, weights, and training health

### Key Training Insights:

- **The loss defines the task**: Next-token cross-entropy is the only training signal
- **Perplexity aids interpretation**: exp(loss), readable as the effective number of choices per token
- **Learning rate is critical**: Use warmup + cosine decay
- **Regularization fights overfitting**: Dropout, weight decay, early stopping (gradient clipping stabilizes rather than regularizes)
- **Monitor everything**: Loss, gradients, weights, samples
- **Scaling laws** (not explored here): Loss tends to fall predictably as model size, data, and compute grow together

### Best Practices:

- As a sanity check, start with a small model and confirm it can overfit (near-zero loss) a tiny dataset
- Use proper learning rate scheduling (warmup + decay)
- Monitor gradient norms and clip to a maximum norm (1.0 is a common default)
- Generate samples regularly to check progress
- Save checkpoints and implement early stopping
- Use mixed precision training for efficiency

### Training Stages:

1. **Early**: Loss drops quickly, gradients are large, samples are nonsensical
2. **Middle**: Steady improvement as the model learns meaningful patterns
3. **Late**: Slow improvement as the learning rate decays; watch for overfitting

The magic of transformers is that this simple next-token prediction objective gives rise to complex language understanding, reasoning, and generation capabilities!

Next, we'll explore the fascinating world of text generation strategies!