# Building Transformer Blocks

In the previous notebook, we explored the attention mechanism. Now we'll see how attention is combined with other components to create complete transformer blocks.

## The Big Picture: Why Do We Need More Than Attention?

Attention is powerful, but it has limitations:

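1. **Attention mostly mixes information** - each output is a weighted average of (projected) value vectors, a bit like reshuffling cards without changing their values
2. **No per-token non-linear processing** - nothing transforms each word's representation on its own after the mixing
3. **Training instability** - activations can drift in scale, which makes deep stacks hard to train
4. **Vanishing gradients** - gradients can shrink as they flow back through many layers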

Transformer blocks solve these problems by adding:
- **Feed-Forward Networks** → Transform information, not just mix it
- **Layer Normalization** → Stabilize training  
- **Residual Connections** → Preserve gradient flow

Think of it like this:
- **Attention**: "Let me gather relevant information from other words"
- **Feed-Forward**: "Now let me think about what this information means"
- **Layer Norm**: "Keep everything balanced and stable"
- **Residuals**: "Don't forget what I started with"

## What You'll Learn

1. **Feed-Forward Networks** - The "thinking" component of transformers
2. **Layer Normalization** - Stabilizing training dynamics
3. **Residual Connections** - Enabling deep networks to train
4. **Complete Transformer Block** - How everything fits together
5. **Stacking Blocks** - Building deep transformers
6. **Parameter Analysis** - Where the parameters live and how they scale

Let's start building!

In [None]:
import sys
import os
sys.path.append('..')

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Tuple, Optional

# Set style for better plots
plt.style.use('default')
sns.set_palette("husl")

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

print("Environment setup complete!")

## 1. Feed-Forward Networks: The "Thinking" Component

After attention gathers relevant information from other positions, we need to **process** that information. That's where Feed-Forward Networks (FFNs) come in.

### Why Do We Need FFNs? 🤔

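Attention is great at **routing information**, but apart from its learned linear projections it only takes weighted averages, so it cannot apply a non-linear **transformation** to each token: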
- Attention: "The word 'bank' should look at 'river' and 'loans'"  
- FFN: "Based on seeing 'river', this 'bank' means 'riverside'" 
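To make the "routing, not transforming" point concrete, here is a minimal sketch of single-head scaled dot-product attention with the learned Q/K/V and output projections left out (a simplification; real attention includes them). Because the softmax weights in each row are non-negative and sum to 1, every output row is a convex combination of the value rows, so no output coordinate can leave the range spanned by the corresponding value column.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, d = 4, 8
q = torch.randn(seq_len, d)
k = torch.randn(seq_len, d)
v = torch.randn(seq_len, d)

# Scaled dot-product attention without learned projections (simplification)
weights = F.softmax(q @ k.T / d ** 0.5, dim=-1)  # [seq_len, seq_len], each row sums to 1
out = weights @ v                                 # [seq_len, d]

# Each output row is a weighted average of the rows of v,
# so it stays inside the per-feature range spanned by v.
v_min, v_max = v.min(dim=0).values, v.max(dim=0).values
print(weights.sum(dim=-1))                                       # all ones (up to float rounding)
print(((out >= v_min - 1e-5) & (out <= v_max + 1e-5)).all())     # tensor(True)
```

A ReLU FFN has no such constraint: it can map a token's vector somewhere no weighted average of the inputs could reach.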

### The FFN Formula
FFNs apply the same transformation to each position independently:

$$\text{FFN}(x) = \text{ReLU}(xW_1 + b_1)W_2 + b_2$$

**Key insights:**
- **Position-wise**: Each word gets the same transformation (but with different inputs)
- **Non-linear**: ReLU allows complex transformations
- **Expand-contract**: Typically `d_ff = 4 × d_model` for more expressivity
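The formula maps directly onto two matrix multiplies. The sketch below uses random placeholder weights (in a real model they are learned) and a hypothetical helper `ffn`. It also checks the position-wise property: running the FFN on the whole sequence gives the same result as running it on each position separately, because nothing in the formula mixes positions.

```python
import torch

torch.manual_seed(0)
d_model, d_ff, seq_len = 8, 32, 5  # d_ff = 4 * d_model

# Random placeholder parameters; a trained model would learn these
W1, b1 = torch.randn(d_model, d_ff), torch.randn(d_ff)
W2, b2 = torch.randn(d_ff, d_model), torch.randn(d_model)

def ffn(x: torch.Tensor) -> torch.Tensor:
    """FFN(x) = ReLU(x W1 + b1) W2 + b2, applied over the last dimension."""
    return torch.relu(x @ W1 + b1) @ W2 + b2

x = torch.randn(seq_len, d_model)
full = ffn(x)                                                     # all positions at once
per_position = torch.stack([ffn(x[i]) for i in range(seq_len)])   # one position at a time

print(full.shape)                                     # torch.Size([5, 8])
print(torch.allclose(full, per_position, atol=1e-4))  # True: no mixing across positions
```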

### The Factory Analogy 🏭
Think of FFNs like assembly line stations:
- Each position is a workstation
- Same tools (weights) at each station  
- Each workstation processes different items (word representations)
- First layer **expands** the representation (more features)
- Second layer **contracts** back to original size

In [None]:
class SimpleFeedForward(nn.Module):
    """Simple feed-forward network that can optionally return its intermediate activations."""
    
    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.d_model = d_model
        self.d_ff = d_ff
        
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()
    
    def forward(self, x, return_intermediate=False):
        # x shape: [batch_size, seq_len, d_model]
        
        # First linear transformation
        hidden = self.linear1(x)  # [batch_size, seq_len, d_ff]
        
        # ReLU activation
        activated = self.relu(hidden)  # [batch_size, seq_len, d_ff]
        
        # Second linear transformation back to d_model
        output = self.linear2(activated)  # [batch_size, seq_len, d_model]
        
        if return_intermediate:
            return output, {'hidden': hidden, 'activated': activated}
        return output

# Create and test feed-forward network
d_model, d_ff = 8, 32  # 4x expansion
batch_size, seq_len = 2, 4

ff_net = SimpleFeedForward(d_model, d_ff)

# Create input
x = torch.randn(batch_size, seq_len, d_model)
print(f"Input shape: {x.shape}")

# Forward pass with intermediate values
output, intermediates = ff_net(x, return_intermediate=True)

print(f"Hidden shape (after linear1): {intermediates['hidden'].shape}")
print(f"Activated shape (after ReLU): {intermediates['activated'].shape}")
print(f"Output shape (after linear2): {output.shape}")

# Count parameters
total_params = sum(p.numel() for p in ff_net.parameters())
print(f"\nFeed-forward parameters: {total_params:,}")
print(f"  Linear1: {d_model * d_ff + d_ff:,} (weights + bias)")
print(f"  Linear2: {d_ff * d_model + d_model:,} (weights + bias)")

## 2. Layer Normalization

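Layer normalization stabilizes training by normalizing the inputs to each sublayer. Unlike batch normalization, which computes statistics across the batch, it normalizes across the feature dimension of each token independently, so it behaves the same for any batch size.

$$\text{LayerNorm}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

Where $\mu$ and $\sigma^2$ are the mean and (biased) variance across the feature dimension, $\epsilon$ is a small constant for numerical stability (`nn.LayerNorm` defaults to `1e-5`), and $\gamma$, $\beta$ are learned per-feature scale and shift vectors.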
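A quick way to check the formula is to compute it by hand and compare against `nn.LayerNorm`. This sketch uses $\gamma = 1$ and $\beta = 0$, which is how `nn.LayerNorm` initializes them; note that the statistics are taken over the last (feature) dimension, not over the batch as batch normalization would.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(2, 3, 4) * 10  # [batch, seq_len, d_model], deliberately large scale
eps = 1e-5                     # nn.LayerNorm's default eps

# Statistics over the feature (last) dimension, separately for every token
mu = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, unbiased=False, keepdim=True)  # biased variance, as LayerNorm uses
gamma, beta = torch.ones(4), torch.zeros(4)        # nn.LayerNorm's initial scale and shift
manual = gamma * (x - mu) / torch.sqrt(var + eps) + beta

builtin = nn.LayerNorm(4, eps=eps)(x)
print(torch.allclose(manual, builtin, atol=1e-5))  # True
```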

In [None]:
def demonstrate_layer_norm():
    """Demonstrate how layer normalization works."""
    
    # Create example data with different scales
    batch_size, seq_len, d_model = 2, 3, 4
    
    # Create data where rows (tokens) have very different scales
    x = torch.tensor([
        [[1.0, 2.0, 3.0, 4.0],    # Example 1, position 1
         [10.0, 20.0, 30.0, 40.0], # Example 1, position 2  
         [0.1, 0.2, 0.3, 0.4]],   # Example 1, position 3
        
        [[100.0, 200.0, 300.0, 400.0],  # Example 2, position 1
         [5.0, 6.0, 7.0, 8.0],          # Example 2, position 2
         [0.01, 0.02, 0.03, 0.04]]      # Example 2, position 3
    ])
    
    print("Original data:")
    print(f"Shape: {x.shape}")
    print("Example 1:")
    print(x[0])
    print("Example 2:")
    print(x[1])
    
    # Apply layer normalization
    layer_norm = nn.LayerNorm(d_model)
    x_normalized = layer_norm(x)
    
    print("\nAfter Layer Normalization:")
    print("Example 1:")
    print(x_normalized[0])
    print("Example 2:")
    print(x_normalized[1])
    
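    # Check normalization properties: each row should have mean ≈ 0 and std ≈ 1.
    # Use the biased std (unbiased=False), which is what LayerNorm divides by;
    # the default unbiased .std() would report about 1.155 (= sqrt(4/3)) for 4 features.
    # The 0.01..0.04 row comes out near 0.96 instead of 1 because eps is
    # no longer negligible next to its tiny variance.
    print("\nNormalization properties:")
    for i in range(batch_size):
        for j in range(seq_len):
            mean = x_normalized[i, j].mean()
            std = x_normalized[i, j].std(unbiased=False)
            print(f"Example {i+1}, Position {j+1}: mean={mean:.6f}, std={std:.6f}")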
    
    return x, x_normalized

# Demonstrate layer normalization
original, normalized = demonstrate_layer_norm()

# Visualize the effect
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

# Original data
im1 = ax1.imshow(original.view(-1, original.size(-1)).numpy(), cmap='viridis', aspect='auto')
ax1.set_title('Original Data')
ax1.set_xlabel('Feature Dimension')
ax1.set_ylabel('Batch × Sequence Position')
plt.colorbar(im1, ax=ax1)

# Normalized data
im2 = ax2.imshow(normalized.view(-1, normalized.size(-1)).detach().numpy(), cmap='viridis', aspect='auto')
ax2.set_title('After Layer Normalization')
ax2.set_xlabel('Feature Dimension')
ax2.set_ylabel('Batch × Sequence Position')
plt.colorbar(im2, ax=ax2)

plt.tight_layout()
plt.show()

print("\nEach row (one token's features) now has roughly zero mean and unit variance, whatever its original scale!")

## 3. Residual Connections: The Gradient Highway

Residual connections are one of the most important innovations in deep learning. They largely mitigate the problem of **vanishing gradients** in deep networks.

### The Problem: Vanishing Gradients 📉
Without residual connections, gradients can shrink as they flow backwards through layers, because the chain rule multiplies in one local derivative per layer. Illustrative numbers:
- Layer 10: gradient = 1.0
- Layer 5: gradient = 0.1  
- Layer 1: gradient = 0.001 → can't learn!
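The numbers above are illustrative, but the mechanism is easy to reproduce. In this minimal sketch, a hypothetical stack of ten scalar "layers" each multiplies its input by 0.5 (an arbitrary choice), so backprop multiplies ten local derivatives of 0.5 together.

```python
import torch

x = torch.tensor(1.0, requires_grad=True)
h = x
for _ in range(10):
    h = 0.5 * h  # each "layer" has local derivative 0.5
h.backward()

# Chain rule: d(output)/d(input) = 0.5 ** 10
print(x.grad.item())  # 0.0009765625, i.e. about 0.001
```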

### The Solution: Gradient Highways 🛣️
Residual connections create "highways" for gradients:

$$\text{output} = x + \text{Sublayer}(x)$$

**Why this works:**
- By the chain rule, the Jacobian of `x + f(x)` is $I + \frac{\partial f}{\partial x}$: the sublayer's term plus an identity term
- Even if $\frac{\partial f}{\partial x}$ becomes tiny, the identity term carries the gradient back unchanged
- It's like having both local roads (sublayer) and highways (residual) for traffic
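The Jacobian claim can be checked numerically with `torch.autograd.functional.jacobian`. In this sketch a small `Linear` + `Tanh` stands in for an arbitrary sublayer (an assumption for illustration). Zeroing its weights then simulates a "dead" sublayer whose own gradient vanishes entirely; the residual version still has an identity Jacobian.

```python
import torch
import torch.nn as nn
from torch.autograd.functional import jacobian

torch.manual_seed(0)
d = 4
f = nn.Sequential(nn.Linear(d, d), nn.Tanh())  # stand-in for any sublayer
x = torch.randn(d)

J_f = jacobian(f, x)                       # ∂f/∂x, shape [d, d]
J_res = jacobian(lambda t: t + f(t), x)    # ∂(x + f(x))/∂x
print(torch.allclose(J_res, torch.eye(d) + J_f, atol=1e-6))  # True

# "Dead" sublayer: with all parameters zero, f(x) = tanh(0) = 0 and ∂f/∂x = 0
with torch.no_grad():
    for p in f.parameters():
        p.zero_()
J_dead = jacobian(lambda t: t + f(t), x)
print(J_dead)  # identity matrix: the gradient still passes straight through
```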

### Architecture Choices: Pre-norm vs Post-norm

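**Post-norm (original Transformer)**: `LayerNorm(x + Sublayer(x))`  
**Pre-norm (GPT-2 and many later models)**: `x + Sublayer(LayerNorm(x))`

Pre-norm tends to be more stable in deep stacks because:
- Each sublayer always receives normalized inputs, whatever the scale of the residual stream
- The residual path itself contains no LayerNorm, so there is an unbroken identity path from the output back to the input; in post-norm, every block's LayerNorm sits on that path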
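The two orderings differ by a single line. This sketch uses hypothetical helper names and a plain `nn.Linear` as a stand-in for attention or the FFN; it shows that post-norm re-normalizes the output of every step, while pre-norm leaves `x` untouched and only adds an update computed from a normalized copy.

```python
import torch
import torch.nn as nn

def post_norm_step(x, sublayer, norm):
    """Original Transformer ordering: normalize after adding the residual."""
    return norm(x + sublayer(x))

def pre_norm_step(x, sublayer, norm):
    """Pre-norm ordering: normalize only the sublayer's input; x flows through as-is."""
    return x + sublayer(norm(x))

torch.manual_seed(0)
d_model = 8
sublayer = nn.Linear(d_model, d_model)  # stand-in for attention or the FFN
norm = nn.LayerNorm(d_model)
x = torch.randn(2, 5, d_model) * 3

post = post_norm_step(x, sublayer, norm)
pre = pre_norm_step(x, sublayer, norm)

print(post.std(dim=-1, unbiased=False).mean())  # ≈ 1: forced back to unit scale every step
print(pre.std(dim=-1, unbiased=False).mean())   # keeps x's scale plus the sublayer's update
```

Because pre-norm never normalizes the residual stream itself, models built this way usually add one final LayerNorm after the last block.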

### The Highway Analogy 🚗
Think of residuals like highway systems:
- **Local roads** (sublayers): Can get congested or blocked
- **Highway** (residual): Always provides a direct route
- **Traffic** (gradients): Can always flow even if local roads are slow

In [None]:
def demonstrate_residual_connections():
    """Show why residual connections are important."""
    
    d_model = 4
    x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])  # [1, d_model]
    
    # Simulate a sublayer that shrinks the signal
    # (e.g., a poorly initialized layer)
    transformation = nn.Linear(d_model, d_model)
    
    # Make the weights tiny so the layer's output is tiny too
    with torch.no_grad():
        transformation.weight.fill_(0.01)
        transformation.bias.fill_(0.0)
    
    # Without residual connection
    output_no_residual = transformation(x)
    
    # With residual connection
    output_with_residual = x + transformation(x)
    
    print("Demonstrating Residual Connections:")
    print(f"Original input: {x.squeeze()}")
    print(f"Transformation output: {output_no_residual.squeeze()}")
    print(f"With residual connection: {output_with_residual.squeeze()}")
    
    print("\nKey insights:")
    print("- Without residual: every output is 0.1, so the input's pattern is lost")
    print("- With residual: the original signal is preserved plus a small modification")
    print("- Stacked, shrinking layers also shrink gradients; the residual path avoids this")
    
    return x, output_no_residual, output_with_residual

# Demonstrate residual connections
original, no_res, with_res = demonstrate_residual_connections()

# Visualize the effect
positions = range(len(original.squeeze()))
values_original = original.squeeze().numpy()
values_no_res = no_res.squeeze().detach().numpy()
values_with_res = with_res.squeeze().detach().numpy()

plt.figure(figsize=(10, 6))
plt.plot(positions, values_original, 'o-', label='Original Input', linewidth=2, markersize=8)
plt.plot(positions, values_no_res, 's-', label='Without Residual', linewidth=2, markersize=8)
plt.plot(positions, values_with_res, '^-', label='With Residual', linewidth=2, markersize=8)

plt.xlabel('Feature Dimension')
plt.ylabel('Value')
plt.title('Effect of Residual Connections')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## 4. Complete Transformer Block

Now let's combine everything into a complete transformer block. The standard architecture is:

1. **Multi-Head Attention** with residual connection and layer norm
2. **Feed-Forward Network** with residual connection and layer norm

Using Pre-Norm architecture:
```
x = x + attention(layer_norm(x))
x = x + ffn(layer_norm(x))
```
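The `TransformerBlock` below depends on `MultiHeadAttention` and `FeedForward` from this project's `src` package, whose code isn't shown here. As a self-contained alternative, here is a minimal sketch of the same pre-norm recipe built only from PyTorch modules. `PreNormBlock` is a hypothetical name, and swapping in `nn.MultiheadAttention` and a ReLU `nn.Sequential` FFN is an assumption, not the project's implementation.

```python
import torch
import torch.nn as nn

class PreNormBlock(nn.Module):
    """Pre-norm transformer block built only from torch.nn modules (illustrative sketch)."""

    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None):
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h, attn_mask=attn_mask, need_weights=False)
        x = x + self.dropout(attn_out)                 # x = x + attention(layer_norm(x))
        x = x + self.dropout(self.ffn(self.norm2(x)))  # x = x + ffn(layer_norm(x))
        return x

block = PreNormBlock(d_model=8, n_heads=2, d_ff=32).eval()  # eval() disables dropout
x = torch.randn(1, 4, 8)
print(block(x).shape)  # torch.Size([1, 4, 8])
```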

In [None]:
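# MultiHeadAttention and FeedForward come from this project's `src` package
# (importable thanks to sys.path.append('..') in the setup cell)
from src.model.attention import MultiHeadAttention
from src.model.feedforward import FeedForward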

class TransformerBlock(nn.Module):
    """Complete transformer block with visualization capabilities."""
    
    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        
        # Multi-head attention
        self.attention = MultiHeadAttention(d_model, n_heads, dropout)
        
        # Feed-forward network
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        
        # Layer normalization
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        
        # Dropout
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, mask=None, return_attention=False):
        """
        Forward pass through transformer block.
        
        Args:
            x: Input tensor [batch_size, seq_len, d_model]
            mask: Optional attention mask
            return_attention: Whether to return attention weights
        """
        # Store for visualization
        intermediates = {}
        intermediates['input'] = x.clone()
        
        # Multi-head attention with residual connection
        normed1 = self.norm1(x)
        intermediates['normed1'] = normed1.clone()
        
        if return_attention:
            attn_out, attention_weights = self.attention(normed1, normed1, normed1, mask, return_attention=True)
            intermediates['attention_weights'] = attention_weights
        else:
            attn_out = self.attention(normed1, normed1, normed1, mask)
            attention_weights = None
        
        attn_out = self.dropout(attn_out)
        x = x + attn_out  # Residual connection
        intermediates['after_attention'] = x.clone()
        
        # Feed-forward with residual connection
        normed2 = self.norm2(x)
        intermediates['normed2'] = normed2.clone()
        
        ff_out = self.feed_forward(normed2)
        ff_out = self.dropout(ff_out)
        x = x + ff_out  # Residual connection
        intermediates['output'] = x.clone()
        
        if return_attention:
            return x, attention_weights, intermediates
        return x, intermediates

# Create and test transformer block
d_model, n_heads, d_ff = 8, 2, 32
block = TransformerBlock(d_model, n_heads, d_ff)
block.eval()  # disable dropout so the analysis below is deterministic

# Create input
batch_size, seq_len = 1, 4
x = torch.randn(batch_size, seq_len, d_model)

print(f"Input shape: {x.shape}")
print(f"Model parameters: {sum(p.numel() for p in block.parameters()):,}")

# Forward pass
output, attention_weights, intermediates = block(x, return_attention=True)

print(f"\nOutput shape: {output.shape}")
print(f"Attention weights shape: {attention_weights.shape}")

# Analyze the transformations
print("\nStep-by-step analysis:")
input_norm = torch.norm(intermediates['input']).item()
after_attn_norm = torch.norm(intermediates['after_attention']).item()
output_norm = torch.norm(intermediates['output']).item()

print(f"Input norm: {input_norm:.3f}")
print(f"After attention norm: {after_attn_norm:.3f}")
print(f"Final output norm: {output_norm:.3f}")

# Show that residual connections preserve information
attention_contribution = torch.norm(intermediates['after_attention'] - intermediates['input']).item()
ff_contribution = torch.norm(intermediates['output'] - intermediates['after_attention']).item()

print(f"\nContribution analysis:")
print(f"Attention contribution: {attention_contribution:.3f}")
print(f"Feed-forward contribution: {ff_contribution:.3f}")
print("\nEach sublayer adds an update to the input instead of replacing it, so the original signal is carried forward!")

## 5. Stacking Transformer Blocks

The power of transformers comes from stacking multiple blocks. Each block can learn different types of patterns and relationships. Let's see how information flows through a stack of blocks.
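PyTorch also ships this stacking pattern. Assuming PyTorch 1.10 or newer (for `norm_first`), the sketch below stacks three pre-norm encoder layers with a final `LayerNorm`, mirroring the `SimpleTransformer` built next. It is a comparison point, not the notebook's model: `nn.TransformerEncoder` deep-copies the layer you pass in, so all three blocks start from identical weights, and some versions print a harmless warning about nested tensors when `norm_first=True`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.TransformerEncoderLayer(
    d_model=8, nhead=2, dim_feedforward=32, dropout=0.1,
    batch_first=True, norm_first=True,  # pre-norm, as in this notebook
)
# Stack 3 copies of the layer and apply a final LayerNorm after the last block
encoder = nn.TransformerEncoder(layer, num_layers=3, norm=nn.LayerNorm(8)).eval()

x = torch.randn(1, 4, 8)
with torch.no_grad():
    y = encoder(x)
print(y.shape)              # torch.Size([1, 4, 8])
print(len(encoder.layers))  # 3
```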

In [None]:
class SimpleTransformer(nn.Module):
    """Simple transformer with multiple blocks for analysis."""
    
    def __init__(self, n_layers: int, d_model: int, n_heads: int, d_ff: int):
        super().__init__()
        self.n_layers = n_layers
        
        # Stack of transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff)
            for _ in range(n_layers)
        ])
        
        # Final layer norm
        self.final_norm = nn.LayerNorm(d_model)
    
    def forward(self, x, return_all_attention=False):
        """
        Forward pass through all transformer blocks.
        """
        all_attention_weights = []
        layer_outputs = [x.clone()]  # Store output from each layer
        
        for i, block in enumerate(self.blocks):
            if return_all_attention:
                x, attention_weights, _ = block(x, return_attention=True)
                all_attention_weights.append(attention_weights)
            else:
                x, _ = block(x)
            
            layer_outputs.append(x.clone())
        
        # Final layer normalization
        x = self.final_norm(x)
        layer_outputs.append(x.clone())
        
        if return_all_attention:
            return x, all_attention_weights, layer_outputs
        return x, layer_outputs

# Create a 3-layer transformer
n_layers = 3
transformer = SimpleTransformer(n_layers, d_model=8, n_heads=2, d_ff=32)
transformer.eval()  # disable dropout so outputs and attention maps are deterministic

# Create input
x = torch.randn(1, 4, 8)

print(f"Transformer with {n_layers} layers")
print(f"Total parameters: {sum(p.numel() for p in transformer.parameters()):,}")

# Forward pass
output, all_attention, layer_outputs = transformer(x, return_all_attention=True)

print(f"\nInput shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Number of attention matrices: {len(all_attention)}")
print(f"Number of layer outputs: {len(layer_outputs)}")

# Analyze how representations change through layers
print("\nRepresentation analysis through layers:")
for i, layer_output in enumerate(layer_outputs):
    norm = torch.norm(layer_output).item()
    mean = layer_output.mean().item()
    std = layer_output.std().item()
    
    if i == 0:
        layer_name = "Input"
    elif i <= n_layers:
        layer_name = f"Layer {i}"
    else:
        layer_name = "Final Norm"
    
    print(f"{layer_name:12}: norm={norm:6.3f}, mean={mean:6.3f}, std={std:6.3f}")

# Visualize attention patterns across layers
fig, axes = plt.subplots(1, n_layers, figsize=(15, 4))

for layer_idx in range(n_layers):
    # Average attention across heads
    avg_attention = all_attention[layer_idx][0].mean(dim=0).detach().numpy()
    
    sns.heatmap(
        avg_attention,
        annot=True, fmt='.2f',
        cmap='Blues',
        ax=axes[layer_idx],
        cbar=layer_idx == n_layers - 1
    )
    axes[layer_idx].set_title(f'Layer {layer_idx + 1}\nAttention')
    axes[layer_idx].set_xlabel('Keys')
    if layer_idx == 0:
        axes[layer_idx].set_ylabel('Queries')

plt.tight_layout()
plt.show()

print("\nThis model is untrained: the differences between layers come from random initialization. After training, layers typically specialize in different patterns.")

## 6. Parameter Analysis

Let's analyze where most parameters are located in a transformer and how this scales with model size.
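Before measuring, it helps to predict the result. Assuming `MultiHeadAttention` has four `d_model × d_model` projections (Q, K, V, output) with biases and `FeedForward` is two biased linear layers (an assumption about the `src` package, whose code isn't shown), one block has $4d^2 + 4d$ attention parameters, $2\,d\,d_{ff} + d_{ff} + d$ FFN parameters and $4d$ LayerNorm parameters. With $d_{ff} = 4d$ that totals $12d^2 + 13d$, of which the FFN's $8d^2 + 5d$ is roughly two-thirds. A hypothetical helper that does the arithmetic:

```python
def block_params(d_model: int, d_ff: int) -> dict:
    """Closed-form per-block parameter counts under the stated assumptions."""
    attention = 4 * d_model * d_model + 4 * d_model  # Q, K, V, O weights + biases
    ff = 2 * d_model * d_ff + d_ff + d_model         # two Linear layers + biases
    norm = 2 * (2 * d_model)                         # two LayerNorms, gamma + beta each
    return {"attention": attention, "ff": ff, "norm": norm,
            "total": attention + ff + norm}

counts = block_params(d_model=256, d_ff=1024)  # the "Medium" config below
print(counts)  # {'attention': 263168, 'ff': 525568, 'norm': 1024, 'total': 789760}
print(f"FFN share: {counts['ff'] / counts['total']:.1%}")  # FFN share: 66.5%
```

If the measured counts below differ, the `src` modules likely use a different layout (for example, no biases).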

In [None]:
def analyze_transformer_parameters():
    """Analyze parameter distribution in transformers."""
    
    # Test different model sizes
    configs = [
        {'name': 'Tiny', 'd_model': 64, 'n_heads': 2, 'd_ff': 256, 'n_layers': 2},
        {'name': 'Small', 'd_model': 128, 'n_heads': 4, 'd_ff': 512, 'n_layers': 6},
        {'name': 'Medium', 'd_model': 256, 'n_heads': 8, 'd_ff': 1024, 'n_layers': 12},
        {'name': 'Large', 'd_model': 512, 'n_heads': 16, 'd_ff': 2048, 'n_layers': 24},
    ]
    
    results = []
    
    for config in configs:
        # Create single transformer block
        block = TransformerBlock(
            d_model=config['d_model'],
            n_heads=config['n_heads'],
            d_ff=config['d_ff']
        )
        
        # Count parameters by component
        attention_params = sum(p.numel() for p in block.attention.parameters())
        ff_params = sum(p.numel() for p in block.feed_forward.parameters())
        norm_params = sum(p.numel() for p in block.norm1.parameters()) + sum(p.numel() for p in block.norm2.parameters())
        
        total_per_block = attention_params + ff_params + norm_params
        total_model = total_per_block * config['n_layers']  # excludes embeddings and any final LayerNorm
        
        results.append({
            'name': config['name'],
            'attention': attention_params,
            'ff': ff_params,
            'norm': norm_params,
            'per_block': total_per_block,
            'total': total_model,
            'n_layers': config['n_layers']
        })
    
    # Display results
    print("Parameter Analysis by Model Size:")
    print("=" * 80)
    print(f"{'Model':<8} {'Attention':>12} {'Feed-Forward':>14} {'Layer Norm':>12} {'Per Block':>12} {'Total':>12}")
    print("-" * 80)
    
    for r in results:
        print(f"{r['name']:<8} {r['attention']:>12,} {r['ff']:>14,} {r['norm']:>12,} {r['per_block']:>12,} {r['total']:>12,}")
    
    # Calculate percentages for the medium model
    medium = results[2]  # Medium model
    print(f"\nParameter Distribution (Medium Model):")
    print(f"Attention: {medium['attention']/medium['per_block']*100:.1f}%")
    print(f"Feed-Forward: {medium['ff']/medium['per_block']*100:.1f}%")
    print(f"Layer Norm: {medium['norm']/medium['per_block']*100:.1f}%")
    
    # Visualize parameter distribution
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
    
    # Model size comparison
    model_names = [r['name'] for r in results]
    total_params = [r['total'] for r in results]
    
    ax1.bar(model_names, total_params, color='skyblue')
    ax1.set_ylabel('Total Parameters')
    ax1.set_title('Total Parameters by Model Size')
    ax1.set_yscale('log')
    
    # Add parameter counts as labels
    for i, v in enumerate(total_params):
        ax1.text(i, v, f'{v:,}', ha='center', va='bottom')
    
    # Component breakdown for medium model
    components = ['Attention', 'Feed-Forward', 'Layer Norm']
    component_params = [medium['attention'], medium['ff'], medium['norm']]
    colors = ['lightcoral', 'lightgreen', 'lightblue']
    
    ax2.pie(component_params, labels=components, colors=colors, autopct='%1.1f%%')
    ax2.set_title('Parameter Distribution\n(Medium Model, Per Block)')
    
    plt.tight_layout()
    plt.show()
    
    print("\nKey Insights:")
    print("• With d_ff = 4 × d_model, about two-thirds of each block's parameters are in the feed-forward network")
    print("• Attention accounts for roughly one-third")
    print("• Layer normalization uses well under 1% of parameters")
    print("• Per-block parameters scale as O(d_model²); the full stack scales as O(n_layers × d_model²)")

analyze_transformer_parameters()

## Summary

In this notebook, we've built complete transformer blocks by combining:

1. **Feed-Forward Networks** - Process each position independently with 2-layer MLPs
2. **Layer Normalization** - Stabilize training by normalizing features
3. **Residual Connections** - Enable deep networks by preserving gradient flow
4. **Complete Blocks** - Combine attention + FFN with proper normalization
5. **Stacking** - Multiple blocks learn hierarchical representations

### Key Architecture Insights:

- **Pre-Norm vs Post-Norm**: Pre-norm (norm before sublayer) tends to be more stable in deep stacks
- **Parameter Distribution**: ~67% in FFN, ~33% in attention (with d_ff = 4 × d_model)
- **Residual Connections**: Essential for training deep networks
- **Layer Normalization**: Provides training stability

### Design Principles:

- Each component serves a specific purpose
- Residual connections preserve information flow
- Normalization enables stable training
- Stacking enables learning complex patterns

## 🚨 Common Misconceptions to Avoid

**❌ "Attention is like human attention"**  
→ ✅ It's more like **information routing** - deciding which information to send where

**❌ "Residuals just add skip connections"**  
→ ✅ They create **gradient highways** that enable deep network training

**❌ "Layer norm just normalizes"**  
→ ✅ It **stabilizes training dynamics** and enables faster convergence

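**❌ "FFNs are just simple MLPs"**  
→ ✅ They are simple MLPs, but they hold about **two-thirds of each block's parameters**, and interpretability work suggests they act partly as **key-value memories** where many "facts" are stored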

**❌ "Bigger models are always better"**  
→ ✅ There are **efficiency trade-offs** - bigger isn't always better for your use case

## 🔬 Try This Yourself!

Before moving on, experiment with these questions:
1. What happens if you remove residual connections? (Hint: try a 6-layer model)
2. How does attention change from layer 1 to layer 6?  
3. What if you make d_ff smaller or larger than 4×d_model?

Next, we'll explore how transformers understand position with positional encoding!