# Understanding the Attention Mechanism

Attention is the core innovation behind transformers. Unlike recurrent models, which process a sequence one position at a time, attention lets a model connect any two positions in a sequence directly.

## Why Attention Matters

Consider this sentence: "The bank by the river was closed because the bank couldn't process loans."

How do you know which meaning each "bank" has? Your brain automatically focuses on context:
- First "bank" + "river" → riverbank (the land beside the river)
- Second "bank" + "loans" → financial institution

This selective focusing is exactly what attention does in neural networks.

## Core Concept

**Attention computes a weighted average of values, where weights are determined by similarity between queries and keys.**

Formula: `Attention(Q,K,V) = softmax(QK^T / √d_k)V`

Think of it like a restaurant:
- **Queries (Q)**: What the customer wants
- **Keys (K)**: What's on the menu  
- **Values (V)**: What's actually served

The waiter (attention) matches desires (Q) with options (K) to decide what to serve (V)!
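
To make the analogy concrete, here is a minimal sketch with one query and three key/value pairs, written with the standard library only. The toy vectors and the helper names `dot` and `softmax` are made up for this illustration and are not part of the notebook's later code.

```python
import math

def dot(a, b):
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))

def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical toy data: one query, three keys, three values.
query = [1.0, 0.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
values = [[10.0, 0.0], [0.0, 10.0], [5.0, 5.0]]

d_k = len(query)
scores = [dot(query, k) / math.sqrt(d_k) for k in keys]
weights = softmax(scores)

# The output is the weighted average of the value vectors.
output = [sum(w * v[i] for w, v in zip(weights, values))
          for i in range(len(values[0]))]

print("weights:", [round(w, 3) for w in weights])
print("output: ", [round(o, 3) for o in output])
```

The first and third keys match the query equally well, so they receive equal weight, and the output leans toward their values.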

In [None]:
import sys
import os
sys.path.append('..')

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Tuple

# Set style for better plots
plt.style.use('default')
sns.set_palette("husl")

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

## 1. Building Attention Step by Step

Let's build the attention mechanism piece by piece to understand each component.

### Step 1: The Core Idea
We want to compute how much each word should "attend to" (focus on) every other word.

### Step 2: Measuring Similarity  
We measure similarity between word vectors using **dot products**:
- Vectors pointing in similar directions → high dot product
- Orthogonal or opposing vectors → low (zero or negative) dot product
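
A quick illustration with made-up 3-dimensional vectors (standard library only; the numbers are not real embeddings). It also shows a caveat: the dot product grows with vector length as well as with alignment, so a long vector can score high without being well aligned.

```python
def dot(a, b):
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))

# Hypothetical toy "embeddings".
cat = [0.9, 0.8, 0.1]
kitten = [0.8, 0.9, 0.2]
car = [0.1, 0.2, 0.9]

sim_close = dot(cat, kitten)   # vectors point in a similar direction
sim_far = dot(cat, car)        # vectors point in different directions

# Same direction as `car`, but ten times longer.
big_car = [10 * x for x in car]
sim_big = dot(cat, big_car)

print(f"cat·kitten  = {sim_close:.2f}")
print(f"cat·car     = {sim_far:.2f}")
print(f"cat·big_car = {sim_big:.2f}")
```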

### Step 3: From Similarity to Attention Weights
1. Compute similarity scores: `scores = query · key`
2. Scale to prevent extreme values: `scores = scores / √d_k`  
3. Convert to probabilities: `weights = softmax(scores)`
4. Use weights to combine values: `output = weights · values`

### The Complete Formula
This gives us the famous attention formula:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Where:
- $Q$ are the queries (what we're looking for)
- $K$ are the keys (what's available to attend to)  
- $V$ are the values (what we actually use)
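- $d_k$ is the dimension of the query and key vectors

**Why the scaling?** If the components of a query and a key have roughly unit variance, their dot product has variance $d_k$, so raw scores grow with the dimension. Large scores make softmax very sharp (almost one-hot) and its gradients tiny. Dividing by $\sqrt{d_k}$ brings the variance back to about 1 and keeps the weights smooth.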
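
The effect of the scaling factor can be checked empirically. The sketch below (standard library only; the sample count and dimensions are arbitrary choices) draws random vectors with unit-variance components and estimates the variance of their dot product with and without the $1/\sqrt{d_k}$ factor. Under that assumption the raw variance should come out near $d_k$ and the scaled variance near 1.

```python
import math
import random

random.seed(0)

def dot_variance(d_k, scaled, n_samples=2000):
    """Estimate the variance of q·k for random unit-variance vectors."""
    samples = []
    for _ in range(n_samples):
        q = [random.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [random.gauss(0.0, 1.0) for _ in range(d_k)]
        s = sum(a * b for a, b in zip(q, k))
        if scaled:
            s /= math.sqrt(d_k)
        samples.append(s)
    mean = sum(samples) / n_samples
    return sum((s - mean) ** 2 for s in samples) / n_samples

for d_k in (4, 64):
    raw = dot_variance(d_k, scaled=False)
    scl = dot_variance(d_k, scaled=True)
    print(f"d_k={d_k:3d}  raw variance ~ {raw:6.2f}  scaled variance ~ {scl:4.2f}")
```

At $d_k = 64$ the unscaled scores have a standard deviation of about 8, and a softmax over scores that far apart puts nearly all of its weight on a single position.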

Let's implement this step by step:

## Attention From First Principles

**The Problem**: How do we determine which parts of the input are most relevant for each position?

**The Solution**: Attention mechanism in 5 steps:

1. **Compute similarity**: `scores = query · key` (dot product measures similarity)
2. **Scale for stability**: `scores = scores / √d_k` (prevents saturation)  
3. **Apply causal mask**: `scores[future] = -∞` (for language modeling)
4. **Normalize to probabilities**: `weights = softmax(scores)` (sums to 1)
5. **Weighted combination**: `output = weights · values` (final result)

**Why scaling?** Without √d_k, large dot products make softmax too sharp (almost one-hot), losing the ability to attend to multiple positions.

def scaled_dot_product_attention(Q, K, V, mask=None, show_steps=True):
    """Compute softmax(QK^T / sqrt(d_k)) V.

    Q: (..., n_queries, d_k), K: (..., n_keys, d_k), V: (..., n_keys, d_v).
    mask: optional tensor broadcastable to (..., n_queries, n_keys);
          positions where mask == 0 are hidden from attention.
    Returns (output, attention_weights).
    """
    d_k = Q.size(-1)

    # Step 1: Compute similarity scores
    scores = torch.matmul(Q, K.transpose(-2, -1))
    if show_steps:
        print(f"Step 1 - Raw scores range: [{scores.min():.3f}, {scores.max():.3f}]")

    # Step 2: Scale so score magnitudes do not grow with d_k
    scores = scores / d_k ** 0.5
    if show_steps:
        print(f"Step 2 - Scaled range: [{scores.min():.3f}, {scores.max():.3f}]")

    # Step 3: Give masked positions the most negative score for this dtype,
    # so they receive (effectively) zero weight after softmax
    if mask is not None:
        scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        if show_steps:
            print("Step 3 - Applied mask")

    # Step 4: Convert to probabilities (each row sums to 1)
    attention_weights = F.softmax(scores, dim=-1)
    if show_steps:
        first_row_sum = attention_weights.sum(dim=-1).flatten()[0]
        print(f"Step 4 - Attention weights sum (first row): {first_row_sum:.3f}")

    # Step 5: Weighted combination of the values
    output = torch.matmul(attention_weights, V)
    if show_steps:
        print(f"Step 5 - Output shape: {output.shape}")

    return output, attention_weights

# Simple example to trace through
Q = torch.tensor([[[1., 0., 0., 0.], [0., 1., 0., 0.], [0., 0., 1., 0.]]])
K = torch.tensor([[[1., 0., 0., 0.], [0., 1., 0., 0.], [0., 0., 1., 0.]]])  
V = torch.tensor([[[2., 1.], [1., 3.], [3., 2.]]])

print("🔍 ATTENTION STEP-BY-STEP:")
output, attn_weights = scaled_dot_product_attention(Q, K, V)
print(f"\nFinal attention weights:\n{attn_weights[0].numpy()}")
print("Notice: each query puts its largest weight (about 0.45) on the matching key,")
print("but softmax still spreads the remaining weight over the other keys.")
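
The diagonal value of roughly 0.45 in the printed weights can be reproduced by hand. With one-hot queries and keys of dimension 4, the matching key scores 1/√4 = 0.5 and the other two keys score 0, so softmax cannot produce a one-hot row. A small standard-library check (variable names are illustrative):

```python
import math

d_k = 4
match_score = 1.0 / math.sqrt(d_k)   # query·key = 1 for the matching key, then scaled
other_score = 0.0                    # orthogonal one-hot vectors give a dot product of 0

# Softmax over [match_score, other_score, other_score]
denom = math.exp(match_score) + 2 * math.exp(other_score)
w_match = math.exp(match_score) / denom
w_other = math.exp(other_score) / denom

print(f"matching key: {w_match:.3f}, each other key: {w_other:.3f}")
# prints "matching key: 0.452, each other key: 0.274"
```

Larger or better-separated scores would push the matching weight closer to 1.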

In [None]:
def create_simple_sequence_data():
    """
    Create a simple sequence where attention patterns should be interpretable.
    """
    # Create embeddings for words: "The", "cat", "sat", "down"
    seq_len, d_model = 4, 6
    
    # Manually create embeddings that should have interesting attention patterns
    embeddings = torch.tensor([
        [1.0, 0.5, 0.0, 0.5, 0.0, 0.0],  # "The" - article
        [0.0, 1.0, 1.0, 0.0, 0.5, 0.0],  # "cat" - noun
        [0.0, 0.0, 0.5, 1.0, 1.0, 0.5],  # "sat" - verb
        [0.5, 0.0, 0.0, 0.5, 1.0, 1.0],  # "down" - adverb
    ]).unsqueeze(0)  # Add batch dimension
    
    words = ["The", "cat", "sat", "down"]
    
    return embeddings, words

# Create interpretable data
embeddings, words = create_simple_sequence_data()
print(f"Embeddings shape: {embeddings.shape}")
print(f"Words: {words}")

# Use embeddings as Q, K, V for self-attention
output, attn_weights = scaled_dot_product_attention(embeddings, embeddings, embeddings)

# Visualize attention weights
plt.figure(figsize=(8, 6))
sns.heatmap(
    attn_weights[0].detach().numpy(),
    annot=True,
    fmt='.3f',
    xticklabels=words,
    yticklabels=words,
    cmap='Blues',
    cbar_kws={'label': 'Attention Weight'}
)
plt.title('Self-Attention Weights')
plt.xlabel('Keys (attending to)')
plt.ylabel('Queries (attending from)')
plt.tight_layout()
plt.show()

print("\nInterpretation:")
print("Each row shows how much each word attends to other words.")
print("Higher values (darker blue) indicate stronger attention.")

## Causal (Masked) Attention

For language modeling, we need **causal attention**: each position can attend only to itself and earlier positions, which prevents the model from "cheating" by looking at future tokens.

**Why mask future positions?** In language modeling, when predicting the next word, the model shouldn't see future words that haven't been generated yet.
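
One way to see what the mask buys us: with a causal mask, the output at position t must not change when later tokens change. The sketch below checks that property on random data. `causal_attention` is a compact helper written for this illustration (it is not the notebook's `scaled_dot_product_attention`), and the tensor sizes are arbitrary.

```python
import torch
import torch.nn.functional as F

def causal_attention(x):
    """Self-attention with Q = K = V = x and a lower-triangular mask."""
    seq_len, d_k = x.size(-2), x.size(-1)
    scores = x @ x.transpose(-2, -1) / d_k ** 0.5
    allowed = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    # Future positions get -inf, so softmax assigns them exactly zero weight.
    scores = scores.masked_fill(~allowed, float("-inf"))
    return F.softmax(scores, dim=-1) @ x

# Local generator, so the notebook's global random state is left alone.
g = torch.Generator().manual_seed(0)
x = torch.randn(1, 5, 8, generator=g)

# Replace only the last two tokens.
x_changed = x.clone()
x_changed[:, 3:] = torch.randn(1, 2, 8, generator=g)

out = causal_attention(x)
out_changed = causal_attention(x_changed)

print(torch.allclose(out[:, :3], out_changed[:, :3]))   # positions 0-2: unaffected
print(torch.allclose(out[:, 3:], out_changed[:, 3:]))   # positions 3-4: changed
```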

In [None]:
def create_causal_mask(seq_len):
    mask = torch.tril(torch.ones(seq_len, seq_len))
    return mask

seq_len = 4
causal_mask = create_causal_mask(seq_len)
print("Causal mask (1 = can attend, 0 = cannot attend):")
print(causal_mask.numpy())

output_causal, attn_weights_causal = scaled_dot_product_attention(
    embeddings, embeddings, embeddings, mask=causal_mask.unsqueeze(0), show_steps=False
)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

sns.heatmap(attn_weights[0].detach().numpy(), annot=True, fmt='.3f',
           xticklabels=words, yticklabels=words, cmap='Blues', ax=ax1)
ax1.set_title('Regular Self-Attention')

sns.heatmap(attn_weights_causal[0].detach().numpy(), annot=True, fmt='.3f',
           xticklabels=words, yticklabels=words, cmap='Blues', ax=ax2)
ax2.set_title('Causal Self-Attention')

plt.tight_layout()
plt.show()

print("Causal attention: each word can only attend to itself and previous words.")
print("Notice the upper triangle is zero - no peeking at the future!")

## Multi-Head Attention

**The Idea**: Instead of one attention computation, run multiple "heads" in parallel to capture different types of relationships.

**Why multiple heads?** Different heads can learn to focus on different aspects:
- Head 1: Subject-verb relationships  
- Head 2: Adjective-noun relationships
- Head 3: Long-range dependencies

Each head uses different learned projections of Q, K, V.
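
The `MultiHeadAttention` class used in the next cell comes from this project's `src.model.attention` module, whose code is not shown here. As a rough sketch of what such a layer typically does, and not a copy of that implementation, the hypothetical `TinyMultiHeadAttention` below projects the input, splits the feature dimension into heads, runs scaled dot-product attention per head, and merges the results. All names and the layer layout are assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyMultiHeadAttention(nn.Module):
    """Illustrative self-attention layer; names and layout are assumptions."""

    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

    def _split(self, t):
        # (batch, seq, d_model) -> (batch, heads, seq, d_head)
        b, s, _ = t.shape
        return t.view(b, s, self.n_heads, self.d_head).transpose(1, 2)

    def forward(self, x):
        q = self._split(self.w_q(x))
        k = self._split(self.w_k(x))
        v = self._split(self.w_v(x))
        scores = q @ k.transpose(-2, -1) / self.d_head ** 0.5
        weights = F.softmax(scores, dim=-1)               # (batch, heads, seq, seq)
        heads = weights @ v                               # (batch, heads, seq, d_head)
        b, _, s, _ = heads.shape
        merged = heads.transpose(1, 2).reshape(b, s, -1)  # (batch, seq, d_model)
        return self.w_o(merged), weights

layer = TinyMultiHeadAttention(d_model=8, n_heads=2)
out, weights = layer(torch.randn(1, 4, 8))
print(out.shape, weights.shape)
```

Each head works on its own `d_head`-sized slice of the projected features, which is why the weights tensor carries a separate `seq × seq` matrix per head.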

In [None]:
from src.model.attention import MultiHeadAttention

d_model, n_heads = 8, 2
mha = MultiHeadAttention(d_model, n_heads)
print(f"Multi-head attention: {n_heads} heads, {d_model // n_heads} dimensions per head")

batch_size, seq_len = 1, 4
x = torch.randn(batch_size, seq_len, d_model)
output, attention_weights = mha(x, x, x, return_attention=True)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {attention_weights.shape}")

fig, axes = plt.subplots(1, n_heads, figsize=(12, 5))
if n_heads == 1:
    axes = [axes]

for head in range(n_heads):
    head_attention = attention_weights[0, head].detach().numpy()
    sns.heatmap(head_attention, annot=True, fmt='.3f', cmap='Blues', ax=axes[head])
    axes[head].set_title(f'Head {head + 1}')

plt.tight_layout()
plt.show()
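print("Each head shows a different pattern because it uses its own projections.")
print("Here the weights are randomly initialized; training is what makes the patterns meaningful.")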

## Summary: Attention Mastery

You've learned the core mechanism that powers all transformers!

### Key Concepts
- **Attention formula**: `softmax(QK^T / √d_k)V` - weighted average based on similarity
- **Scaling factor**: `√d_k` keeps score magnitudes independent of dimension, preventing softmax saturation
- **Causal masking**: Prevents future information leakage in language modeling
- **Multi-head attention**: Parallel heads capture different relationship types

### Why Attention is Revolutionary
- **Before**: RNNs processed sequences step-by-step, limiting parallelization
- **After**: Attention connects any two positions directly, enabling parallelization

### Applications
- **Self-attention**: Each position attends to all positions in the same sequence
- **Cross-attention**: Queries from one sequence, keys/values from another (e.g., translation)
- **Causal attention**: For autoregressive language modeling
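
These variants differ only in where Q, K, and V come from and whether a mask is applied. As a small shape check for the cross-attention case, the sketch below uses queries from a 3-position "target" sequence and keys/values from a 5-position "source" sequence. The tensor sizes and variable names are arbitrary, and the learned projections a real layer would apply are omitted.

```python
import torch
import torch.nn.functional as F

# Local generator, so the notebook's global random state is left alone.
g = torch.Generator().manual_seed(0)
d_k = 8
decoder_states = torch.randn(1, 3, d_k, generator=g)   # 3 target positions -> queries
encoder_states = torch.randn(1, 5, d_k, generator=g)   # 5 source positions -> keys and values

scores = decoder_states @ encoder_states.transpose(-2, -1) / d_k ** 0.5
weights = F.softmax(scores, dim=-1)      # one row per query, one column per source position
context = weights @ encoder_states       # one context vector per query

print(weights.shape)   # torch.Size([1, 3, 5])
print(context.shape)   # torch.Size([1, 3, 8])
```

The weight matrix is rectangular here: the number of rows follows the query sequence and the number of columns follows the key/value sequence.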

### Next Steps
Now you understand attention! Next, we'll see how it combines with other components to build complete transformer blocks.

The foundation is solid - let's build transformers! 🚀