# Module 1.3: Feed-Forward Networks & Normalization

**Goal**: Understand FFN architectures and normalization methods

**Time**: 45 minutes

**Concepts Covered**:
- SwiGLU activation function
- ReLU vs GELU vs SwiGLU comparison
- LayerNorm vs RMSNorm
- Memory profiling

## Setup

In [None]:
!pip install torch numpy matplotlib seaborn -q

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import time

# Fix both RNGs up front so every cell below is reproducible, then pick a plot theme.
np.random.seed(42)
torch.manual_seed(42)
plt.style.use('seaborn-v0_8-darkgrid')

## Lesson 1: Activation Functions (15 mins)

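Compare ReLU, GELU, and SwiGLU activation functions. SwiGLU is a gated unit that needs two input projections, so the plot shows its Swish (SiLU) nonlinearity instead.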

In [None]:
# Define activation functions
def relu(x):
    """Rectified linear unit: element-wise ``max(x, 0)``."""
    return torch.relu(x)

def gelu(x):
    """Gaussian Error Linear Unit using the exact (erf-based) formulation."""
    return F.gelu(x, approximate="none")

def swiglu(x):
    """Apply the SwiGLU gate to an already-projected tensor.

    The full SwiGLU layer is ``SwiGLU(x) = Swish(xW + b) ⊙ (xV + c)`` with
    ``Swish(z) = z * sigmoid(z)`` (PyTorch calls it SiLU). This helper does not
    hold W or V. It expects both projections to be concatenated along the last
    dimension already: the first half is the gate, the second half the value.

    Args:
        x: tensor whose last dimension has even size ``2 * h``.

    Returns:
        ``silu(gate) * value``, with the same leading shape and a last
        dimension of size ``h``.

    Raises:
        ValueError: if the last dimension is odd. ``chunk(2)`` would split it
            unevenly. That either errors confusingly or, for size 3, silently
            broadcasts a (..., 2) gate against a (..., 1) value.
    """
    size = x.shape[-1]
    if size % 2 != 0:
        raise ValueError(f"swiglu expects an even last dimension, got {size}")
    gate, value = x.chunk(2, dim=-1)
    return F.silu(gate) * value  # SiLU is Swish

# Sample each activation on a dense grid over [-5, 5] and overlay the curves.
x = torch.linspace(-5, 5, 1000)
y_relu = relu(x)
y_gelu = gelu(x)
# SwiGLU is gated and needs two projected inputs, so plot only its Swish (SiLU) part.
y_swish = F.silu(x)

curves = (
    ('ReLU', y_relu),
    ('GELU', y_gelu),
    ('Swish (SwiGLU component)', y_swish),
)
x_np = x.numpy()

plt.figure(figsize=(12, 5))
for curve_label, curve_y in curves:
    plt.plot(x_np, curve_y.numpy(), label=curve_label, linewidth=2)
plt.xlabel('Input')
plt.ylabel('Output')
plt.title('Activation Functions Comparison', fontweight='bold')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## Lesson 2: Feed-Forward Networks (15 mins)

Implement FFN with different activation functions.

In [None]:
class FFN_ReLU(nn.Module):
    """Classic two-layer feed-forward block: Linear -> ReLU -> Linear.

    Shape flow: ``(..., d_model) -> (..., d_ff) -> (..., d_model)``.

    Args:
        d_model: width of the input and output features.
        d_ff: width of the hidden (expanded) layer.
    """
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff)   # expand
        self.w2 = nn.Linear(d_ff, d_model)   # project back
        self.activation = F.relu

    def forward(self, x):
        hidden = self.w1(x)
        hidden = self.activation(hidden)
        return self.w2(hidden)

class FFN_GELU(nn.Module):
    """Two-layer feed-forward block with a GELU nonlinearity (GPT/BERT style).

    Shape flow: ``(..., d_model) -> (..., d_ff) -> (..., d_model)``.

    Args:
        d_model: width of the input and output features.
        d_ff: width of the hidden (expanded) layer.
    """
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff)
        self.w2 = nn.Linear(d_ff, d_model)
        self.activation = F.gelu

    def forward(self, x):
        expanded = self.activation(self.w1(x))
        return self.w2(expanded)

class FFN_SwiGLU(nn.Module):
    """Feed-forward block with a SwiGLU gate (the FFN style used in PaLM, LLaMA).

    Computes ``w2(silu(x W_gate) * (x W_up))``. The gate and up projections are
    fused into one ``w1`` of output width ``2 * d_ff`` and split with ``chunk``.

    Args:
        d_model: input/output feature width.
        d_ff: hidden width of EACH of the gate and up branches. This class does
            not apply any 2/3 scaling itself. Callers pass a reduced width
            (e.g. ``int(d_ff * 2/3)``) so the three weight blocks (gate, up,
            down) have roughly the parameter count of a two-matrix FFN.
        bias: whether ``w1`` and ``w2`` carry bias terms. Defaults to True;
            pass False for bias-free projections.
    """
    def __init__(self, d_model, d_ff, bias=True):
        super().__init__()
        # Fused gate + up projection: first half is the gate, second half the value.
        self.w1 = nn.Linear(d_model, d_ff * 2, bias=bias)
        self.w2 = nn.Linear(d_ff, d_model, bias=bias)  # down projection

    def forward(self, x):
        gate, up = self.w1(x).chunk(2, dim=-1)
        return self.w2(F.silu(gate) * up)

# Run all three FFN variants on the same toy batch and report output shapes.
d_model, d_ff = 128, 512
batch_size, seq_len = 2, 10

x = torch.randn(batch_size, seq_len, d_model)

ffn_relu = FFN_ReLU(d_model, d_ff)
ffn_gelu = FFN_GELU(d_model, d_ff)
# Shrink SwiGLU's hidden width to 2/3 so its three weight blocks
# (3 * 128 * 341) roughly match the two-matrix FFNs (2 * 128 * 512).
ffn_swiglu = FFN_SwiGLU(d_model, int(d_ff * 2/3))

out_relu, out_gelu, out_swiglu = ffn_relu(x), ffn_gelu(x), ffn_swiglu(x)

print(f"Input shape: {x.shape}")
for variant, result in (("ReLU", out_relu), ("GELU", out_gelu), ("SwiGLU", out_swiglu)):
    print(f"{variant} FFN output: {result.shape}")

## Lesson 3: Normalization Layers (15 mins)

Compare LayerNorm and RMSNorm.

In [None]:
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (used in LLaMA).

    ``y = weight * x / sqrt(mean(x**2, dim=-1) + eps)``

    Unlike LayerNorm, it neither subtracts the mean nor adds a bias.

    Args:
        d_model: size of the last (normalized) dimension.
        eps: added inside the square root for numerical stability.
    """
    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def _norm(self, x):
        # Multiply by 1/rms via rsqrt instead of a separate sqrt and divide.
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Normalize in float32. Squaring fp16 activations overflows once
        # |x| > ~256 (fp16 max ~65504), which would turn rms into inf and the
        # output into zeros. Low-precision inputs also lose accuracy in the mean.
        normed = self._norm(x.float()).type_as(x)
        return self.weight * normed

class LayerNorm(nn.Module):
    """Standard Layer Normalization, wrapping ``nn.LayerNorm`` with eps=1e-6.

    The wrapped module is kept as ``self.norm``. The forward pass calls the
    functional form with that module's own shape, affine parameters, and eps.
    """
    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.norm = nn.LayerNorm(d_model, eps=eps)

    def forward(self, x):
        inner = self.norm
        return inner(x)

# Apply both normalizations to the same random activations and compare statistics.
d_model, batch_size, seq_len = 128, 2, 10

x = torch.randn(batch_size, seq_len, d_model)

layer_norm = LayerNorm(d_model)
rms_norm = RMSNorm(d_model)

out_ln, out_rms = layer_norm(x), rms_norm(x)

print(f"Input shape: {x.shape}")
for norm_label, normed in (("LayerNorm", out_ln), ("RMSNorm", out_rms)):
    print(f"{norm_label} output shape: {normed.shape}")
print(f"\nLayerNorm mean: {out_ln.mean().item():.6f}, std: {out_ln.std().item():.6f}")
print(f"RMSNorm mean: {out_rms.mean().item():.6f}, std: {out_rms.std().item():.6f}")

In [None]:
# Memory profiling comparison
import torch

def profile_memory(model, x, name):
    """Return peak CUDA memory (MB) allocated while running ``model(x)`` once.

    The model and a GPU copy of ``x`` are used for the measurement. Afterwards
    the model is moved back to the device its parameters started on. This
    leaves the caller's module where it was. It also stops weights from a
    previously profiled model from staying on the GPU and inflating the next
    model's peak.

    The figure is ``torch.cuda.max_memory_allocated()`` after resetting the
    peak counter. It therefore includes this model's weights, the input copy,
    and the forward activations (autograd stays enabled).

    Args:
        model: the ``nn.Module`` to run.
        x: input tensor for ``model``.
        name: label for the model; not used in the measurement.

    Returns:
        Peak allocated memory in MB, or 0 when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return 0

    first_param = next(model.parameters(), None)
    home_device = first_param.device if first_param is not None else torch.device("cpu")
    try:
        # nn.Module.cuda() moves the module in place, hence the restore in finally.
        model = model.cuda()
        x = x.cuda()
        torch.cuda.reset_peak_memory_stats()
        _ = model(x)
        return torch.cuda.max_memory_allocated() / 1024**2
    finally:
        model.to(home_device)

# Larger configuration for comparing peak GPU memory across FFN variants.
d_model, d_ff = 256, 1024
batch_size, seq_len = 4, 32

x = torch.randn(batch_size, seq_len, d_model)

models = {
    "ReLU FFN": FFN_ReLU(d_model, d_ff),
    "GELU FFN": FFN_GELU(d_model, d_ff),
    "SwiGLU FFN": FFN_SwiGLU(d_model, int(d_ff * 2/3)),
}

print("Memory Usage (if CUDA available):")
for name, model in models.items():
    peak_mb = profile_memory(model, x, name)
    # profile_memory returns 0 without CUDA; print nothing in that case.
    if not peak_mb > 0:
        continue
    print(f"  {name}: {peak_mb:.2f} MB")

## Key Takeaways

✅ **SwiGLU**: A gated FFN variant that empirically outperforms ReLU/GELU FFNs; used in modern models (LLaMA, PaLM)

✅ **RMSNorm**: Simpler than LayerNorm. It skips mean centering and the bias term and rescales only by the root mean square, which makes it cheaper to compute

✅ **FFN Architecture**: Typically 4x expansion (d_ff = 4 * d_model)

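✅ **Parameters & memory**: SwiGLU has three weight matrices (gate, up, down). Its hidden width is therefore usually scaled to about 2/3 · d_ff, which keeps the parameter count close to a standard two-matrix FFN. At that matched size it often outperforms ReLU/GELU FFNs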

## Next Steps

Continue to **Module 1.4: Complete Transformer Block** to build a full transformer layer.