In [None]:
# Setup: Import required libraries
import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# Add parent directory to path
sys.path.insert(0, str(Path.cwd().parent))

from src.modules.attention import ScaledDotProductAttention

# Visualization setup
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")
torch.manual_seed(42)
np.random.seed(42)

print("✅ Imports successful!")
print(f"PyTorch version: {torch.__version__}")

---

## 1. The Attention Revolution <a id="revolution"></a>

### The Problem with RNNs

Before Transformers, sequence modeling was dominated by RNNs (LSTM, GRU):

```
Input:  "The cat sat on the mat"
RNN:    →  →  →  →  →  → 
        h₁ h₂ h₃ h₄ h₅ h₆
```

**Problems:**
1. ❌ **Sequential Processing**: Tokens must be processed one by one (slow; steps can't run in parallel)
2. ❌ **Long-Range Dependencies**: Information from early tokens gets diluted over many steps
3. ❌ **Fixed-Size Context**: A single hidden-state vector must compress the entire history

### The Attention Solution

**Key Idea:** Let each token directly attend to (look at) all other tokens!

```
"The cat sat on the mat"
 ↕   ↕   ↕   ↕   ↕   ↕
 ←→  ←→  ←→  ←→  ←→  ←→  (Each token can attend to any other)
```

**Benefits:**
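- ✅ **Parallel Processing**: All tokens are processed simultaneously
- ✅ **Direct Connections**: Any token can attend to any other in a single step, with no chain of intermediate hidden states in between
- ✅ **Dynamic Context**: Attention weights are computed from the input, so each query gets its own weighting of the sequence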
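
To make the sequential-vs-parallel contrast concrete, here is a minimal NumPy sketch with toy sizes and random, untrained weights. The helper names `rnn_forward` and `attention_forward` are illustrative only, and the attention version uses the input itself as queries, keys, and values (no learned projections) to keep it short. The RNN needs a Python loop because step $t$ depends on $h_{t-1}$; the attention version handles every position with a few matrix products.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 6, 8                       # toy sequence length and hidden size
X = rng.standard_normal((n, d))   # one embedded sentence, one row per token

def rnn_forward(X, W_h, W_x):
    """Vanilla RNN: each hidden state depends on the previous one, so the loop is unavoidable."""
    h = np.zeros(X.shape[1])
    states = []
    for x_t in X:                                 # strictly sequential over time steps
        h = np.tanh(W_h @ h + W_x @ x_t)
        states.append(h)
    return np.stack(states)

def attention_forward(X):
    """Simplified self-attention (Q = K = V = X): all positions at once."""
    scores = X @ X.T / np.sqrt(X.shape[1])         # (n, n): every token vs every token
    scores -= scores.max(axis=-1, keepdims=True)   # numerical stability for exp
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True) # row-wise softmax
    return weights @ X, weights                    # (n, d): no loop over positions

H = rnn_forward(X, 0.1 * rng.standard_normal((d, d)), 0.1 * rng.standard_normal((d, d)))
out, w = attention_forward(X)
print(H.shape, out.shape)   # (6, 8) (6, 8)
print(w[-1, 0])             # direct weight from the last token to the first
```

Both produce one vector per token, but only the RNN's cost grows with a sequential chain of steps.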

### The Intuition

Attention answers: **"Where should I look to understand this word?"**

Example: *"The animal didn't cross the street because **it** was too tired"*

When processing "**it**", attention might:
- Look strongly at "**animal**" (high attention weight)
- Look weakly at "street" (low attention weight)
- Determine "it" = "animal", not "street"

This is **learned automatically** from data! üéØ

In [None]:
# Visualize the difference between RNN and Attention
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# RNN: Sequential processing
ax1 = axes[0]
tokens = ["The", "cat", "sat", "on", "mat"]
for i, token in enumerate(tokens):
    ax1.text(i, 0.5, token, ha='center', va='center', fontsize=14, weight='bold',
            bbox=dict(boxstyle='round', facecolor='lightblue'))
    if i < len(tokens) - 1:
        ax1.arrow(i+0.3, 0.5, 0.4, 0, head_width=0.1, head_length=0.1, fc='black')

ax1.set_xlim(-0.5, len(tokens))
ax1.set_ylim(0, 1)
ax1.axis('off')
ax1.set_title('RNN: Sequential Processing\n(One token at a time)', fontsize=14, weight='bold')

# Attention: All-to-all connections
ax2 = axes[1]
for i, token in enumerate(tokens):
    ax2.text(i, 0.5, token, ha='center', va='center', fontsize=14, weight='bold',
            bbox=dict(boxstyle='round', facecolor='lightgreen'))

# Draw one arc per token pair above the tokens (straight lines would be hidden behind the boxes)
t = np.linspace(0, np.pi, 50)
for i in range(len(tokens)):
    for j in range(i + 1, len(tokens)):
        xs = (i + j) / 2 + (j - i) / 2 * np.cos(t)
        ys = 0.58 + 0.08 * (j - i) * np.sin(t)
        ax2.plot(xs, ys, color='gray', alpha=0.5, linewidth=1)

ax2.set_xlim(-0.5, len(tokens))
ax2.set_ylim(0, 1)
ax2.axis('off')
ax2.set_title('Attention: All-to-All Connections\n(Parallel processing)', fontsize=14, weight='bold')

plt.tight_layout()
plt.show()

print("üîë Key Difference:")
print("  RNN: Token 5 must pass through tokens 1-4 to see token 0")
print("  Attention: Token 5 directly attends to token 0 (no intermediary!)")

---

## 2. Query, Key, Value Intuition <a id="qkv"></a>

### The Database Analogy

Think of attention like a **database lookup**:

- **Query (Q)**: "What am I looking for?" (your search query)
- **Key (K)**: "What does each item offer?" (database indices)
- **Value (V)**: "What information does each item contain?" (database content)

### Real-World Example

Imagine searching a library:

```python
# You want to learn about "machine learning"
Query = "machine learning concepts"

# Books in the library:
Keys = [
    "deep learning and neural networks",  # High similarity!
    "cooking recipes for beginners",       # Low similarity
    "artificial intelligence overview",    # Medium similarity
    "gardening tips and tricks"            # Low similarity
]

Values = [
    "<content of deep learning book>",
    "<content of cooking book>",
    "<content of AI book>",
    "<content of gardening book>"
]

# Attention computes:
# 1. Similarity between Query and each Key
# 2. Weighted sum of Values based on similarities
```
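
Here is the same lookup with numbers attached. The 3-dimensional "topic" vectors below (axes: machine learning, cooking, gardening) are hand-picked for illustration rather than learned, and each "value" is a made-up 2-number summary of a book; the sketch only demonstrates the two steps listed in the comments above.

```python
import numpy as np

# Hypothetical topic features: [machine learning, cooking, gardening]
query = np.array([1.0, 0.0, 0.0])   # "machine learning concepts"
keys = np.array([
    [0.9, 0.0, 0.1],   # deep learning and neural networks
    [0.0, 1.0, 0.1],   # cooking recipes for beginners
    [0.6, 0.0, 0.0],   # artificial intelligence overview
    [0.0, 0.1, 1.0],   # gardening tips and tricks
])
# Hypothetical "content" of each book: [technical depth, practical tips]
values = np.array([
    [0.9, 0.2],
    [0.1, 0.9],
    [0.6, 0.4],
    [0.1, 0.8],
])

scores = keys @ query                              # 1. similarity between the query and each key
weights = np.exp(scores) / np.exp(scores).sum()    # softmax turns scores into weights
result = weights @ values                          # 2. weighted sum of the values

for title, w in zip(["deep learning", "cooking", "AI overview", "gardening"], weights):
    print(f"{title:>14}: {w:.3f}")
print("retrieved content:", result.round(3))
```

The deep learning book gets the largest weight and the AI overview the next largest, but every book contributes a little: this is a soft lookup, not a single hit.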

### In Transformers

For a sequence "The cat sat":

When processing "sat":
- **Query**: "sat" asks "what words are relevant to me?"
- **Keys**: Each word ("The", "cat", "sat") offers what it represents
- **Values**: The actual semantic content of each word

**A trained model might learn:** "sat" should pay attention to "cat" (its subject!)

### Mathematical Projection

Q, K, V are **learned linear projections** of the input:

$$Q = XW^Q, \quad K = XW^K, \quad V = XW^V$$

Where:
- $X \in \mathbb{R}^{n \times d_{model}}$ (input sequence)
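- $W^Q, W^K \in \mathbb{R}^{d_{model} \times d_k}$ and $W^V \in \mathbb{R}^{d_{model} \times d_v}$ (learned weight matrices; commonly $d_v = d_k$)
- Training adjusts these matrices, so the model learns **what to query**, **what to key on**, and **what values to pass along**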
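
A shape check of the projections in NumPy, assuming toy sizes $n = 5$, $d_{model} = 16$, $d_k = 8$ and a value width $d_v = 12$ (chosen only so the shapes are easy to tell apart; the notebook cell below uses 64 for everything). Queries and keys must share $d_k$ so that $QK^T$ is defined, while $d_v$ only sets the width of the output. As a side note, PyTorch's `nn.Linear(d_model, d_k)` stores its weight with shape `(d_k, d_model)` and computes `X @ weight.T`, so `weight.T` plays the role of $W^Q$.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_model, d_k, d_v = 5, 16, 8, 12   # toy sizes; d_v need not equal d_k

X = rng.standard_normal((n, d_model))      # one sequence of n embedded tokens
W_Q = rng.standard_normal((d_model, d_k))  # learned in practice, random here
W_K = rng.standard_normal((d_model, d_k))
W_V = rng.standard_normal((d_model, d_v))

Q, K, V = X @ W_Q, X @ W_K, X @ W_V
print(Q.shape, K.shape, V.shape)   # (5, 8) (5, 8) (5, 12)

# Q and K share d_k, so every query can be compared with every key
scores = Q @ K.T
print(scores.shape)                # (5, 5)
```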

In [None]:
# Demonstrate Q, K, V projections
batch_size = 1
seq_len = 5
d_model = 64
d_k = 64

# Simulated input (e.g., embedded tokens)
X = torch.randn(batch_size, seq_len, d_model)

# Learnable projection matrices
W_q = nn.Linear(d_model, d_k, bias=False)
W_k = nn.Linear(d_model, d_k, bias=False)
W_v = nn.Linear(d_model, d_k, bias=False)

# Project to Q, K, V
Q = W_q(X)
K = W_k(X)
V = W_v(X)

print("üîß Query, Key, Value Projections\n")
print(f"Input X shape: {X.shape}")
print(f"  ‚Üí [batch_size, seq_len, d_model]\n")

print(f"Query Q shape: {Q.shape}")
print(f"Key K shape: {K.shape}")
print(f"Value V shape: {V.shape}")
print(f"  ‚Üí All: [batch_size, seq_len, d_k]\n")

print("üí° Interpretation:")
print("  - Each position has a Query: 'What do I need?'")
print("  - Each position has a Key: 'What do I offer?'")
print("  - Each position has a Value: 'Here's my content'")
print("\n  The model learns these projections during training!")

---

## 3. Scaled Dot-Product Attention Formula <a id="formula"></a>

### The Complete Formula

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Let's break this down step by step:

### Step 1: Compute Attention Scores

$$\text{scores} = QK^T$$

- Dot product between queries and keys
- High score = high similarity
- Shape: $[n \times n]$ (every query to every key)

### Step 2: Scale the Scores

$$\text{scaled\_scores} = \frac{QK^T}{\sqrt{d_k}}$$

**Why scale?**
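- The variance of a dot product grows linearly with $d_k$ (for independent, zero-mean, unit-variance components it equals $d_k$)
- Large scores → softmax saturates → tiny gradients
- Dividing by $\sqrt{d_k}$ brings the variance back to about 1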

**Example:**
```python
import math

d_k = 64
raw_score = 100.0                          # Too large for a well-behaved softmax
scaled_score = raw_score / math.sqrt(d_k)  # 100 / 8 = 12.5, a much smaller magnitude
```
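
A quick empirical check of the variance argument, assuming query and key components are independent standard-normal draws (roughly the situation at initialization, not after training). The raw dot product's variance should track $d_k$, while the scaled version stays near 1.

```python
import numpy as np

rng = np.random.default_rng(0)
num_pairs = 10_000
results = {}

for d_k in (16, 64, 256):
    q = rng.standard_normal((num_pairs, d_k))   # random queries, components ~ N(0, 1)
    k = rng.standard_normal((num_pairs, d_k))   # random keys, components ~ N(0, 1)
    raw = (q * k).sum(axis=1)                   # one dot product per (q, k) pair
    scaled = raw / np.sqrt(d_k)
    results[d_k] = (raw.var(), scaled.var())
    print(f"d_k={d_k:4d}  Var(raw)={raw.var():7.1f}  Var(scaled)={scaled.var():.2f}")
```

With these settings `Var(raw)` lands near 16, 64, and 256, while `Var(scaled)` stays near 1 in every row.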

### Step 3: Apply Softmax

$$\text{weights} = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)$$

- Converts each row of scores into a probability distribution over the keys
- Each row sums to 1
- Higher scores → higher attention weights
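
A from-scratch softmax makes the "each row sums to 1" property concrete. Subtracting the row maximum before exponentiating is a standard trick: softmax is unchanged when the same constant is added to every entry of a row, and the shift keeps `np.exp` from overflowing on large scores. The helper name `softmax` here is just for this sketch.

```python
import numpy as np

def softmax(scores, axis=-1):
    """Row-wise softmax; subtracting the max keeps np.exp from overflowing."""
    shifted = scores - scores.max(axis=axis, keepdims=True)
    exps = np.exp(shifted)
    return exps / exps.sum(axis=axis, keepdims=True)

scores = np.array([[2.0, 1.0, 0.1],
                   [1000.0, 999.0, 998.0]])  # naive np.exp(1000.0) would overflow to inf
weights = softmax(scores)
print(weights.round(3))
print(weights.sum(axis=-1))                   # each row sums to 1
```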

### Step 4: Weighted Sum of Values

$$\text{output} = \text{weights} \cdot V$$

- Multiply attention weights by values
- Each position becomes a weighted combination of all values
- Positions with high attention contribute more
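
The matrix product $\text{weights} \cdot V$ is shorthand for "for each query row $i$, add up the value vectors $v_j$ scaled by $w_{ij}$". This sketch uses random toy data (any row-stochastic weight matrix will do) and computes the output both ways to show they agree.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_v = 4, 3
V = rng.standard_normal((n, d_v))
raw = rng.random((n, n))
weights = raw / raw.sum(axis=-1, keepdims=True)   # rows sum to 1, like softmax output

# Explicit loop: output_i = sum_j weights[i, j] * V[j]
loop_out = np.zeros((n, d_v))
for i in range(n):
    for j in range(n):
        loop_out[i] += weights[i, j] * V[j]

matmul_out = weights @ V
print(np.allclose(loop_out, matmul_out))  # True
```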

### Why This Works

1. **Dot product** measures similarity (Q · K)
2. **Scaling** prevents gradient issues
3. **Softmax** creates a distribution (interpretable weights)
4. **Weighted sum** aggregates relevant information

### Computational Complexity

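- Time: $O(n^2 \cdot d_k)$ to compute $QK^T$ (plus $O(n^2 \cdot d_v)$ for the weighted sum of values)
- Memory: $O(n^2)$ per head to store the attention weights
- This quadratic growth in the sequence length $n$ is why very long sequences are challenging!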
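
A back-of-the-envelope calculation of the memory for the attention matrix alone, assuming 4-byte float32 entries and counting a single head and a single sequence. Real models multiply this by batch size and number of heads, and memory-efficient kernels can avoid materializing the full matrix at once; `attention_matrix_bytes` is a hypothetical helper for this sketch.

```python
def attention_matrix_bytes(n, bytes_per_entry=4):
    """Size of one n x n attention-weight matrix (one head, one sequence)."""
    return n * n * bytes_per_entry

for n in (512, 2048, 8192, 32768):
    mib = attention_matrix_bytes(n) / 2**20
    print(f"n={n:6d}: {mib:10.1f} MiB")
```

Each doubling of $n$ quadruples the memory: 1 MiB at $n = 512$ grows to 4096 MiB (4 GiB) at $n = 32768$.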

In [None]:
# Manual implementation to understand each step
def manual_attention_step_by_step(Q, K, V, mask=None):
    """
    Implement scaled dot-product attention with detailed output
    """
    d_k = Q.size(-1)
    
    print("üìê Step-by-Step Attention Computation\n")
    print("="*70)
    
    # Step 1: Compute attention scores
    scores = torch.matmul(Q, K.transpose(-2, -1))
    print("\n1️⃣ Compute scores: Q @ K^T")
    print(f"   Shape: {scores.shape}")
    print(f"   Range: [{scores.min():.2f}, {scores.max():.2f}]")
    print(f"   Mean: {scores.mean():.2f}, Std: {scores.std():.2f}")
    
    # Step 2: Scale by sqrt(d_k)
    scaled_scores = scores / np.sqrt(d_k)
    print(f"\n2️⃣ Scale by √{d_k} = {np.sqrt(d_k):.2f}")
    print(f"   Range: [{scaled_scores.min():.2f}, {scaled_scores.max():.2f}]")
    print(f"   Mean: {scaled_scores.mean():.2f}, Std: {scaled_scores.std():.2f}")
    print("   ✓ Variance normalized!")
    
    # Step 3: Apply mask (if provided)
    if mask is not None:
        scaled_scores = scaled_scores.masked_fill(mask == 0, float('-inf'))
        print("\n3️⃣ Apply mask (set masked positions to -inf)")
        print("   Masked positions will get 0 attention after softmax")
    else:
        print("\n3️⃣ No mask provided (step skipped)")
    
    # Step 4: Apply softmax
    attn_weights = F.softmax(scaled_scores, dim=-1)
    print("\n4️⃣ Apply softmax (convert to probabilities)")
    print(f"   Shape: {attn_weights.shape}")
    print(f"   Range: [{attn_weights.min():.3f}, {attn_weights.max():.3f}] (always within [0, 1])")
    print(f"   Each row sums to: {attn_weights[0].sum(dim=-1).mean():.4f} (≈ 1.0)")
    
    # Step 5: Weighted sum of values
    output = torch.matmul(attn_weights, V)
    print("\n5️⃣ Compute weighted sum: attention_weights @ V")
    print(f"   Output shape: {output.shape}")
    print(f"   Each position is now a weighted combination of all values!")
    
    print("\n" + "="*70)
    
    return output, attn_weights


# Create example data
batch_size = 1
seq_len = 6
d_k = 64

Q = torch.randn(batch_size, seq_len, d_k)
K = torch.randn(batch_size, seq_len, d_k)
V = torch.randn(batch_size, seq_len, d_k)

# Run manual attention
output, weights = manual_attention_step_by_step(Q, K, V)

---

## 4. Implementation from Scratch <a id="implementation"></a>

Let's first use the `ScaledDotProductAttention` module imported from `src.modules.attention`, then implement a minimal version from scratch to see what it does internally:

In [None]:
# Use our implementation
attention = ScaledDotProductAttention(dropout=0.1)

# Test it
output, attn_weights = attention(Q, K, V)

print("\nüîß Using ScaledDotProductAttention Module\n")
print(f"Input shapes:")
print(f"  Q: {Q.shape}")
print(f"  K: {K.shape}")
print(f"  V: {V.shape}")

print(f"\nOutput shapes:")
print(f"  Output: {output.shape}")
print(f"  Attention weights: {attn_weights.shape}")

print("\n✅ Attention successfully computed!")
print(f"   Each position now contains information from all positions")
print(f"   weighted by their relevance (attention weights)")

In [None]:
# Implement from scratch for learning
class SimpleAttention(nn.Module):
    """Minimal attention implementation for educational purposes"""
    
    def __init__(self):
        super().__init__()
    
    def forward(self, Q, K, V, mask=None):
        # Get dimension
        d_k = Q.size(-1)
        
        # 1. Compute scores: Q @ K^T
        scores = torch.matmul(Q, K.transpose(-2, -1))
        
        # 2. Scale by sqrt(d_k)
        scores = scores / np.sqrt(d_k)
        
        # 3. Apply mask (if provided)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        # 4. Apply softmax
        attn_weights = F.softmax(scores, dim=-1)
        
        # 5. Weighted sum of values
        output = torch.matmul(attn_weights, V)
        
        return output, attn_weights


# Test our implementation
simple_attn = SimpleAttention()
output2, weights2 = simple_attn(Q, K, V)

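print("✅ Custom attention implementation works!")
print(f"   Output shape: {output2.shape}")
print(f"   Weights shape: {weights2.shape}")

# The step-by-step version (same Q, K, V, no dropout) performs identical operations, so its weights
# should match exactly; the module's output above may differ in value because it used dropout=0.1.
print("\n📊 Consistency check:")
print(f"   Same output shape as the module: {output.shape == output2.shape}")
print(f"   Weights match the step-by-step version: {torch.allclose(weights, weights2)}")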

---

## 5. Visualization & Analysis <a id="visualization"></a>

Let's visualize attention patterns to understand what the model is learning:

In [None]:
# Create a more interpretable example
# Simulate a sentence: "The cat sat on the mat"
seq_len = 6
d_k = 64

# Create Q, K, V with some structure
torch.manual_seed(42)
Q = torch.randn(1, seq_len, d_k)
K = torch.randn(1, seq_len, d_k)
V = torch.randn(1, seq_len, d_k)

# Point the query for "sat" toward the key for "cat" (a toy subject-verb link),
# so the "sat" row of the heatmap should concentrate on "cat"
Q[0, 2] = 0.5 * K[0, 1] + 0.1 * Q[0, 2]

# Compute attention
attention = ScaledDotProductAttention(dropout=0.0)  # No dropout for visualization
output, attn_weights = attention(Q, K, V)

# Visualize
tokens = ["The", "cat", "sat", "on", "the", "mat"]

plt.figure(figsize=(10, 8))
sns.heatmap(attn_weights[0].detach().numpy(), 
            annot=True, fmt='.3f', cmap='YlOrRd',
            xticklabels=tokens,
            yticklabels=tokens,
            cbar_kws={'label': 'Attention Weight'})

plt.title('Attention Weight Matrix\n(Each row shows where that token attends)', fontsize=14, weight='bold')
plt.xlabel('Keys (attending TO)', fontsize=12)
plt.ylabel('Queries (attending FROM)', fontsize=12)
plt.tight_layout()
plt.show()

print("üìä Interpretation:")
print("  - Darker colors = stronger attention")
print("  - Each ROW is a query's attention distribution")
print("  - Each row sums to 1.0 (probability distribution)")
print(f"\n  Example: '{tokens[2]}' attends most to '{tokens[attn_weights[0, 2].argmax().item()]}'")
print(f"  (Attention weight: {attn_weights[0, 2].max():.3f})")

In [None]:
# Visualize how scaling affects attention
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Create scores with different scales
Q_test = torch.randn(1, 4, 64)
K_test = torch.randn(1, 4, 64)

scores = torch.matmul(Q_test, K_test.transpose(-2, -1))[0]

# Without scaling
ax1 = axes[0]
weights_no_scale = F.softmax(scores, dim=-1)
sns.heatmap(weights_no_scale.detach().numpy(), annot=True, fmt='.3f', cmap='YlOrRd', ax=ax1, vmin=0, vmax=1)
ax1.set_title('No Scaling\n(Softmax on raw scores)', fontsize=12, weight='bold')
ax1.set_xlabel('Key')
ax1.set_ylabel('Query')

# With proper scaling
ax2 = axes[1]
weights_scaled = F.softmax(scores / np.sqrt(64), dim=-1)
sns.heatmap(weights_scaled.detach().numpy(), annot=True, fmt='.3f', cmap='YlOrRd', ax=ax2, vmin=0, vmax=1)
ax2.set_title('Scaled by √64 = 8\n(Better distribution)', fontsize=12, weight='bold')
ax2.set_xlabel('Key')
ax2.set_ylabel('Query')

# Show the difference
ax3 = axes[2]
difference = (weights_scaled - weights_no_scale).detach().numpy()
sns.heatmap(difference, annot=True, fmt='.3f', cmap='RdBu_r', ax=ax3, center=0)
ax3.set_title('Difference\n(Scaled - Unscaled)', fontsize=12, weight='bold')
ax3.set_xlabel('Key')
ax3.set_ylabel('Query')

plt.tight_layout()
plt.show()

print("üí° Key Insight:")
print("  - Scaling prevents attention from becoming too 'peaky' (dominated by one position)")
print("  - Allows more balanced attention distribution")
print("  - Better gradients during training!")

---

## 6. DeepSeek Insights <a id="deepseek"></a>

### 🔬 DeepSeek-R1 Perspective on Scaled Dot-Product Attention

The perspectives below, framed around **DeepSeek-R1**, describe how attention supports reasoning:

#### 1. **Attention as Differentiable Memory Access**

> "Attention is not just a weighted average - it's a differentiable way to read from memory. The Query is the 'read address', Keys are 'memory indices', and Values are 'memory content'."

**Why this matters:**
- Traditional neural networks have fixed computation
- Attention allows **dynamic, input-dependent** computation
- The model learns **where to look** based on **what it needs**
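
The "differentiable memory read" view can be checked directly. If the query matches one key far better than the others, softmax puts nearly all the weight on that slot and the read returns almost exactly that slot's value, like a dictionary lookup; with a weaker match the read blends several slots. The orthonormal keys, the stored values, and the `sharpness` factor are illustrative assumptions, and `soft_read` is a hypothetical helper.

```python
import numpy as np

def soft_read(query, keys, values):
    """Scaled dot-product read from a small key-value memory."""
    scores = keys @ query / np.sqrt(keys.shape[1])
    w = np.exp(scores - scores.max())
    w /= w.sum()
    return w @ values, w

d = 4
keys = np.eye(d)                                     # four orthonormal "addresses"
values = np.array([[10.0], [20.0], [30.0], [40.0]])  # memory content at each address

for sharpness in (1.0, 20.0):
    query = sharpness * keys[2]                      # "read address" pointing at slot 2
    out, w = soft_read(query, keys, values)
    print(f"sharpness={sharpness:5.1f}  weights={w.round(3)}  read={out[0]:.2f}")
```

Because every step is a smooth function of the query, gradients can tell the model how to move its "read address", which a hard index lookup cannot.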

#### 2. **The Scaling Factor is Critical**

> "Without proper scaling, attention patterns become overconfident early in training, leading to mode collapse and poor generalization."

**Mathematical reasoning:**
- Dot product variance grows with dimension: $\text{Var}(QK^T) \propto d_k$
- Large values → softmax saturates → gradients vanish
- Scaling by $\sqrt{d_k}$ ensures: $\text{Var}(\frac{QK^T}{\sqrt{d_k}}) \approx 1$
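
The variance claim can be derived for a single query-key pair $q, k \in \mathbb{R}^{d_k}$, under the standard simplifying assumption that all components $q_i, k_i$ are independent with mean 0 and variance 1. Independence lets the variance of the sum split into a sum of variances, and each term satisfies $\text{Var}(q_i k_i) = \mathbb{E}[q_i^2 k_i^2] - (\mathbb{E}[q_i k_i])^2 = 1 \cdot 1 - 0 = 1$:

$$\text{Var}(q \cdot k) = \text{Var}\left(\sum_{i=1}^{d_k} q_i k_i\right) = \sum_{i=1}^{d_k} \text{Var}(q_i k_i) = \sum_{i=1}^{d_k} \mathbb{E}[q_i^2]\,\mathbb{E}[k_i^2] = d_k$$

$$\text{Var}\left(\frac{q \cdot k}{\sqrt{d_k}}\right) = \frac{1}{d_k}\,\text{Var}(q \cdot k) = \frac{d_k}{d_k} = 1$$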

#### 3. **Information Routing**

> "Attention is the Transformer's way of routing information. Each layer decides: 'Which information from which positions should flow where?'"

**In practice:**
- Early layers: Local patterns (adjacent words)
- Middle layers: Syntactic relationships (subject-verb)
- Late layers: Semantic relationships (reasoning steps)

#### 4. **Why Dot Product?**

There are other similarity measures (cosine, Euclidean), but dot product wins because:

1. **Efficient**: Matrix multiplication is highly optimized on GPUs
2. **Differentiable**: Smooth gradients for learning
3. **Expressive**: Can represent both similarity AND magnitude
4. **Stable**: With proper scaling

#### 5. **The Softmax Distribution**

> "Softmax creates a 'soft' selection mechanism. Instead of picking the single best match (hard attention), we get a distribution over all matches. This is crucial for gradient flow."

**Benefits:**
- Differentiable (hard attention isn't)
- Allows weighted combinations
- Temperature-like behavior (sharper vs softer)
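
A small sketch of the "temperature-like behavior". Here the scores are divided by a temperature `T` before softmax; `T` is a hypothetical knob added for illustration (in standard attention the only fixed divisor is $\sqrt{d_k}$). Smaller `T` gives sharper weights, larger `T` flatter ones, and entropy (in nats) summarizes how spread out they are. The hard-attention line shows the non-differentiable alternative.

```python
import numpy as np

def softmax(x):
    """Softmax over a 1-D score vector (max-shifted for stability)."""
    e = np.exp(x - x.max())
    return e / e.sum()

def entropy(p):
    """Shannon entropy in nats; higher means attention is more spread out."""
    return float(-(p * np.log(p)).sum())

scores = np.array([2.0, 1.0, 0.5, 0.0])
max_entropy = np.log(len(scores))   # entropy of the uniform distribution over 4 keys
entropies = {}

for T in (0.1, 1.0, 10.0):          # hypothetical temperature values
    p = softmax(scores / T)
    entropies[T] = entropy(p)
    print(f"T={T:5.1f}  weights={p.round(3)}  entropy={entropies[T]:.3f} (max {max_entropy:.3f})")

# Hard attention picks the single best key; argmax is piecewise constant in the scores,
# so it provides no useful gradient
print("hard attention picks index", int(scores.argmax()))
```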

---

### DeepSeek's Training Insights

During DeepSeek-R1 training, researchers observed:

1. **Attention patterns emerge gradually**
   - Random at initialization
   - Local patterns first (adjacent tokens)
   - Long-range patterns later (complex reasoning)

2. **Different heads specialize**
   - Some focus on syntax
   - Some focus on semantics
   - Some focus on specific linguistic phenomena

3. **Reasoning requires multi-hop attention**
   - Single attention layer: direct associations
   - Multiple layers: chains of reasoning
   - Example: A→B (layer 1), B→C (layer 2), conclude A→C
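
A toy illustration of the A→B→C chain under heavy simplifications: hand-set one-hot attention matrices, no learned projections, and no residual connections. Values are one-hot, so each output row shows whose content it holds. Position 0 ("A") can only obtain position 2's ("C") content after position 1 ("B") has gathered it in the layer before.

```python
import numpy as np

tokens = ["A", "B", "C"]
V = np.eye(3)   # one-hot "content": row i is token i's content

# Layer 1 (hand-set): B attends to C; A and C attend to themselves
A1 = np.array([[1, 0, 0],
               [0, 0, 1],
               [0, 0, 1]], dtype=float)
# Layer 2 (hand-set): A attends to B; B and C attend to themselves
A2 = np.array([[0, 1, 0],
               [0, 1, 0],
               [0, 0, 1]], dtype=float)

H1 = A1 @ V    # after layer 1, B holds C's content
H2 = A2 @ H1   # after layer 2, A reads B, which now carries C

print("A after layer 1 holds:", tokens[int(H1[0].argmax())])  # A
print("A after layer 2 holds:", tokens[int(H2[0].argmax())])  # C
print("effective A-row of A2 @ A1:", (A2 @ A1)[0])           # [0. 0. 1.]
```

No single attention matrix here connects A to C; the connection exists only in the product `A2 @ A1`, which is why depth matters for this kind of chaining.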

In [None]:
# Demonstrate the effect of different similarity measures
def compare_similarity_measures(Q, K):
    """
    Compare dot product, cosine similarity, and Euclidean distance
    """
    # Scaled dot product (what Transformers use, as in Section 3)
    dot_product = torch.matmul(Q, K.transpose(-2, -1))[0] / np.sqrt(Q.size(-1))
    
    # Cosine similarity
    Q_norm = Q / Q.norm(dim=-1, keepdim=True)
    K_norm = K / K.norm(dim=-1, keepdim=True)
    cosine_sim = torch.matmul(Q_norm, K_norm.transpose(-2, -1))[0]
    
    # Euclidean distance (convert to similarity)
    Q_expanded = Q.unsqueeze(2)  # [1, n, 1, d]
    K_expanded = K.unsqueeze(1)  # [1, 1, n, d]
    euclidean_dist = torch.norm(Q_expanded - K_expanded, dim=-1)[0]
    euclidean_sim = -euclidean_dist  # Negative because smaller distance = more similar
    
    # Visualize
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    
    measures = [
        (dot_product, 'Dot Product\n(Used in Transformers)'),
        (cosine_sim, 'Cosine Similarity'),
        (euclidean_sim, 'Negative Euclidean Distance'),
    ]
    
    for ax, (scores, title) in zip(axes, measures):
        # Apply softmax so every panel shows attention weights on the same 0-1 scale
        weights = F.softmax(scores, dim=-1)
        sns.heatmap(weights.detach().numpy(), annot=True, fmt='.3f', 
                   cmap='YlOrRd', ax=ax, vmin=0, vmax=1)
        ax.set_title(title, fontsize=12, weight='bold')
        ax.set_xlabel('Key')
        ax.set_ylabel('Query')
    
    plt.tight_layout()
    plt.show()

# Test with small example
Q_test = torch.randn(1, 4, 64)
K_test = torch.randn(1, 4, 64)

compare_similarity_measures(Q_test, K_test)

print("üîç Comparison:")
print("  - Dot product: Fast, expressive, but needs scaling")
print("  - Cosine: Normalized, but loses magnitude information")
print("  - Euclidean: Distance-based, less efficient for high dimensions")
print("\n  ✅ Dot product + scaling wins for Transformers!")

---

## 🎯 Summary & Key Takeaways

### What We Learned

1. **The Attention Revolution**
   - Replaced sequential RNNs with parallel attention
   - Direct connections between all positions
   - No fixed-size hidden-state bottleneck

2. **Query, Key, Value**
   - Q: "What am I looking for?"
   - K: "What do I offer?"
   - V: "Here's my content"
   - Learned projections of input

3. **The Formula**
   $$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
   - Dot product for similarity
   - Scale to prevent gradient issues
   - Softmax for probability distribution
   - Weighted sum for output

4. **Why It Works**
   - Dynamic information routing
   - Differentiable memory access
   - Learned attention patterns
   - Enables reasoning through information flow

5. **DeepSeek Insights**
   - Attention = differentiable read from memory
   - Scaling is critical for training stability
   - Patterns emerge from simple to complex
   - Multi-hop reasoning requires depth

### Next Steps

In **Tutorial 3: Multi-Head Attention & Masking**, we'll learn:
- How multiple attention heads provide different perspectives
- Parallel attention computation
- Masking strategies (padding, causal)
- Different attention patterns (self, cross)

The scaled dot-product attention we learned is the **building block** for multi-head attention!

---

## 🧪 Exercises

1. **Implement Without Scaling**: Remove the $\sqrt{d_k}$ scaling and observe the effect on attention patterns

2. **Experiment with Dimensions**: Try different $d_k$ values (16, 32, 128, 256) and see how it affects attention

3. **Create Structured Patterns**: Design Q, K, V matrices to create specific attention patterns (e.g., each position attends only to itself)

4. **Visualize Gradients**: Compute gradients and visualize how they flow through the attention mechanism

5. **Alternative Similarity**: Implement attention using cosine similarity instead of dot product - what changes?