# The Transformer Architecture

Welcome to Topic 4! In this notebook, we'll explore the complete transformer architecture in detail. We'll understand how all the components work together to create this revolutionary model.

## Learning Objectives

By the end of this notebook, you will:
- Understand the complete transformer architecture
- Learn about encoder and decoder stacks
- Master positional encoding
- Understand layer normalization and residual connections
- See how transformers process sequences

## Setup and Imports

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Optional, Tuple
import math

# Set random seeds for reproducibility
np.random.seed(42)
torch.manual_seed(42)

# Configure matplotlib
plt.style.use('seaborn-v0_8-darkgrid')
plt.rcParams['figure.figsize'] = (10, 6)
plt.rcParams['font.size'] = 12

# Check if GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 1. Transformer Architecture Overview

Let's start by visualizing the complete transformer architecture.
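For comparison, PyTorch ships a reference `nn.Transformer` module with the same encoder-decoder layout. The sketch below only checks tensor shapes. It assumes a recent PyTorch (for `batch_first=True`) and feeds random vectors in place of embedded tokens, because `nn.Transformer` has no embedding layer or vocabulary projection of its own. All sizes are arbitrary illustration values.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Small reference encoder-decoder stack (sizes chosen only for illustration)
reference = nn.Transformer(d_model=64, nhead=4, num_encoder_layers=2,
                           num_decoder_layers=2, dim_feedforward=128,
                           batch_first=True)

src = torch.randn(2, 10, 64)  # stand-in for embedded source: [batch, src_len, d_model]
tgt = torch.randn(2, 7, 64)   # stand-in for embedded target: [batch, tgt_len, d_model]

# PyTorch's causal mask is a float mask with -inf above the diagonal and 0 elsewhere
tgt_mask = reference.generate_square_subsequent_mask(7)

out = reference(src, tgt, tgt_mask=tgt_mask)
print(out.shape)  # torch.Size([2, 7, 64])
```

The decoder returns one vector per target position; the "Linear & Softmax" box in the diagram would sit on top of this output.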

In [None]:
# Visualize the transformer architecture
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 10))

# Encoder side
ax1.set_title('Transformer Encoder', fontsize=16, fontweight='bold')
ax1.set_xlim(0, 10)
ax1.set_ylim(0, 10)

# Draw encoder components
encoder_components = [
    (5, 1, 'Input Embeddings', 'lightblue'),
    (5, 2, 'Positional Encoding', 'lightgreen'),
    (5, 3.5, 'Multi-Head Attention', 'lightyellow'),
    (5, 4.5, 'Add & Norm', 'lightgray'),
    (5, 6, 'Feed Forward', 'lightcoral'),
    (5, 7, 'Add & Norm', 'lightgray'),
    (5, 8.5, 'Output', 'lightblue')
]

for x, y, label, color in encoder_components:
    rect = plt.Rectangle((x-2, y-0.4), 4, 0.8, facecolor=color, edgecolor='black', linewidth=2)
    ax1.add_patch(rect)
    ax1.text(x, y, label, ha='center', va='center', fontweight='bold')

# Draw connections
for i in range(len(encoder_components)-1):
    y1 = encoder_components[i][1] + 0.4
    y2 = encoder_components[i+1][1] - 0.4
    ax1.arrow(5, y1, 0, y2-y1, head_width=0.2, head_length=0.1, fc='black', ec='black')

# Add "Nx" notation
ax1.text(8.5, 5.5, 'Nx', fontsize=14, fontweight='bold', bbox=dict(boxstyle="round", facecolor='white'))

ax1.axis('off')

# Decoder side
ax2.set_title('Transformer Decoder', fontsize=16, fontweight='bold')
ax2.set_xlim(0, 10)
ax2.set_ylim(0, 10)

# Draw decoder components
decoder_components = [
    (5, 0.5, 'Output Embeddings', 'lightblue'),
    (5, 1.5, 'Positional Encoding', 'lightgreen'),
    (5, 2.5, 'Masked Multi-Head Attention', 'lightyellow'),
    (5, 3.3, 'Add & Norm', 'lightgray'),
    (5, 4.5, 'Multi-Head Attention', 'lightyellow'),
    (5, 5.3, 'Add & Norm', 'lightgray'),
    (5, 6.5, 'Feed Forward', 'lightcoral'),
    (5, 7.3, 'Add & Norm', 'lightgray'),
    (5, 8.5, 'Linear & Softmax', 'lightblue')
]

for x, y, label, color in decoder_components:
    rect = plt.Rectangle((x-2.5, y-0.35), 5, 0.7, facecolor=color, edgecolor='black', linewidth=2)
    ax2.add_patch(rect)
    ax2.text(x, y, label, ha='center', va='center', fontweight='bold', fontsize=10)

# Draw connections
for i in range(len(decoder_components)-1):
    y1 = decoder_components[i][1] + 0.35
    y2 = decoder_components[i+1][1] - 0.35
    ax2.arrow(5, y1, 0, y2-y1, head_width=0.2, head_length=0.1, fc='black', ec='black')

# Add encoder-decoder connection
ax2.arrow(1, 4.5, 1.5, 0, head_width=0.2, head_length=0.1, fc='red', ec='red', linestyle='--')
ax2.text(1, 4.8, 'From\nEncoder', ha='center', fontsize=10, color='red')

# Add "Nx" notation
ax2.text(8.5, 5, 'Nx', fontsize=14, fontweight='bold', bbox=dict(boxstyle="round", facecolor='white'))

ax2.axis('off')

plt.tight_layout()
plt.show()

print("Key Architecture Components:")
print("1. Input/Output Embeddings: Convert tokens to vectors")
print("2. Positional Encoding: Add position information")
print("3. Multi-Head Attention: Core attention mechanism")
print("4. Feed Forward: Position-wise transformations")
print("5. Add & Norm: Residual connections and layer normalization")
print("6. Nx: Stack of N identical layers (N = 6 in the original paper; larger models use more)")

## 2. Positional Encoding

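Self-attention has no built-in notion of order: permuting the input tokens simply permutes the outputs. To let the model use word order, we add a positional encoding to the token embeddings before the first layer.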
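A quick way to see the problem: without positional information, self-attention is permutation-equivariant, so shuffling the input tokens just shuffles the outputs in the same way. The sketch below checks this with PyTorch's built-in `nn.MultiheadAttention` (used instead of the notebook's own class so that the snippet runs on its own); the sizes are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, n_heads, seq_len = 16, 4, 5

attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
attn.eval()  # make sure no dropout is active

x = torch.randn(1, seq_len, d_model)   # [batch, seq, d_model], no positional encoding
perm = torch.randperm(seq_len)         # a random reordering of the tokens
x_perm = x[:, perm]

with torch.no_grad():
    out, _ = attn(x, x, x)                       # self-attention on the original order
    out_perm, _ = attn(x_perm, x_perm, x_perm)   # self-attention on the shuffled order

# Shuffling the input only shuffles the output: the layer cannot tell the orders apart
print(torch.allclose(out[:, perm], out_perm, atol=1e-5))  # True
```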

In [None]:
class PositionalEncoding(nn.Module):
    """Positional encoding using sine and cosine functions."""
    
    def __init__(self, d_model: int, max_seq_length: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(dropout)
        
        # Create positional encoding matrix
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length).unsqueeze(1).float()
        
        # Create div_term for the sinusoidal pattern
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * 
                           -(math.log(10000.0) / d_model))
        
        # Apply sine to even indices
        pe[:, 0::2] = torch.sin(position * div_term)
        
        # Apply cosine to odd indices
        pe[:, 1::2] = torch.cos(position * div_term)
        
        # Add batch dimension and register as buffer
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional encoding to input embeddings."""
        # x shape: [batch_size, seq_length, d_model]
        seq_length = x.size(1)
        x = x + self.pe[:, :seq_length, :]
        return self.dropout(x)

# Visualize positional encoding
d_model = 512
max_length = 100
pos_encoder = PositionalEncoding(d_model, max_length)

# Get positional encoding values
pe_values = pos_encoder.pe[0, :max_length, :].numpy()

# Plot heatmap
plt.figure(figsize=(12, 6))
plt.imshow(pe_values.T, aspect='auto', cmap='RdBu', interpolation='nearest')
plt.colorbar(label='Value')
plt.xlabel('Position')
plt.ylabel('Dimension')
plt.title('Positional Encoding Pattern')
plt.tight_layout()
plt.show()

# Plot specific dimensions
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
dimensions = [0, 1, 10, 100]

for ax, dim in zip(axes.flat, dimensions):
    ax.plot(pe_values[:, dim])
    ax.set_title(f'Dimension {dim}')
    ax.set_xlabel('Position')
    ax.set_ylabel('Encoding Value')
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("Positional Encoding Properties:")
print(f"- Each position has a unique encoding vector of dimension {d_model}")
print("- Uses sinusoidal functions with different frequencies")
print("- PE(pos + k) is a linear function of PE(pos), which may make relative positions easy to learn")
print("- Might extrapolate to sequences longer than those seen in training (up to max_seq_length here)")
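The relative-position property can be made concrete. For each frequency w_i, the pair (sin(w_i * pos), cos(w_i * pos)) is rotated by the angle w_i * k when the position shifts by k, so PE(pos + k) = M_k PE(pos) for a block-diagonal matrix M_k that depends only on k. The NumPy sketch below rebuilds the same sinusoidal table as `PositionalEncoding` (a standalone re-implementation, not the class itself) and verifies the identity numerically.

```python
import math
import numpy as np

def sinusoidal_table(max_len: int, d_model: int):
    """Same formula as PositionalEncoding above, without dropout or a batch dimension."""
    position = np.arange(max_len)[:, None]
    freqs = np.exp(np.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(position * freqs)
    pe[:, 1::2] = np.cos(position * freqs)
    return pe, freqs

def shift_matrix(k: int, freqs: np.ndarray) -> np.ndarray:
    """Block-diagonal rotation M_k such that PE[pos + k] = M_k @ PE[pos]."""
    d_model = 2 * len(freqs)
    m = np.zeros((d_model, d_model))
    for i, w in enumerate(freqs):
        c, s = math.cos(w * k), math.sin(w * k)
        # sin(a + b) = sin(a)cos(b) + cos(a)sin(b)
        m[2 * i, 2 * i], m[2 * i, 2 * i + 1] = c, s
        # cos(a + b) = cos(a)cos(b) - sin(a)sin(b)
        m[2 * i + 1, 2 * i], m[2 * i + 1, 2 * i + 1] = -s, c
    return m

pe, freqs = sinusoidal_table(max_len=50, d_model=16)
k = 7
m_k = shift_matrix(k, freqs)

# With rows as vectors: PE[pos + k] = PE[pos] @ M_k^T, checked for every pos at once
print(np.allclose(pe[k:], pe[:-k] @ m_k.T))  # True
```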

## 3. Multi-Head Attention Layer

Let's implement the complete multi-head attention mechanism.
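Each head computes `Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V`, with masked scores pushed to a large negative value before the softmax. As a cross-check, the sketch below compares a hand-written version with `torch.nn.functional.scaled_dot_product_attention`, which assumes PyTorch 2.0 or later. It uses the boolean mask convention of this notebook (`True` = may attend), which is also how that function interprets boolean masks.

```python
import math
import torch
import torch.nn.functional as F

def manual_attention(q, k, v, mask=None):
    """q, k, v: [batch, heads, seq, d_k]; mask: bool, True where attention is allowed."""
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if mask is not None:
        scores = scores.masked_fill(~mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v

torch.manual_seed(0)
q, k, v = (torch.randn(2, 4, 6, 8) for _ in range(3))
causal = torch.tril(torch.ones(6, 6, dtype=torch.bool))  # lower triangle = visible

ref = F.scaled_dot_product_attention(q, k, v, attn_mask=causal)
ours = manual_attention(q, k, v, causal)
print(torch.allclose(ref, ours, atol=1e-5))  # True
```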

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-Head Attention mechanism."""
    
    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % n_heads == 0
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        # Linear projections
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.d_k)
        
    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass of multi-head attention."""
        batch_size = query.size(0)
        seq_length = query.size(1)
        
        # 1. Linear projections in batch from d_model => h x d_k
        Q = self.W_q(query).view(batch_size, seq_length, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(key).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(value).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        
        # 2. Apply attention on all the projected vectors in batch
        attention_output, attention_weights = self.scaled_dot_product_attention(
            Q, K, V, mask
        )
        
        # 3. Concatenate heads and put through final linear layer
        attention_output = attention_output.transpose(1, 2).contiguous().view(
            batch_size, seq_length, self.d_model
        )
        
        output = self.W_o(attention_output)
        
        return output, attention_weights
    
    def scaled_dot_product_attention(self, Q: torch.Tensor, K: torch.Tensor, 
                                   V: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Scaled dot-product attention."""
        # Calculate attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        
        # Apply mask if provided
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        
        # Apply softmax over the key dimension
        attention_weights = F.softmax(scores, dim=-1)
        
        # Apply dropout only to the weights used to mix the values; return the
        # un-dropped weights so they remain valid distributions for inspection
        output = torch.matmul(self.dropout(attention_weights), V)
        
        return output, attention_weights

# Test multi-head attention
d_model = 512
n_heads = 8
seq_length = 10
batch_size = 2

mha = MultiHeadAttention(d_model, n_heads)

# Create sample input
x = torch.randn(batch_size, seq_length, d_model)

# Forward pass
output, attention_weights = mha(x, x, x)  # Self-attention

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {attention_weights.shape}")

# Visualize attention weights for one head
plt.figure(figsize=(8, 6))
sns.heatmap(attention_weights[0, 0].detach().numpy(), 
            cmap='Blues', cbar=True, square=True,
            xticklabels=range(seq_length),
            yticklabels=range(seq_length))
plt.xlabel('Keys')
plt.ylabel('Queries')
plt.title('Attention Weights (Head 0, Batch 0)')
plt.tight_layout()
plt.show()

## 4. Feed-Forward Network

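The position-wise feed-forward network, `FFN(x) = max(0, x W1 + b1) W2 + b2`, applies the same two-layer MLP to every position independently (this implementation also applies dropout after the ReLU).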
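Because `nn.Linear` acts only on the last dimension, running the FFN on a whole `[batch, seq, d_model]` tensor gives the same result as running it on each position separately. A small sketch, with `nn.Sequential` standing in for `FeedForwardNetwork` (dropout omitted) and arbitrary sizes:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, d_ff = 8, 32

# Stand-in for FeedForwardNetwork: Linear -> ReLU -> Linear
ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))

x = torch.randn(2, 5, d_model)  # [batch, seq, d_model]
with torch.no_grad():
    whole = ffn(x)  # all positions at once
    # One position at a time, then stacked back along the sequence dimension
    per_position = torch.stack([ffn(x[:, t]) for t in range(x.size(1))], dim=1)

print(torch.allclose(whole, per_position, atol=1e-6))  # True
```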

In [None]:
class FeedForwardNetwork(nn.Module):
    """Position-wise feed-forward network."""
    
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.ReLU()
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass: Linear -> ReLU -> Dropout -> Linear."""
        x = self.linear1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.linear2(x)
        return x

# Visualize feed-forward network
d_model = 512
d_ff = 2048
ffn = FeedForwardNetwork(d_model, d_ff)

# Create sample input
x = torch.randn(1, 10, d_model)
output = ffn(x)

print(f"FFN Architecture:")
print(f"Input: {x.shape} -> Linear({d_model}, {d_ff}) -> ReLU -> Dropout -> Linear({d_ff}, {d_model}) -> Output: {output.shape}")

# Visualize the transformation
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

# Input distribution
ax1.hist(x.numpy().flatten(), bins=50, alpha=0.7, color='blue', edgecolor='black')
ax1.set_title('Input Distribution')
ax1.set_xlabel('Value')
ax1.set_ylabel('Frequency')

# Output distribution
ax2.hist(output.detach().numpy().flatten(), bins=50, alpha=0.7, color='green', edgecolor='black')
ax2.set_title('Output Distribution (after FFN)')
ax2.set_xlabel('Value')
ax2.set_ylabel('Frequency')

plt.tight_layout()
plt.show()

print("\nFeed-Forward Network Properties:")
print("- Applied independently to each position")
print("- Two linear transformations with ReLU activation")
print(f"- Hidden dimension ({d_ff}) is typically 4x the model dimension ({d_model})")
print("- Adds non-linearity to the model")

## 5. Layer Normalization and Residual Connections

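Residual connections give gradients a direct path through deep stacks, and layer normalization keeps activation scales stable from layer to layer; both are important for training deep transformers.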
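For reference, PyTorch's `nn.LayerNorm` computes `(x - mean) / sqrt(Var[x] + eps)` over the last dimension using the biased (population) variance, then applies a learned scale gamma and shift beta. The sketch below checks that formula against the built-in module at initialization, where gamma = 1 and beta = 0.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(2, 5, 16) * 10 + 5
eps = 1e-5  # nn.LayerNorm's default epsilon

mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)  # biased (population) variance
manual = (x - mean) / torch.sqrt(var + eps)

ln = nn.LayerNorm(16, eps=eps)  # weight (gamma) = 1, bias (beta) = 0 at initialization
with torch.no_grad():
    print(torch.allclose(ln(x), manual, atol=1e-5))  # True
```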

In [None]:
class LayerNorm(nn.Module):
    """Layer normalization."""
    
    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize across the feature dimension (same formula as nn.LayerNorm)."""
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)  # biased variance, as in nn.LayerNorm
        normalized = (x - mean) / torch.sqrt(var + self.eps)
        return self.gamma * normalized + self.beta

class ResidualConnection(nn.Module):
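    """Pre-norm residual connection: x + Dropout(Sublayer(LayerNorm(x))).

    Note: the encoder and decoder layers below use the post-norm form of the
    original paper, LayerNorm(x + Dropout(Sublayer(x))), instead.
    """
    
    def __init__(self, d_model: int, dropout: float = 0.1):
        super().__init__()
        self.norm = LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x: torch.Tensor, sublayer: nn.Module) -> torch.Tensor:
        """Apply the residual connection around any sublayer, normalizing its input first."""
        return x + self.dropout(sublayer(self.norm(x)))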

# Demonstrate layer normalization
batch_size = 2
seq_length = 10
d_model = 512

# Create sample data with different scales
x = torch.randn(batch_size, seq_length, d_model) * 10 + 5

# Apply layer normalization
layer_norm = LayerNorm(d_model)
x_normalized = layer_norm(x)

# Visualize the effect
fig, axes = plt.subplots(2, 2, figsize=(12, 8))

# Before normalization
axes[0, 0].hist(x[0, 0].numpy(), bins=50, alpha=0.7, color='red', edgecolor='black')
axes[0, 0].set_title('Before LayerNorm (Position 0)')
axes[0, 0].set_xlabel('Value')
axes[0, 0].set_ylabel('Frequency')

# After normalization
axes[0, 1].hist(x_normalized[0, 0].detach().numpy(), bins=50, alpha=0.7, color='green', edgecolor='black')
axes[0, 1].set_title('After LayerNorm (Position 0)')
axes[0, 1].set_xlabel('Value')
axes[0, 1].set_ylabel('Frequency')

# Statistics before
axes[1, 0].text(0.1, 0.8, f"Mean: {x[0, 0].mean():.4f}", transform=axes[1, 0].transAxes, fontsize=14)
axes[1, 0].text(0.1, 0.6, f"Std: {x[0, 0].std():.4f}", transform=axes[1, 0].transAxes, fontsize=14)
axes[1, 0].text(0.1, 0.4, f"Min: {x[0, 0].min():.4f}", transform=axes[1, 0].transAxes, fontsize=14)
axes[1, 0].text(0.1, 0.2, f"Max: {x[0, 0].max():.4f}", transform=axes[1, 0].transAxes, fontsize=14)
axes[1, 0].set_title('Statistics Before')
axes[1, 0].axis('off')

# Statistics after
axes[1, 1].text(0.1, 0.8, f"Mean: {x_normalized[0, 0].mean():.4f}", transform=axes[1, 1].transAxes, fontsize=14)
axes[1, 1].text(0.1, 0.6, f"Std: {x_normalized[0, 0].std():.4f}", transform=axes[1, 1].transAxes, fontsize=14)
axes[1, 1].text(0.1, 0.4, f"Min: {x_normalized[0, 0].min():.4f}", transform=axes[1, 1].transAxes, fontsize=14)
axes[1, 1].text(0.1, 0.2, f"Max: {x_normalized[0, 0].max():.4f}", transform=axes[1, 1].transAxes, fontsize=14)
axes[1, 1].set_title('Statistics After')
axes[1, 1].axis('off')

plt.tight_layout()
plt.show()

print("Layer Normalization Benefits:")
print("- Stabilizes training of deep networks")
print("- Originally motivated as reducing 'internal covariate shift'")
print("- Allows higher learning rates")
print("- Normalizes across features (not batch)")

## 6. Encoder Layer

Now let's combine all components into a complete encoder layer.
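Two residual arrangements appear in practice. The `EncoderLayer` below follows the original paper's post-norm form, `LayerNorm(x + Sublayer(x))`, while the `ResidualConnection` helper in the previous section uses the pre-norm form, `x + Sublayer(LayerNorm(x))`, which many later models adopt. A minimal side-by-side sketch, with a plain `nn.Linear` standing in for the attention or feed-forward sublayer and dropout omitted:

```python
import torch
import torch.nn as nn

def post_norm(x, sublayer, norm):
    """Original Transformer: normalize after adding the residual."""
    return norm(x + sublayer(x))

def pre_norm(x, sublayer, norm):
    """Pre-norm variant: normalize the sublayer input; the residual path is left untouched."""
    return x + sublayer(norm(x))

torch.manual_seed(0)
d_model = 16
sublayer = nn.Linear(d_model, d_model)  # stand-in for attention or the FFN
norm = nn.LayerNorm(d_model)
x = torch.randn(2, 5, d_model) * 3 + 1

with torch.no_grad():
    y_post = post_norm(x, sublayer, norm)
    y_pre = pre_norm(x, sublayer, norm)
    # Post-norm outputs are normalized per position (mean ~ 0 at initialization)
    print(y_post.mean(dim=-1).abs().max().item() < 1e-5)  # True
    # Pre-norm outputs carry x through the residual path unchanged
    print(torch.allclose(y_pre - x, sublayer(norm(x)), atol=1e-5))  # True
```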

In [None]:
class EncoderLayer(nn.Module):
    """Single encoder layer."""
    
    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.self_attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.feed_forward = FeedForwardNetwork(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Forward pass through encoder layer."""
        # Self-attention with residual connection and layer norm
        attn_output, _ = self.self_attention(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        
        # Feed-forward with residual connection and layer norm
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        
        return x

# Test encoder layer
encoder_layer = EncoderLayer(d_model=512, n_heads=8, d_ff=2048)
x = torch.randn(2, 10, 512)  # [batch_size, seq_length, d_model]
output = encoder_layer(x)

print(f"Encoder Layer:")
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"\nNumber of parameters: {sum(p.numel() for p in encoder_layer.parameters()):,}")

# Visualize information flow
fig, ax = plt.subplots(1, 1, figsize=(10, 8))

# Draw encoder layer components
components = [
    (5, 8, 'Input', 'lightblue'),
    (5, 7, 'Multi-Head\nSelf-Attention', 'lightyellow'),
    (5, 6, 'Dropout', 'lightgray'),
    (5, 5, 'Add & Norm', 'lightgreen'),
    (5, 4, 'Feed Forward', 'lightcoral'),
    (5, 3, 'Dropout', 'lightgray'),
    (5, 2, 'Add & Norm', 'lightgreen'),
    (5, 1, 'Output', 'lightblue')
]

for cx, cy, label, color in components:  # cx/cy avoid overwriting the tensor `x` defined above
    if label == 'Add & Norm':
        rect = plt.Rectangle((cx-1.5, cy-0.3), 3, 0.6, facecolor=color, edgecolor='black', linewidth=2)
    else:
        rect = plt.Rectangle((cx-2, cy-0.3), 4, 0.6, facecolor=color, edgecolor='black', linewidth=2)
    ax.add_patch(rect)
    ax.text(cx, cy, label, ha='center', va='center', fontweight='bold', fontsize=10)

# Draw connections
# Main flow
for i in range(len(components)-1):
    y1 = components[i][1] - 0.3
    y2 = components[i+1][1] + 0.3
    ax.arrow(5, y1, 0, y2-y1, head_width=0.15, head_length=0.1, fc='black', ec='black')

# Residual connections
# First residual
ax.arrow(7.5, 7.7, 0, -2.4, head_width=0.15, head_length=0.1, fc='red', ec='red', linestyle='--', linewidth=2)
ax.arrow(7.5, 5.3, -2, 0, head_width=0.15, head_length=0.1, fc='red', ec='red', linestyle='--', linewidth=2)

# Second residual
ax.arrow(7.5, 4.7, 0, -2.4, head_width=0.15, head_length=0.1, fc='red', ec='red', linestyle='--', linewidth=2)
ax.arrow(7.5, 2.3, -2, 0, head_width=0.15, head_length=0.1, fc='red', ec='red', linestyle='--', linewidth=2)

ax.text(8, 6.5, 'Residual', color='red', fontsize=10, rotation=-90)
ax.text(8, 3.5, 'Residual', color='red', fontsize=10, rotation=-90)

ax.set_xlim(0, 10)
ax.set_ylim(0, 9)
ax.set_title('Encoder Layer Information Flow', fontsize=16, fontweight='bold')
ax.axis('off')

plt.tight_layout()
plt.show()

## 7. Decoder Layer

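The decoder layer follows the same pattern with two changes: its self-attention is masked so that each position sees only itself and earlier positions, and an additional encoder-decoder (cross-) attention sublayer lets the decoder attend to the encoder output.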
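The effect of the causal mask can be checked directly: with masking, the output at position t does not change when later positions change. The sketch below uses `torch.nn.functional.scaled_dot_product_attention` with `is_causal=True` (PyTorch 2.0+) as a stand-in for the masked self-attention, with queries, keys and values all set to the same random tensor for simplicity.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 1, 6, 8)  # [batch, heads, seq, d_k]; used as q, k and v
out = F.scaled_dot_product_attention(x, x, x, is_causal=True)

# Overwrite the last two positions ("future" tokens for positions 0-3)
x_changed = x.clone()
x_changed[:, :, 4:] = torch.randn(1, 1, 2, 8)
out_changed = F.scaled_dot_product_attention(x_changed, x_changed, x_changed, is_causal=True)

# Positions 0-3 only see positions 0-3, so their outputs are unaffected
print(torch.allclose(out[:, :, :4], out_changed[:, :, :4]))  # True
```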

In [None]:
class DecoderLayer(nn.Module):
    """Single decoder layer."""
    
    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.masked_self_attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.encoder_decoder_attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.feed_forward = FeedForwardNetwork(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x: torch.Tensor, encoder_output: torch.Tensor,
                src_mask: Optional[torch.Tensor] = None,
                tgt_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Forward pass through decoder layer."""
        # Masked self-attention
        attn_output, _ = self.masked_self_attention(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))
        
        # Encoder-decoder attention
        attn_output, _ = self.encoder_decoder_attention(x, encoder_output, encoder_output, src_mask)
        x = self.norm2(x + self.dropout(attn_output))
        
        # Feed-forward
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        
        return x

# Create causal mask for decoder
def create_causal_mask(seq_length: int) -> torch.Tensor:
    """Create causal mask to prevent attending to future positions."""
    mask = torch.triu(torch.ones(seq_length, seq_length), diagonal=1)
    return mask == 0

# Test decoder layer
decoder_layer = DecoderLayer(d_model=512, n_heads=8, d_ff=2048)

# Sample inputs
tgt = torch.randn(2, 10, 512)  # Decoder input
memory = torch.randn(2, 15, 512)  # Encoder output
tgt_mask = create_causal_mask(10)

output = decoder_layer(tgt, memory, tgt_mask=tgt_mask)

print(f"Decoder Layer:")
print(f"Target input shape: {tgt.shape}")
print(f"Encoder output shape: {memory.shape}")
print(f"Output shape: {output.shape}")
print(f"\nNumber of parameters: {sum(p.numel() for p in decoder_layer.parameters()):,}")

# Visualize causal mask
plt.figure(figsize=(8, 8))
plt.imshow(tgt_mask.numpy(), cmap='Blues', interpolation='nearest')
plt.colorbar(label='Can Attend')
plt.xlabel('Key Position')
plt.ylabel('Query Position')
plt.title('Causal Mask (Prevents Looking at Future Tokens)')
plt.tight_layout()
plt.show()

print("\nCausal Mask Properties:")
print("- Each position can attend only to itself and earlier positions")
print("- Ensures the autoregressive property: predictions cannot depend on future tokens")
print("- Masked (upper-triangle) scores are set to a large negative value (-1e9) before softmax")

## 8. Complete Transformer Model

Now let's put everything together into a complete transformer model.
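Before assembling the full model, it helps to know where the parameters live. With biases included, one attention block has 4(d^2 + d) parameters (four d x d projections), the FFN has 2 * d * d_ff + d_ff + d, and each LayerNorm has 2d. The sketch below turns that arithmetic into small helper functions (hypothetical names, for illustration only) and checks them against PyTorch's `nn.TransformerEncoderLayer` and `nn.TransformerDecoderLayer`, which contain the same sublayers. For d_model = 512 and d_ff = 2048 the encoder-layer count should also match the number printed for `EncoderLayer` earlier.

```python
import torch.nn as nn

def mha_params(d: int) -> int:
    return 4 * (d * d + d)  # W_q, W_k, W_v, W_o, each with a bias

def ffn_params(d: int, d_ff: int) -> int:
    return (d * d_ff + d_ff) + (d_ff * d + d)  # two linear layers with biases

def ln_params(d: int) -> int:
    return 2 * d  # gamma and beta

def encoder_layer_params(d: int, d_ff: int) -> int:
    return mha_params(d) + ffn_params(d, d_ff) + 2 * ln_params(d)

def decoder_layer_params(d: int, d_ff: int) -> int:
    return 2 * mha_params(d) + ffn_params(d, d_ff) + 3 * ln_params(d)

def count(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())

enc = nn.TransformerEncoderLayer(d_model=512, nhead=8, dim_feedforward=2048)
dec = nn.TransformerDecoderLayer(d_model=512, nhead=8, dim_feedforward=2048)
print(encoder_layer_params(512, 2048), count(enc))  # 3152384 3152384
print(decoder_layer_params(512, 2048), count(dec))  # 4204032 4204032
```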

In [None]:
class Transformer(nn.Module):
    """Complete Transformer model."""
    
    def __init__(self, 
                 src_vocab_size: int,
                 tgt_vocab_size: int,
                 d_model: int = 512,
                 n_heads: int = 8,
                 n_encoder_layers: int = 6,
                 n_decoder_layers: int = 6,
                 d_ff: int = 2048,
                 max_seq_length: int = 5000,
                 dropout: float = 0.1):
        super().__init__()
        
        # Embeddings
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length, dropout)
        
        # Encoder
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(d_model, n_heads, d_ff, dropout)
            for _ in range(n_encoder_layers)
        ])
        
        # Decoder
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(d_model, n_heads, d_ff, dropout)
            for _ in range(n_decoder_layers)
        ])
        
        # Output projection
        self.output_projection = nn.Linear(d_model, tgt_vocab_size)
        
        # Initialize parameters
        self._init_parameters()
        
    def _init_parameters(self):
        """Initialize parameters with Xavier uniform."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
                
    def encode(self, src: torch.Tensor, src_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Encode source sequence."""
        # Embed and add positional encoding
        x = self.src_embedding(src) * math.sqrt(self.src_embedding.embedding_dim)
        x = self.positional_encoding(x)
        
        # Pass through encoder layers
        for layer in self.encoder_layers:
            x = layer(x, src_mask)
            
        return x
    
    def decode(self, tgt: torch.Tensor, memory: torch.Tensor,
               src_mask: Optional[torch.Tensor] = None,
               tgt_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Decode target sequence."""
        # Embed and add positional encoding
        x = self.tgt_embedding(tgt) * math.sqrt(self.tgt_embedding.embedding_dim)
        x = self.positional_encoding(x)
        
        # Pass through decoder layers
        for layer in self.decoder_layers:
            x = layer(x, memory, src_mask, tgt_mask)
            
        return x
    
    def forward(self, src: torch.Tensor, tgt: torch.Tensor,
                src_mask: Optional[torch.Tensor] = None,
                tgt_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Forward pass through the transformer."""
        # Encode source
        memory = self.encode(src, src_mask)
        
        # Decode target
        output = self.decode(tgt, memory, src_mask, tgt_mask)
        
        # Project to vocabulary
        output = self.output_projection(output)
        
        return output

# Create a small transformer model
model = Transformer(
    src_vocab_size=1000,
    tgt_vocab_size=1000,
    d_model=256,
    n_heads=4,
    n_encoder_layers=2,
    n_decoder_layers=2,
    d_ff=1024,
    dropout=0.1
)

# Test forward pass
src = torch.randint(0, 1000, (2, 10))  # [batch_size, src_seq_length]
tgt = torch.randint(0, 1000, (2, 8))   # [batch_size, tgt_seq_length]
tgt_mask = create_causal_mask(8)

output = model(src, tgt, tgt_mask=tgt_mask)

print(f"Complete Transformer Model:")
print(f"Source shape: {src.shape}")
print(f"Target shape: {tgt.shape}")
print(f"Output shape: {output.shape}")
print(f"\nTotal parameters: {sum(p.numel() for p in model.parameters()):,}")

# Parameter breakdown
print("\nParameter Breakdown:")
for name, module in model.named_children():
    params = sum(p.numel() for p in module.parameters())
    print(f"{name}: {params:,} parameters")

## 9. Inference Example

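Let's see how the transformer generates sequences during inference. The model has not been trained, so the generated tokens are arbitrary; the point here is the decoding loop itself.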
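The `greedy_decode` function below handles one sequence at a time (it calls `.item()` on the predicted token). A batched variant has to track which sequences have already produced the end token. The sketch below is model-agnostic: `batched_greedy_decode` and `step_fn` are hypothetical names, where `step_fn` maps the tokens generated so far to next-token logits, and a toy rule stands in for a real model.

```python
import torch

def batched_greedy_decode(step_fn, batch_size: int, max_length: int,
                          start_token: int = 1, end_token: int = 2) -> torch.Tensor:
    """step_fn(ys) -> logits of shape [batch, vocab] for the next position (assumed interface)."""
    ys = torch.full((batch_size, 1), start_token, dtype=torch.long)
    finished = torch.zeros(batch_size, dtype=torch.bool)
    for _ in range(max_length - 1):
        next_tok = step_fn(ys).argmax(dim=-1)
        # Sequences that already emitted end_token keep emitting it (acts as padding)
        next_tok = torch.where(finished, torch.full_like(next_tok, end_token), next_tok)
        ys = torch.cat([ys, next_tok.unsqueeze(1)], dim=1)
        finished |= next_tok == end_token
        if finished.all():
            break
    return ys

def toy_step(ys: torch.Tensor) -> torch.Tensor:
    """Toy 'model': sequence b emits token 5 until it has b + 2 tokens, then end_token 2."""
    batch, length = ys.shape
    logits = torch.zeros(batch, 10)
    for b in range(batch):
        logits[b, 2 if length >= b + 2 else 5] = 1.0
    return logits

out = batched_greedy_decode(toy_step, batch_size=3, max_length=10)
print(out.tolist())  # [[1, 5, 2, 2, 2], [1, 5, 5, 2, 2], [1, 5, 5, 5, 2]]
```

With a real model, `step_fn` would run the decoder on `ys` and return the logits of the last position.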

In [None]:
@torch.no_grad()
def greedy_decode(model: Transformer, src: torch.Tensor, max_length: int = 50,
                  start_token: int = 1, end_token: int = 2) -> torch.Tensor:
    """Greedy decoding for a single source sequence (batch size 1)."""
    model.eval()
    
    # Encode source
    memory = model.encode(src)
    
    # Start with start token
    ys = torch.ones(1, 1).fill_(start_token).type_as(src)
    
    for i in range(max_length - 1):
        # Create mask
        tgt_mask = create_causal_mask(ys.size(1)).type_as(src)
        
        # Decode
        out = model.decode(ys, memory, tgt_mask=tgt_mask)
        
        # Project and get next token
        prob = model.output_projection(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.item()
        
        # Append to sequence
        ys = torch.cat([ys, torch.ones(1, 1).fill_(next_word).type_as(src)], dim=1)
        
        # Stop if end token is generated
        if next_word == end_token:
            break
            
    return ys

# Simulate translation task
print("Simulating Translation with Greedy Decoding:")
print("="*50)

# Create source sequence
src_sequence = torch.randint(3, 100, (1, 8))  # Random source tokens
print(f"Source tokens: {src_sequence.tolist()[0]}")

# Generate translation
translation = greedy_decode(model, src_sequence, max_length=15)
print(f"Generated tokens: {translation.tolist()[0]}")
print(f"Generated length: {translation.size(1)}")

# Illustrate what attention patterns during generation look like.
# NOTE: these are synthetic placeholders, not weights read from the model above;
# extracting real weights would require the attention layers to store or return them.
def synthetic_generation_attention(src: torch.Tensor, tgt: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build illustrative causal self-attention and cross-attention patterns."""
    seq_len = tgt.size(1)
    src_len = src.size(1)
    
    # Self-attention pattern (causal): uniform weight over each position's visible prefix
    self_attn = torch.tril(torch.ones(seq_len, seq_len))
    self_attn = self_attn / self_attn.sum(dim=-1, keepdim=True)
    
    # Cross-attention pattern: random scores over source positions, normalized with softmax
    cross_attn = F.softmax(torch.rand(seq_len, src_len), dim=-1)
    
    return self_attn, cross_attn

# Build the illustrative patterns for the generated sequence
self_attn, cross_attn = synthetic_generation_attention(src_sequence, translation)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

# Self-attention
sns.heatmap(self_attn.numpy(), cmap='Blues', cbar=True, square=True, ax=ax1)
ax1.set_xlabel('Position')
ax1.set_ylabel('Position')
ax1.set_title('Decoder Self-Attention (Causal, synthetic)')

# Cross-attention
sns.heatmap(cross_attn.numpy(), cmap='Reds', cbar=True, ax=ax2)
ax2.set_xlabel('Source Position')
ax2.set_ylabel('Target Position')
ax2.set_title('Encoder-Decoder Cross-Attention (synthetic)')

plt.tight_layout()
plt.show()
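The heatmaps above are synthetic. One way to capture real attention weights without changing a model's code is a forward hook. The sketch below registers one on PyTorch's built-in `nn.MultiheadAttention`, which returns `(output, weights)` and by default averages the weights over heads. The same pattern should work with the notebook's `MultiHeadAttention`, whose `forward` also returns `(output, attention_weights)`, although there the weights are kept per head. Sizes and tensor names are illustrative assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
cross_attn = nn.MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True)
cross_attn.eval()

captured = []

def save_weights(module, inputs, output):
    # output is (attn_output, attn_weights); weights are head-averaged by default
    captured.append(output[1].detach())

handle = cross_attn.register_forward_hook(save_weights)
decoder_states = torch.randn(1, 6, 32)   # queries: [batch, tgt_len, d_model]
encoder_memory = torch.randn(1, 8, 32)   # keys/values: [batch, src_len, d_model]
with torch.no_grad():
    cross_attn(decoder_states, encoder_memory, encoder_memory)
handle.remove()  # detach the hook once the weights are captured

weights = captured[0][0]  # [tgt_len, src_len]; each row is a distribution over source positions
print(weights.shape)  # torch.Size([6, 8])
```

The captured `weights` can be passed to the same `sns.heatmap` call used above in place of the synthetic `cross_attn`.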

## Summary

In this notebook, we've explored the complete transformer architecture:

1. **Architecture Overview**: Encoder-decoder structure with self-attention
2. **Positional Encoding**: Adding position information with sinusoidal functions
3. **Multi-Head Attention**: Parallel attention mechanisms
4. **Feed-Forward Networks**: Position-wise transformations
5. **Layer Normalization**: Stabilizing deep network training
6. **Residual Connections**: Enabling gradient flow
7. **Complete Model**: Putting all components together
8. **Inference**: Generating sequences with the model

The transformer's power comes from:
- **Parallelization**: During training, all positions are processed simultaneously (autoregressive generation is still sequential)
- **Long-range Dependencies**: Direct connections between all positions
- **Scalability**: Can be made deeper and wider
- **Flexibility**: Works for many sequence tasks

Next, we'll implement a transformer from scratch to deepen our understanding!