# Building Transformer Blocks

Attention routes information between positions, but transformers need more components for complete functionality.

## Why More Than Attention?

**Attention limitations**:
- Only mixes information (no position-wise processing)  
- Can be unstable in deep networks
- Limited per-position transformation: each output is just a weighted average of linearly projected values

**Complete transformer blocks add**:
- **Feed-Forward Networks**: Position-wise processing and transformations
- **Layer Normalization**: Training stability  
- **Residual Connections**: Enable deep network training

## Architecture
Each transformer block: `Attention + FFN + LayerNorm + Residuals`

## Environment Setup

Import the required libraries. The transformer block below also needs the `MultiHeadAttention` module built in the previous notebook.

In [None]:
import sys
import os
sys.path.append('..')

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Tuple, Optional

# Seed both random number generators up front so later cells are reproducible.
torch.manual_seed(42)
np.random.seed(42)

# Plotting defaults: matplotlib base style plus seaborn's "husl" palette.
plt.style.use('default')
sns.set_palette("husl")

print("Environment setup complete!")

## Feed-Forward Networks

Position-wise processing with expand-contract architecture: `d_model → d_ff → d_model`.

In [None]:
class FeedForward(nn.Module):
    """Position-wise feed-forward sublayer: expand to ``d_ff``, ReLU, contract back.

    Args:
        d_model: Size of the input and output feature dimension.
        d_ff: Size of the hidden (expanded) dimension.
        dropout: Dropout probability applied to the hidden activations.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        # Registration order (linear1 before linear2) fixes both the random
        # initialization sequence and the state_dict key order.
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply Linear -> ReLU -> Dropout -> Linear along the last dimension of ``x``."""
        hidden = self.linear1(x)
        hidden = F.relu(hidden)
        hidden = self.dropout(hidden)
        return self.linear2(hidden)

# Test feed-forward network on a (batch=1, seq_len=4, d_model) input.
d_model, d_ff = 8, 32
ff_net = FeedForward(d_model, d_ff)

x = torch.randn(1, 4, d_model)
output = ff_net(x)

print("Feed-Forward Network:")
print(f"Input shape:  {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Architecture: {d_model} → {d_ff} → {d_model} (expand-contract)")

# Weights and biases of both Linear layers.
total_params = sum(p.numel() for p in ff_net.parameters())
print(f"Parameters: {total_params:,}")
print("✅ Position-wise processing with non-linear transformations!")

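## Layer Normalization

Normalize each position's features to zero mean and unit variance so activations stay in a stable range: `LayerNorm(x) = γ · (x − mean) / sqrt(var + ε) + β`.

In [None]: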
# Demonstrate layer normalization on two positions whose features differ in scale by 100x.
x = torch.tensor([
    [[1.0, 2.0, 3.0, 4.0],
     [100.0, 200.0, 300.0, 400.0]]
])

# nn.LayerNorm normalizes with the biased (population) variance, so every std
# below uses unbiased=False. The default unbiased std would report ~1.155 for
# the normalized rows (sqrt(4/3) with 4 features), contradicting "std≈1".
print("Problem: Different scales break training!")
print(f"Position 1: {x[0,0].tolist()}")
print(f"  → mean={x[0,0].mean():.1f}, std={x[0,0].std(unbiased=False):.1f}")
print(f"Position 2: {x[0,1].tolist()}")
print(f"  → mean={x[0,1].mean():.1f}, std={x[0,1].std(unbiased=False):.1f}")

# Normalizes over the last dimension (size 4). A freshly constructed LayerNorm
# has weight=1 and bias=0, so the output is just the normalized values.
layer_norm = nn.LayerNorm(4)
x_normalized = layer_norm(x)

print("\nSolution: LayerNorm fixes the scale problem!")
print(f"Position 1: {[round(val, 3) for val in x_normalized[0,0].tolist()]}")
print(f"  → mean={x_normalized[0,0].mean():.3f}, std={x_normalized[0,0].std(unbiased=False):.3f}")
print(f"Position 2: {[round(val, 3) for val in x_normalized[0,1].tolist()]}")
print(f"  → mean={x_normalized[0,1].mean():.3f}, std={x_normalized[0,1].std(unbiased=False):.3f}")

print("✅ Both positions now have mean≈0, std≈1")

## Residual Connections

Residual connections enable deep networks by giving gradients a direct path around each sublayer: `output = x + f(x)`.

In [ ]:
# Demonstrate residual connections on a single 4-feature vector.
x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
print(f"Input: {x.squeeze().tolist()}")

# Weak transformation that would lose signal: with every weight 0.01 and zero
# bias, each output equals 0.01 * sum(x), so the input's structure collapses
# to one small repeated value.
weak_transform = nn.Linear(4, 4)
with torch.no_grad():
    weak_transform.weight.fill_(0.01)
    weak_transform.bias.zero_()

output_no_res = weak_transform(x)
print(f"Without residual: {[round(val, 3) for val in output_no_res.squeeze().tolist()]} (signal lost!)")

# The identity path adds x back unchanged, so the original values survive.
output_with_res = x + weak_transform(x)
print(f"With residual:    {[round(val, 3) for val in output_with_res.squeeze().tolist()]} (signal preserved!)")

print("✅ Residual connections preserve signals and enable deep networks")

## Complete Transformer Block

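Integrate all components using the pre-norm architecture. LayerNorm is applied to each sublayer's input, and the sublayer's output is added back to the residual stream. Pre-norm tends to train more stably than post-norm in deep stacks.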

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention (minimal copy of the previous notebook's module).

    Called as ``attention(query, key, value, mask)`` and returns a single tensor
    (not a tuple), which is what ``TransformerBlock`` below expects.

    Args:
        d_model: Model (embedding) dimension; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, d_model, n_heads):
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

    def _split_heads(self, t):
        # (batch, seq_len, d_model) -> (batch, n_heads, seq_len, d_k)
        batch, seq_len, _ = t.shape
        return t.view(batch, seq_len, self.n_heads, self.d_k).transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` positions over ``key``/``value`` positions.

        Args:
            query, key, value: Tensors of shape (batch, seq_len, d_model).
            mask: Optional tensor broadcastable to (batch, n_heads, q_len, k_len);
                positions where ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape (batch, q_len, d_model).
        """
        q = self._split_heads(self.w_q(query))
        k = self._split_heads(self.w_k(key))
        v = self._split_heads(self.w_v(value))

        # Scale by sqrt(d_k) so the softmax does not saturate as head size grows.
        scores = q @ k.transpose(-2, -1) / (self.d_k ** 0.5)
        if mask is not None:
            # A query row with every key masked would softmax to NaN; callers
            # must leave at least one visible key per query.
            scores = scores.masked_fill(mask == 0, float("-inf"))
        weights = F.softmax(scores, dim=-1)

        context = weights @ v  # (batch, n_heads, q_len, d_k)
        batch, _, q_len, _ = context.shape
        context = context.transpose(1, 2).reshape(batch, q_len, self.n_heads * self.d_k)
        return self.w_o(context)


class TransformerBlock(nn.Module):
    """Pre-norm transformer block: self-attention and feed-forward sublayers with residuals.

    Computes ``x = x + Dropout(Attention(LN1(x)))`` then
    ``x = x + Dropout(FFN(LN2(x)))``; the output has the same shape as ``x``.

    Args:
        d_model: Model dimension of the input and output.
        n_heads: Number of attention heads.
        d_ff: Hidden size of the feed-forward sublayer.
        dropout: Dropout probability for both sublayer outputs and inside the FFN.
    """

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, n_heads)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Apply both sublayers to ``x``; ``mask`` is passed through to attention."""
        # Attention sublayer with pre-norm and residual
        normed = self.norm1(x)
        attn_out = self.attention(normed, normed, normed, mask)
        x = x + self.dropout(attn_out)

        # Feed-forward sublayer with pre-norm and residual
        normed = self.norm2(x)
        ff_out = self.feed_forward(normed)
        x = x + self.dropout(ff_out)

        return x

# Test complete transformer block on a (batch=1, seq_len=4, d_model) input.
d_model, n_heads, d_ff = 8, 2, 32
block = TransformerBlock(d_model, n_heads, d_ff)

x = torch.randn(1, 4, d_model)
output = block(x)

print("Complete Transformer Block:")
print(f"Input shape:  {x.shape}")
print(f"Output shape: {output.shape}")

total_params = sum(p.numel() for p in block.parameters())
print(f"Total parameters: {total_params:,}")
print("✅ Successfully combines attention + FFN + LayerNorm + residuals!")

## Stacking Transformer Blocks

Stack multiple blocks to build deep transformers with hierarchical learning.

In [None]:
class SimpleTransformer(nn.Module):
    """Stack of ``n_layers`` transformer blocks followed by a final LayerNorm.

    Each pre-norm block normalizes only its sublayer inputs and leaves the
    residual stream un-normalized, so a final LayerNorm is applied to the output.

    Args:
        n_layers: Number of stacked ``TransformerBlock`` modules.
        d_model: Model dimension.
        n_heads: Attention heads per block.
        d_ff: Feed-forward hidden size per block.
        dropout: Dropout probability forwarded to every block (default 0.1,
            the same as ``TransformerBlock``'s default).
    """

    def __init__(self, n_layers: int, d_model: int, n_heads: int, d_ff: int,
                 dropout: float = 0.1):
        super().__init__()
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])
        self.final_norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run ``x`` through every block with the same optional ``mask``, then normalize."""
        for block in self.blocks:
            x = block(x, mask)
        return self.final_norm(x)

# Test stacked transformer on a (batch=1, seq_len=4, d_model=8) input.
n_layers = 3
transformer = SimpleTransformer(n_layers, d_model=8, n_heads=2, d_ff=32)

x = torch.randn(1, 4, 8)
output = transformer(x)

total_params = sum(p.numel() for p in transformer.parameters())

print("Stacked Transformer:")
print(f"Layers: {n_layers}")
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Total parameters: {total_params:,}")

print("\nHierarchical learning:")
print("• Layer 1: Basic features and simple patterns")
print("• Layer 2: More complex relationships")
print("• Layer 3: High-level abstractions")
print("✅ Deep networks learn increasingly complex representations!")

## Summary

You've built complete transformer blocks from first principles!

**Key Components**:
- **Feed-Forward Networks**: Position-wise processing with expand-contract architecture
- **Layer Normalization**: Stabilizes training by normalizing feature scales
- **Residual Connections**: Enable deep networks via gradient highways
- **Integration**: Pre-norm architecture for stable deep training

**Architecture Pattern**: `x → LayerNorm → Attention → Residual → LayerNorm → FFN → Residual`

**Why It Works**:
- Attention routes information between positions
- FFN processes each position independently  
- LayerNorm maintains stable scales
- Residuals preserve gradient flow

**Next**: Add positional encoding so transformers know the order of tokens in a sequence!

class SimpleTransformer(nn.Module):
    """Stack of ``n_layers`` transformer blocks followed by a final LayerNorm.

    Each pre-norm block normalizes only its sublayer inputs and leaves the
    residual stream un-normalized, so a final LayerNorm is applied to the output.

    Args:
        n_layers: Number of stacked ``TransformerBlock`` modules.
        d_model: Model dimension.
        n_heads: Attention heads per block.
        d_ff: Feed-forward hidden size per block.
        dropout: Dropout probability forwarded to every block (default 0.1,
            the same as ``TransformerBlock``'s default).
    """

    def __init__(self, n_layers: int, d_model: int, n_heads: int, d_ff: int,
                 dropout: float = 0.1):
        super().__init__()
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])
        self.final_norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run ``x`` through every block with the same optional ``mask``, then normalize."""
        for block in self.blocks:
            x = block(x, mask)
        return self.final_norm(x)

# Test stacked transformer on a (batch=1, seq_len=4, d_model=8) input.
n_layers = 3
transformer = SimpleTransformer(n_layers, d_model=8, n_heads=2, d_ff=32)

x = torch.randn(1, 4, 8)
output = transformer(x)

total_params = sum(p.numel() for p in transformer.parameters())

print("Stacked Transformer:")
print(f"Layers: {n_layers}")
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Total parameters: {total_params:,}")

print("\nHierarchical learning:")
print("• Layer 1: Basic features and simple patterns")
print("• Layer 2: More complex relationships")
print("• Layer 3: High-level abstractions")
print("✅ Deep networks learn increasingly complex representations!")

## Stacking Transformer Blocks

The power of transformers comes from stacking multiple blocks. Each layer can learn increasingly complex patterns and relationships.

**Why stacking works**:
- Layer 1: Basic features and simple attention patterns  
- Layer 2: More complex interactions between positions
- Layer 3+: High-level reasoning and abstract relationships

Residual connections make deep stacking possible by preserving gradient flow.

In [None]:
# Create a simple multi-layer transformer for demonstration
class SimpleTransformer(nn.Module):
    """Stack of ``n_layers`` transformer blocks followed by a final LayerNorm.

    Each pre-norm block normalizes only its sublayer inputs and leaves the
    residual stream un-normalized, so a final LayerNorm is applied to the output.

    Args:
        n_layers: Number of stacked ``TransformerBlock`` modules.
        d_model: Model dimension.
        n_heads: Attention heads per block.
        d_ff: Feed-forward hidden size per block.
        dropout: Dropout probability forwarded to every block (default 0.1,
            the same as ``TransformerBlock``'s default).
    """

    def __init__(self, n_layers: int, d_model: int, n_heads: int, d_ff: int,
                 dropout: float = 0.1):
        super().__init__()

        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])

        self.final_norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run ``x`` through every block with the same optional ``mask``, then normalize."""
        for block in self.blocks:
            x = block(x, mask)
        return self.final_norm(x)

# Test stacking on a (batch=1, seq_len=4, d_model=8) input.
print("📚 STACKING TRANSFORMER BLOCKS")
n_layers = 3
transformer = SimpleTransformer(n_layers, d_model=8, n_heads=2, d_ff=32)

x = torch.randn(1, 4, 8)
output = transformer(x)

total_params = sum(p.numel() for p in transformer.parameters())
# Count one block directly: total_params // n_layers would also fold the
# final LayerNorm's parameters into the "per layer" figure. All blocks are
# built with identical sizes, so block 0 is representative.
params_per_layer = sum(p.numel() for p in transformer.blocks[0].parameters())

print(f"Layers:           {n_layers}")
print(f"Input shape:      {x.shape}")
print(f"Output shape:     {output.shape}")
print(f"Total parameters: {total_params:,}")
print(f"Per layer:        {params_per_layer:,}")

print("\n✨ Each layer can learn different patterns:")
print("• Layer 1: Basic features and attention patterns")
print("• Layer 2: More complex relationships")
print("• Layer 3: High-level abstractions and reasoning")

print("\n🔑 KEY INSIGHT: Deep networks can learn hierarchical representations!")
print("✅ Residual connections make deep stacking possible")
print("✅ Each layer builds on previous layers' understanding")