# Building Transformer Blocks

In the previous notebook, we explored the attention mechanism. Now we'll see how attention is combined with other components to create complete transformer blocks.

## The Big Picture: Why Do We Need More Than Attention?

Attention is powerful, but it has limitations:

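1. **Attention only mixes information** - each output is a weighted average of value vectors, like reshuffling cards without changing their values
2. **No position-wise processing** - nothing transforms each position's features individually and non-linearly
3. **Training instability** - activations can drift to very large or very small scales as layers are stacked
4. **Gradient flow** - gradients can vanish as they are backpropagated through deep networks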

Transformer blocks solve these problems by adding:
- **Feed-Forward Networks** → Transform information, not just mix it
- **Layer Normalization** → Stabilize training
- **Residual Connections** → Preserve gradient flow

Think of it like this:
- **Attention**: "Let me gather relevant information from other words"
- **Feed-Forward**: "Now let me think about what this information means"
- **Layer Norm**: "Keep everything balanced and stable"
- **Residuals**: "Don't forget what I started with"
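To make point 1 concrete, here is a minimal scaled dot-product attention sketch. It uses random queries, keys, and values with no learned projections (a simplification for illustration). Because the softmax weights in each row are non-negative and sum to 1, every output row is a weighted average of the value rows, so each output feature stays within the range that feature spans across the values.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, d = 4, 8
q = torch.randn(seq_len, d)  # queries (random stand-ins, no learned projections)
k = torch.randn(seq_len, d)  # keys
v = torch.randn(seq_len, d)  # values

# Scaled dot-product attention weights: each row is a probability distribution
weights = F.softmax(q @ k.T / d ** 0.5, dim=-1)
out = weights @ v  # each output row is a convex combination of the rows of v

# Every output feature lies between that feature's min and max over the value rows
lo = v.min(dim=0).values
hi = v.max(dim=0).values
print(weights.sum(dim=-1))  # each entry is 1 up to rounding
print(bool(((out >= lo - 1e-6) & (out <= hi + 1e-6)).all()))  # True
```

However the weights are chosen, attention on its own cannot push a feature outside this range; the per-position non-linear transformation has to come from somewhere else.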

## What You'll Learn

1. **Feed-Forward Networks** - The "thinking" component of transformers
2. **Layer Normalization** - Stabilizing training dynamics
3. **Residual Connections** - Enabling deep networks to train
4. **Complete Transformer Block** - How everything fits together
5. **Stacking Blocks** - Building deep transformers

Let's start building!

In [None]:
import sys
import os
sys.path.append('..')

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Tuple, Optional

# Set style for better plots
plt.style.use('default')
sns.set_palette("husl")

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

print("Environment setup complete!")

## 1. Feed-Forward Networks: Position-wise Processing

After attention routes information between positions, each position needs **individual processing**. That's where Feed-Forward Networks (FFNs) come in.

### The Core Problem 🤔
- **Attention**: Routes information ("look at relevant context")
- **FFN**: Processes information ("think about what this means")

### FFN Architecture
```
FFN(x) = ReLU(xW₁ + b₁)W₂ + b₂
```

**Key properties:**
- **Position-wise**: The same weights are applied to each position independently
- **Expand-contract**: `d_model → d_ff → d_model` (typically a 4× expansion)
- **Non-linear**: The ReLU between the two linear layers lets the FFN compute more than a single linear map

Think of it like workstations on an assembly line - same tools, different inputs at each position.
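A minimal PyTorch sketch of the formula above. `PositionwiseFFN` is a hypothetical name, and it omits dropout, so it is not necessarily identical to the repository's `src.model.feedforward.FeedForward`. The check at the end illustrates the position-wise property: running the FFN on the whole sequence gives the same result as running it on each position separately.

```python
import torch
import torch.nn as nn

class PositionwiseFFN(nn.Module):
    """FFN(x) = ReLU(x W1 + b1) W2 + b2, applied over the last dimension (no dropout)."""

    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)  # expand: d_model -> d_ff
        self.linear2 = nn.Linear(d_ff, d_model)  # contract: d_ff -> d_model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear2(torch.relu(self.linear1(x)))

torch.manual_seed(0)
ffn = PositionwiseFFN(d_model=8, d_ff=32)
x = torch.randn(1, 4, 8)  # [batch, seq_len, d_model]

with torch.no_grad():
    full = ffn(x)  # whole sequence at once
    # Same FFN applied to one position at a time, then restacked along seq_len
    per_position = torch.stack([ffn(x[:, t]) for t in range(x.shape[1])], dim=1)

print(full.shape)                                        # torch.Size([1, 4, 8])
print(torch.allclose(full, per_position, atol=1e-6))     # True
```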

In [None]:
# Import the actual FFN implementation we'll use
from src.model.feedforward import FeedForward

# Understanding the dimensions - this is crucial!
print("🧠 UNDERSTANDING TRANSFORMER DIMENSIONS")
print("=" * 40)

d_model = 8   # Model dimension (embedding size) - how we represent each word
d_ff = 32     # Feed-forward dimension - internal processing width

print(f"d_model = {d_model}")
print("  ↳ This is how many features each word position has")
print("  ↳ Like having 8 attributes to describe each word")
print()
print(f"d_ff = {d_ff}")
print("  ↳ FFN expands to this size for internal processing")
print(f"  ↳ {d_ff // d_model}x expansion (standard is 4x)")
print("  ↳ More space = more complex transformations")

print("\n🔄 FFN PROCESSING FLOW:")
print(f"Input:  [seq_len, {d_model}]    # Each position has {d_model} features")
print(f"Expand: [seq_len, {d_ff}]    # Expand to {d_ff} for processing")
print(f"Output: [seq_len, {d_model}]    # Contract back to {d_model}")

# Now demonstrate with actual data
batch_size, seq_len = 1, 4
ff_net = FeedForward(d_model, d_ff, dropout=0.1)
x = torch.randn(batch_size, seq_len, d_model)

print(f"\n📊 PRACTICAL DEMONSTRATION:")
print(f"Input shape:  {x.shape}")
output = ff_net(x)
print(f"Output shape: {output.shape}")

# Show parameter breakdown (expected counts assume two nn.Linear layers with biases)
total_params = sum(p.numel() for p in ff_net.parameters())
linear1_params = d_model * d_ff + d_ff  # W1 + b1
linear2_params = d_ff * d_model + d_model  # W2 + b2

print(f"\n🔧 PARAMETER BREAKDOWN:")
print(f"Linear1 ({d_model}→{d_ff}): {linear1_params:,} params")
print(f"Linear2 ({d_ff}→{d_model}): {linear2_params:,} params")
print(f"Total FFN params:    {total_params:,}")

print(f"\n✨ Why this architecture works:")
print(f"✅ Expansion gives more 'thinking space' for each position")
print(f"✅ Each position processed independently (no cross-talk)")
print(f"✅ Same output size maintains compatibility with attention")

## 2. Layer Normalization: Why We Need It

Before we dive into how Layer Normalization works, let's understand **why** transformers need it.

### The Problem: Training Instability 🌪️

Imagine training a transformer without normalization:
- **Early training**: Features have small values like `[0.1, 0.2, 0.3]`
- **Later training**: Same features explode to `[100, 200, 300]`  
- **Result**: Gradients become too large or too small → training can diverge or stall
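A toy illustration of the scale problem, assuming a deliberately contrived layer whose weight matrix is `2 * I`, so it doubles every feature. Without normalization the scale grows by 2 at each of 10 layers (a factor of 2^10 = 1024); re-normalizing after each layer keeps the position at a standard deviation of about 1. `pop_std` is a hypothetical helper that computes the population (biased) standard deviation.

```python
import torch
import torch.nn as nn

d, depth = 16, 10
layer = nn.Linear(d, d, bias=False)
with torch.no_grad():
    layer.weight.copy_(2.0 * torch.eye(d))  # toy layer: doubles every feature
norm = nn.LayerNorm(d)  # gamma=1, beta=0 at initialization

def pop_std(t: torch.Tensor) -> float:
    # Population (biased) standard deviation over all elements
    return ((t - t.mean()) ** 2).mean().sqrt().item()

torch.manual_seed(0)
x = torch.randn(1, d)

h_plain, h_normed = x, x
with torch.no_grad():
    for _ in range(depth):
        h_plain = layer(h_plain)            # scale doubles every layer
        h_normed = norm(layer(h_normed))    # rescaled to unit variance every layer

print(f"growth without LayerNorm: {pop_std(h_plain) / pop_std(x):.1f}x")  # 1024.0x
print(f"std with LayerNorm:       {pop_std(h_normed):.3f}")               # 1.000
```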

### What Layer Normalization Does 🎯

Layer Norm solves this by **normalizing each position's features**:

```
For each position in the sequence:
1. Calculate the mean and variance of that position's features
2. Normalize: (features - mean) / sqrt(variance + eps)
3. Apply learnable scale (γ) and shift (β) parameters
```

### LayerNorm Formula
$$\text{LayerNorm}(x) = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

- μ, σ²: mean and (biased) variance computed **across the features of each position**
- γ, β: learnable per-feature scale and shift, initialized to 1 and 0
- ε: small constant (PyTorch's default is 1e-5) that keeps the denominator away from zero
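The formula can be checked directly against PyTorch. `layer_norm_manual` is a hypothetical helper written for illustration: it computes μ and the biased variance over the last dimension, applies random γ and β, and compares the result with `torch.nn.functional.layer_norm`.

```python
import torch
import torch.nn.functional as F

def layer_norm_manual(x: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor,
                      eps: float = 1e-5) -> torch.Tensor:
    """LayerNorm over the last dimension, written out step by step."""
    mu = x.mean(dim=-1, keepdim=True)                        # per-position mean
    var = ((x - mu) ** 2).mean(dim=-1, keepdim=True)         # biased variance (divide by N)
    return gamma * (x - mu) / torch.sqrt(var + eps) + beta   # normalize, then scale and shift

torch.manual_seed(0)
x = torch.randn(2, 3, 4)   # [batch, seq_len, d_model]
gamma = torch.randn(4)     # arbitrary scale, to exercise the affine part
beta = torch.randn(4)      # arbitrary shift

ours = layer_norm_manual(x, gamma, beta)
ref = F.layer_norm(x, (4,), weight=gamma, bias=beta, eps=1e-5)
print((ours - ref).abs().max().item())  # tiny: float32 rounding only
```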

### Key Insight: Position-wise Normalization
- Each position in the sequence is normalized **independently**
- If position 1 has values `[1, 2, 3, 4]` and position 2 has `[100, 200, 300, 400]`
- Both get normalized to have mean≈0, std≈1 **separately**

Think of it as giving each position a "fresh start" with well-behaved values!
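Applying `nn.LayerNorm` to the two example positions shows that each row is normalized with its own statistics. Because position 2 is exactly 100× position 1, both rows map to (almost exactly) the same values; the tiny difference comes from ε. A side effect worth noting: LayerNorm discards each position's overall magnitude, and since γ and β are shared across positions they cannot restore it.

```python
import torch
import torch.nn as nn

ln = nn.LayerNorm(4)  # gamma=1, beta=0 at initialization
x = torch.tensor([[1.0, 2.0, 3.0, 4.0],          # position 1
                  [100.0, 200.0, 300.0, 400.0]])  # position 2: exactly 100x position 1
with torch.no_grad():
    out = ln(x)

print(out)  # both rows ≈ [-1.342, -0.447, 0.447, 1.342]
```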

In [None]:
# Let's see LayerNorm in action with a clear, step-by-step example
print("🔧 LAYER NORMALIZATION STEP-BY-STEP")
print("=" * 50)

# Create data where positions have very different scales
x = torch.tensor([
    [[1.0, 2.0, 3.0, 4.0],            # Position 1: small values
     [100.0, 200.0, 300.0, 400.0]]    # Position 2: values 100x larger
])
# LayerNorm uses the biased (population) std, so we use unbiased=False throughout
print("🚨 PROBLEM: Different positions have wildly different scales!")
print(f"Position 1: {x[0,0].tolist()}")
print(f"  → mean={x[0,0].mean():.1f}, std={x[0,0].std(unbiased=False):.1f}")
print(f"Position 2: {x[0,1].tolist()}")
print(f"  → mean={x[0,1].mean():.1f}, std={x[0,1].std(unbiased=False):.1f}")
print("  → Scale differences like this make activations and gradients hard to keep in range!")

print("\n✨ SOLUTION: Apply LayerNorm to each position")

# Apply layer normalization
layer_norm = nn.LayerNorm(4)  # Normalize across the 4 features
x_normalized = layer_norm(x)

print("After LayerNorm:")
print(f"Position 1: {[round(val, 3) for val in x_normalized[0,0].tolist()]}")
print(f"  → mean={x_normalized[0,0].mean():.3f}, std={x_normalized[0,0].std(unbiased=False):.3f}")
print(f"Position 2: {[round(val, 3) for val in x_normalized[0,1].tolist()]}")
print(f"  → mean={x_normalized[0,1].mean():.3f}, std={x_normalized[0,1].std(unbiased=False):.3f}")

# Show LayerNorm's parameters (freshly initialized: gamma=1, beta=0; nothing learned yet)
print(f"\n🎛️  LayerNorm parameters at initialization:")
print(f"Scale (γ): {layer_norm.weight.tolist()}")
print(f"Shift (β): {layer_norm.bias.tolist()}")

print(f"\n✅ SUCCESS: LayerNorm fixed the scale problem!")
print(f"• Both positions now have mean≈0, std≈1")
print(f"• Activations stay in a consistent range, which helps gradient-based training")
print(f"• Each position normalized independently")

# Manual calculation to show how it works
print(f"\n🔍 HOW IT WORKS (Position 1 example):")
pos1_original = x[0, 0]
pos1_mean = pos1_original.mean()
pos1_var = pos1_original.var(unbiased=False)  # biased variance, as LayerNorm uses
pos1_normalized = (pos1_original - pos1_mean) / torch.sqrt(pos1_var + layer_norm.eps)
print(f"Original: {pos1_original.tolist()}")
print(f"Mean: {pos1_mean:.2f}, Var: {pos1_var:.2f}")
print(f"(x - mean) / sqrt(var + eps): {[round(val.item(), 3) for val in pos1_normalized]}")
print(f"nn.LayerNorm output (γ=1, β=0): {[round(val, 3) for val in x_normalized[0,0].tolist()]}")

## 3. Residual Connections: Gradient Highways

Deep networks can suffer from **vanishing gradients**: the gradient signal shrinks as it is backpropagated through many layers, so early layers barely learn. Residual connections address this.

### The Solution: Skip Connections 🛣️
Instead of `output = f(x)`, use:
$$\text{output} = x + f(x)$$

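**Why this works:**
- The derivative of `x + f(x)` with respect to `x` is `I + ∂f/∂x`, so the skip path always contributes an identity term to the gradient
- Even if `∂f/∂x` becomes very small, gradients still flow back through the skip connection
- Like having a highway alongside local roads: traffic keeps moving even when the local roads are congested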
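A small autograd experiment makes the `I + ∂f/∂x` argument concrete. The sketch stacks `depth` weak linear layers (all weights 0.01, no bias, in float64 for precision; these values are chosen purely for illustration) and measures the gradient of the summed output with respect to the input. `input_grad` is a hypothetical helper. Without the skip path, each layer multiplies the gradient by 0.04; with it, each layer multiplies it by 1.04 instead.

```python
import torch
import torch.nn as nn

def input_grad(depth: int, residual: bool) -> torch.Tensor:
    """Gradient of sum(output) w.r.t. the input through `depth` weak linear layers."""
    layers = []
    for _ in range(depth):
        lin = nn.Linear(4, 4).double()
        with torch.no_grad():
            lin.weight.fill_(0.01)  # tiny weights: each layer shrinks its input
            lin.bias.zero_()
        layers.append(lin)

    x = torch.tensor([[1.0, 2.0, 3.0, 4.0]], dtype=torch.float64, requires_grad=True)
    h = x
    for lin in layers:
        h = h + lin(h) if residual else lin(h)  # the skip connection adds an identity path
    h.sum().backward()
    return x.grad

for depth in (1, 10):
    plain = input_grad(depth, residual=False)[0, 0].item()
    skip = input_grad(depth, residual=True)[0, 0].item()
    print(f"depth={depth:2d}  no residual: {plain:.3e}   residual: {skip:.3f}")
# depth= 1  no residual: 4.000e-02   residual: 1.040
# depth=10  no residual: 1.049e-14   residual: 1.480
```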

### Pre-norm vs Post-norm Architecture
- **Post-norm**: `LayerNorm(x + sublayer(x))` (the original Transformer)
- **Pre-norm**: `x + sublayer(LayerNorm(x))` (common in modern models, generally easier to train)

We'll use pre-norm because it tends to be more stable when many blocks are stacked.
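A minimal sketch of the two orderings, using a single `nn.Linear` as a stand-in for the attention or feed-forward sublayer (an assumption for brevity). Post-norm normalizes the sum, so its output has mean≈0 and variance≈1 at every position; pre-norm leaves the residual path `x` untouched, so its output is not normalized. That is one reason pre-norm stacks usually end with a final LayerNorm.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 8
norm = nn.LayerNorm(d_model)
sublayer = nn.Linear(d_model, d_model)  # stand-in for attention or the FFN

def post_norm(x: torch.Tensor) -> torch.Tensor:
    # Original Transformer ordering: add the residual first, then normalize the sum
    return norm(x + sublayer(x))

def pre_norm(x: torch.Tensor) -> torch.Tensor:
    # Pre-norm ordering: normalize only the sublayer's input; the skip path stays raw
    return x + sublayer(norm(x))

def per_position_var(t: torch.Tensor) -> torch.Tensor:
    # Biased variance across features, one value per position
    return ((t - t.mean(dim=-1, keepdim=True)) ** 2).mean(dim=-1)

x = 5.0 * torch.randn(1, 4, d_model)  # deliberately large-scale input
with torch.no_grad():
    post = post_norm(x)
    pre = pre_norm(x)

print("post-norm variance per position:", per_position_var(post))  # ≈ 1 everywhere
print("pre-norm variance per position: ", per_position_var(pre))   # not normalized
```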

In [None]:
# Quick demonstration of residual connections
print("🛣️  RESIDUAL CONNECTION DEMONSTRATION")

x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
print(f"Input:            {x.squeeze().tolist()}")

# Simulate a layer with tiny weights (such layers shrink both signals and gradients)
weak_transform = nn.Linear(4, 4)
with torch.no_grad():
    weak_transform.weight.fill_(0.01)  # Very small weights
    weak_transform.bias.zero_()

    # Without residual: every output is 0.01 * (1 + 2 + 3 + 4) = 0.1, so the input pattern is lost
    output_no_res = weak_transform(x)
    # With residual: the input passes through unchanged and the layer adds a small correction
    output_with_res = x + weak_transform(x)

print(f"Without residual: {[round(v, 3) for v in output_no_res.squeeze().tolist()]} (signal lost!)")
print(f"With residual:    {[round(v, 3) for v in output_with_res.squeeze().tolist()]} (signal preserved!)")

print(f"\n✨ Residual connections preserve the original signal!")
print(f"✅ Enable training of very deep networks")
print(f"✅ Gradient highways prevent vanishing gradients")

## 4. Complete Transformer Block: Putting It All Together

Now we can assemble a complete transformer block using all three components:

1. **Multi-Head Attention** (from notebook 1) + Layer Norm + Residual
2. **Feed-Forward Network** + Layer Norm + Residual

### Pre-norm Architecture (Modern Standard)
```python
# Step 1: Attention with pre-norm
normed = LayerNorm(x)
attention_out = MultiHeadAttention(normed)
x = x + attention_out  # Residual connection

# Step 2: Feed-forward with pre-norm  
normed = LayerNorm(x)
ff_out = FeedForward(normed)
x = x + ff_out  # Residual connection
```

This creates a stable, trainable building block that we can stack into deep networks.
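The code cell below depends on the repository's `src.model` modules, which are not shown here. As a self-contained alternative, here is a sketch of the same pre-norm block built on PyTorch's built-in `nn.MultiheadAttention` (with `batch_first=True`) and an `nn.Sequential` FFN. `PreNormBlock` is a hypothetical name, and the boolean causal mask (where `True` marks a position that may not be attended to) is an added illustration, not something the original block requires.

```python
import torch
import torch.nn as nn

class PreNormBlock(nn.Module):
    """Pre-norm transformer block: x + Attn(LN(x)), then x + FFN(LN(x))."""

    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Dropout(dropout), nn.Linear(d_ff, d_model)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, attn_mask=None) -> torch.Tensor:
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h, attn_mask=attn_mask, need_weights=False)
        x = x + self.dropout(attn_out)              # residual around attention
        x = x + self.dropout(self.ffn(self.norm2(x)))  # residual around the FFN
        return x

torch.manual_seed(0)
block = PreNormBlock(d_model=8, n_heads=2, d_ff=32).eval()  # eval() disables dropout
x = torch.randn(1, 4, 8)
causal = torch.triu(torch.ones(4, 4, dtype=torch.bool), diagonal=1)  # True = blocked

x_changed = x.clone()
x_changed[:, -1] += 10.0  # perturb only the last token
with torch.no_grad():
    y = block(x, attn_mask=causal)
    y_changed = block(x_changed, attn_mask=causal)

print(y.shape)  # torch.Size([1, 4, 8])
# With the causal mask, position 0 cannot see the last token, so its output is unchanged
print(torch.allclose(y[:, 0], y_changed[:, 0], atol=1e-6))  # True
```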

In [None]:
from src.model.attention import MultiHeadAttention
from src.model.feedforward import FeedForward

class TransformerBlock(nn.Module):
    """Complete transformer block with pre-norm architecture."""
    
    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        
        # Core components
        self.attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        
        # Normalization layers  
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, mask=None):
        """Pre-norm transformer block forward pass."""
        
        # Step 1: Multi-head attention with pre-norm and residual
        normed = self.norm1(x)
        attn_out = self.attention(normed, normed, normed, mask)
        x = x + self.dropout(attn_out)
        
        # Step 2: Feed-forward with pre-norm and residual
        normed = self.norm2(x)
        ff_out = self.feed_forward(normed)
        x = x + self.dropout(ff_out)
        
        return x

# Test the complete transformer block
print("🏗️  COMPLETE TRANSFORMER BLOCK")
d_model, n_heads, d_ff = 8, 2, 32
block = TransformerBlock(d_model, n_heads, d_ff)

# Forward pass
x = torch.randn(1, 4, d_model)
output = block(x)

print(f"Input shape:  {x.shape}")
print(f"Output shape: {output.shape}")

# Analyze parameter distribution
attention_params = sum(p.numel() for p in block.attention.parameters())
ff_params = sum(p.numel() for p in block.feed_forward.parameters())
norm_params = sum(p.numel() for p in block.norm1.parameters()) + sum(p.numel() for p in block.norm2.parameters())
total_params = attention_params + ff_params + norm_params

print(f"\n📊 PARAMETER BREAKDOWN:")
print(f"Attention:     {attention_params:,} ({attention_params/total_params*100:.1f}%)")
print(f"Feed-forward:  {ff_params:,} ({ff_params/total_params*100:.1f}%)")
print(f"Layer norms:   {norm_params:,} ({norm_params/total_params*100:.1f}%)")
print(f"Total:         {total_params:,}")

print(f"\n✨ The transformer block successfully combines all components!")
print(f"✅ Attention routes information between positions")
print(f"✅ FFN processes each position independently")
print(f"✅ LayerNorm provides training stability")
print(f"✅ Residuals enable deep network training")
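The breakdown above can be predicted in closed form. This sketch assumes the attention module has four `d_model × d_model` projections (Q, K, V, output) with biases and that the FFN's two linear layers have biases; if `src.model.attention.MultiHeadAttention` is built differently, the printed numbers will differ. `block_param_count` is a hypothetical helper.

```python
def block_param_count(d_model: int, d_ff: int) -> dict:
    """Parameter count of one pre-norm transformer block (assumed layout, see above)."""
    attention = 4 * (d_model * d_model + d_model)                        # W_Q, W_K, W_V, W_O + biases
    feed_forward = (d_model * d_ff + d_ff) + (d_ff * d_model + d_model)  # W_1, b_1, W_2, b_2
    layer_norms = 2 * (2 * d_model)                                      # two LayerNorms, gamma + beta each
    return {
        "attention": attention,
        "feed_forward": feed_forward,
        "layer_norms": layer_norms,
        "total": attention + feed_forward + layer_norms,
    }

print(block_param_count(8, 32))
# {'attention': 288, 'feed_forward': 552, 'layer_norms': 32, 'total': 872}
print(block_param_count(768, 3072)["total"])
# 7087872
```

Note that `n_heads` does not appear: splitting into heads partitions the projections rather than adding parameters. With the usual 4× expansion, the FFN holds about two-thirds of each block's parameters (8·d_model² versus 4·d_model² for attention, ignoring biases).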

## 5. Stacking Transformer Blocks

The power of transformers comes from stacking multiple blocks. Each block can learn different types of patterns and relationships. Let's see how information flows through a stack of blocks.
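For comparison, PyTorch ships a ready-made block, `nn.TransformerEncoderLayer`; with `norm_first=True` it uses the pre-norm ordering. The sketch below stacks three of them in an `nn.ModuleList` with a final `nn.LayerNorm`, mirroring the structure used in this section (the FFN activation defaults to ReLU), and checks that the parameter count grows linearly with depth. `EncoderStack` is a hypothetical name.

```python
import torch
import torch.nn as nn

d_model, n_heads, d_ff, n_layers = 8, 2, 32, 3

class EncoderStack(nn.Module):
    """Hypothetical stack of PyTorch's built-in pre-norm encoder layers."""

    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model, n_heads, dim_feedforward=d_ff,
                                       dropout=0.1, batch_first=True, norm_first=True)
            for _ in range(n_layers)
        ])
        self.final_norm = nn.LayerNorm(d_model)  # pre-norm stacks normalize once at the end

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            x = block(x)
        return self.final_norm(x)

torch.manual_seed(0)
stack = EncoderStack()
x = torch.randn(1, 4, d_model)
y = stack(x)

per_block = sum(p.numel() for p in stack.blocks[0].parameters())
total = sum(p.numel() for p in stack.parameters())
print(y.shape)           # torch.Size([1, 4, 8])
print(per_block, total)  # 872 2632
```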

In [None]:
# Create a simple stacked transformer for clear demonstration
class SimpleTransformer(nn.Module):
    """Simple transformer with multiple blocks for analysis."""
    
    def __init__(self, n_layers: int, d_model: int, n_heads: int, d_ff: int):
        super().__init__()
        self.n_layers = n_layers
        
        # Stack of transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff)
            for _ in range(n_layers)
        ])
        
        # Final layer norm
        self.final_norm = nn.LayerNorm(d_model)
    
    def forward(self, x):
        # Store representations at each layer for analysis
        layer_outputs = [x.detach().clone()]
        
        for i, block in enumerate(self.blocks):
            x = block(x)
            layer_outputs.append(x.detach().clone())
            print(f"After Layer {i+1}: mean={x.mean().item():.3f}, std={x.std().item():.3f}")
        
        # Final normalization
        x = self.final_norm(x)
        layer_outputs.append(x.detach().clone())
        print(f"After Final Norm: mean={x.mean().item():.3f}, std={x.std().item():.3f}")
        
        return x, layer_outputs

print("📚 STACKING TRANSFORMER BLOCKS")
print("=" * 40)

# Create a 3-layer transformer with clear dimensions
d_model, n_heads, d_ff = 8, 2, 32  # Same dimensions we explained earlier
n_layers = 3

transformer = SimpleTransformer(n_layers, d_model, n_heads, d_ff)

# Create input representing a sequence of 4 tokens
x = torch.randn(1, 4, d_model)  # [batch=1, seq_len=4, d_model=8]

print(f"Input shape: {x.shape}")
print(f"d_model: {d_model}, n_heads: {n_heads}, d_ff: {d_ff}")
print(f"Number of layers: {n_layers}")

# Count parameters (per-layer figure taken from one block, so the final LayerNorm is not folded in)
total_params = sum(p.numel() for p in transformer.parameters())
params_per_layer = sum(p.numel() for p in transformer.blocks[0].parameters())

print(f"\n📊 PARAMETER ANALYSIS:")
print(f"Total parameters: {total_params:,}")
print(f"Parameters per layer: {params_per_layer:,}")

# Forward pass with monitoring
print(f"\n🔄 FORWARD PASS THROUGH ALL LAYERS:")
print(f"Input: mean={x.mean().item():.3f}, std={x.std().item():.3f}")

output, layer_outputs = transformer(x)

print(f"\n✨ WHAT EACH LAYER TENDS TO LEARN (once trained; these weights are still random):")
print(f"• Layer 1: Basic feature combinations and simple attention patterns")
print(f"• Layer 2: More complex interactions between positions")
print(f"• Layer 3: Higher-level, more abstract relationships")

print(f"\n🔑 KEY OBSERVATIONS:")
print(f"✅ Each layer transforms the representation differently")
print(f"✅ LayerNorm keeps each sublayer's input well-scaled; the final norm rescales the output")
print(f"✅ Residual connections carry the input forward through every layer")
print(f"✅ With training, the stack can learn increasingly complex patterns")

# Show how the representation changes
print(f"\n📈 REPRESENTATION EVOLUTION:")
for i, layer_out in enumerate(layer_outputs):
    magnitude = torch.norm(layer_out).item()
    if i == 0:
        print(f"Input:        magnitude={magnitude:.2f}")
    elif i <= n_layers:
        print(f"Layer {i}:      magnitude={magnitude:.2f}")
    else:
        print(f"Final output: magnitude={magnitude:.2f}")

In [None]:
# A leaner SimpleTransformer without per-layer logging (this redefines the class above)
class SimpleTransformer(nn.Module):
    """Stack of transformer blocks."""
    
    def __init__(self, n_layers: int, d_model: int, n_heads: int, d_ff: int):
        super().__init__()
        
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff)
            for _ in range(n_layers)
        ])
        
        self.final_norm = nn.LayerNorm(d_model)
    
    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return self.final_norm(x)

# Test stacking
print("📚 STACKING TRANSFORMER BLOCKS")
n_layers = 3
transformer = SimpleTransformer(n_layers, d_model=8, n_heads=2, d_ff=32)

x = torch.randn(1, 4, 8)
output = transformer(x)

total_params = sum(p.numel() for p in transformer.parameters())
params_per_layer = sum(p.numel() for p in transformer.blocks[0].parameters())  # excludes the final LayerNorm

print(f"Layers:           {n_layers}")
print(f"Input shape:      {x.shape}")
print(f"Output shape:     {output.shape}")
print(f"Total parameters: {total_params:,}")
print(f"Per layer:        {params_per_layer:,}")

print(f"\n✨ After training, layers often specialize (these weights are still random):")
print(f"• Layer 1: Basic features and attention patterns")
print(f"• Layer 2: More complex relationships")
print(f"• Layer 3: Higher-level abstractions")

print(f"\n🔑 KEY INSIGHT: Deep networks can learn hierarchical representations!")
print(f"✅ Residual connections make deep stacking possible")
print(f"✅ Each layer builds on the previous layers' output")