# Building Transformer Blocks

In the previous notebook, we explored the attention mechanism. Now we'll see how attention is combined with other components to create complete transformer blocks.

## The Big Picture: Why Do We Need More Than Attention?

Attention is powerful, but it has limitations:

1. **Attention only mixes information** - it's like shuffling cards but not changing their values
2. **No position-wise processing** - attention alone has no non-linear step that transforms each token's own representation
3. **Training instability** - activation scales can drift as networks get deeper
4. **Vanishing gradients** - gradients can shrink as they propagate back through many layers

Transformer blocks solve these problems by adding:
- **Feed-Forward Networks** → Transform information, not just mix it
- **Layer Normalization** → Stabilize training  
- **Residual Connections** → Preserve gradient flow

Think of it like this:
- **Attention**: "Let me gather relevant information from other words"
- **Feed-Forward**: "Now let me think about what this information means"
- **Layer Norm**: "Keep everything balanced and stable"
- **Residuals**: "Don't forget what I started with"

## What You'll Learn

1. **Feed-Forward Networks** - The "thinking" component of transformers
2. **Layer Normalization** - Stabilizing training dynamics
3. **Residual Connections** - Enabling deep networks to train
4. **Complete Transformer Block** - How everything fits together
5. **Stacking Blocks** - Building deep transformers

Let's start building!

In [None]:
import sys
import os
sys.path.append('..')

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Tuple, Optional

# Set style for better plots
plt.style.use('default')
sns.set_palette("husl")

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

print("Environment setup complete!")

## 1. Feed-Forward Networks: Position-wise Processing

After attention routes information between positions, each position needs **individual processing**. That's where Feed-Forward Networks (FFNs) come in.

### The Core Problem 🤔
- **Attention**: Routes information ("look at relevant context")
- **FFN**: Processes information ("think about what this means")

### FFN Architecture
```
FFN(x) = ReLU(xW₁ + b₁)W₂ + b₂
```

**Key properties:**
- **Position-wise**: Same transformation applied to each position independently
- **Expand-contract**: `d_model → d_ff → d_model` (typically 4× expansion)  
- **Non-linear**: ReLU enables complex transformations

Think of it like workstations on an assembly line - same tools, different inputs at each position.
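
To make the "position-wise" property concrete, here is a minimal sketch of the formula above. `TinyFFN` is a hypothetical stand-in written for this illustration (no dropout), not the project's `FeedForward` class, whose internals may differ.

```python
import torch
import torch.nn as nn

class TinyFFN(nn.Module):
    """Hypothetical minimal FFN: ReLU(x W1 + b1) W2 + b2, no dropout."""

    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.expand = nn.Linear(d_model, d_ff)    # x W1 + b1
        self.contract = nn.Linear(d_ff, d_model)  # (...) W2 + b2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.contract(torch.relu(self.expand(x)))

torch.manual_seed(0)
ffn = TinyFFN(d_model=8, d_ff=32)
x = torch.randn(1, 4, 8)  # (batch, seq_len, d_model)

full = ffn(x)               # all four positions at once
single = ffn(x[:, 2:3, :])  # position 2 processed on its own

# Position-wise: position 2 gets the same result either way
same = torch.allclose(full[:, 2:3, :], single, atol=1e-5)
print(f"Output shape: {tuple(full.shape)}, position 2 unchanged by neighbours: {same}")
```

Because the linear layers act only on the last dimension, no position can influence another inside the FFN; all cross-position mixing is left to attention.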

In [None]:
# Import the actual FFN implementation we'll use
from src.model.feedforward import FeedForward

# Create and test feed-forward network  
d_model, d_ff = 8, 32  # 4x expansion
batch_size, seq_len = 1, 4

ff_net = FeedForward(d_model, d_ff, dropout=0.1)
x = torch.randn(batch_size, seq_len, d_model)

print(f"📊 FFN DEMONSTRATION")
print(f"Input shape:  {x.shape}")
print(f"d_model → d_ff: {d_model} → {d_ff} (4x expansion)")

# Forward pass
output = ff_net(x)
print(f"Output shape: {output.shape}")

# Parameter analysis
total_params = sum(p.numel() for p in ff_net.parameters())
linear1_params = d_model * d_ff + d_ff  # weights + bias
linear2_params = d_ff * d_model + d_model  # weights + bias

print(f"\n🔧 PARAMETER BREAKDOWN:")
print(f"Linear 1 (expand):   {linear1_params:,}")
print(f"Linear 2 (contract): {linear2_params:,}") 
print(f"Total FFN params:    {total_params:,}")
print(f"\n✨ The FFN transforms each position's representation independently!")

## 2. Layer Normalization: Training Stability

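Deep networks can be unstable during training because the scale and distribution of activations shift from layer to layer and from step to step (often described as internal covariate shift). Layer normalization counters this by normalizing each position's features.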

### Why Layer Norm? 🎯
- **Problem**: Feature values can vary wildly across layers and training steps
- **Solution**: Normalize features to have mean=0, std=1 for each position
- **Benefit**: Stable gradients and faster convergence

### LayerNorm Formula
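$$\text{LayerNorm}(x) = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

Where μ and σ² are the mean and variance computed across the feature dimension for each position independently, ε is a small constant for numerical stability, and γ and β are learned scale and shift parameters.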
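
The formula can be checked by hand against `nn.LayerNorm`. The sketch below assumes the biased (population) variance and a small ε added inside the square root for numerical stability, which is how `nn.LayerNorm` computes it; `manual_layer_norm` is a helper name introduced only for this illustration.

```python
import torch
import torch.nn as nn

def manual_layer_norm(x, gamma, beta, eps=1e-5):
    """Normalize over the last (feature) dimension, then scale and shift."""
    mu = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, keepdim=True, unbiased=False)  # biased variance
    return gamma * (x - mu) / torch.sqrt(var + eps) + beta

x = torch.tensor([[[1.0, 2.0, 3.0, 4.0],
                   [100.0, 200.0, 300.0, 400.0]]])

ln = nn.LayerNorm(4)  # gamma is initialized to 1, beta to 0
manual = manual_layer_norm(x, ln.weight, ln.bias, eps=ln.eps)

matches = torch.allclose(manual, ln(x), atol=1e-5)
print(f"Manual result matches nn.LayerNorm: {matches}")
print(manual[0])  # both rows are nearly identical despite the 100x scale gap
```

Both positions end up with almost the same normalized vector, since they differ only by a scale factor that the normalization removes.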

In [None]:
# Demonstrate LayerNorm effect with a clear example
print("🔧 LAYER NORMALIZATION DEMONSTRATION")

# Create data with different scales (this causes training instability)
x = torch.tensor([
    [[1.0, 2.0, 3.0, 4.0],      # Position 1: small values  
     [100.0, 200.0, 300.0, 400.0]], # Position 2: large values
])

# LayerNorm uses the biased (population) std, so we report std with unbiased=False
print("Before LayerNorm:")
print(f"Position 1: {x[0,0].tolist()} (mean={x[0,0].mean():.1f}, std={x[0,0].std(unbiased=False):.1f})")
print(f"Position 2: {x[0,1].tolist()} (mean={x[0,1].mean():.1f}, std={x[0,1].std(unbiased=False):.1f})")

# Apply layer normalization
layer_norm = nn.LayerNorm(4)
x_normalized = layer_norm(x)

print(f"\nAfter LayerNorm:")
print(f"Position 1: {x_normalized[0,0].tolist()!r} (mean={x_normalized[0,0].mean():.3f}, std={x_normalized[0,0].std(unbiased=False):.3f})")
print(f"Position 2: {x_normalized[0,1].tolist()!r} (mean={x_normalized[0,1].mean():.3f}, std={x_normalized[0,1].std(unbiased=False):.3f})")

print(f"\n✨ LayerNorm normalizes each position independently!")
print(f"✅ Mean ≈ 0, Std ≈ 1 for both positions")
print(f"✅ Removes scale differences that hurt training")

## 3. Residual Connections: Gradient Highways

Deep networks suffer from **vanishing gradients** - the gradient signal shrinks as it is backpropagated through many layers, so the earliest layers barely learn. Residual connections mitigate this.

### The Solution: Skip Connections 🛣️
Instead of `output = f(x)`, use:
$$\text{output} = x + f(x)$$

**Why this works:**
- The `+ x` path contributes a constant identity term to the gradient: `∂output/∂x = 1 + ∂f/∂x`
- Even if `∇f(x)` vanishes, gradients still flow via the skip connection
- Like having highway and local roads for traffic
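
The gradient claim can be verified directly with autograd. This is a small sketch under the same assumption as the demonstration in this section (a linear layer with every weight set to 0.01); `input_grad` is a helper introduced only for this example.

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 4)
with torch.no_grad():
    layer.weight.fill_(0.01)  # a deliberately weak transformation
    layer.bias.zero_()

def input_grad(use_residual: bool) -> torch.Tensor:
    """Gradient of sum(output) with respect to the input."""
    x = torch.tensor([[1.0, 2.0, 3.0, 4.0]], requires_grad=True)
    out = layer(x)
    if use_residual:
        out = x + out  # skip connection
    out.sum().backward()
    return x.grad

grad_plain = input_grad(use_residual=False)
grad_res = input_grad(use_residual=True)

print(f"Without residual: {grad_plain.tolist()}")  # each entry is about 0.04
print(f"With residual:    {grad_res.tolist()}")    # each entry is about 1.04
```

The residual version adds exactly 1 to every gradient entry, which is the identity term from the skip path. Stack many such weak layers without skips and the gradient shrinks geometrically; with skips it stays close to 1.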

### Pre-norm vs Post-norm Architecture
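- **Post-norm**: `LayerNorm(x + sublayer(x))` (original Transformer design)
- **Pre-norm**: `x + sublayer(LayerNorm(x))` (common in modern models, usually more stable)

We'll use pre-norm because it tends to train more stably in deep stacks: the residual path is never normalized, so the identity route stays intact from input to output. The trade-off is that the last block's output is unnormalized, which is why pre-norm models add a final LayerNorm after the stack.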
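
The two orderings differ only in where the normalization sits. A minimal sketch, assuming a plain `nn.Linear` as a hypothetical stand-in for the sublayer (attention or FFN):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 8
norm = nn.LayerNorm(d_model)
sublayer = nn.Linear(d_model, d_model)  # hypothetical stand-in for attention/FFN

def post_norm(x: torch.Tensor) -> torch.Tensor:
    # Normalize after the residual sum: the skip path passes through LayerNorm
    return norm(x + sublayer(x))

def pre_norm(x: torch.Tensor) -> torch.Tensor:
    # Normalize only the sublayer input: the skip path is left untouched
    return x + sublayer(norm(x))

x = 10.0 * torch.randn(1, 4, d_model)  # deliberately large-scale input

with torch.no_grad():
    post_out = post_norm(x)
    pre_out = pre_norm(x)

print("post-norm std per position:", post_out.std(dim=-1, unbiased=False))
print("pre-norm  std per position:", pre_out.std(dim=-1, unbiased=False))
```

Post-norm output has a standard deviation of about 1 at every position, while pre-norm output keeps the scale of the residual stream. That unnormalized stream is what a final LayerNorm after the last block cleans up.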

In [None]:
# Quick demonstration of residual connections
print("🛣️  RESIDUAL CONNECTION DEMONSTRATION")

x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
print(f"Input:           {x.squeeze().tolist()}")

# Simulate a layer whose output is tiny compared to its input
weak_transform = nn.Linear(4, 4)
with torch.no_grad():
    weak_transform.weight.fill_(0.01)  # Very small weights
    weak_transform.bias.zero_()

# Without residual connection
output_no_res = weak_transform(x)
print(f"Without residual: {output_no_res.squeeze().tolist()!r} (signal lost!)")

# With residual connection  
output_with_res = x + weak_transform(x)
print(f"With residual:    {output_with_res.squeeze().tolist()!r} (signal preserved!)")

print(f"\n✨ Residual connections preserve the original signal!")
print(f"✅ Enable training of very deep networks")
print(f"✅ Gradient highways prevent vanishing gradients")

## 4. Complete Transformer Block: Putting It All Together

Now we can assemble a complete transformer block using all three components:

1. **Multi-Head Attention** (from notebook 1) + Layer Norm + Residual
2. **Feed-Forward Network** + Layer Norm + Residual

### Pre-norm Architecture (Modern Standard)
```python
# Step 1: Attention with pre-norm
normed = LayerNorm(x)
attention_out = MultiHeadAttention(normed)
x = x + attention_out  # Residual connection

# Step 2: Feed-forward with pre-norm  
normed = LayerNorm(x)
ff_out = FeedForward(normed)
x = x + ff_out  # Residual connection
```

This creates a stable, trainable building block that we can stack into deep networks.

In [None]:
from src.model.attention import MultiHeadAttention
from src.model.feedforward import FeedForward

class TransformerBlock(nn.Module):
    """Complete transformer block with pre-norm architecture."""
    
    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        
        # Core components
        self.attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        
        # Normalization layers  
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, mask=None):
        """Pre-norm transformer block forward pass."""
        
        # Step 1: Multi-head attention with pre-norm and residual
        normed = self.norm1(x)
        attn_out = self.attention(normed, normed, normed, mask)
        x = x + self.dropout(attn_out)
        
        # Step 2: Feed-forward with pre-norm and residual
        normed = self.norm2(x)
        ff_out = self.feed_forward(normed)
        x = x + self.dropout(ff_out)
        
        return x

# Test the complete transformer block
print("🏗️  COMPLETE TRANSFORMER BLOCK")
d_model, n_heads, d_ff = 8, 2, 32
block = TransformerBlock(d_model, n_heads, d_ff)

# Forward pass
x = torch.randn(1, 4, d_model)
output = block(x)

print(f"Input shape:  {x.shape}")
print(f"Output shape: {output.shape}")

# Analyze parameter distribution
attention_params = sum(p.numel() for p in block.attention.parameters())
ff_params = sum(p.numel() for p in block.feed_forward.parameters())
norm_params = sum(p.numel() for p in block.norm1.parameters()) + sum(p.numel() for p in block.norm2.parameters())
total_params = attention_params + ff_params + norm_params

print(f"\n📊 PARAMETER BREAKDOWN:")
print(f"Attention:     {attention_params:,} ({attention_params/total_params*100:.1f}%)")
print(f"Feed-forward:  {ff_params:,} ({ff_params/total_params*100:.1f}%)")  
print(f"Layer norms:   {norm_params:,} ({norm_params/total_params*100:.1f}%)")
print(f"Total:         {total_params:,}")

print(f"\n✨ The transformer block successfully combines all components!")
print(f"✅ Attention routes information between positions")
print(f"✅ FFN processes each position independently") 
print(f"✅ LayerNorm provides training stability")
print(f"✅ Residuals enable deep network training")

## 5. Stacking Transformer Blocks

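The power of transformers comes from stacking multiple blocks. Each block can learn different types of patterns and relationships. Let's see how information flows through a stack of blocks. Keep in mind that the model below is untrained, so what we observe reflects random initialization rather than learned behavior.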

In [None]:
class SimpleTransformer(nn.Module):
    """Simple transformer with multiple blocks for analysis."""
    
    def __init__(self, n_layers: int, d_model: int, n_heads: int, d_ff: int):
        super().__init__()
        self.n_layers = n_layers
        
        # Stack of transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff)
            for _ in range(n_layers)
        ])
        
        # Final layer norm
        self.final_norm = nn.LayerNorm(d_model)
    
    def forward(self, x, return_all_attention=False):
        """
        Forward pass through all transformer blocks.

        Returns (x, layer_outputs), or (x, all_attention_weights, layer_outputs)
        when return_all_attention=True.
        """
        all_attention_weights = []
        layer_outputs = [x.clone()]  # Input, then the output of each layer
        
        for block in self.blocks:
            out = block(x)
            if isinstance(out, tuple):
                # Block variants that return (x, attention_weights, ...)
                x = out[0]
                if return_all_attention and len(out) > 1:
                    all_attention_weights.append(out[1])
            else:
                # TransformerBlock from section 4 returns only the output tensor
                x = out
            
            layer_outputs.append(x.clone())
        
        # Final layer normalization
        x = self.final_norm(x)
        layer_outputs.append(x.clone())
        
        if return_all_attention:
            return x, all_attention_weights, layer_outputs
        return x, layer_outputs

# Create a 3-layer transformer
n_layers = 3
transformer = SimpleTransformer(n_layers, d_model=8, n_heads=2, d_ff=32)
transformer.eval()  # Disable dropout so the analysis is deterministic

# Create input
x = torch.randn(1, 4, 8)

print(f"Transformer with {n_layers} layers")
print(f"Total parameters: {sum(p.numel() for p in transformer.parameters()):,}")

# Forward pass (no gradients needed for analysis)
with torch.no_grad():
    output, all_attention, layer_outputs = transformer(x, return_all_attention=True)

print(f"\nInput shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Number of attention matrices: {len(all_attention)}")
print(f"Number of layer outputs: {len(layer_outputs)}")

# Analyze how representations change through layers
print("\nRepresentation analysis through layers:")
for i, layer_output in enumerate(layer_outputs):
    norm = torch.norm(layer_output).item()
    mean = layer_output.mean().item()
    std = layer_output.std().item()
    
    if i == 0:
        layer_name = "Input"
    elif i <= n_layers:
        layer_name = f"Layer {i}"
    else:
        layer_name = "Final Norm"
    
    print(f"{layer_name:12}: norm={norm:6.3f}, mean={mean:6.3f}, std={std:6.3f}")

# Visualize attention patterns across layers (only if the blocks returned weights)
if all_attention:
    fig, axes = plt.subplots(1, n_layers, figsize=(15, 4))

    for layer_idx in range(n_layers):
        # Average attention across heads; assumes weights shaped (batch, heads, seq, seq)
        avg_attention = all_attention[layer_idx][0].mean(dim=0).detach().numpy()
        
        sns.heatmap(
            avg_attention,
            annot=True, fmt='.2f',
            cmap='Blues',
            ax=axes[layer_idx],
            cbar=layer_idx == n_layers - 1
        )
        axes[layer_idx].set_title(f'Layer {layer_idx + 1}\nAttention')
        axes[layer_idx].set_xlabel('Keys')
        if layer_idx == 0:
            axes[layer_idx].set_ylabel('Queries')

    plt.tight_layout()
    plt.show()

    print("\nNotice how the attention patterns differ between layers!")
    print("The model is untrained, so these differences come from random initialization.")
else:
    print("\nNo attention weights to plot: the blocks returned only their output tensor.")
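
Plotting per-layer attention requires the attention module to hand back its weights. A minimal sketch of how they can be collected, assuming PyTorch's built-in `nn.MultiheadAttention` as a stand-in for the project's `MultiHeadAttention` (whose return values are not shown in this notebook). `PreNormBlockWithWeights` is a hypothetical name used only here.

```python
import torch
import torch.nn as nn

class PreNormBlockWithWeights(nn.Module):
    """Hypothetical pre-norm block that also returns attention weights."""

    def __init__(self, d_model: int, n_heads: int, d_ff: int):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor):
        normed = self.norm1(x)
        # With need_weights=True the weights are averaged over heads by default
        attn_out, weights = self.attn(normed, normed, normed, need_weights=True)
        x = x + attn_out
        x = x + self.ffn(self.norm2(x))
        return x, weights

torch.manual_seed(0)
blocks = nn.ModuleList([PreNormBlockWithWeights(8, 2, 32) for _ in range(3)])
x = torch.randn(1, 4, 8)

all_weights = []
with torch.no_grad():
    for block in blocks:
        x, w = block(x)
        all_weights.append(w)  # (batch, seq_len, seq_len)

print([tuple(w.shape) for w in all_weights])
print(all_weights[0].sum(dim=-1))  # every query row sums to 1
```

Each entry of `all_weights` can be passed straight to `sns.heatmap` after selecting a batch element, e.g. `all_weights[0][0].numpy()`.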

### Scaling Up: A Minimal Stack

Without the per-layer bookkeeping, stacking reduces to a loop over blocks followed by a final LayerNorm. The cell below redefines `SimpleTransformer` in that minimal form.

In [None]:
# Create a simple multi-layer transformer for demonstration
class SimpleTransformer(nn.Module):
    """Stack of transformer blocks."""
    
    def __init__(self, n_layers: int, d_model: int, n_heads: int, d_ff: int):
        super().__init__()
        
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff)
            for _ in range(n_layers)
        ])
        
        self.final_norm = nn.LayerNorm(d_model)
    
    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return self.final_norm(x)

# Test stacking
print("📚 STACKING TRANSFORMER BLOCKS")
n_layers = 3
transformer = SimpleTransformer(n_layers, d_model=8, n_heads=2, d_ff=32)

x = torch.randn(1, 4, 8)
output = transformer(x)

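total_params = sum(p.numel() for p in transformer.parameters())
# Count one block directly; dividing the total by n_layers would fold in final_norm
params_per_layer = sum(p.numel() for p in transformer.blocks[0].parameters())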

print(f"Layers:           {n_layers}")
print(f"Input shape:      {x.shape}")
print(f"Output shape:     {output.shape}")
print(f"Total parameters: {total_params:,}")
print(f"Per layer:        {params_per_layer:,}")

print(f"\n✨ After training, layers tend to specialize (a rough guide, not a rule):")
print(f"• Layer 1: Basic features and attention patterns")
print(f"• Layer 2: More complex relationships") 
print(f"• Layer 3: High-level abstractions and reasoning")

print(f"\n🔑 KEY INSIGHT: Deep networks can learn hierarchical representations!")
print(f"✅ Residual connections make deep stacking possible")
print(f"✅ Each layer builds on previous layers' understanding")
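
For comparison, PyTorch ships a pre-norm stack as built-in modules. This sketch assumes `nn.TransformerEncoderLayer` with `norm_first=True` is an acceptable substitute for the `TransformerBlock` above; its internals (for example a fused Q/K/V projection) may differ from the project's implementation.

```python
import torch
import torch.nn as nn

# One pre-norm block: attention + FFN, each wrapped in LayerNorm and a residual
layer = nn.TransformerEncoderLayer(
    d_model=8, nhead=2, dim_feedforward=32,
    dropout=0.1, batch_first=True, norm_first=True,
)

# Three copies of the block plus the final LayerNorm that pre-norm stacks need
encoder = nn.TransformerEncoder(layer, num_layers=3, norm=nn.LayerNorm(8))
encoder.eval()  # disable dropout for a deterministic forward pass

x = torch.randn(1, 4, 8)  # (batch, seq_len, d_model)
with torch.no_grad():
    out = encoder(x)

n_params = sum(p.numel() for p in encoder.parameters())
print(f"Output shape: {tuple(out.shape)}, parameters: {n_params:,}")
```

The input and output shapes match, as with the hand-written stack, so the built-in version can be swapped in when the per-component view is not needed.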

## Summary: Building Blocks of Transformers

We've learned how to build complete transformer blocks by combining four essential components:

### The Architecture Stack 📚
1. **Feed-Forward Networks** → Process each position independently  
2. **Layer Normalization** → Stabilize training dynamics
3. **Residual Connections** → Enable deep network training via gradient highways
4. **Multi-Head Attention** → Route information between positions

### Key Design Choices ⚙️
- **Pre-norm architecture**: More stable than post-norm for deep networks
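- **4x FFN expansion**: `d_ff = 4 × d_model` is the standard ratio
- **Parameter distribution**: with 4x expansion, roughly two-thirds of a block's weights sit in the FFN and one-third in attention; the norms' share is negligible at realistic model widths (in the `d_model = 8` toy block it is a few percent)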
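
The parameter split follows from simple counting. The sketch below assumes four `d_model × d_model` attention projections (Q, K, V, output) and two FFN linear layers, all with biases; the project's `MultiHeadAttention` may be laid out differently. `block_param_counts` is a helper introduced only for this illustration.

```python
def block_param_counts(d_model: int, d_ff: int) -> dict:
    """Parameter counts for one block, assuming biased linear layers."""
    attention = 4 * (d_model * d_model + d_model)  # Q, K, V, output projections
    ffn = (d_model * d_ff + d_ff) + (d_ff * d_model + d_model)
    norms = 2 * (2 * d_model)                      # gamma and beta for two LayerNorms
    return {"attention": attention, "ffn": ffn, "norms": norms}

for d_model in (8, 512):
    counts = block_param_counts(d_model, 4 * d_model)
    total = sum(counts.values())
    shares = {name: round(100 * n / total, 1) for name, n in counts.items()}
    print(f"d_model={d_model}: {counts} -> shares (%) {shares}")
```

At `d_model = 8` the norms still hold a few percent of the parameters; at `d_model = 512` the split settles at about one-third attention and two-thirds FFN, with the norms well under 1%.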

### The Big Picture 🎯
Each transformer block performs two main operations:
1. **Communication**: Attention mixes information between positions
2. **Computation**: FFN processes each position's updated representation

Stacking multiple blocks creates increasingly sophisticated representations, enabled by residual connections that maintain gradient flow through many layers.

---

**Next**: We'll explore how transformers understand word order through positional encoding!