# Building Transformer Blocks

Attention is powerful, but on its own it only routes information between positions. To build complete transformers, we need three more components.

## Why More Than Attention?

**Attention limitations**:
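- Only mixes information across positions (each output is a weighted average of value vectors, like reshuffling cards rather than drawing new ones)
- No non-linear processing of each position's features on its own
- Hard to train stably when many layers are stacked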

**Complete transformer blocks add**:
- **Feed-Forward Networks**: Transform information, not just mix it
- **Layer Normalization**: Stabilize training dynamics  
- **Residual Connections**: Enable deep network training
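
To make "only mixes information" concrete, here is a minimal single-head attention sketch. It assumes no learned Q/K/V projections and arbitrary toy shapes, purely to keep it short. Each output row is a softmax-weighted average of the value rows, so every output feature stays inside the range already present in the values: attention recombines, it does not compute new features.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy sequence: 4 positions, 3 features each (illustrative shapes)
q = torch.randn(4, 3)
k = torch.randn(4, 3)
v = torch.randn(4, 3)

# Scaled dot-product attention without learned projections
scores = q @ k.T / k.shape[-1] ** 0.5   # (4, 4) similarity scores
weights = F.softmax(scores, dim=-1)       # each row is a probability distribution
out = weights @ v                         # (4, 3) weighted averages of the rows of v

rows_sum_to_one = torch.allclose(weights.sum(dim=-1), torch.ones(4))

# A convex combination cannot leave the per-feature [min, max] range of v
lo = v.min(dim=0).values - 1e-6
hi = v.max(dim=0).values + 1e-6
within_value_range = bool(((out >= lo) & (out <= hi)).all())

print(rows_sum_to_one)      # True
print(within_value_range)   # True
```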

## Learning Objectives

1. **Feed-Forward Networks**: The "thinking" component 
2. **Layer Normalization**: Training stability
3. **Residual Connections**: Gradient flow in deep networks
4. **Complete Transformer Block**: Integration of all components
5. **Stacking Blocks**: Building deep transformers

Let's build complete transformer blocks!

In [None]:
import sys
import os
sys.path.append('..')

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Tuple, Optional

# Set style for better plots
plt.style.use('default')
sns.set_palette("husl")

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

print("Environment setup complete!")

## Feed-Forward Networks: Position-wise Processing

**The Role**: After attention routes information between positions, each position needs individual processing.

**Architecture**: `FFN(x) = ReLU(xW₁ + b₁)W₂ + b₂`

**Key Properties**:
- **Position-wise**: Same transformation applied to each position independently
- **Expand-contract**: `d_model → d_ff → d_model` (typically 4× expansion)
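- **Non-linear**: The ReLU between the two projections lets the FFN compute functions that no single linear map can (many later models use GELU instead)

Think of it as one workstation applied at every position: the same weights, a different input vector each time.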
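
The repository's `FeedForward` class (from `src.model.feedforward`) is not shown here, so the sketch below uses a hypothetical stand-in, `SketchFeedForward`, assuming a Linear → ReLU → Dropout → Linear layout. It checks the position-wise property directly: running the whole sequence at once gives the same result as running each position separately, and the parameter count matches `d_model·d_ff + d_ff + d_ff·d_model + d_model`.

```python
import torch
import torch.nn as nn

class SketchFeedForward(nn.Module):
    """Hypothetical position-wise FFN: Linear -> ReLU -> Dropout -> Linear."""

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.0):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)   # expand
        self.linear2 = nn.Linear(d_ff, d_model)   # contract
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # nn.Linear acts on the last dimension only, so every position
        # is transformed independently with the same weights.
        return self.linear2(self.dropout(torch.relu(self.linear1(x))))

torch.manual_seed(0)
ffn = SketchFeedForward(d_model=8, d_ff=32)   # dropout=0.0 keeps the check deterministic
x = torch.randn(1, 4, 8)                      # (batch, seq_len, d_model)

whole = ffn(x)
# Apply the FFN to one position at a time, then stack the results back together
per_position = torch.stack([ffn(x[:, i]) for i in range(x.shape[1])], dim=1)
n_params = sum(p.numel() for p in ffn.parameters())

print(whole.shape)                                     # torch.Size([1, 4, 8])
print(torch.allclose(whole, per_position, atol=1e-6))  # True
print(n_params)                                        # 8*32 + 32 + 32*8 + 8 = 552
```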

In [None]:
from src.model.feedforward import FeedForward

d_model, d_ff = 8, 32
ff_net = FeedForward(d_model, d_ff, dropout=0.1)

print(f"Feed-Forward Network:")
print(f"d_model = {d_model} (input/output dimension)")
print(f"d_ff = {d_ff} (internal expansion, {d_ff//d_model}x larger)")

print(f"\nProcessing flow:")
print(f"Input:  [seq_len, {d_model}] → Each position has {d_model} features")
print(f"Expand: [seq_len, {d_ff}] → More space for complex processing")
print(f"Output: [seq_len, {d_model}] → Back to original dimension")

batch_size, seq_len = 1, 4
x = torch.randn(batch_size, seq_len, d_model)
output = ff_net(x)

print(f"\nDemonstration:")
print(f"Input shape:  {x.shape}")
print(f"Output shape: {output.shape}")

total_params = sum(p.numel() for p in ff_net.parameters())
linear1_params = d_model * d_ff + d_ff
linear2_params = d_ff * d_model + d_model

print(f"\nParameter breakdown:")
print(f"Linear1 ({d_model}→{d_ff}): {linear1_params:,} params")
print(f"Linear2 ({d_ff}→{d_model}): {linear2_params:,} params")
print(f"Total: {total_params:,} params")

print(f"\n✅ FFN provides position-wise transformation capability!")

## Layer Normalization: Training Stability

**The Problem**: Without normalization, activation scales can drift widely across positions and layers, which leaves gradients poorly scaled and makes optimization unstable.

**The Solution**: Normalize each position's features independently.

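**Formula**: `LayerNorm(x) = γ · (x - μ) / √(σ² + ε) + β`
- μ, σ²: mean and (biased) variance across the features of each position
- ε: small constant for numerical stability (`1e-5` by default in PyTorch)
- γ, β: learnable scale and shift parameters

**Key Insight**: The statistics are computed per position over the feature dimension, not across the batch, so every position starts from zero mean and unit variance before γ and β are applied.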
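
A minimal from-scratch version of the formula, checked against `nn.LayerNorm`. The helper name `layer_norm_manual` is hypothetical. It makes the two details above explicit: the variance is the biased (population) one, and `eps` sits inside the square root. Because the second row of the toy input is just 100× the first, both rows normalize to (almost) the same vector.

```python
import torch
import torch.nn as nn

def layer_norm_manual(x, gamma, beta, eps=1e-5):
    # Statistics over the last (feature) dimension, separately for each position
    mu = x.mean(dim=-1, keepdim=True)
    var = ((x - mu) ** 2).mean(dim=-1, keepdim=True)   # biased variance, as nn.LayerNorm uses
    return gamma * (x - mu) / torch.sqrt(var + eps) + beta

x = torch.tensor([[[1.0, 2.0, 3.0, 4.0],
                   [100.0, 200.0, 300.0, 400.0]]])

ln = nn.LayerNorm(4)   # gamma initialised to ones, beta to zeros
manual = layer_norm_manual(x, ln.weight, ln.bias)

print(torch.allclose(manual, ln(x), atol=1e-5))            # True
print(torch.allclose(manual[0, 0], manual[0, 1], atol=1e-4))  # True: scale no longer matters
```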

In [None]:
x = torch.tensor([
    [[1.0, 2.0, 3.0, 4.0],
     [100.0, 200.0, 300.0, 400.0]]
])

print("🚨 PROBLEM: Different scales break training!")
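# Report the biased (population) std, which is the statistic LayerNorm normalizes by
print(f"Position 1: {x[0,0].tolist()}")
print(f"  → mean={x[0,0].mean():.1f}, std={x[0,0].std(unbiased=False):.1f}")
print(f"Position 2: {x[0,1].tolist()}")
print(f"  → mean={x[0,1].mean():.1f}, std={x[0,1].std(unbiased=False):.1f}")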

layer_norm = nn.LayerNorm(4)
x_normalized = layer_norm(x)

print("\n✨ SOLUTION: LayerNorm fixes the scale problem!")
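# unbiased=False matches LayerNorm's own variance; the default (unbiased) std
# would report ≈1.155 here for 4 features instead of ≈1
print(f"Position 1: {[round(val, 3) for val in x_normalized[0,0].tolist()]}")
print(f"  → mean={x_normalized[0,0].mean():.3f}, std={x_normalized[0,0].std(unbiased=False):.3f}")
print(f"Position 2: {[round(val, 3) for val in x_normalized[0,1].tolist()]}")
print(f"  → mean={x_normalized[0,1].mean():.3f}, std={x_normalized[0,1].std(unbiased=False):.3f}")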

print(f"\nLayerNorm parameters:")
print(f"Scale (γ): {layer_norm.weight.tolist()}")
print(f"Shift (β): {layer_norm.bias.tolist()}")

print(f"\n✅ Both positions now have mean≈0, std≈1")
print(f"✅ Later layers receive inputs on a consistent scale, whatever the raw magnitudes")

## Residual Connections: Gradient Highways

**The Problem**: Deep networks suffer from vanishing gradients - signals become weaker through many layers.

**The Solution**: Skip connections enable direct gradient flow.

**Formula**: `output = x + f(x)` instead of `output = f(x)`

**Why it works**: The Jacobian of `x + f(x)` is `I + ∂f/∂x`. The identity term passes the upstream gradient straight through to `x`, creating a "gradient highway" even when `∂f/∂x` is close to zero.
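
The demo below only shows the forward signal, so here is a sketch of the gradient side. It pushes an input through 20 layers of `f(h) = tanh(h @ W)` with small random `W`, a made-up stand-in for a sublayer whose own Jacobian is weak, and compares the gradient norm reaching the input with and without the `+ x` path. The function name `input_grad_norm` and all sizes are illustrative assumptions.

```python
import torch

def input_grad_norm(depth: int, use_residual: bool, d: int = 16, scale: float = 0.1) -> float:
    """Norm of d(sum(output))/d(input) through `depth` weak tanh layers (toy setup)."""
    torch.manual_seed(0)   # identical weights and input for both variants
    weights = [scale * torch.randn(d, d) / d ** 0.5 for _ in range(depth)]
    x = torch.randn(1, d, requires_grad=True)

    h = x
    for W in weights:
        f = torch.tanh(h @ W)
        h = h + f if use_residual else f   # residual: h + f(h); plain: f(h)

    h.sum().backward()
    return x.grad.norm().item()

plain = input_grad_norm(depth=20, use_residual=False)
residual = input_grad_norm(depth=20, use_residual=True)

print(f"plain:    {plain:.2e}")     # vanishingly small: a product of 20 small Jacobians
print(f"residual: {residual:.2e}")  # many orders of magnitude larger
```

Without the residual, the gradient is a product of twenty small Jacobians and collapses toward zero; with it, each factor is `I + (something small)`, so the product stays close to the identity.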

In [None]:
x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
print(f"Input: {x.squeeze().tolist()}")

weak_transform = nn.Linear(4, 4)
with torch.no_grad():
    weak_transform.weight.fill_(0.01)
    weak_transform.bias.zero_()

output_no_res = weak_transform(x)
print(f"Without residual: {[round(val, 3) for val in output_no_res.squeeze().tolist()]} (signal lost!)")

output_with_res = x + weak_transform(x)
print(f"With residual:    {[round(val, 3) for val in output_with_res.squeeze().tolist()]} (signal preserved!)")

print(f"\n✅ Residual connections preserve the original signal")
print(f"✅ Enable training of very deep networks")
print(f"✅ Create gradient highways preventing vanishing gradients")

## Complete Transformer Block

Now we integrate all components using the **pre-norm** layout, where LayerNorm is applied to the input of each sublayer (as in GPT-2 and many later models):

```
# Step 1: Attention sublayer
normed = LayerNorm(x)
attention_out = MultiHeadAttention(normed)  
x = x + attention_out  # Residual

# Step 2: Feed-forward sublayer
normed = LayerNorm(x)
ff_out = FeedForward(normed)
x = x + ff_out  # Residual
```

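Placing LayerNorm inside each sublayer branch keeps the residual path a pure identity. The original Transformer used post-norm (`x = LayerNorm(x + sublayer(x))`), which puts a normalization on the residual path itself; pre-norm stacks are generally easier to train when many blocks are stacked.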
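
A hypothetical `SketchBlock` that supports both layouts makes the difference visible. It uses PyTorch's built-in `nn.MultiheadAttention` in place of the notebook's `src.model.attention.MultiHeadAttention` (whose internals are not shown) and omits dropout for brevity. With a deliberately large-scale input, the post-norm output is re-normalized at every position, while the pre-norm output carries the unnormalized residual stream forward, which is why pre-norm stacks usually end with a final LayerNorm. PyTorch's `nn.TransformerEncoderLayer` exposes the same choice through its `norm_first` flag.

```python
import torch
import torch.nn as nn

class SketchBlock(nn.Module):
    """Hypothetical block supporting pre-norm and post-norm (no dropout)."""

    def __init__(self, d_model: int, n_heads: int, d_ff: int, pre_norm: bool = True):
        super().__init__()
        self.pre_norm = pre_norm
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ff = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def _self_attn(self, x: torch.Tensor) -> torch.Tensor:
        out, _ = self.attn(x, x, x, need_weights=False)
        return out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.pre_norm:
            # Pre-norm: normalise inside each branch; the residual path is an identity
            x = x + self._self_attn(self.norm1(x))
            x = x + self.ff(self.norm2(x))
        else:
            # Post-norm (original Transformer): normalise after each residual sum
            x = self.norm1(x + self._self_attn(x))
            x = self.norm2(x + self.ff(x))
        return x

torch.manual_seed(0)
x = 10 * torch.randn(1, 4, 8)   # deliberately large-scale input

pre = SketchBlock(8, 2, 32, pre_norm=True)(x)
post = SketchBlock(8, 2, 32, pre_norm=False)(x)

# Post-norm ends in a LayerNorm, so every position has mean ≈ 0
print(post.mean(dim=-1).abs().max().item() < 1e-5)   # True
# Pre-norm passes the large input through the identity path
print(pre.abs().max().item() > 1.0)                  # True
```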

In [None]:
from src.model.attention import MultiHeadAttention

class TransformerBlock(nn.Module):
    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, mask=None):
        # Attention sublayer with pre-norm and residual
        normed = self.norm1(x)
        attn_out = self.attention(normed, normed, normed, mask)
        x = x + self.dropout(attn_out)
        
        # Feed-forward sublayer with pre-norm and residual
        normed = self.norm2(x)
        ff_out = self.feed_forward(normed)
        x = x + self.dropout(ff_out)
        
        return x

d_model, n_heads, d_ff = 8, 2, 32
block = TransformerBlock(d_model, n_heads, d_ff)

x = torch.randn(1, 4, d_model)
output = block(x)

print(f"Complete Transformer Block:")
print(f"Input shape:  {x.shape}")
print(f"Output shape: {output.shape}")

attention_params = sum(p.numel() for p in block.attention.parameters())
ff_params = sum(p.numel() for p in block.feed_forward.parameters())
norm_params = sum(p.numel() for norm in (block.norm1, block.norm2) for p in norm.parameters())

print(f"\nParameter breakdown:")
print(f"Attention:    {attention_params:,}")
print(f"Feed-forward: {ff_params:,}")
print(f"Layer norms:  {norm_params:,}")
print(f"Total:        {attention_params + ff_params + norm_params:,}")

print(f"\n✅ Successfully combines all transformer components!")
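
The exact counts printed above depend on how `src.model`'s `MultiHeadAttention` and `FeedForward` are implemented, which is not shown. As a reference point, assuming the standard layout (Q, K, V and output projections plus two FFN linears, all with biases), one block has `4(d² + d)` attention parameters, `2·d·d_ff + d_ff + d` FFN parameters and `4d` LayerNorm parameters. The sketch cross-checks that arithmetic against PyTorch's built-in modules. Note that `n_heads` does not appear: splitting `d_model` across heads reshapes the projections without adding parameters.

```python
import torch.nn as nn

d_model, n_heads, d_ff = 8, 2, 32

# Closed-form counts for a standard block (biases everywhere)
attn_formula = 4 * (d_model * d_model + d_model)      # W_q, W_k, W_v, W_o and their biases
ffn_formula = 2 * d_model * d_ff + d_ff + d_model     # Linear(d, d_ff) + Linear(d_ff, d)
norm_formula = 2 * (2 * d_model)                      # gamma and beta, for two LayerNorms
block_formula = attn_formula + ffn_formula + norm_formula

# PyTorch modules with the same shapes
attn = nn.MultiheadAttention(d_model, n_heads)
ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
norms = nn.ModuleList([nn.LayerNorm(d_model), nn.LayerNorm(d_model)])

def count(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())

print(attn_formula, count(attn))    # 288 288
print(ffn_formula, count(ffn))      # 552 552
print(norm_formula, count(norms))   # 32 32
print(block_formula)                # 872
```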

## Stacking Transformer Blocks

The power of transformers comes from stacking multiple blocks. Each layer can learn increasingly complex patterns and relationships.

**Why stacking helps** (an empirical tendency observed in trained models, not a rule):
- Lower layers: more local, surface-level features and simple attention patterns
- Middle layers: more complex interactions between positions
- Upper layers: more abstract, task-specific relationships

Residual connections make deep stacking possible by preserving gradient flow.
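
Stacking adds parameters linearly: `n_layers × (params per block) + final LayerNorm`. The sketch below builds a 3-block pre-norm stack from PyTorch's `nn.TransformerEncoderLayer` (its `norm_first=True` flag, available since PyTorch 1.10, selects pre-norm), used here only as a self-contained substitute for the notebook's `TransformerBlock`. Layers are cloned with `copy.deepcopy`, which is also how `nn.TransformerEncoder` clones its layer: every block gets its own parameters, identical at initialization, which then diverge during training.

```python
import copy
import torch
import torch.nn as nn

d_model, n_heads, d_ff, n_layers = 8, 2, 32, 3

layer = nn.TransformerEncoderLayer(d_model, n_heads, dim_feedforward=d_ff,
                                   dropout=0.0, batch_first=True, norm_first=True)
blocks = nn.ModuleList([copy.deepcopy(layer) for _ in range(n_layers)])  # separate parameter copies
final_norm = nn.LayerNorm(d_model)

def count(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())

x = torch.randn(1, 4, d_model)
h = x
for block in blocks:
    h = block(h)          # each block maps (batch, seq, d_model) to the same shape
h = final_norm(h)

per_block = count(blocks[0])
total = count(blocks) + count(final_norm)

print(h.shape)                                        # torch.Size([1, 4, 8])
print(per_block)                                      # 872
print(total == n_layers * per_block + 2 * d_model)    # True
```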

In [None]:
class SimpleTransformer(nn.Module):
    def __init__(self, n_layers: int, d_model: int, n_heads: int, d_ff: int):
        super().__init__()
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff)
            for _ in range(n_layers)
        ])
        self.final_norm = nn.LayerNorm(d_model)
    
    def forward(self, x, mask=None):
        # Pass an optional attention mask through to every block
        for block in self.blocks:
            x = block(x, mask)
        return self.final_norm(x)

n_layers = 3
transformer = SimpleTransformer(n_layers, d_model=8, n_heads=2, d_ff=32)

x = torch.randn(1, 4, 8)
output = transformer(x)

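total_params = sum(p.numel() for p in transformer.parameters())
# Count one block directly; total_params // n_layers would also spread
# the final LayerNorm's parameters across the layers
params_per_layer = sum(p.numel() for p in transformer.blocks[0].parameters())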

print(f"Stacked Transformer:")
print(f"Layers:           {n_layers}")
print(f"Input shape:      {x.shape}")
print(f"Output shape:     {output.shape}")
print(f"Total parameters: {total_params:,}")
print(f"Per layer:        {params_per_layer:,}")

print(f"\n✨ Layers tend to specialize (an empirical tendency, not a guarantee):")
print(f"• Layer 1: Basic features and attention patterns")
print(f"• Layer 2: More complex relationships")
print(f"• Layer 3: High-level abstractions and reasoning")

print(f"\n🔑 Residual connections enable deep stacking!")
print(f"✅ Each layer builds on previous understanding")
print(f"✅ Deep networks learn hierarchical representations")

## Summary: Complete Transformer Architecture

You've built complete transformer blocks from first principles!

### Key Components
- **Feed-Forward Networks**: Position-wise processing with expand-contract architecture
- **Layer Normalization**: Stabilizes training by normalizing each position's features
- **Residual Connections**: Enable deep networks via gradient highways
- **Integration**: Pre-norm architecture for stable deep training

### Architecture Pattern
Each transformer block follows: `x → LayerNorm → Attention → Residual → LayerNorm → FFN → Residual`
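
The same pattern as equations (dropout omitted), together with the chain rule for the block's input-output Jacobian, which shows the identity term each residual connection contributes:

```latex
\begin{aligned}
h &= x + \mathrm{Attention}\big(\mathrm{LayerNorm}_1(x)\big) \\
y &= h + \mathrm{FFN}\big(\mathrm{LayerNorm}_2(h)\big) \\
\frac{\partial y}{\partial x} &= \Big(I + \frac{\partial\,\mathrm{FFN}\big(\mathrm{LayerNorm}_2(h)\big)}{\partial h}\Big)
\Big(I + \frac{\partial\,\mathrm{Attention}\big(\mathrm{LayerNorm}_1(x)\big)}{\partial x}\Big)
\end{aligned}
```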

### Why It Works
- **Attention**: Routes information between positions
- **FFN**: Processes each position independently
- **LayerNorm**: Maintains stable feature scales
- **Residuals**: Preserve gradient flow in deep networks

### Next Steps
Now you understand complete transformer blocks! Without masking, self-attention is permutation-equivariant and has no notion of token order, so next we'll explore positional encoding to give transformers that information.

Foundation complete - let's add position information! 🧭