# Modern Transformer Architecture Improvements

This notebook explores the key architectural innovations that have improved transformer performance, stability, and efficiency since the original "Attention Is All You Need" paper. We'll implement and compare modern components that are now standard in state-of-the-art models.

## Learning Objectives

By the end of this notebook, you will understand:
1. **RMSNorm vs LayerNorm**: Why RMSNorm is preferred in modern models
2. **SwiGLU Activation**: How gated activations improve transformer performance
3. **Rotary Position Embedding (RoPE)**: Relative position encoding that scales well
4. **Pre-norm vs Post-norm**: Architecture choices and their impact on training
5. **Modern Integration**: How these components work together in practice

## Prerequisites

- Understanding of basic transformer architecture (notebooks 01-05)
- Familiarity with normalization techniques
- Basic knowledge of activation functions

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import math
from typing import Optional, Tuple, List
import time
from dataclasses import dataclass

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Configure plotting
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 1. RMSNorm vs LayerNorm

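Root Mean Square Layer Normalization (RMSNorm) simplifies LayerNorm by dropping the mean-centering step (and the bias): each vector is rescaled by its root mean square alone. This saves computation and has been reported to match LayerNorm's quality in practice. It's used in models such as T5, Gopher, and LLaMA.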
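
To make the missing mean-centering step concrete, here is a minimal pure-Python sketch of both normalizations on a single vector. It assumes the learnable gain is fixed to 1 and LayerNorm's bias to 0, and the helper names `layer_norm` and `rms_norm` are hypothetical. When the input already has zero mean the two coincide; otherwise RMSNorm keeps an offset that LayerNorm removes.

```python
import math

def layer_norm(x, eps=1e-6):
    # Center, then divide by the standard deviation (biased variance, as in nn.LayerNorm)
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def rms_norm(x, eps=1e-6):
    # No centering: divide by the root mean square of the raw values
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]

zero_mean = [1.0, -2.0, 3.0, -2.0]
shifted = [v + 2.0 for v in zero_mean]

print(layer_norm(zero_mean))
print(rms_norm(zero_mean))   # same as LayerNorm: the mean is already 0
print(layer_norm(shifted))   # LayerNorm removes the shift entirely
print(rms_norm(shifted))     # RMSNorm keeps a positive offset
```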

In [None]:
class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization.
    
    RMSNorm normalizes using only the root mean square of the inputs,
    without centering (subtracting the mean).
    
    Formula: RMSNorm(x) = x / RMS(x) * γ
    where RMS(x) = sqrt(mean(x²) + ε)
    """
    
    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Compute RMS
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        
        # Normalize and scale
        x_normalized = x / rms
        return self.weight * x_normalized


class LayerNormComparison(nn.Module):
    """Wrapper to compare LayerNorm and RMSNorm side by side."""
    
    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.layer_norm = nn.LayerNorm(d_model, eps=eps)
        self.rms_norm = RMSNorm(d_model, eps=eps)
    
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        ln_out = self.layer_norm(x)
        rms_out = self.rms_norm(x)
        return ln_out, rms_out


def analyze_normalization_techniques():
    """Analyze and compare LayerNorm vs RMSNorm."""
    d_model = 512
    batch_size = 32
    seq_len = 128
    
    # Create comparison module
    norm_comparison = LayerNormComparison(d_model).to(device)
    
    # Test with different input distributions
    test_cases = {
        'Standard Normal': torch.randn(batch_size, seq_len, d_model).to(device),
        'Shifted Distribution': torch.randn(batch_size, seq_len, d_model).to(device) + 2.0,
        'High Variance': torch.randn(batch_size, seq_len, d_model).to(device) * 5.0,
        'Asymmetric': torch.abs(torch.randn(batch_size, seq_len, d_model)).to(device),
    }
    
    results = {}
    
    for case_name, input_tensor in test_cases.items():
        ln_out, rms_out = norm_comparison(input_tensor)
        
        # Compute statistics
        input_stats = {
            'mean': input_tensor.mean().item(),
            'std': input_tensor.std().item(),
            'min': input_tensor.min().item(),
            'max': input_tensor.max().item()
        }
        
        ln_stats = {
            'mean': ln_out.mean().item(),
            'std': ln_out.std().item(),
            'min': ln_out.min().item(),
            'max': ln_out.max().item()
        }
        
        rms_stats = {
            'mean': rms_out.mean().item(),
            'std': rms_out.std().item(),
            'min': rms_out.min().item(),
            'max': rms_out.max().item()
        }
        
        # Cosine similarity between the flattened outputs (reported as 'correlation')
        correlation = F.cosine_similarity(
            ln_out.flatten(), rms_out.flatten(), dim=0
        ).item()
        
        results[case_name] = {
            'input': input_stats,
            'layernorm': ln_stats,
            'rmsnorm': rms_stats,
            'correlation': correlation
        }
    
    return results

# Run analysis
norm_results = analyze_normalization_techniques()

# Display results
print("Normalization Technique Comparison:")
print("=" * 80)

for case_name, stats in norm_results.items():
    print(f"\n{case_name}:")
    print(f"  Input    - Mean: {stats['input']['mean']:7.3f}, Std: {stats['input']['std']:7.3f}")
    print(f"  LayerNorm- Mean: {stats['layernorm']['mean']:7.3f}, Std: {stats['layernorm']['std']:7.3f}")
    print(f"  RMSNorm  - Mean: {stats['rmsnorm']['mean']:7.3f}, Std: {stats['rmsnorm']['std']:7.3f}")
    print(f"  Correlation: {stats['correlation']:.4f}")

# Visualize the comparison
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
axes = axes.flatten()

case_names = list(norm_results.keys())
correlations = [norm_results[case]['correlation'] for case in case_names]
ln_stds = [norm_results[case]['layernorm']['std'] for case in case_names]
rms_stds = [norm_results[case]['rmsnorm']['std'] for case in case_names]
ln_means = [abs(norm_results[case]['layernorm']['mean']) for case in case_names]
rms_means = [abs(norm_results[case]['rmsnorm']['mean']) for case in case_names]

# Correlation plot
axes[0].bar(case_names, correlations, alpha=0.7, color='skyblue')
axes[0].set_title('Output Correlation: LayerNorm vs RMSNorm')
axes[0].set_ylabel('Cosine Similarity')
axes[0].tick_params(axis='x', rotation=45)
axes[0].axhline(y=0.95, color='red', linestyle='--', alpha=0.5, label='High Correlation')
axes[0].legend()

# Standard deviation comparison
x = np.arange(len(case_names))
width = 0.35
axes[1].bar(x - width/2, ln_stds, width, label='LayerNorm', alpha=0.7)
axes[1].bar(x + width/2, rms_stds, width, label='RMSNorm', alpha=0.7)
axes[1].set_title('Output Standard Deviation')
axes[1].set_ylabel('Standard Deviation')
axes[1].set_xticks(x)
axes[1].set_xticklabels(case_names, rotation=45)
axes[1].legend()

# Mean comparison (absolute values)
axes[2].bar(x - width/2, ln_means, width, label='LayerNorm', alpha=0.7)
axes[2].bar(x + width/2, rms_means, width, label='RMSNorm', alpha=0.7)
axes[2].set_title('Output Mean (Absolute Value)')
axes[2].set_ylabel('|Mean|')
axes[2].set_xticks(x)
axes[2].set_xticklabels(case_names, rotation=45)
axes[2].legend()

# Parameter count comparison
d_model_sizes = [128, 256, 512, 1024, 2048, 4096]
ln_params = [d * 2 for d in d_model_sizes]  # weight + bias
rms_params = [d * 1 for d in d_model_sizes]  # only weight
reduction = [(ln - rms) / ln * 100 for ln, rms in zip(ln_params, rms_params)]

axes[3].plot(d_model_sizes, reduction, 'o-', linewidth=2, markersize=8, color='green')
axes[3].set_title('Parameter Reduction: RMSNorm vs LayerNorm')
axes[3].set_xlabel('Model Dimension')
axes[3].set_ylabel('Parameter Reduction (%)')
axes[3].grid(True, alpha=0.3)
axes[3].axhline(y=50, color='red', linestyle='--', alpha=0.5, label='50% Reduction')
axes[3].legend()

plt.tight_layout()
plt.show()

print("\n🔬 Key Insights:")
print("• RMSNorm has 50% fewer normalization parameters (gain only, no bias)")
print("• Outputs closely match LayerNorm only for zero-mean inputs; a nonzero mean lowers the similarity")
print("• RMSNorm does not re-center, so a nonzero input mean survives (rescaled) in the output")
print("• Cheaper computation: no mean subtraction")
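
The cosine similarity printed above is not uniformly high. For a single row with mean μ and standard deviation σ (ignoring ε and the learnable gains), LayerNorm outputs the standardized values z, while RMSNorm outputs (σz + μ) / √(σ² + μ²). Because z sums to zero, the cosine between the two is exactly σ / √(σ² + μ²). For the "Shifted Distribution" case (μ ≈ 2, σ ≈ 1) that predicts roughly 1/√5 ≈ 0.45. The pure-Python sketch below checks the identity on one random row; the helper name `cosine` is illustrative.

```python
import math
import random

def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

rng = random.Random(0)
n = 512
row = [rng.gauss(2.0, 1.0) for _ in range(n)]  # one "Shifted Distribution" row

mu = sum(row) / n
sigma = math.sqrt(sum((v - mu) ** 2 for v in row) / n)  # population std, as in LayerNorm

ln_out = [(v - mu) / sigma for v in row]      # LayerNorm with eps = 0, gain 1, bias 0
rms = math.sqrt(sum(v * v for v in row) / n)
rms_out = [v / rms for v in row]              # RMSNorm with eps = 0, gain 1

predicted = sigma / math.sqrt(sigma ** 2 + mu ** 2)
print(f"measured cosine:  {cosine(ln_out, rms_out):.4f}")
print(f"predicted cosine: {predicted:.4f}")
```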

## 2. SwiGLU Activation Function

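SwiGLU (Swish-Gated Linear Unit) multiplies a Swish-activated projection of the input element-wise with a second, purely linear projection, so one path acts as a learned gate on the other. It's used in models like PaLM and LLaMA, and gated variants such as SwiGLU have been reported to reach lower perplexity than ReLU- and GELU-based feed-forward layers of similar size.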
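
Written out for a single token, the SwiGLU feed-forward implemented below computes `W_down · (SiLU(W_gate x) ⊙ (W_up x))`. The following pure-Python sketch omits biases and uses tiny hand-picked matrices; the helper names are hypothetical. It shows the gate and up paths being multiplied element-wise before the down projection.

```python
import math

def sigmoid(z):
    # Numerically stable logistic function
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def silu(z):
    # SiLU / Swish with beta = 1
    return z * sigmoid(z)

def matvec(W, x):
    # Plain matrix-vector product, W given as a list of rows
    return [sum(w * v for w, v in zip(row, x)) for row in W]

def swiglu_ffn(x, W_gate, W_up, W_down):
    gate = [silu(g) for g in matvec(W_gate, x)]   # gate path
    up = matvec(W_up, x)                          # linear "up" path
    hidden = [g * u for g, u in zip(gate, up)]    # element-wise gating
    return matvec(W_down, hidden)                 # project back to the model width

identity = [[1.0, 0.0], [0.0, 1.0]]
x = [1.0, -1.0]
out = swiglu_ffn(x, identity, identity, identity)
print(out)  # [sigmoid(1), sigmoid(-1)]: the negative input still contributes, unlike ReLU
```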

In [None]:
class SwiGLU(nn.Module):
    """
    SwiGLU feed-forward block: FFN(x) = (Swish(xW + b) ⊗ (xV + c)) W2 + d
    
    Where:
    - Swish(x) = x * sigmoid(βx), here β=1 (i.e. SiLU)
    - ⊗ denotes element-wise multiplication (gating)
    - W (gate_proj), V (up_proj), W2 (down_proj) are linear transformations
    """
    
    def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True):
        super().__init__()
        self.gate_proj = nn.Linear(input_dim, hidden_dim, bias=bias)
        self.up_proj = nn.Linear(input_dim, hidden_dim, bias=bias)
        self.down_proj = nn.Linear(hidden_dim, input_dim, bias=bias)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Gate path: apply Swish activation
        gate = F.silu(self.gate_proj(x))  # SiLU is Swish with β=1
        
        # Up path: linear transformation
        up = self.up_proj(x)
        
        # Element-wise multiplication (gating)
        gated = gate * up
        
        # Down projection
        return self.down_proj(gated)


class ModernFeedForward(nn.Module):
    """Modern feed-forward network with choice of activation."""
    
    def __init__(self, d_model: int, d_ff: int, activation: str = 'swiglu', dropout: float = 0.1):
        super().__init__()
        self.activation_type = activation
        
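        if activation == 'swiglu':
            # Three projections instead of two: a hidden width of 2/3 * d_ff keeps the
            # parameter count roughly equal to a standard two-matrix FFN.
            # (This branch applies no internal dropout.)
            hidden_dim = int(2 * d_ff / 3)
            self.ffn = SwiGLU(d_model, hidden_dim)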
        else:
            # Standard FFN with specified activation
            self.ffn = nn.Sequential(
                nn.Linear(d_model, d_ff),
                self._get_activation(activation),
                nn.Dropout(dropout),
                nn.Linear(d_ff, d_model),
                nn.Dropout(dropout)
            )
    
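    def _get_activation(self, activation: str) -> nn.Module:
        activations = {
            'relu': nn.ReLU,
            'gelu': nn.GELU,
            'silu': nn.SiLU,  # Swish with β=1
            'mish': nn.Mish,
        }
        # Fail loudly instead of silently falling back to ReLU on a typo
        if activation not in activations:
            raise ValueError(f"Unknown activation: {activation!r}")
        return activations[activation]()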
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.ffn(x)


def compare_activation_functions():
    """Compare different activation functions in feed-forward networks."""
    d_model = 512
    d_ff = 2048
    batch_size = 16
    seq_len = 64
    
    # Create different FFN variants
    activations = ['relu', 'gelu', 'silu', 'swiglu']
    ffns = {}
    
    for act in activations:
        # eval() disables dropout so the statistics reflect the activation, not random masking
        ffns[act] = ModernFeedForward(d_model, d_ff, activation=act).to(device).eval()
    
    # Test input
    x = torch.randn(batch_size, seq_len, d_model).to(device)
    
    results = {}
    outputs = {}
    
    with torch.no_grad():
        for act_name, ffn in ffns.items():
            # Forward pass (stored once and reused for plotting)
            output = ffn(x)
            outputs[act_name] = output
            
            # Compute statistics
            param_count = sum(p.numel() for p in ffn.parameters())
            
            results[act_name] = {
                'output_mean': output.mean().item(),
                'output_std': output.std().item(),
                'output_range': (output.min().item(), output.max().item()),
                'param_count': param_count,
                'has_gating': act_name == 'swiglu',
            }
    
    return results, x, outputs


# Visualize activation function properties
def plot_activation_functions():
    """Plot the activation functions themselves."""
    x = torch.linspace(-5, 5, 1000)
    
    activations = {
        'ReLU': F.relu(x),
        'GELU': F.gelu(x),
        'SiLU/Swish': F.silu(x),
        'Mish': F.mish(x),
    }
    
    # For SwiGLU, show the gating behavior
    gate_values = F.silu(x)
    up_values = x  # Simplified - in practice this would be a linear transformation
    swiglu_like = gate_values * up_values
    
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
    
    # Standard activations
    for name, values in activations.items():
        ax1.plot(x.numpy(), values.numpy(), label=name, linewidth=2)
    
    ax1.set_title('Activation Functions')
    ax1.set_xlabel('Input')
    ax1.set_ylabel('Output')
    ax1.grid(True, alpha=0.3)
    ax1.legend()
    ax1.axhline(y=0, color='black', linestyle='-', alpha=0.3)
    ax1.axvline(x=0, color='black', linestyle='-', alpha=0.3)
    
    # SwiGLU components
    ax2.plot(x.numpy(), gate_values.numpy(), label='Gate (SiLU)', linewidth=2, alpha=0.7)
    ax2.plot(x.numpy(), up_values.numpy(), label='Up (Linear)', linewidth=2, alpha=0.7)
    ax2.plot(x.numpy(), swiglu_like.numpy(), label='Gated Output', linewidth=3)
    
    ax2.set_title('SwiGLU Gating Mechanism')
    ax2.set_xlabel('Input')
    ax2.set_ylabel('Output')
    ax2.grid(True, alpha=0.3)
    ax2.legend()
    ax2.axhline(y=0, color='black', linestyle='-', alpha=0.3)
    ax2.axvline(x=0, color='black', linestyle='-', alpha=0.3)
    
    plt.tight_layout()
    plt.show()

# Run comparisons
plot_activation_functions()
ffn_results, test_input, ffn_outputs = compare_activation_functions()

# Display results
print("\nFeed-Forward Network Comparison:")
print("Activation\tParam Count\tOutput Mean\tOutput Std\tRange")
print("-" * 70)

for act_name, stats in ffn_results.items():
    range_str = f"[{stats['output_range'][0]:.2f}, {stats['output_range'][1]:.2f}]"
    print(f"{act_name:<12}\t{stats['param_count']:>8,}\t{stats['output_mean']:>9.4f}\t"
          f"{stats['output_std']:>9.4f}\t{range_str}")

# Visualize output distributions
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
axes = axes.flatten()

for idx, (act_name, output) in enumerate(ffn_outputs.items()):
    if idx < 4:  # We have 4 subplots
        output_flat = output.detach().cpu().flatten().numpy()
        axes[idx].hist(output_flat, bins=50, alpha=0.7, density=True)
        axes[idx].set_title(f'{act_name.upper()} Output Distribution')
        axes[idx].set_xlabel('Output Value')
        axes[idx].set_ylabel('Density')
        axes[idx].axvline(x=output_flat.mean(), color='red', linestyle='--', 
                         label=f'Mean: {output_flat.mean():.3f}')
        axes[idx].legend()
        axes[idx].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\n🎯 Key Insights:")
print("• SwiGLU multiplies a SiLU-activated gate with a linear 'up' path before projecting down")
print("• SiLU/Swish is smooth and non-monotonic (slightly negative for negative inputs)")
print("• Gated FFNs have been reported to reach lower perplexity than ReLU/GELU FFNs of similar size")
print("• With the 2/3 hidden-width rule, SwiGLU's parameter count matches the standard FFN almost exactly")
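
A quick count confirms that the 2/3 rule in `ModernFeedForward` keeps the budgets matched: a standard FFN has two d_model × d_ff matrices, while SwiGLU has three d_model × (2/3 · d_ff) matrices. The sketch below counts weights and biases the way `nn.Linear` does with its default `bias=True`, for the d_model = 512, d_ff = 2048 setting used above; the function names are hypothetical. Its totals should agree with the Param Count column printed above.

```python
def linear_params(n_in, n_out, bias=True):
    # nn.Linear stores an n_out x n_in weight plus an optional n_out bias
    return n_in * n_out + (n_out if bias else 0)

def standard_ffn_params(d_model, d_ff, bias=True):
    # Two matrices: d_model -> d_ff -> d_model
    return linear_params(d_model, d_ff, bias) + linear_params(d_ff, d_model, bias)

def swiglu_ffn_params(d_model, d_ff, bias=True):
    # Three matrices (gate, up, down) at hidden width int(2 * d_ff / 3), as in ModernFeedForward
    hidden = int(2 * d_ff / 3)
    return 2 * linear_params(d_model, hidden, bias) + linear_params(hidden, d_model, bias)

std = standard_ffn_params(512, 2048)
glu = swiglu_ffn_params(512, 2048)
print(f"standard FFN: {std:,}  SwiGLU FFN: {glu:,}  relative difference: {abs(glu - std) / std:.4%}")
```

The small residual gap comes from rounding the hidden width down to an integer and from the extra bias vector of the third projection.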

## 3. Rotary Position Embedding (RoPE)

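RoPE encodes position by rotating each pair of query and key dimensions by an angle proportional to the token's position. Because the attention score is a dot product of two rotated vectors, its positional part depends only on the offset between the two tokens, so relative position is encoded without any learned position parameters. It's used in models like GPT-J, GPT-NeoX, and LLaMA.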
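
The relative-position property can be checked without any framework. The pure-Python sketch below rotates dimension pairs (i, i + d/2) using the same pairing and frequency schedule θ_i = base^(-2i/d) as the `rotate_half` / `_build_cache` code that follows; the helper names and the sample vectors are hypothetical. Moving both tokens by the same amount leaves the score unchanged, and the rotation preserves vector length.

```python
import math

def rope_rotate(x, pos, base=10000.0):
    # Rotate pair (x[i], x[i + d/2]) by angle pos * theta_i, with theta_i = base ** (-2i/d)
    d = len(x)
    half = d // 2
    out = list(x)
    for i in range(half):
        angle = pos * base ** (-2 * i / d)
        c, s = math.cos(angle), math.sin(angle)
        out[i] = x[i] * c - x[i + half] * s
        out[i + half] = x[i + half] * c + x[i] * s
    return out

def dot(a, b):
    return sum(u * v for u, v in zip(a, b))

q = [0.3, -1.2, 0.8, 0.5, -0.7, 1.1, 0.2, -0.4]
k = [1.0, 0.4, -0.6, 0.9, 0.3, -0.2, 0.7, 0.1]

# Same offset (m - n = 3) at different absolute positions gives the same score
print(dot(rope_rotate(q, 5), rope_rotate(k, 2)))
print(dot(rope_rotate(q, 105), rope_rotate(k, 102)))
```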

In [None]:
class RotaryPositionalEmbedding(nn.Module):
    """
    Rotary Position Embedding (RoPE).
    
    RoPE rotates query and key vectors by an angle proportional to their position.
    This naturally encodes relative positional information into the attention mechanism.
    """
    
    def __init__(self, d_model: int, max_len: int = 8192, base: float = 10000.0):
        super().__init__()
        self.d_model = d_model
        self.max_len = max_len
        self.base = base
        
        # Precompute frequency matrix
        self._build_cache(max_len)
    
    def _build_cache(self, max_len: int):
        """Build rotation matrices for all positions."""
        # Frequency for each dimension pair
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.d_model, 2).float() / self.d_model))
        
        # Position indices
        position = torch.arange(max_len).float()
        
        # Compute angles: outer product of positions and frequencies
        angles = torch.outer(position, inv_freq)  # [max_len, d_model//2]
        
        # Precompute cos and sin
        cos_cached = torch.cos(angles)  # [max_len, d_model//2]
        sin_cached = torch.sin(angles)  # [max_len, d_model//2]
        
        # Register as buffers (non-trainable tensors saved with the module and moved by .to(device))
        self.register_buffer('cos_cached', cos_cached)
        self.register_buffer('sin_cached', sin_cached)
    
    def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
        """Rotate half the dimensions of x."""
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat([-x2, x1], dim=-1)
    
    def apply_rotary_emb(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        """Apply rotary embedding to tensor x."""
        # x shape: [batch, heads, seq_len, head_dim]
        # cos, sin shape: [seq_len, head_dim//2]
        
        # Expand cos and sin to match x dimensions
        cos = cos.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, head_dim//2]
        sin = sin.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, head_dim//2]
        
        # Duplicate cos and sin for full head dimension
        cos = torch.cat([cos, cos], dim=-1)  # [1, 1, seq_len, head_dim]
        sin = torch.cat([sin, sin], dim=-1)  # [1, 1, seq_len, head_dim]
        
        # Apply rotation
        return (x * cos) + (self.rotate_half(x) * sin)
    
    def forward(self, q: torch.Tensor, k: torch.Tensor, 
                position_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply RoPE to query and key tensors.
        
        Args:
            q: Query tensor [batch, heads, seq_len, head_dim]
            k: Key tensor [batch, heads, seq_len, head_dim]
            position_offset: Offset for positions (useful for KV caching)
        """
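        seq_len = q.size(2)
        
        # Get position range (the cache only covers max_len positions)
        start_pos = position_offset
        end_pos = start_pos + seq_len
        if end_pos > self.max_len:
            raise ValueError(f"Positions up to {end_pos} exceed max_len={self.max_len}")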
        
        # Extract cos and sin for current positions
        cos = self.cos_cached[start_pos:end_pos]
        sin = self.sin_cached[start_pos:end_pos]
        
        # Apply rotation
        q_rotated = self.apply_rotary_emb(q, cos, sin)
        k_rotated = self.apply_rotary_emb(k, cos, sin)
        
        return q_rotated, k_rotated


class RoPEAttention(nn.Module):
    """Multi-head attention with Rotary Position Embedding."""
    
    def __init__(self, d_model: int, n_heads: int, max_len: int = 8192, dropout: float = 0.1):
        super().__init__()
        assert d_model % n_heads == 0
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        # Linear projections
        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)
        self.w_o = nn.Linear(d_model, d_model, bias=False)
        
        # Rotary position embedding
        self.rope = RotaryPositionalEmbedding(self.d_k, max_len)
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None,
                position_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size, seq_len, d_model = x.shape
        
        # Compute Q, K, V
        Q = self.w_q(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        K = self.w_k(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        V = self.w_v(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        
        # Apply RoPE to Q and K
        Q, K = self.rope(Q, K, position_offset)
        
        # Compute attention
        attn_weights = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        if mask is not None:
            attn_weights = attn_weights.masked_fill(mask == 0, -1e9)
        
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = self.dropout(attn_weights)
        
        # Apply attention to values
        out = torch.matmul(attn_weights, V)
        
        # Reshape and project output
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        out = self.w_o(out)
        
        return out, attn_weights


def analyze_rope_properties():
    """Analyze how RoPE encodes positional information."""
    d_model = 128
    max_len = 64
    
    rope = RotaryPositionalEmbedding(d_model, max_len)
    
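    # Create sample queries and keys
    batch_size, n_heads, seq_len = 1, 1, max_len
    
    # Place the same unit-norm content vector at every position, so any position
    # dependence in the scores comes purely from the RoPE rotation
    # (expanding a [d_model, d_model] identity to seq_len rows would fail when seq_len != d_model)
    content = torch.ones(d_model) / math.sqrt(d_model)
    Q = content.expand(batch_size, n_heads, seq_len, d_model)
    K = Q.clone()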
    
    # Apply RoPE
    Q_rope, K_rope = rope(Q, K)
    
    # Compute attention patterns
    attn_no_rope = torch.matmul(Q, K.transpose(-2, -1))
    attn_with_rope = torch.matmul(Q_rope, K_rope.transpose(-2, -1))
    
    return attn_no_rope[0, 0], attn_with_rope[0, 0], rope


# Analyze RoPE
attn_no_rope, attn_with_rope, rope_module = analyze_rope_properties()

# Visualize RoPE effects
fig, axes = plt.subplots(2, 3, figsize=(18, 10))

# Attention patterns
im1 = axes[0, 0].imshow(attn_no_rope.detach().numpy(), cmap='RdBu_r', vmin=-1, vmax=1)
axes[0, 0].set_title('Attention Without RoPE')
axes[0, 0].set_xlabel('Key Position')
axes[0, 0].set_ylabel('Query Position')
plt.colorbar(im1, ax=axes[0, 0])

im2 = axes[0, 1].imshow(attn_with_rope.detach().numpy(), cmap='RdBu_r', vmin=-1, vmax=1)
axes[0, 1].set_title('Attention With RoPE')
axes[0, 1].set_xlabel('Key Position')
axes[0, 1].set_ylabel('Query Position')
plt.colorbar(im2, ax=axes[0, 1])

# Difference
diff = attn_with_rope - attn_no_rope
im3 = axes[0, 2].imshow(diff.detach().numpy(), cmap='RdBu_r')
axes[0, 2].set_title('Difference (RoPE - No RoPE)')
axes[0, 2].set_xlabel('Key Position')
axes[0, 2].set_ylabel('Query Position')
plt.colorbar(im3, ax=axes[0, 2])

# RoPE frequency analysis
freqs = rope_module.cos_cached
positions = torch.arange(freqs.size(0))

# Plot a few frequency components
for i in range(0, min(8, freqs.size(1)), 2):
    axes[1, 0].plot(positions.numpy(), freqs[:, i].numpy(), 
                   label=f'Pair {i}', alpha=0.7)
axes[1, 0].set_title('RoPE Cosine Components')
axes[1, 0].set_xlabel('Position')
axes[1, 0].set_ylabel('Cosine Value')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

# Frequency spectrum
d_model_test = 64
inv_freq = 1.0 / (10000 ** (torch.arange(0, d_model_test, 2).float() / d_model_test))
axes[1, 1].semilogy(range(len(inv_freq)), inv_freq.numpy(), 'o-')
axes[1, 1].set_title('RoPE Frequency Spectrum')
axes[1, 1].set_xlabel('Dimension Pair')
axes[1, 1].set_ylabel('Frequency (log scale)')
axes[1, 1].grid(True, alpha=0.3)

# Relative position bias from RoPE
relative_positions = torch.arange(-32, 32)  # keeps every key position inside [0, max_len - 1]
q_pos = 32  # Fix query at position 32
k_positions = q_pos + relative_positions

# Clamp to valid range
k_positions = torch.clamp(k_positions, 0, rope_module.max_len - 1)

# Compute relative attention scores
q_cos = rope_module.cos_cached[q_pos]
q_sin = rope_module.sin_cached[q_pos]
k_cos = rope_module.cos_cached[k_positions]
k_sin = rope_module.sin_cached[k_positions]

# Simplified relative score (sum over dimensions)
relative_scores = (q_cos * k_cos + q_sin * k_sin).sum(dim=-1)

axes[1, 2].plot(relative_positions.numpy(), relative_scores.numpy(), 'o-', linewidth=2)
axes[1, 2].set_title('RoPE Relative Position Bias')
axes[1, 2].set_xlabel('Relative Position (k - q)')
axes[1, 2].set_ylabel('Attention Score')
axes[1, 2].grid(True, alpha=0.3)
axes[1, 2].axvline(x=0, color='red', linestyle='--', alpha=0.5, label='Same Position')
axes[1, 2].legend()

plt.tight_layout()
plt.show()

print("\n🔄 RoPE Key Insights:")
print("• Encodes relative position: the positional part of q·k depends only on the offset")
print("• No learned parameters (fixed frequency schedule)")
print("• Extrapolation beyond the training length is limited; extensions such as position interpolation are often used")
print("• For fixed content, scores tend to decay (with oscillation) as relative distance grows")
print("• Each dimension pair rotates at a different frequency, from fast to slow")
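
The bottom-right panel relies on the identity cos(mθ)cos(nθ) + sin(mθ)sin(nθ) = cos((m - n)θ). For a query and key whose content in every dimension pair is (1, 0), the RoPE score is therefore Σ_i cos((m - n)θ_i), a function of the offset alone. The pure-Python sketch below (hypothetical helper `relative_score`, same θ_i = base^(-2i/d) schedule as `_build_cache`, d = 128 as in the analysis) evaluates it directly and checks that it peaks at zero offset and is symmetric.

```python
import math

def relative_score(delta, d=128, base=10000.0):
    # Sum over the d/2 frequency pairs of cos(delta * theta_i)
    return sum(math.cos(delta * base ** (-2 * i / d)) for i in range(d // 2))

scores = {delta: relative_score(delta) for delta in range(-32, 32)}
print(f"offset 0: {scores[0]:.2f}, offset 1: {scores[1]:.2f}, offset 31: {scores[31]:.2f}")
```

The peak at zero is strict because the fastest pair (θ_0 = 1 radian) never completes a full turn at an integer offset.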

## 4. Pre-norm vs Post-norm Architecture

The placement of layer normalization significantly affects training dynamics and model performance. Let's compare pre-norm and post-norm architectures.
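
The two orderings differ only in where the normalization sits relative to the residual add: post-norm computes `x_next = Norm(x + F(x))`, while pre-norm computes `x_next = x + F(Norm(x))`. One consequence is easiest to see in isolation: in pre-norm the input passes through the residual path untouched, whereas post-norm re-normalizes it at every layer. The sketch below uses a parameter-free RMS normalization and a placeholder sublayer; both are hypothetical stand-ins, not the notebook's modules.

```python
import math

def rms_normalize(x, eps=1e-6):
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]

def post_norm_step(x, sublayer):
    # Original Transformer: add the residual first, then normalize the sum
    fx = sublayer(x)
    return rms_normalize([a + b for a, b in zip(x, fx)])

def pre_norm_step(x, sublayer):
    # Modern ordering: normalize only the sublayer input; the residual path is untouched
    fx = sublayer(rms_normalize(x))
    return [a + b for a, b in zip(x, fx)]

def silent(x):
    # Placeholder sublayer that contributes nothing
    return [0.0] * len(x)

x = [10.0, -20.0, 30.0, -40.0]
print(pre_norm_step(x, silent))   # input returned exactly: identity path preserved
print(post_norm_step(x, silent))  # input rescaled to unit RMS
```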

In [None]:
class PreNormTransformerBlock(nn.Module):
    """Transformer block with pre-normalization (modern approach)."""
    
    def __init__(self, d_model: int, n_heads: int, d_ff: int, 
                 dropout: float = 0.1, norm_type: str = 'layernorm'):
        super().__init__()
        
        self.attention = RoPEAttention(d_model, n_heads, dropout=dropout)
        self.feed_forward = ModernFeedForward(d_model, d_ff, 'swiglu', dropout)
        
        # Normalization layers
        if norm_type == 'rmsnorm':
            self.norm1 = RMSNorm(d_model)
            self.norm2 = RMSNorm(d_model)
        else:
            self.norm1 = nn.LayerNorm(d_model)
            self.norm2 = nn.LayerNorm(d_model)
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
        # Pre-norm: normalize before attention
        normed_x = self.norm1(x)
        attn_out, attn_weights = self.attention(normed_x, mask)
        x = x + self.dropout(attn_out)
        
        # Pre-norm: normalize before feed-forward
        normed_x = self.norm2(x)
        ff_out = self.feed_forward(normed_x)
        x = x + self.dropout(ff_out)
        
        return x, attn_weights


class PostNormTransformerBlock(nn.Module):
    """Transformer block with post-normalization (original approach)."""
    
    def __init__(self, d_model: int, n_heads: int, d_ff: int, 
                 dropout: float = 0.1, norm_type: str = 'layernorm'):
        super().__init__()
        
        self.attention = RoPEAttention(d_model, n_heads, dropout=dropout)
        self.feed_forward = ModernFeedForward(d_model, d_ff, 'swiglu', dropout)
        
        # Normalization layers
        if norm_type == 'rmsnorm':
            self.norm1 = RMSNorm(d_model)
            self.norm2 = RMSNorm(d_model)
        else:
            self.norm1 = nn.LayerNorm(d_model)
            self.norm2 = nn.LayerNorm(d_model)
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
        # Post-norm: normalize after residual connection
        attn_out, attn_weights = self.attention(x, mask)
        x = self.norm1(x + self.dropout(attn_out))
        
        # Post-norm: normalize after residual connection
        ff_out = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_out))
        
        return x, attn_weights


def analyze_norm_placement():
    """Analyze the effects of pre-norm vs post-norm placement."""
    d_model = 256
    n_heads = 4
    d_ff = 1024
    batch_size = 8
    seq_len = 32
    
    # Create both variants
    pre_norm_block = PreNormTransformerBlock(d_model, n_heads, d_ff).to(device)
    post_norm_block = PostNormTransformerBlock(d_model, n_heads, d_ff).to(device)
    
    # Test input with varying magnitudes
    test_cases = {
        'Small': torch.randn(batch_size, seq_len, d_model).to(device) * 0.1,
        'Normal': torch.randn(batch_size, seq_len, d_model).to(device),
        'Large': torch.randn(batch_size, seq_len, d_model).to(device) * 5.0,
    }
    
    results = {}
    
    for case_name, input_tensor in test_cases.items():
        # Forward pass through both architectures
        pre_out, _ = pre_norm_block(input_tensor)
        post_out, _ = post_norm_block(input_tensor)
        
        # Compute gradients to analyze gradient flow
        pre_loss = pre_out.sum()
        post_loss = post_out.sum()
        
        # Backward pass
        pre_norm_block.zero_grad()
        pre_loss.backward(retain_graph=True)
        pre_grad_norms = [p.grad.norm().item() for p in pre_norm_block.parameters() 
                         if p.grad is not None]
        
        post_norm_block.zero_grad()
        post_loss.backward(retain_graph=True)
        post_grad_norms = [p.grad.norm().item() for p in post_norm_block.parameters() 
                          if p.grad is not None]
        
        results[case_name] = {
            'input_norm': input_tensor.norm().item(),
            'pre_output_norm': pre_out.norm().item(),
            'post_output_norm': post_out.norm().item(),
            'pre_grad_norm_mean': np.mean(pre_grad_norms),
            'post_grad_norm_mean': np.mean(post_grad_norms),
            'pre_grad_norm_std': np.std(pre_grad_norms),
            'post_grad_norm_std': np.std(post_grad_norms),
        }
    
    return results


# Simulate training dynamics
def simulate_training_dynamics():
    """Toy training loop (random inputs and targets) to compare gradient-norm behavior."""
    d_model = 128
    n_heads = 4
    d_ff = 512
    batch_size = 4
    seq_len = 16
    n_steps = 50
    
    # Create models
    pre_norm_block = PreNormTransformerBlock(d_model, n_heads, d_ff).to(device)
    post_norm_block = PostNormTransformerBlock(d_model, n_heads, d_ff).to(device)
    
    # Optimizers
    pre_optimizer = torch.optim.Adam(pre_norm_block.parameters(), lr=1e-3)
    post_optimizer = torch.optim.Adam(post_norm_block.parameters(), lr=1e-3)
    
    # Training loop
    pre_losses = []
    post_losses = []
    pre_grad_norms = []
    post_grad_norms = []
    
    for step in range(n_steps):
        # Generate random input and target
        x = torch.randn(batch_size, seq_len, d_model).to(device)
        target = torch.randn(batch_size, seq_len, d_model).to(device)
        
        # Pre-norm forward and backward
        pre_optimizer.zero_grad()
        pre_out, _ = pre_norm_block(x)
        pre_loss = F.mse_loss(pre_out, target)
        pre_loss.backward()
        
        # Compute gradient norm
        total_norm = 0
        for p in pre_norm_block.parameters():
            if p.grad is not None:
                total_norm += p.grad.data.norm(2).item() ** 2
        pre_grad_norm = total_norm ** 0.5
        
        pre_optimizer.step()
        
        # Post-norm forward and backward
        post_optimizer.zero_grad()
        post_out, _ = post_norm_block(x)
        post_loss = F.mse_loss(post_out, target)
        post_loss.backward()
        
        # Compute gradient norm
        total_norm = 0
        for p in post_norm_block.parameters():
            if p.grad is not None:
                total_norm += p.grad.data.norm(2).item() ** 2
        post_grad_norm = total_norm ** 0.5
        
        post_optimizer.step()
        
        # Store metrics
        pre_losses.append(pre_loss.item())
        post_losses.append(post_loss.item())
        pre_grad_norms.append(pre_grad_norm)
        post_grad_norms.append(post_grad_norm)
    
    return pre_losses, post_losses, pre_grad_norms, post_grad_norms


# Run analyses
norm_analysis = analyze_norm_placement()
pre_losses, post_losses, pre_grads, post_grads = simulate_training_dynamics()

# Display norm placement analysis
print("Pre-norm vs Post-norm Analysis:")
print("=" * 80)
print("Input Type\tInput Norm\tPre-Out Norm\tPost-Out Norm\tPre-Grad\tPost-Grad")
print("-" * 80)

for case_name, stats in norm_analysis.items():
    print(f"{case_name:<12}\t{stats['input_norm']:>8.3f}\t{stats['pre_output_norm']:>10.3f}\t"
          f"{stats['post_output_norm']:>11.3f}\t{stats['pre_grad_norm_mean']:>7.4f}\t"
          f"{stats['post_grad_norm_mean']:>8.4f}")

# Visualize training dynamics
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Loss curves
steps = range(len(pre_losses))
axes[0, 0].plot(steps, pre_losses, label='Pre-norm', linewidth=2)
axes[0, 0].plot(steps, post_losses, label='Post-norm', linewidth=2)
axes[0, 0].set_title('Training Loss Comparison')
axes[0, 0].set_xlabel('Training Step')
axes[0, 0].set_ylabel('MSE Loss')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)
axes[0, 0].set_yscale('log')

# Gradient norms
axes[0, 1].plot(steps, pre_grads, label='Pre-norm', linewidth=2, alpha=0.7)
axes[0, 1].plot(steps, post_grads, label='Post-norm', linewidth=2, alpha=0.7)
axes[0, 1].set_title('Gradient Norm Comparison')
axes[0, 1].set_xlabel('Training Step')
axes[0, 1].set_ylabel('Gradient Norm')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)
axes[0, 1].set_yscale('log')

# Architecture comparison
architectures = ['Pre-norm', 'Post-norm']
# Std. dev. of the gradient norm over training: a rough proxy for stability
stability_scores = [np.std(pre_grads), np.std(post_grads)]

axes[1, 0].bar(architectures, stability_scores, alpha=0.7)
axes[1, 0].set_title('Gradient Stability (Lower is Better)')
axes[1, 0].set_ylabel('Gradient Std Dev')

# Final loss comparison
final_losses = [pre_losses[-1], post_losses[-1]]
axes[1, 1].bar(architectures, final_losses, alpha=0.7, color=['orange', 'green'])
axes[1, 1].set_title('Final Training Loss')
axes[1, 1].set_ylabel('MSE Loss')

plt.tight_layout()
plt.show()

print("\n🏗️ Architecture Insights (general findings from the literature; this toy run may not reproduce all of them):")
print("• Pre-norm generally provides more stable training and can work without learning-rate warmup")
print("• Pre-norm keeps an identity residual path, which gives better gradient flow")
print("• Post-norm may achieve slightly better final performance when it trains successfully")
print("• Pre-norm is preferred for very deep networks")
print("• Modern models (GPT-3, LLaMA) predominantly use pre-norm")
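The structural difference behind these insights can be checked directly. Below is a minimal sketch. It uses a hypothetical `ZeroSublayer` whose output is all zeros, standing in for an attention or FFN sublayer that contributes nothing, and it uses `nn.LayerNorm` as the norm. A pre-norm block computes `x + f(norm(x))`, so it reduces exactly to the identity when `f` contributes nothing. A post-norm block computes `norm(x + f(x))`, so it always renormalizes the residual stream.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class ZeroSublayer(nn.Module):
    """Hypothetical stand-in for attention/FFN that contributes nothing."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.zeros_like(x)


def pre_norm_block(x, norm, sublayer):
    # Pre-norm: normalize the sublayer input, keep the residual path untouched
    return x + sublayer(norm(x))


def post_norm_block(x, norm, sublayer):
    # Post-norm: add the residual first, then normalize the sum
    return norm(x + sublayer(x))


d_model = 16
x = 3.0 * torch.randn(2, 5, d_model) + 1.0  # deliberately not zero-mean / unit-variance
norm = nn.LayerNorm(d_model)                 # default gain = 1, bias = 0
f = ZeroSublayer()

pre_out = pre_norm_block(x, norm, f)
post_out = post_norm_block(x, norm, f)

# Pre-norm passes x through unchanged; post-norm rescales every token
print("pre-norm is identity: ", torch.equal(pre_out, x))
print("post-norm is identity:", torch.allclose(post_out, x))
print("post-norm per-token RMS:", post_out.pow(2).mean(dim=-1).sqrt().mean().item())
```

In a deep pre-norm stack, each sublayer only adds to an untouched residual stream. In post-norm, every layer's LayerNorm sits on the path that gradients must travel back through.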

## 5. Putting It All Together: Modern Transformer Block

Let's create a complete modern transformer block that incorporates all the improvements we've discussed.

In [None]:
class ModernTransformerBlock(nn.Module):
    """
    Modern transformer block incorporating all improvements:
    - RMSNorm for normalization
    - SwiGLU for activation
    - RoPE for positional encoding
    - Pre-norm architecture
    """
    
    def __init__(self, d_model: int, n_heads: int, d_ff: int, 
                 max_len: int = 8192, dropout: float = 0.1):
        super().__init__()
        
        # Multi-head attention with RoPE
        self.attention = RoPEAttention(d_model, n_heads, max_len, dropout)
        
        # Feed-forward with SwiGLU
        self.feed_forward = ModernFeedForward(d_model, d_ff, 'swiglu', dropout)
        
        # RMSNorm layers (pre-norm style)
        self.norm1 = RMSNorm(d_model)
        self.norm2 = RMSNorm(d_model)
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None,
                position_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass of modern transformer block.
        
        Args:
            x: Input tensor [batch_size, seq_len, d_model]
            mask: Optional attention mask
            position_offset: Position offset for RoPE (useful for KV caching)
        """
        # Pre-norm attention with residual connection
        normed_x = self.norm1(x)
        attn_out, attn_weights = self.attention(normed_x, mask, position_offset)
        x = x + self.dropout(attn_out)
        
        # Pre-norm feed-forward with residual connection
        normed_x = self.norm2(x)
        ff_out = self.feed_forward(normed_x)
        x = x + self.dropout(ff_out)
        
        return x, attn_weights


class ModernTransformer(nn.Module):
    """Complete modern transformer model."""
    
    def __init__(self, vocab_size: int, d_model: int, n_layers: int, 
                 n_heads: int, d_ff: int, max_len: int = 8192, 
                 dropout: float = 0.1, tie_weights: bool = True):
        super().__init__()
        
        self.d_model = d_model
        self.n_layers = n_layers
        
        # Token embeddings (no positional embeddings - using RoPE)
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        
        # Transformer layers
        self.layers = nn.ModuleList([
            ModernTransformerBlock(d_model, n_heads, d_ff, max_len, dropout)
            for _ in range(n_layers)
        ])
        
        # Final normalization
        self.final_norm = RMSNorm(d_model)
        
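        # Output projection
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        
        # Initialize weights before tying; otherwise the Linear (Xavier) init
        # would overwrite the embedding init on the shared matrix
        self.apply(self._init_weights)
        
        # Optionally tie input and output embeddings
        if tie_weights:
            self.lm_head.weight = self.token_embedding.weight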
    
    def _init_weights(self, module):
        """Initialize weights using modern practices."""
        if isinstance(module, nn.Linear):
            # Use Xavier/Glorot initialization
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, RMSNorm):
            nn.init.ones_(module.weight)
    
    def forward(self, input_ids: torch.Tensor, 
                attention_mask: Optional[torch.Tensor] = None,
                position_offset: int = 0) -> torch.Tensor:
        """
        Forward pass of the modern transformer.
        
        Args:
            input_ids: Token IDs [batch_size, seq_len]
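            attention_mask: Optional attention mask; when omitted, a causal
                mask of shape [1, 1, seq_len, seq_len] is built
            position_offset: Position offset for RoPE
        
        Returns:
            Logits of shape [batch_size, seq_len, vocab_size]
        """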
        batch_size, seq_len = input_ids.shape
        
        # Token embeddings (scaled by sqrt(d_model) like in original paper)
        x = self.token_embedding(input_ids) * math.sqrt(self.d_model)
        
        # Create causal mask if not provided
        if attention_mask is None:
            attention_mask = torch.tril(torch.ones(seq_len, seq_len, device=input_ids.device))
            attention_mask = attention_mask.unsqueeze(0).unsqueeze(0)
        
        # Pass through transformer layers
        for layer in self.layers:
            x, _ = layer(x, attention_mask, position_offset)
        
        # Final normalization
        x = self.final_norm(x)
        
        # Project to vocabulary
        logits = self.lm_head(x)
        
        return logits
    
    def count_parameters(self) -> dict:
        """Count parameters by component."""
        total = 0
        breakdown = {}
        
        # Token embeddings
        emb_params = sum(p.numel() for p in self.token_embedding.parameters())
        breakdown['token_embedding'] = emb_params
        total += emb_params
        
        # Transformer layers
        layer_params = sum(p.numel() for p in self.layers.parameters())
        breakdown['transformer_layers'] = layer_params
        total += layer_params
        
        # Final norm
        norm_params = sum(p.numel() for p in self.final_norm.parameters())
        breakdown['final_norm'] = norm_params
        total += norm_params
        
        # Output head: a tied head shares the embedding matrix and adds nothing
        if self.lm_head.weight is self.token_embedding.weight:
            breakdown['lm_head'] = 0  # Tied with embeddings
        else:
            head_params = sum(p.numel() for p in self.lm_head.parameters())
            breakdown['lm_head'] = head_params
            total += head_params
        
        breakdown['total'] = total
        return breakdown


def compare_model_architectures():
    """Instantiate the modern transformer and report its size and output statistics."""
    # Model configuration
    config = {
        'vocab_size': 50257,
        'd_model': 512,
        'n_layers': 6,
        'n_heads': 8,
        'd_ff': 2048,
        'max_len': 1024,
        'dropout': 0.1
    }
    
    # Create modern model
    modern_model = ModernTransformer(**config).to(device)
    
    # Test input
    batch_size, seq_len = 4, 64
    input_ids = torch.randint(0, config['vocab_size'], (batch_size, seq_len)).to(device)
    
    # Forward pass (inspection only, so skip building the autograd graph)
    with torch.no_grad():
        logits = modern_model(input_ids)
    
    # Parameter analysis
    param_breakdown = modern_model.count_parameters()
    
    print("Modern Transformer Architecture Analysis:")
    print("=" * 50)
    print(f"Model Configuration:")
    for key, value in config.items():
        print(f"  {key}: {value:,}")
    
    print(f"\nParameter Breakdown:")
    for component, count in param_breakdown.items():
        percentage = (count / param_breakdown['total']) * 100 if count > 0 else 0
        print(f"  {component}: {count:,} ({percentage:.1f}%)")
    
    print(f"\nOutput Analysis:")
    print(f"  Input shape: {input_ids.shape}")
    print(f"  Output shape: {logits.shape}")
    print(f"  Output range: [{logits.min().item():.3f}, {logits.max().item():.3f}]")
    print(f"  Output std: {logits.std().item():.3f}")
    
    return modern_model, param_breakdown


# Test modern architecture
modern_model, param_breakdown = compare_model_architectures()

# Visualize parameter distribution
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

# Parameter breakdown pie chart
components = [k for k in param_breakdown.keys() if k != 'total' and param_breakdown[k] > 0]
sizes = [param_breakdown[k] for k in components]
colors = plt.cm.Set3(np.linspace(0, 1, len(components)))

ax1.pie(sizes, labels=components, autopct='%1.1f%%', colors=colors, startangle=90)
ax1.set_title('Parameter Distribution in Modern Transformer')

# Architecture improvements comparison
improvements = ['RMSNorm', 'SwiGLU', 'RoPE', 'Pre-norm']
benefits = ['Parameter\nReduction', 'Better\nGradients', 'Relative\nPositions', 'Training\nStability']
scores = [8, 9, 9, 8]  # Subjective benefit scores

bars = ax2.bar(improvements, scores, color=['lightblue', 'lightgreen', 'lightyellow', 'lightcoral'])
ax2.set_title('Modern Architecture Improvements (illustrative)')
ax2.set_ylabel('Subjective Benefit Score (1-10)')
ax2.set_ylim(0, 10)

# Add benefit labels
for bar, benefit in zip(bars, benefits):
    height = bar.get_height()
    ax2.text(bar.get_x() + bar.get_width()/2., height + 0.1,
             benefit, ha='center', va='bottom', fontsize=9)

plt.tight_layout()
plt.show()

print("\n🚀 Modern Architecture Summary:")
print("• RMSNorm: no mean-centering or bias, so half the parameters of LayerNorm per norm layer and cheaper to compute")
print("• SwiGLU: gated feed-forward that empirically improves quality over ReLU/GELU")
print("• RoPE: relative positional encoding without extra parameters")
print("• Pre-norm: more stable training for deep networks")
print("• Combined: the recipe used by LLaMA-style models")
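Weight tying affects how parameters should be counted. Below is a minimal sketch. It uses a hypothetical `TinyLM` (an embedding plus an output head only, not the `ModernTransformer` above). Once `lm_head.weight` is assigned the embedding's `Parameter`, both modules hold the same tensor. `nn.Module.parameters()` yields a shared parameter only once, so summing over `model.parameters()` gives the correct total. Summing module by module counts the matrix twice.

```python
import torch
import torch.nn as nn


class TinyLM(nn.Module):
    """Hypothetical minimal model: token embedding + output projection."""

    def __init__(self, vocab_size: int, d_model: int, tie_weights: bool):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        if tie_weights:
            # Embedding weight is [vocab, d_model]; Linear(d_model, vocab) weight
            # is also [vocab, d_model], so the same Parameter fits both
            self.lm_head.weight = self.token_embedding.weight

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.lm_head(self.token_embedding(input_ids))


vocab_size, d_model = 1000, 64
tied = TinyLM(vocab_size, d_model, tie_weights=True)
untied = TinyLM(vocab_size, d_model, tie_weights=False)

# parameters() de-duplicates shared tensors, so the tied matrix is counted once
tied_total = sum(p.numel() for p in tied.parameters())
untied_total = sum(p.numel() for p in untied.parameters())

# Naive per-module counting double-counts the tied matrix
naive = (sum(p.numel() for p in tied.token_embedding.parameters())
         + sum(p.numel() for p in tied.lm_head.parameters()))

print(f"tied:   {tied_total:,}")              # 64,000
print(f"untied: {untied_total:,}")            # 128,000
print(f"naive per-module sum: {naive:,}")     # 128,000
print("shared tensor:", tied.lm_head.weight is tied.token_embedding.weight)

logits = tied(torch.randint(0, vocab_size, (2, 8)))
print(logits.shape)                           # torch.Size([2, 8, 1000])
```

The same reasoning applies to `count_parameters` above: a tied `lm_head` adds no parameters of its own.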

## Summary and Key Takeaways

### 🎯 What We've Learned

1. **RMSNorm vs LayerNorm**:
   - RMSNorm halves the normalization parameters (one gain vector instead of gain plus bias) by dropping the bias term β
   - Maintains high correlation with LayerNorm outputs (>95%)
   - Simpler and cheaper to compute, since no mean is subtracted
   - Used in modern models like LLaMA, PaLM, and Chinchilla

2. **SwiGLU Activation**:
   - Combines Swish activation with gating mechanism
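   - Smooth like GELU (unlike ReLU); in Shazeer (2020) it improved perplexity over ReLU and GELU feed-forward layers
   - The gating mechanism is thought to improve gradient flow and model capacity
   - Uses three weight matrices instead of two, so the hidden size is usually shrunk to about 2/3 of the standard d_ff to keep the parameter count unchanged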

3. **Rotary Position Embedding (RoPE)**:
   - Encodes relative positions through rotation angles
   - No additional parameters required (frequency-based)
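   - Can be computed for any position, but quality still degrades well beyond the training length unless combined with context-extension methods such as position interpolation
   - Has a long-term decay property: the bound on query-key interaction shrinks as relative distance grows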
   - Different frequencies for different dimension pairs

4. **Pre-norm vs Post-norm**:
   - Pre-norm provides more stable training dynamics
   - Better gradient flow characteristics for deep networks
   - Post-norm may achieve slightly better final performance
   - Modern models predominantly use pre-norm architecture

5. **Integration Benefits**:
   - Combined, these components form the default recipe of many recent open LLMs
   - Better parameter efficiency and training stability
   - Improved scaling properties for large models
   - LLaMA uses all four; PaLM uses SwiGLU, RoPE and pre-norm; GPT-3 uses pre-norm but keeps LayerNorm, GELU and learned positions
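The claim that RoPE encodes relative positions can be checked numerically. Below is a minimal sketch. It assumes the common pairing of adjacent dimensions (2i, 2i+1), base 10000, and frequencies θ_i = 10000^(-2i/d). The helper `apply_rope` is hypothetical and is not necessarily identical to the `RoPEAttention` implementation used above. The sketch checks that the dot product between a query rotated to position m and a key rotated to position n depends only on m − n.

```python
import torch

torch.manual_seed(0)


def apply_rope(x: torch.Tensor, pos: int, base: float = 10000.0) -> torch.Tensor:
    """Rotate each adjacent pair (x[2i], x[2i+1]) by angle pos * theta_i.

    Hypothetical helper for illustration; x has shape [d] with d even.
    """
    d = x.shape[-1]
    two_i = torch.arange(0, d, 2, dtype=x.dtype)   # 0, 2, 4, ..., d-2
    theta = base ** (-two_i / d)                   # theta_i = base^(-2i/d)
    angle = pos * theta
    cos, sin = torch.cos(angle), torch.sin(angle)
    x_even, x_odd = x[0::2], x[1::2]
    out = torch.empty_like(x)
    out[0::2] = x_even * cos - x_odd * sin         # 2-D rotation of each pair
    out[1::2] = x_even * sin + x_odd * cos
    return out


d = 64
q = torch.randn(d, dtype=torch.float64)
k = torch.randn(d, dtype=torch.float64)

# Same relative offset (m - n = 3) at different absolute positions
s1 = apply_rope(q, 5) @ apply_rope(k, 2)
s2 = apply_rope(q, 105) @ apply_rope(k, 102)
# A different offset (m - n = 10) generally gives a different score
s3 = apply_rope(q, 12) @ apply_rope(k, 2)

print(f"offset 3 at (5, 2):     {s1.item():.6f}")
print(f"offset 3 at (105, 102): {s2.item():.6f}")
print(f"offset 10 at (12, 2):   {s3.item():.6f}")

# Rotations preserve vector length, so no information is lost
print(torch.allclose(apply_rope(q, 7).norm(), q.norm()))
```

This works because rotating q by mθ and k by nθ leaves only the rotation by (n − m)θ in the dot product. No learned parameters are involved.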

### 🔬 Practical Implications

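- **Memory Efficiency**: Savings are modest, since RMSNorm saves d_model parameters per norm layer and RoPE removes the learned position table (max_len × d_model)
- **Training Stability**: Pre-norm (here paired with RMSNorm) makes very deep networks easier to train
- **Sequence Length**: Unlike a learned position table, RoPE can be evaluated at unseen positions and pairs well with context-extension methods
- **Performance**: SwiGLU outperformed ReLU and GELU feed-forward layers in Shazeer (2020) and has since become a common default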
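To put these points in numbers, here is a back-of-the-envelope count for the configuration used in `compare_model_architectures` (d_model = 512, n_layers = 6, d_ff = 2048, max_len = 1024). It makes three assumptions: two norm layers per block plus one final norm, a learned absolute position table as the baseline that RoPE replaces, and bias-free FFN matrices. The SwiGLU hidden size is rounded down from 2/3 · d_ff. That is one simple choice; some implementations round to a hardware-friendly multiple instead.

```python
d_model, n_layers, d_ff, max_len = 512, 6, 2048, 1024

# Normalization: LayerNorm has gain + bias (2*d), RMSNorm only a gain (d)
n_norms = 2 * n_layers + 1                 # two per block plus the final norm
layernorm_params = n_norms * 2 * d_model
rmsnorm_params = n_norms * d_model
norm_saving = layernorm_params - rmsnorm_params

# Positions: a learned table needs max_len * d_model weights, RoPE needs none
learned_pos_params = max_len * d_model

# FFN per block (no biases): standard uses 2 matrices, SwiGLU uses 3
standard_ffn = 2 * d_model * d_ff
swiglu_hidden = int(2 * d_ff / 3)          # 2/3 * d_ff, rounded down
swiglu_ffn = 3 * d_model * swiglu_hidden

print(f"norm params saved by RMSNorm: {norm_saving:,}")         # 6,656
print(f"learned position table:       {learned_pos_params:,}")  # 524,288
print(f"standard FFN per block:       {standard_ffn:,}")        # 2,097,152
print(f"SwiGLU hidden size:           {swiglu_hidden}")         # 1365
print(f"SwiGLU FFN per block:         {swiglu_ffn:,}")          # 2,096,640
```

The RMSNorm saving is tiny next to the roughly 25.7M-parameter token embedding (50,257 × 512). The position table is the larger saving. At 2/3 width, SwiGLU keeps the FFN size essentially unchanged.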

### 🔄 Next Steps

In the upcoming notebooks, we'll explore:
- **Training Optimization** (09): Advanced learning rate schedules, gradient clipping, mixed precision
- **Debugging and Monitoring** (10): Training failure modes, gradient monitoring, troubleshooting
- **Scaling Laws** (11): Chinchilla scaling, compute-optimal training, emergence

### 📚 Further Reading

- **RMSNorm**: Zhang & Sennrich (2019) - "Root Mean Square Layer Normalization"
- **SwiGLU**: Shazeer (2020) - "GLU Variants Improve Transformer"
- **RoPE**: Su et al. (2021) - "RoFormer: Enhanced Transformer with Rotary Position Embedding"
- **Pre-norm**: Xiong et al. (2020) - "On Layer Normalization in the Transformer Architecture"