# Problem 10: Training Instability

Demonstrates gradient-flow problems in deep transformer stacks and mitigations such as Pre-LayerNorm.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/gaurav-redhat/transformer_problems/blob/main/10_training_instability/demo.ipynb)


In [None]:
!pip install torch matplotlib -q
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

# Post-LayerNorm (original transformer - unstable)
class PostLNBlock(nn.Module):
    """Residual block that applies LayerNorm after the skip connection.

    A single square ``nn.Linear`` layer stands in for the attention
    sub-layer, so the block computes ``LayerNorm(x + Linear(x))``.
    """

    def __init__(self, d_model: int) -> None:
        super().__init__()
        # Square projection used in place of a real attention sub-layer.
        self.attn = nn.Linear(in_features=d_model, out_features=d_model)
        self.norm = nn.LayerNorm(normalized_shape=d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add the sub-layer output to ``x`` and normalize the sum."""
        residual_sum = x + self.attn(x)
        return self.norm(residual_sum)

# Pre-LayerNorm (modern - stable)
class PreLNBlock(nn.Module):
    """Residual block that applies LayerNorm before the sub-layer.

    A single square ``nn.Linear`` layer stands in for the attention
    sub-layer, so the block computes ``x + Linear(LayerNorm(x))`` and the
    skip path carries ``x`` through unnormalized.
    """

    def __init__(self, d_model: int) -> None:
        super().__init__()
        # Square projection used in place of a real attention sub-layer.
        self.attn = nn.Linear(in_features=d_model, out_features=d_model)
        self.norm = nn.LayerNorm(normalized_shape=d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x``, run the sub-layer, and add the result to ``x``."""
        normalized = self.norm(x)
        sublayer_out = self.attn(normalized)
        return x + sublayer_out

# Compare gradient flow
def check_gradients(model_class, n_layers=24, d_model=64, seq_len=10):
    """Measure the mean absolute gradient reaching the input of a block stack.

    Builds ``n_layers`` modules by calling ``model_class(d_model)``, chains
    them with ``nn.Sequential``, runs one random ``(1, seq_len, d_model)``
    input through the stack and back-propagates a random upstream gradient.

    Args:
        model_class: Callable that takes the model width and returns an
            ``nn.Module`` mapping a tensor to a tensor of the same shape.
        n_layers: Number of blocks to stack.
        d_model: Feature width passed to ``model_class`` and used for the
            last dimension of the input.
        seq_len: Length of the middle (sequence) dimension of the input.

    Returns:
        Mean absolute value of the gradient with respect to the input, as a
        Python float.
    """
    blocks = nn.Sequential(*(model_class(d_model) for _ in range(n_layers)))
    x = torch.randn(1, seq_len, d_model, requires_grad=True)
    out = blocks(x)
    # Back-propagate a random upstream gradient rather than ``out.sum()``.
    # A freshly initialised LayerNorm (weight=1, bias=0) produces features
    # that always sum to zero, so the gradient of ``out.sum()`` is
    # analytically zero for any stack that ends in LayerNorm. The measured
    # value would then be floating-point noise, not a depth effect.
    out.backward(torch.randn_like(out))
    return x.grad.abs().mean().item()

# Seed the global RNG so the weights, inputs and printed magnitudes are the
# same on every run of this cell.
torch.manual_seed(0)

# Input-gradient magnitude for each variant at the default depth of 24 blocks.
post_grad = check_gradients(PostLNBlock)
pre_grad = check_gradients(PreLNBlock)

# Guard the ratio: a Post-LN gradient of exactly 0.0 would otherwise raise
# ZeroDivisionError instead of printing a result.
grad_ratio = pre_grad / post_grad if post_grad else float("inf")

print("Gradient magnitude (24 layers):")
print(f"  Post-LayerNorm: {post_grad:.6f}")
print(f"  Pre-LayerNorm:  {pre_grad:.6f}")
print(f"\n✓ Pre-LN has {grad_ratio:.1f}x stronger gradients!")
