# Building a Toy Transformer from Scratch

This notebook demonstrates how to build a simplified Transformer model from scratch. We'll cover:

1. Core components of the Transformer
2. Implementation of each component
3. Training on a simple task
4. Visualizing the model's behavior

## Transformer Architecture Overview

The Transformer consists of several key components:

1. **Input Embedding**: Convert input tokens to vectors
2. **Positional Encoding**: Add position information to embeddings
3. **Multi-Head Attention**: Process relationships between tokens
4. **Feed-Forward Network**: Transform token representations
5. **Layer Normalization**: Stabilize training
6. **Residual Connections**: Help with gradient flow

In [None]:
import math
from typing import List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F

# Seed PyTorch and NumPy with the same value so random draws are repeatable
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

## 1. Positional Encoding

First, let's implement the positional encoding layer:

In [None]:
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding that is added to embeddings.

    A table of shape ``(1, max_seq_length, d_model)`` is built once at
    construction time. Even feature columns hold ``sin(pos * w_i)`` and odd
    feature columns hold ``cos(pos * w_i)``, where
    ``w_i = 10000 ** (-2i / d_model)``. The table is registered as a buffer
    named ``pe``, so it is not a trainable parameter.

    Args:
        d_model: Size of the feature (last) dimension. May be even or odd.
        max_seq_length: Number of positions to precompute.
    """

    def __init__(self, d_model: int, max_seq_length: int = 5000):
        super().__init__()

        # Create positional encoding matrix
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term  # (max_seq_length, ceil(d_model / 2))

        # Apply sine to even indices and cosine to odd indices
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns,
        # so the cosine half uses only the first d_model // 2 frequencies.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Add batch dimension and register as buffer
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional encoding to input embeddings.

        Args:
            x: Tensor whose dimension 1 is the sequence axis and whose last
                dimension is ``d_model``; the sequence length should not
                exceed ``max_seq_length``.

        Returns:
            ``x`` plus the first ``x.size(1)`` rows of the encoding table,
            broadcast over the batch dimension.
        """
        return x + self.pe[:, :x.size(1)]

## 2. Multi-Head Attention

Next, let's implement the multi-head attention mechanism:

In [None]:
class MultiHeadAttention(nn.Module):
    """Scaled dot-product attention computed in parallel over several heads.

    The inputs are linearly projected to queries, keys and values, split into
    ``num_heads`` slices of size ``d_model // num_heads``, attended
    independently, concatenated again and passed through an output projection.

    Args:
        d_model: Feature size of the inputs and outputs; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # Linear projections for Q, K, V
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)

        # Output projection
        self.out_proj = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)
        # Scores are divided by sqrt(head_dim) before the softmax
        self.scale = math.sqrt(self.head_dim)

    def _split_heads(self, x: torch.Tensor, batch_size: int) -> torch.Tensor:
        """Reshape (batch, seq, d_model) into (batch, heads, seq, head_dim)."""
        return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Attend from ``query`` positions to ``key``/``value`` positions.

        Args:
            query: Tensor of shape ``(batch, q_len, d_model)``.
            key: Tensor of shape ``(batch, k_len, d_model)``.
            value: Tensor of shape ``(batch, k_len, d_model)``.
            mask: Optional tensor broadcastable to
                ``(batch, num_heads, q_len, k_len)``. Positions where the
                mask equals 0 are excluded from attention.

        Returns:
            A tuple ``(output, attention_weights)`` where ``output`` has shape
            ``(batch, q_len, d_model)`` and ``attention_weights`` has shape
            ``(batch, num_heads, q_len, k_len)``. The returned weights are the
            values after dropout. A query row whose keys are all masked gets
            all-zero weights instead of NaN.
        """
        batch_size = query.size(0)

        # Project Q, K, V and split into heads
        q = self._split_heads(self.q_proj(query), batch_size)
        k = self._split_heads(self.k_proj(key), batch_size)
        v = self._split_heads(self.v_proj(value), batch_size)

        # Compute attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale

        # Apply mask if provided
        fully_blocked = None
        if mask is not None:
            blocked = mask == 0
            scores = scores.masked_fill(blocked, float('-inf'))
            # A row of only -inf makes softmax return NaN (and NaN gradients).
            # Reset such rows to a finite value here and zero them afterwards.
            fully_blocked = blocked.all(dim=-1, keepdim=True)
            scores = scores.masked_fill(fully_blocked, 0.0)

        # Compute attention weights
        attention_weights = F.softmax(scores, dim=-1)
        if fully_blocked is not None:
            attention_weights = attention_weights.masked_fill(fully_blocked, 0.0)
        attention_weights = self.dropout(attention_weights)

        # Apply attention weights to values
        context = torch.matmul(attention_weights, v)

        # Merge heads back into a single d_model-sized feature dimension
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

        # Project output
        output = self.out_proj(context)

        return output, attention_weights

## 3. Feed-Forward Network

Now, let's implement the feed-forward network:

In [None]:
class FeedForward(nn.Module):
    """Two-layer MLP applied to the last dimension of its input.

    The data flow is ``Linear(d_model, d_ff) -> ReLU -> Dropout ->
    Linear(d_ff, d_model)``, so the output has the same shape as the input.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        # Expansion to the hidden width, then contraction back to d_model
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the transformed representation of ``x``."""
        hidden = torch.relu(self.linear1(x))
        return self.linear2(self.dropout(hidden))

## 4. Transformer Encoder Layer

Let's combine these components into an encoder layer:

In [None]:
class TransformerEncoderLayer(nn.Module):
    """One encoder block: self-attention, then a feed-forward network.

    Each sub-layer's output passes through dropout, is added back to the
    sub-layer's input (residual connection), and the sum is layer-normalized.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        d_ff: int,
        dropout: float = 0.1
    ):
        super().__init__()

        # Sub-layers
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)

        # One LayerNorm per sub-layer
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # Dropout shared by both residual branches
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the block on ``x``.

        Args:
            x: Input representations; used as query, key and value.
            mask: Optional attention mask forwarded to the attention module.

        Returns:
            The transformed representations and the attention weights
            produced by the self-attention sub-layer.
        """
        # Attention sub-layer: residual add followed by LayerNorm
        attended, attention_weights = self.self_attn(query=x, key=x, value=x, mask=mask)
        after_attention = self.norm1(x + self.dropout(attended))

        # Feed-forward sub-layer: residual add followed by LayerNorm
        transformed = self.feed_forward(after_attention)
        after_ff = self.norm2(after_attention + self.dropout(transformed))

        return after_ff, attention_weights

## 5. Complete Transformer Model

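Now, let's build the complete model: a token embedding, the positional encoding, a stack of encoder layers, and a final linear projection onto the vocabulary: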

In [None]:
class Transformer(nn.Module):
    """Encoder-only Transformer that maps token ids to per-position logits.

    The pipeline is: token embedding -> positional encoding -> dropout ->
    ``num_layers`` encoder layers -> linear projection to ``vocab_size``.

    Args:
        vocab_size: Number of distinct token ids; also the size of the output
            logits' last dimension.
        d_model: Embedding / hidden feature size.
        num_heads: Attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        d_ff: Hidden width of each layer's feed-forward network.
        dropout: Dropout probability used throughout the model.
        max_seq_length: Number of positions the positional encoding covers.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        num_heads: int = 8,
        num_layers: int = 6,
        d_ff: int = 2048,
        dropout: float = 0.1,
        max_seq_length: int = 5000
    ):
        super().__init__()

        # Token embedding
        self.embedding = nn.Embedding(vocab_size, d_model)

        # Positional encoding
        self.pos_encoding = PositionalEncoding(d_model, max_seq_length)

        # Encoder layers
        self.encoder_layers = nn.ModuleList([
            TransformerEncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])

        # Output projection
        self.output_proj = nn.Linear(d_model, vocab_size)

        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Compute logits and collect each layer's attention weights.

        Args:
            x: Integer token ids of shape ``(batch, seq_len)``.
            mask: Optional attention mask passed unchanged to every layer.

        Returns:
            A tuple ``(output, attention_weights)``. ``output`` holds the
            logits with ``vocab_size`` as the last dimension, and
            ``attention_weights`` is a list with one tensor per encoder
            layer, in layer order.
        """
        # Token embedding and positional encoding
        x = self.embedding(x)
        x = self.pos_encoding(x)
        x = self.dropout(x)

        # Store attention weights
        attention_weights: List[torch.Tensor] = []

        # Process through encoder layers; each layer consumes the previous output
        for layer in self.encoder_layers:
            x, attn_weights = layer(x, mask)
            attention_weights.append(attn_weights)

        # Project to vocabulary
        output = self.output_proj(x)

        return output, attention_weights

## 6. Training on a Simple Task

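Let's train our Transformer on a simple sequence prediction task: the target is the input sequence shifted left by one position, so at each position the model must output the next token. No causal mask is used, so the encoder sees the whole sequence and can solve the task by attending to the following position: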

In [None]:
def generate_sequence_data(
    num_samples: int,
    seq_length: int,
    vocab_size: int = 10
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Generate synthetic next-token data for training.

    Args:
        num_samples: Number of sequences to draw.
        seq_length: Length of each sequence; 0 yields empty tensors.
        vocab_size: Exclusive upper bound for the random token ids. Inputs
            are drawn uniformly from ``[1, vocab_size)``; id 0 is reserved as
            the padding token.

    Returns:
        A tuple ``(x, y)`` of integer tensors, each of shape
        ``(num_samples, seq_length)``. ``y[:, t]`` equals ``x[:, t + 1]`` and
        the final position of ``y`` is the padding token 0.

    Raises:
        ValueError: If ``vocab_size`` is smaller than 2, which leaves no
            non-padding token to draw.
    """
    if vocab_size < 2:
        raise ValueError("vocab_size must be at least 2")

    # Generate random sequences
    x = torch.randint(1, vocab_size, (num_samples, seq_length))

    # Target is the sequence shifted by one position; start from all-padding
    # so the last column stays 0 and an empty sequence needs no special case.
    y = torch.full_like(x, 0)
    y[:, :-1] = x[:, 1:]

    return x, y

def train_transformer(
    model: Transformer,
    num_epochs: int,
    batch_size: int,
    seq_length: int,
    learning_rate: float = 0.001
) -> List[float]:
    """Train the Transformer model on freshly generated synthetic batches.

    Each "epoch" here is a single optimizer step on one newly generated
    batch, not a pass over a fixed dataset.

    Args:
        model: The model to optimize in place.
        num_epochs: Number of optimization steps to run.
        batch_size: Number of sequences generated per step.
        seq_length: Length of each generated sequence.
        learning_rate: Learning rate for the Adam optimizer.

    Returns:
        The cross-entropy loss recorded at every step, in order.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    # Generated batches are placed on whatever device holds the model
    device = next(model.parameters()).device

    # Make sure dropout is active even if the caller left the model in eval mode
    model.train()

    losses: List[float] = []

    for epoch in range(num_epochs):
        # Generate training data
        x, y = generate_sequence_data(batch_size, seq_length)
        x, y = x.to(device), y.to(device)

        # Forward pass
        output, _ = model(x)

        # Flatten (batch, seq) so every position counts as one classification
        loss = criterion(output.reshape(-1, output.size(-1)), y.reshape(-1))

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        losses.append(loss_value)

        if (epoch + 1) % 10 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss_value:.4f}")

    return losses

# Hyperparameters: a small vocabulary and a small model so training is fast
vocab_size, d_model = 10, 64
num_heads, num_layers = 4, 2
d_ff = 256

# Build the model; dropout and max_seq_length keep their default values
model = Transformer(vocab_size, d_model, num_heads, num_layers, d_ff)

# Train for 100 steps on batches of 32 sequences, each 10 tokens long
losses = train_transformer(model, num_epochs=100, batch_size=32, seq_length=10)

## 7. Visualizing Training Progress

Let's plot the training loss:

In [None]:
# Draw the recorded loss values on a single labelled axis
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(losses)
ax.set_title('Training Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.grid(True)
plt.show()

## 8. Visualizing Attention Patterns

Let's examine the attention patterns learned by our model:

In [None]:
def visualize_attention_patterns(
    model: Transformer,
    input_sequence: torch.Tensor,
    layer_idx: int = -1
) -> None:
    """Visualize attention patterns for a given input sequence.

    Plots a heatmap of the first head's attention weights in the chosen
    layer, with query positions on the y-axis and key positions on the
    x-axis. The model is switched to evaluation mode for the forward pass so
    that dropout does not distort the weights, and its previous mode is
    restored afterwards.

    Args:
        model: The model whose attention weights are inspected.
        input_sequence: 1-D tensor of token ids; a batch dimension is added
            internally.
        layer_idx: Index into the list of per-layer attention weights;
            negative values count from the last layer.
    """
    device = next(model.parameters()).device
    batch = input_sequence.unsqueeze(0).to(device)

    # Get model output and attention weights with dropout disabled
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            _, attention_weights = model(batch)
    finally:
        model.train(was_training)

    # Get attention weights for the specified layer (first batch, first head);
    # convert to a CPU NumPy array because the plotting library cannot read
    # tensors that live on another device.
    layer_attention = attention_weights[layer_idx][0, 0].detach().cpu().numpy()

    token_labels = input_sequence.tolist()

    # Plot attention heatmap
    plt.figure(figsize=(8, 8))
    sns.heatmap(
        layer_attention,
        cmap='viridis',
        xticklabels=token_labels,
        yticklabels=token_labels
    )
    plt.title(f'Attention Patterns (Layer {layer_idx})')
    plt.xlabel('Key Position')
    plt.ylabel('Query Position')
    plt.show()

# Draw one random sequence of 10 token ids from the range [1, 10)
test_sequence = torch.randint(low=1, high=10, size=(10,))
print(f"Test sequence: {test_sequence.tolist()}")

# Show the attention heatmap of every encoder layer in turn
for layer_idx in range(num_layers):
    visualize_attention_patterns(model, input_sequence=test_sequence, layer_idx=layer_idx)

## Conclusion

In this notebook, we've built a toy, encoder-only Transformer model from scratch and explored its components:

1. Positional encoding
2. Multi-head attention
3. Feed-forward networks
4. Layer normalization
5. Residual connections

Key takeaways:

- The Transformer architecture is modular and elegant
- Each component serves a specific purpose
- Attention patterns provide insights into model behavior

This implementation is simplified for educational purposes. Real-world Transformers (like BERT, GPT) use more sophisticated components and training techniques.