# Project 4: The Attention Surgeon - SOLUTION
## Building Scaled Dot-Product Attention from Scratch

**This notebook contains complete solutions to all tasks.**

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Optional, Tuple, Dict, List

# Set random seed
torch.manual_seed(42)
np.random.seed(42)

print(f"PyTorch version: {torch.__version__}")

## Part 1: Implement Core Attention Function - SOLUTION

In [None]:
def attention(Q: torch.Tensor,
              K: torch.Tensor,
              V: torch.Tensor,
              mask: Optional[torch.Tensor] = None,
              scale: bool = True,
              return_weights: bool = True) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Scaled Dot-Product Attention - COMPLETE IMPLEMENTATION
    
    Attention(Q, K, V) = softmax(QK^T / √d_k) V
    
    Args:
        Q: Query [batch, seq_len, d_k]
        K: Key [batch, seq_len, d_k]
        V: Value [batch, seq_len, d_v]
        mask: Optional mask [batch, seq_len, seq_len] or [seq_len, seq_len];
            1 = may attend, 0 = masked out
        scale: Whether to apply 1/√d_k scaling
        return_weights: Whether to return attention weights
        
    Returns:
        output: [batch, seq_len, d_v]
        attention_weights: [batch, seq_len, seq_len] (if return_weights=True)
    """
    # Step 1: Compute Q @ K^T
    scores = torch.bmm(Q, K.transpose(1, 2))  # [batch, seq_len, seq_len]
    
    # Step 2: Scale by 1/√d_k
    if scale:
        d_k = Q.size(-1)
        scores = scores / np.sqrt(d_k)
    
    # Step 3: Apply mask (set masked positions to -inf)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    
    # Step 4: Apply softmax
    attention_weights = F.softmax(scores, dim=-1)
    
    # Handle NaN (occurs when entire row is masked)
    attention_weights = attention_weights.masked_fill(torch.isnan(attention_weights), 0.0)
    
    # Step 5: Multiply by V
    output = torch.bmm(attention_weights, V)
    
    if return_weights:
        return output, attention_weights
    else:
        return output, None

# Test the attention function
batch_size = 2
seq_len = 5
d_k = 8
d_v = 8

Q = torch.randn(batch_size, seq_len, d_k)
K = torch.randn(batch_size, seq_len, d_k)
V = torch.randn(batch_size, seq_len, d_v)

output, weights = attention(Q, K, V)

print(f"Input shapes:")
print(f"  Q: {Q.shape}")
print(f"  K: {K.shape}")
print(f"  V: {V.shape}")
print(f"\nOutput shapes:")
print(f"  Output: {output.shape}")
print(f"  Attention weights: {weights.shape}")
print(f"\nAttention weights sum (should be ~1.0 per row):")
print(f"  {weights[0].sum(dim=-1)}")
print("\n✓ Attention implemented successfully!")
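
As an optional sanity check, the unmasked computation can be compared against PyTorch's built-in kernel. This is a minimal sketch, assuming PyTorch 2.0 or later (where `F.scaled_dot_product_attention` is available). `reference_attention` is a hypothetical, stripped-down copy of the steps in `attention()` so that the snippet runs on its own.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)

def reference_attention(Q, K, V):
    # Unmasked scaled dot-product attention: softmax(QK^T / sqrt(d_k)) V
    scores = torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(Q.size(-1))
    return torch.bmm(F.softmax(scores, dim=-1), V)

Q = torch.randn(2, 5, 8)
K = torch.randn(2, 5, 8)
V = torch.randn(2, 5, 8)

manual = reference_attention(Q, K, V)
# Built-in kernel; assumed available (PyTorch 2.0 and later)
builtin = F.scaled_dot_product_attention(Q, K, V)

max_abs_diff = (manual - builtin).abs().max().item()
print(f"max |manual - builtin| = {max_abs_diff:.2e}")
```

The two results should agree up to floating-point rounding, since both apply the same 1/√d_k scaling by default.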

## Part 2: Visualize Attention Patterns - SOLUTION

In [None]:
def plot_attention_heatmap(attention_weights: torch.Tensor,
                          tokens: List[str],
                          title: str = "Attention Weights"):
    """
    Plot attention weight matrix as heatmap.
    """
    plt.figure(figsize=(10, 8))
    
    # Convert to numpy
    if isinstance(attention_weights, torch.Tensor):
        attention_weights = attention_weights.detach().cpu().numpy()
    
    # Plot heatmap
    sns.heatmap(attention_weights, 
                xticklabels=tokens, 
                yticklabels=tokens,
                cmap='viridis',
                annot=True,
                fmt='.2f',
                cbar_kws={'label': 'Attention Weight'})
    
    plt.xlabel('Key / Value')
    plt.ylabel('Query')
    plt.title(title)
    plt.tight_layout()
    plt.show()

# Create sample tokens and attention
tokens = ['The', 'cat', 'sat', 'on', 'mat']
sample_attention = weights[0]  # Take first batch

plot_attention_heatmap(sample_attention, tokens)

## Part 3: Experiment with Scaling Factor - SOLUTION

In [None]:
def experiment_scaling_factor(d_k_values=(16, 64, 256, 1024)):
    """
    Test impact of scaling factor across different d_k values.
    """
    results = {
        'd_k': [],
        'with_scale_max': [],
        'without_scale_max': [],
        'with_scale_entropy': [],
        'without_scale_entropy': []
    }
    
    for d_k in d_k_values:
        # Create random Q, K, V
        Q = torch.randn(1, 10, d_k)
        K = torch.randn(1, 10, d_k)
        V = torch.randn(1, 10, d_k)
        
        # With scaling
        _, weights_scaled = attention(Q, K, V, scale=True)
        
        # Without scaling
        _, weights_unscaled = attention(Q, K, V, scale=False)
        
        # Compute statistics
        results['d_k'].append(d_k)
        results['with_scale_max'].append(weights_scaled.max().item())
        results['without_scale_max'].append(weights_unscaled.max().item())
        
        # Compute entropy (measure of distribution sharpness)
        def compute_entropy(w):
            w = w + 1e-9  # Avoid log(0)
            return -(w * torch.log(w)).sum(dim=-1).mean().item()
        
        results['with_scale_entropy'].append(compute_entropy(weights_scaled))
        results['without_scale_entropy'].append(compute_entropy(weights_unscaled))
    
    # Plot results
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    
    # Max attention weight
    ax1.plot(results['d_k'], results['with_scale_max'], marker='o', label='With Scaling', linewidth=2)
    ax1.plot(results['d_k'], results['without_scale_max'], marker='s', label='Without Scaling', linewidth=2)
    ax1.axhline(y=1.0, color='r', linestyle='--', alpha=0.5, label='Saturation (1.0)')
    ax1.set_xlabel('d_k (Key Dimension)', fontsize=12)
    ax1.set_ylabel('Max Attention Weight', fontsize=12)
    ax1.set_title('Softmax Saturation vs d_k', fontsize=14)
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    ax1.set_xscale('log')
    
    # Entropy
    ax2.plot(results['d_k'], results['with_scale_entropy'], marker='o', label='With Scaling', linewidth=2)
    ax2.plot(results['d_k'], results['without_scale_entropy'], marker='s', label='Without Scaling', linewidth=2)
    ax2.set_xlabel('d_k (Key Dimension)', fontsize=12)
    ax2.set_ylabel('Entropy (nats)', fontsize=12)
    ax2.set_title('Attention Entropy vs d_k', fontsize=14)
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    ax2.set_xscale('log')
    
    plt.tight_layout()
    plt.show()
    
    # Print analysis
    print("\n" + "="*70)
    print("ANALYSIS: Why Scaling Matters")
    print("="*70)
    print("\nWithout scaling:")
    print("  - As d_k increases, QK^T values grow (variance ~ d_k)")
    print("  - Softmax becomes peaked (max → 1.0)")
    print("  - Gradients vanish in saturated regions")
    print("  - Entropy decreases (less diverse attention)")
    print("\nWith scaling (1/√d_k):")
    print("  - QK^T values normalized")
    print("  - Softmax remains balanced")
    print("  - Gradients flow properly")
    print("  - Entropy stable across different d_k")
    
    return results

results = experiment_scaling_factor()
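
To help read the entropy plot, it may be useful to know the two extremes. The experiment attends over 10 positions, so the entropy of a row lies between 0 (one-hot) and log 10 ≈ 2.30 nats (uniform). The sketch below computes both bounds with the standard library only; `entropy` is a hypothetical helper, separate from the `compute_entropy` used in the experiment.

```python
import math

def entropy(weights):
    # Shannon entropy in nats; zero-probability terms contribute nothing
    return sum(-w * math.log(w) for w in weights if w > 0)

n_keys = 10  # number of positions attended over in the experiment
uniform = [1.0 / n_keys] * n_keys          # perfectly spread attention
one_hot = [1.0] + [0.0] * (n_keys - 1)     # fully saturated attention

h_uniform = entropy(uniform)
h_one_hot = entropy(one_hot)
print(f"uniform: {h_uniform:.4f} nats, log({n_keys}) = {math.log(n_keys):.4f}")
print(f"one-hot: {h_one_hot:.4f} nats")
```

An unscaled curve that drops toward 0 as d_k grows indicates rows approaching the one-hot extreme.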

## Part 4: Masking - SOLUTION

In [None]:
def create_causal_mask(seq_len: int, device='cpu') -> torch.Tensor:
    """
    Create lower-triangular mask for causal attention.
    
    Returns:
        mask: [seq_len, seq_len] with 1s in lower triangle, 0s in upper
    """
    mask = torch.tril(torch.ones(seq_len, seq_len, device=device))
    return mask

def create_reverse_mask(seq_len: int, device='cpu') -> torch.Tensor:
    """
    WRONG: Create upper-triangular mask (lets model cheat!).
    """
    mask = torch.triu(torch.ones(seq_len, seq_len, device=device))
    return mask

# Visualize masks
seq_len = 8
causal_mask = create_causal_mask(seq_len)
reverse_mask = create_reverse_mask(seq_len)

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 5))

# Causal mask
sns.heatmap(causal_mask.numpy(), ax=ax1, cmap='Blues', cbar=False, square=True, annot=True, fmt='.0f')
ax1.set_title('Causal Mask (Correct)\nCan see past and present', fontsize=12)
ax1.set_xlabel('Key Position')
ax1.set_ylabel('Query Position')

# Reverse mask
sns.heatmap(reverse_mask.numpy(), ax=ax2, cmap='Reds', cbar=False, square=True, annot=True, fmt='.0f')
ax2.set_title('Reverse Mask (Wrong!)\nCan see future!', fontsize=12)
ax2.set_xlabel('Key Position')
ax2.set_ylabel('Query Position')

# No mask
no_mask = torch.ones(seq_len, seq_len)
sns.heatmap(no_mask.numpy(), ax=ax3, cmap='Greens', cbar=False, square=True, annot=True, fmt='.0f')
ax3.set_title('No Mask (Bidirectional)\nCan see everything', fontsize=12)
ax3.set_xlabel('Key Position')
ax3.set_ylabel('Query Position')

plt.tight_layout()
plt.show()

print("\nMasking Examples:")
print("  1 = Can attend")
print("  0 = Cannot attend (masked out)")

### Test Causal Masking - SOLUTION

In [None]:
# Test attention with causal mask
seq_len = 6
d_k = 8

Q = torch.randn(1, seq_len, d_k)
K = torch.randn(1, seq_len, d_k)
V = torch.randn(1, seq_len, d_k)

mask = create_causal_mask(seq_len)

# Without mask
_, weights_no_mask = attention(Q, K, V, mask=None)

# With causal mask
_, weights_masked = attention(Q, K, V, mask=mask)

# Visualize
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

sns.heatmap(weights_no_mask[0].detach().numpy(), ax=ax1, cmap='viridis', annot=True, fmt='.2f', vmin=0, vmax=1)
ax1.set_title('Attention Without Mask (Bidirectional)')
ax1.set_xlabel('Key Position')
ax1.set_ylabel('Query Position')

sns.heatmap(weights_masked[0].detach().numpy(), ax=ax2, cmap='viridis', annot=True, fmt='.2f', vmin=0, vmax=1)
ax2.set_title('Attention With Causal Mask (Autoregressive)')
ax2.set_xlabel('Key Position')
ax2.set_ylabel('Query Position')

plt.tight_layout()
plt.show()

print("\n✓ Notice: With causal mask, upper triangle is all zeros!")
print("  This prevents the model from 'cheating' by seeing future tokens.")
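
The visual check can also be expressed as assertions. The sketch below uses random scores as a stand-in for QK^T and verifies that future positions receive exactly zero weight and that every row still sums to 1. It also shows, under the same assumptions, why `attention()` replaces NaN with 0: a row in which every position is masked has no valid softmax.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len = 4
scores = torch.randn(1, seq_len, seq_len)  # stand-in for QK^T / sqrt(d_k)

# Causal mask: 1 = may attend, 0 = masked out
mask = torch.tril(torch.ones(seq_len, seq_len))
weights = F.softmax(scores.masked_fill(mask == 0, float('-inf')), dim=-1)

upper = torch.triu(weights[0], diagonal=1)  # strictly upper triangle = future positions
future_is_zero = bool((upper == 0).all())
rows_sum_to_one = torch.allclose(weights.sum(dim=-1), torch.ones(1, seq_len))
print("future weights all zero:", future_is_zero)
print("rows sum to 1:", rows_sum_to_one)

# Mask out every key for query 0: softmax over all -inf is 0/0, which yields NaN
empty_mask = mask.clone()
empty_mask[0, :] = 0
raw = F.softmax(scores.masked_fill(empty_mask == 0, float('-inf')), dim=-1)
row_is_nan = bool(torch.isnan(raw[0, 0]).all())
print("fully masked row is NaN:", row_is_nan)
```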

## Part 5: Multi-Head Attention - SOLUTION

In [None]:
class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention module - COMPLETE IMPLEMENTATION
    """
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        # Linear projections for Q, K, V
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        
        # Output projection
        self.W_o = nn.Linear(d_model, d_model)
        
    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
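        """
        Args:
            x: [batch, seq_len, d_model]
            mask: [seq_len, seq_len], 1 = may attend, 0 = masked out

        Returns:
            output: [batch, seq_len, d_model]
            attn_weights: [batch * num_heads, seq_len, seq_len]
        """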
        batch_size, seq_len, d_model = x.shape
        
        # Linear projections
        Q = self.W_q(x)  # [batch, seq_len, d_model]
        K = self.W_k(x)
        V = self.W_v(x)
        
        # Split into multiple heads
        Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        # Now: [batch, num_heads, seq_len, d_k]
        
        # Reshape for attention: [batch * num_heads, seq_len, d_k]
        Q = Q.contiguous().view(batch_size * self.num_heads, seq_len, self.d_k)
        K = K.contiguous().view(batch_size * self.num_heads, seq_len, self.d_k)
        V = V.contiguous().view(batch_size * self.num_heads, seq_len, self.d_k)
        
        # Apply attention
        output, attn_weights = attention(Q, K, V, mask=mask)
        
        # Reshape back: [batch, num_heads, seq_len, d_k]
        output = output.view(batch_size, self.num_heads, seq_len, self.d_k)
        
        # Concatenate heads: [batch, seq_len, d_model]
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        
        # Final linear projection
        output = self.W_o(output)
        
        return output, attn_weights

# Test
d_model = 64
num_heads = 8
mha = MultiHeadAttention(d_model, num_heads)

x = torch.randn(2, 10, d_model)
output, _ = mha(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Number of heads: {num_heads}")
print(f"d_k per head: {mha.d_k}")
print(f"Total parameters: {sum(p.numel() for p in mha.parameters())}")
print("\n✓ Multi-Head Attention implemented successfully!")
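
The printed parameter count can be cross-checked by hand. The module holds four `d_model × d_model` linear layers (W_q, W_k, W_v, W_o), each with a bias, so the total does not depend on `num_heads`. A short sketch of that arithmetic, using standalone `nn.Linear` layers rather than the class itself:

```python
import torch.nn as nn

d_model = 64

# Four projections, each with d_model * d_model weights plus d_model biases
expected = 4 * (d_model * d_model + d_model)

# Count the same thing from actual layers
layers = [nn.Linear(d_model, d_model) for _ in range(4)]
counted = sum(p.numel() for layer in layers for p in layer.parameters())

print(f"expected: {expected}, counted: {counted}")
```

Splitting into heads only reshapes the projected tensors, which is why the head count adds no parameters.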

## Part 6: Analysis and Reflection

### Questions and Answers:

#### 1. Why does increasing d_k without scaling cause softmax saturation?

**Answer:**
When the components of Q and K are independent with zero mean and unit variance:
- Each dot product q·k is a sum of d_k such terms, so its variance is d_k (standard deviation √d_k)
- As d_k increases, the entries of QK^T grow in magnitude
- Softmax of large-magnitude scores → nearly one-hot distribution
- Gradients ≈ 0 in saturated regions (vanishing gradients)
- Scaling by 1/√d_k brings the variance back to ~1
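
The variance claim is easy to check empirically. The following sketch draws standard-normal queries and keys (an assumption that matches the `torch.randn` inputs used in this notebook, not trained weights) and measures the variance of their dot products before and after scaling.

```python
import torch

torch.manual_seed(0)
n_samples = 20_000
variances = {}

for d_k in (16, 64, 256):
    q = torch.randn(n_samples, d_k)
    k = torch.randn(n_samples, d_k)
    dots = (q * k).sum(dim=-1)                      # one dot product per sample
    raw_var = dots.var().item()                     # expected to be close to d_k
    scaled_var = (dots / d_k ** 0.5).var().item()   # expected to be close to 1
    variances[d_k] = (raw_var, scaled_var)
    print(f"d_k={d_k:4d}  var(q.k)={raw_var:8.2f}  var(q.k/sqrt(d_k))={scaled_var:.3f}")
```

The raw variance should track d_k, while the scaled variance should stay near 1 for every d_k.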

#### 2. What happens if we use -100 instead of -inf for masking?

**Answer:**
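- With -100, a masked position keeps a weight proportional to exp(-100): ≈ 0 for typical scores, but not guaranteed to be exactly 0
- The leak depends on the other scores: if the unmasked scores in a row are themselves far below -100, the masked position can dominate
- Across many layers and positions, small leaks can let the model pick up information it should not see
- -inf makes the masking exact: exp(-inf) = 0, so the masked weight is exactly 0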
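
A deliberately contrived example makes the difference visible. The scores below are assumptions chosen to be extreme (real attention scores are rarely this negative); they show that a finite fill value is only "small" relative to the other scores in the row.

```python
import torch
import torch.nn.functional as F

# Position 2 should be masked. The unmasked scores are very negative on purpose.
scores = torch.tensor([-150.0, -160.0, 0.0])
masked = torch.tensor([False, False, True])

finite = F.softmax(scores.masked_fill(masked, -100.0), dim=-1)
exact = F.softmax(scores.masked_fill(masked, float('-inf')), dim=-1)

print("fill = -100:", finite)   # the masked position receives almost all the weight
print("fill = -inf:", exact)    # the masked position receives exactly zero
```

With -100 the "masked" entry is the largest score in the row and wins the softmax. With -inf it is excluded no matter how the remaining scores are distributed.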

#### 3. Can attention increase the rank of the input?

**Answer:**
No! Attention output is a weighted sum (convex combination) of V:
- Output[i] = Σ_j attention_weights[i, j] * V[j]
- In matrix form, Output = A V, and a matrix product cannot have higher rank than either factor
- rank(Output) ≤ rank(V)
- This is why FFN layers are needed to increase expressiveness
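
The rank bound can be checked numerically. This sketch builds a rank-2 value matrix, applies random row-stochastic weights as a stand-in for real attention weights, and compares ranks with `torch.linalg.matrix_rank`. The sizes and the use of float64 are illustrative choices.

```python
import torch

torch.manual_seed(0)
seq_len, d_v, true_rank = 6, 8, 2

# Rank-2 value matrix, built as a product of two thin matrices
left = torch.randn(seq_len, true_rank, dtype=torch.float64)
right = torch.randn(true_rank, d_v, dtype=torch.float64)
V = left @ right

# Random attention weights: each row is non-negative and sums to 1
weights = torch.softmax(torch.randn(seq_len, seq_len, dtype=torch.float64), dim=-1)
output = weights @ V

rank_V = torch.linalg.matrix_rank(V).item()
rank_out = torch.linalg.matrix_rank(output).item()
print(f"rank(V) = {rank_V}, rank(weights @ V) = {rank_out}")
```

However the weights are chosen, the output rows stay inside the row space of V, so the output rank cannot exceed rank(V).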

#### 4. Why do we need multiple heads?

**Answer:**
Multiple heads allow learning different types of relationships:
- Head 1: Syntactic dependencies (subject-verb)
- Head 2: Semantic similarity
- Head 3: Positional patterns
- Each head can specialize in different patterns
- Analogous to multiple convolutional filters in CNNs
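
To inspect what each head attends to, the flattened weights returned by `MultiHeadAttention.forward` (shape `[batch * num_heads, seq_len, seq_len]`) can be viewed per head. The sketch below uses random weights as a stand-in for the real return value and assumes the same flattening order as the implementation above, where the head index varies fastest within each batch element.

```python
import torch

torch.manual_seed(0)
batch_size, num_heads, seq_len = 2, 8, 10

# Stand-in for attn_weights: one [seq_len, seq_len] matrix per (batch, head) pair
flat_weights = torch.softmax(
    torch.randn(batch_size * num_heads, seq_len, seq_len), dim=-1
)

# Recover the head dimension: [batch, num_heads, seq_len, seq_len]
per_head = flat_weights.view(batch_size, num_heads, seq_len, seq_len)

# Entry (b, h) corresponds to flat index b * num_heads + h
b, h = 1, 3
same = torch.equal(per_head[b, h], flat_weights[b * num_heads + h])
print(per_head.shape, same)
```

Each `per_head[b, h]` slice is a `[seq_len, seq_len]` matrix that can be passed to `plot_attention_heatmap` to compare heads side by side.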

## 🎯 Completion Checklist

- ✅ Implemented `attention()` function
- ✅ Visualized attention heatmap
- ✅ Tested scaling factor experiment
- ✅ Observed softmax saturation without scaling
- ✅ Implemented causal masking
- ✅ Visualized masked vs unmasked attention
- ✅ Implemented Multi-Head Attention
- ✅ Answered reflection questions

## Key Takeaways

1. **Scaling is critical**: 1/√d_k prevents softmax saturation
2. **Masking enables causality**: -inf ensures no future information
3. **Multiple heads = multiple perspectives**: Different relationship types
4. **Attention is a routing mechanism**: Weighted sum, not transformation

## 🚀 Next Project
Move to **05_block_builder** to assemble the full Transformer block!