# Attention Mechanisms: The Foundation of Transformers

Attention is the mechanism behind modern AI systems such as GPT, BERT, and many image generation models. It lets a network focus on the parts of the input that are most relevant to each output.

## What You'll Learn:
1. **The Attention Problem** - Why we need attention
2. **Attention Basics** - Query, Key, Value
3. **Scaled Dot-Product Attention** - The core mechanism
4. **Self-Attention** - Attending to your own sequence
5. **Multi-Head Attention** - Multiple attention patterns
6. **Implementation** - From scratch in NumPy and PyTorch

**Key Idea:** Attention lets the model decide which parts of the input to focus on for each output.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F

np.random.seed(42)
torch.manual_seed(42)

sns.set_style('whitegrid')

## 1. The Attention Problem

### Traditional RNNs/LSTMs:
Process sequences sequentially and compress everything into a fixed-size hidden state.

**Problems:**
1. **Information bottleneck**: Long sequences lose information
2. **No parallelization**: Must process sequentially
3. **Fixed context**: The decoder sees only a single summary vector, so it can't focus on specific input positions

### Example: Translation
**English:** "The cat sat on the mat"

**French:** "Le chat s'est assis sur le tapis"

When translating "chat", we should focus on "cat" in the input, not "mat"!

**Attention** solves this by letting the model dynamically focus on relevant input positions.

## 2. Attention Basics: Query, Key, Value

Attention works like a **soft database lookup**:

- **Query (Q)**: "What am I looking for?"
- **Key (K)**: "What does each item represent?"
- **Value (V)**: "What is the actual content?"

### Real-World Analogy:
Imagine searching your memory for "that Italian restaurant":
- **Query**: "Italian restaurant I liked"
- **Keys**: Properties of all restaurants you know
- **Values**: Full information about each restaurant

Your brain compares the query against keys to find matching restaurants, then retrieves their values.

### Mathematical Formulation:

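**Attention = weighted sum of Values**

For a single query $q$ and key/value pairs $(k_j, v_j)$:

$\text{Attention}(q, K, V) = \sum_j \alpha_j v_j, \quad \alpha_j = \frac{\exp(\text{score}(q, k_j))}{\sum_{j'} \exp(\text{score}(q, k_{j'}))}$

The weights $\alpha_j$ are non-negative and sum to 1, and a key that matches the query better gets a larger weight. Unlike a real database, the result is a blend of all values rather than a single exact match. Scaled dot-product attention (Section 3) uses $\text{score}(q, k_j) = q \cdot k_j / \sqrt{d_k}$.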
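To make the lookup analogy concrete, here is a minimal sketch of a *soft* lookup for one query, using a plain dot product as the score. The keys, values, and query are made-up toy numbers, and `soft_lookup` is a hypothetical helper name, not part of any library.

```python
import numpy as np

def soft_lookup(query, keys, values):
    """Return a softmax-weighted average of `values` (toy single-query attention)."""
    scores = keys @ query                      # one similarity score per key
    weights = np.exp(scores - scores.max())    # numerically stable softmax
    weights /= weights.sum()
    return weights @ values, weights

# Three "items", each with a 2-d key and a 1-d value (hypothetical toy data)
keys = np.array([[1.0, 0.0],
                 [0.0, 1.0],
                 [1.0, 1.0]])
values = np.array([[10.0],
                   [20.0],
                   [40.0]])
query = np.array([2.0, 0.0])  # matches keys 0 and 2 equally, key 1 not at all

result, weights = soft_lookup(query, keys, values)
print("weights:", weights.round(3))
print("result:", result.round(3))
```

Items 0 and 2 get equal, large weights and item 1 gets a small but nonzero weight, so the result lies between the stored values instead of being one of them.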

## 3. Scaled Dot-Product Attention

The most common attention mechanism.

### Formula:

$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$

### Steps:

1. **Compute scores**: $S = QK^T$ (how well does Q match each K?)
2. **Scale**: $S_{scaled} = \frac{S}{\sqrt{d_k}}$ (prevent large values)
3. **Normalize**: $A = \text{softmax}(S_{scaled})$ (get probabilities)
4. **Weighted sum**: $\text{Output} = AV$ (combine values)

### Why scale by $\sqrt{d_k}$?
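When $d_k$ is large, dot products grow large in magnitude. If the components of $q$ and $k$ are independent with mean 0 and variance 1, then $q \cdot k = \sum_{i=1}^{d_k} q_i k_i$ has mean 0 and variance $d_k$. Large scores push softmax toward a one-hot distribution, where its gradients are very small. Dividing by $\sqrt{d_k}$ brings the variance back to 1.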
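A quick empirical check of that variance argument, assuming independent standard-normal queries and keys (learned projections in a real model need not look like this):

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 5_000
results = {}

for d_k in (4, 64, 512):
    # Independent standard-normal queries and keys: the assumption in the argument above
    q = rng.standard_normal((n_samples, d_k))
    k = rng.standard_normal((n_samples, d_k))
    dots = np.sum(q * k, axis=1)               # one dot product per (q, k) pair
    results[d_k] = (dots.var(), (dots / np.sqrt(d_k)).var())
    print(f"d_k={d_k:4d}  var(q.k)={results[d_k][0]:8.2f}  "
          f"var(q.k/sqrt(d_k))={results[d_k][1]:.2f}")
```

The unscaled variance tracks $d_k$, while the scaled variance stays near 1 for every $d_k$.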

In [None]:
def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Compute scaled dot-product attention
    
    Args:
        Q: Query matrix (seq_len_q, d_k)
        K: Key matrix (seq_len_k, d_k)
        V: Value matrix (seq_len_v, d_v) where seq_len_v == seq_len_k
        mask: Optional array broadcastable to (seq_len_q, seq_len_k); 1 = masked out, 0 = keep
    
    Returns:
        output: Attention output (seq_len_q, d_v)
        attention_weights: Attention distribution (seq_len_q, seq_len_k)
    """
    d_k = Q.shape[-1]
    
    # 1. Compute attention scores
    scores = np.matmul(Q, K.T)  # (seq_len_q, seq_len_k)
    
    # 2. Scale
    scores = scores / np.sqrt(d_k)
    
    # 3. Apply mask (if provided)
    if mask is not None:
        scores = scores + (mask * -1e9)  # add large negative to masked positions
    
    # 4. Softmax to get attention weights
    attention_weights = np.exp(scores - np.max(scores, axis=-1, keepdims=True))
    attention_weights = attention_weights / np.sum(attention_weights, axis=-1, keepdims=True)
    
    # 5. Apply attention to values
    output = np.matmul(attention_weights, V)  # (seq_len_q, d_v)
    
    return output, attention_weights

print("Scaled dot-product attention implemented!")

### 3.1 Simple Example

In [None]:
# Create simple example
# Imagine 3 input tokens, each with 4-dimensional embeddings
seq_len = 3
d_model = 4

# For simplicity, use Q = K = V = X (self-attention without learned projections)
X = np.array([[1.0, 0.0, 1.0, 0.0],  # Token 1
              [0.0, 2.0, 0.0, 2.0],  # Token 2
              [1.0, 1.0, 1.0, 1.0]]) # Token 3

Q = K = V = X

# Compute attention
output, attention_weights = scaled_dot_product_attention(Q, K, V)

print("Input tokens (Q=K=V):")
print(X)
print(f"\nAttention weights shape: {attention_weights.shape}")
print("\nAttention weights (how much each output attends to each input):")
print(attention_weights)
print("\nOutput (weighted combination of values):")
print(output)

# Visualize attention weights
plt.figure(figsize=(6, 5))
sns.heatmap(attention_weights, annot=True, fmt='.3f', cmap='Blues',
            xticklabels=['Token 1', 'Token 2', 'Token 3'],
            yticklabels=['Token 1', 'Token 2', 'Token 3'])
plt.xlabel('Attending to (Keys)')
plt.ylabel('Attending from (Queries)')
plt.title('Attention Weights Matrix')
plt.show()

print("\nInterpretation:")
print("- Rows sum to 1 (each query attends to all keys with total weight = 1)")
print("- Higher values = stronger attention")

## 4. Self-Attention

**Self-attention** is when a sequence attends to itself. This is the key mechanism in Transformers!

### Process:
1. Start with input sequence $X = [x_1, x_2, ..., x_n]$
2. Create Q, K, V by linear projections:
   - $Q = XW_Q$
   - $K = XW_K$
   - $V = XW_V$
3. Apply attention: $\text{Attention}(Q, K, V)$

### Why is this powerful?
Each token can look at **every other token** in the sequence to understand context!

Example: "The animal didn't cross the street because it was too tired"
- "it" should attend strongly to "animal", not "street"

In [None]:
class SelfAttention:
    """
    Self-attention layer
    """
    def __init__(self, d_model, d_k, d_v):
        """
        Args:
            d_model: Input dimension (the output dimension is d_v)
            d_k: Dimension of queries and keys
            d_v: Dimension of values
        """
        self.d_k = d_k
        
        # Initialize projection matrices
        self.W_q = np.random.randn(d_model, d_k) / np.sqrt(d_model)
        self.W_k = np.random.randn(d_model, d_k) / np.sqrt(d_model)
        self.W_v = np.random.randn(d_model, d_v) / np.sqrt(d_model)
    
    def forward(self, X):
        """
        Forward pass
        
        Args:
            X: Input sequence (seq_len, d_model)
        
        Returns:
            output: Self-attention output (seq_len, d_v)
            attention_weights: Attention distribution
        """
        # Project to Q, K, V
        Q = X @ self.W_q  # (seq_len, d_k)
        K = X @ self.W_k  # (seq_len, d_k)
        V = X @ self.W_v  # (seq_len, d_v)
        
        # Apply scaled dot-product attention
        output, attention_weights = scaled_dot_product_attention(Q, K, V)
        
        return output, attention_weights

# Test
np.random.seed(42)
seq_len = 5
d_model = 8
d_k = d_v = 8

# Random input sequence
X = np.random.randn(seq_len, d_model)

# Create self-attention layer
self_attn = SelfAttention(d_model, d_k, d_v)
output, attn_weights = self_attn.forward(X)

print(f"Input shape: {X.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {attn_weights.shape}")

# Visualize
plt.figure(figsize=(7, 6))
sns.heatmap(attn_weights, annot=True, fmt='.2f', cmap='viridis',
            xticklabels=range(1, seq_len+1),
            yticklabels=range(1, seq_len+1))
plt.xlabel('Attending to (position)')
plt.ylabel('Attending from (position)')
plt.title('Self-Attention Weights')
plt.show()

## 5. Multi-Head Attention

Instead of computing a single attention function, run **multiple attention heads** in parallel, each on its own lower-dimensional projection of Q, K, and V.

### Why Multiple Heads?
Different heads can learn to attend to different aspects, for example (illustrative; what a given head actually learns varies):
- Head 1: Syntactic relationships
- Head 2: Semantic relationships
- Head 3: Long-range dependencies
- etc.

### Formula:

$\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$

$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W^O$

Where:
- $h$ = number of heads
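- Each head has its own $W_i^Q, W_i^K, W_i^V$ matrices
- $W^O$ projects the concatenated heads (dimension $h \cdot d_v$) back to $d_{model}$
- The implementations below use $d_k = d_v = d_{model}/h$, so the total computation is similar to one head of full dimension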
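The formula gives each head its own $W_i^Q$, while the code below uses one combined `W_q` of shape `(d_model, d_model)` and then reshapes. A small sketch (random toy matrices, variable names chosen for illustration) showing the two are the same thing: head $i$'s projection is simply the $i$-th block of `d_k` columns of the combined matrix.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d_model, num_heads = 5, 8, 2
d_k = d_model // num_heads

X = rng.standard_normal((seq_len, d_model))
W_q = rng.standard_normal((d_model, d_model))   # combined projection for all heads

# Route 1: project once, then reshape into heads -> (num_heads, seq_len, d_k)
Q_split = (X @ W_q).reshape(seq_len, num_heads, d_k).transpose(1, 0, 2)

# Route 2: give each head its own W_i^Q, taken as a column block of W_q
Q_per_head = np.stack([X @ W_q[:, i * d_k:(i + 1) * d_k] for i in range(num_heads)])

print(np.allclose(Q_split, Q_per_head))
```

Projecting once with the combined matrix is one large matrix multiply instead of `num_heads` small ones, which is why implementations usually prefer it.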

In [None]:
class MultiHeadAttention:
    """
    Multi-head attention mechanism
    """
    def __init__(self, d_model, num_heads):
        """
        Args:
            d_model: Model dimension
            num_heads: Number of attention heads
        """
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # dimension per head
        
        # Projection matrices for all heads (combined)
        self.W_q = np.random.randn(d_model, d_model) / np.sqrt(d_model)
        self.W_k = np.random.randn(d_model, d_model) / np.sqrt(d_model)
        self.W_v = np.random.randn(d_model, d_model) / np.sqrt(d_model)
        self.W_o = np.random.randn(d_model, d_model) / np.sqrt(d_model)
    
    def split_heads(self, x):
        """
        Split into multiple heads
        
        Args:
            x: (seq_len, d_model)
        
        Returns:
            (num_heads, seq_len, d_k)
        """
        seq_len = x.shape[0]
        # Reshape to (seq_len, num_heads, d_k)
        x = x.reshape(seq_len, self.num_heads, self.d_k)
        # Transpose to (num_heads, seq_len, d_k)
        return x.transpose(1, 0, 2)
    
    def combine_heads(self, x):
        """
        Combine multiple heads
        
        Args:
            x: (num_heads, seq_len, d_k)
        
        Returns:
            (seq_len, d_model)
        """
        # Transpose to (seq_len, num_heads, d_k)
        x = x.transpose(1, 0, 2)
        seq_len = x.shape[0]
        # Reshape to (seq_len, d_model)
        return x.reshape(seq_len, self.d_model)
    
    def forward(self, X):
        """
        Forward pass
        
        Args:
            X: Input (seq_len, d_model)
        
        Returns:
            output: (seq_len, d_model)
            attention_weights: list of attention matrices for each head
        """
        # Project to Q, K, V
        Q = X @ self.W_q
        K = X @ self.W_k
        V = X @ self.W_v
        
        # Split into multiple heads
        Q_heads = self.split_heads(Q)  # (num_heads, seq_len, d_k)
        K_heads = self.split_heads(K)
        V_heads = self.split_heads(V)
        
        # Apply attention for each head
        outputs = []
        attention_weights = []
        
        for i in range(self.num_heads):
            output_i, attn_i = scaled_dot_product_attention(
                Q_heads[i], K_heads[i], V_heads[i]
            )
            outputs.append(output_i)
            attention_weights.append(attn_i)
        
        # Stack and combine heads
        outputs = np.stack(outputs)  # (num_heads, seq_len, d_k)
        output = self.combine_heads(outputs)  # (seq_len, d_model)
        
        # Final linear projection
        output = output @ self.W_o
        
        return output, attention_weights

# Test multi-head attention
np.random.seed(42)
seq_len = 6
d_model = 8
num_heads = 2

X = np.random.randn(seq_len, d_model)

mha = MultiHeadAttention(d_model, num_heads)
output, attn_weights = mha.forward(X)

print(f"Input shape: {X.shape}")
print(f"Output shape: {output.shape}")
print(f"Number of attention heads: {len(attn_weights)}")
print(f"Each head attention shape: {attn_weights[0].shape}")

# Visualize each head
fig, axes = plt.subplots(1, num_heads, figsize=(12, 5))

for i in range(num_heads):
    ax = axes[i] if num_heads > 1 else axes
    sns.heatmap(attn_weights[i], annot=True, fmt='.2f', cmap='viridis', ax=ax)
    ax.set_title(f'Head {i+1} Attention')
    ax.set_xlabel('Key position')
    ax.set_ylabel('Query position')

plt.tight_layout()
plt.show()

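print("\nNotice: even with random (untrained) weights, the two heads produce different patterns;")
print("after training, heads can specialize in different relationships.")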
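The Python loop over heads in `forward` is easy to read, but NumPy's `@` (matmul) broadcasts over leading axes, so all heads can be computed in one call. A minimal sketch on random toy inputs, checked against a per-head loop; the `softmax` helper is defined here for illustration.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)    # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(1)
num_heads, seq_len, d_k = 2, 6, 4
Q = rng.standard_normal((num_heads, seq_len, d_k))
K = rng.standard_normal((num_heads, seq_len, d_k))
V = rng.standard_normal((num_heads, seq_len, d_k))

# Batched: matmul treats the leading axis as a batch of heads
weights_batched = softmax(Q @ K.transpose(0, 2, 1) / np.sqrt(d_k))
out_batched = weights_batched @ V      # (num_heads, seq_len, d_k)

# Loop over heads, as in MultiHeadAttention.forward
out_loop = np.stack([
    softmax(Q[h] @ K[h].T / np.sqrt(d_k)) @ V[h] for h in range(num_heads)
])

print(np.allclose(out_batched, out_loop))
```

The PyTorch implementation in the next section uses the same idea, with shapes `(batch, num_heads, seq_len, d_k)`.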

## 6. PyTorch Implementation

Now let's see how clean this is in PyTorch!

In [None]:
class MultiHeadAttentionPyTorch(nn.Module):
    """
    Multi-head attention in PyTorch
    """
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        # Linear projections
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
    
    def split_heads(self, x):
        """Split into multiple heads"""
        batch_size, seq_len, d_model = x.size()
        # (batch, seq_len, d_model) -> (batch, seq_len, num_heads, d_k)
        x = x.view(batch_size, seq_len, self.num_heads, self.d_k)
        # (batch, seq_len, num_heads, d_k) -> (batch, num_heads, seq_len, d_k)
        return x.transpose(1, 2)
    
    def combine_heads(self, x):
        """Combine multiple heads"""
        batch_size, num_heads, seq_len, d_k = x.size()
        # (batch, num_heads, seq_len, d_k) -> (batch, seq_len, num_heads, d_k)
        x = x.transpose(1, 2).contiguous()
        # (batch, seq_len, num_heads, d_k) -> (batch, seq_len, d_model)
        return x.view(batch_size, seq_len, self.d_model)
    
    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Compute scaled dot-product attention"""
        # Q, K, V: (batch, num_heads, seq_len, d_k)
        d_k = Q.size(-1)
        
        # Attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)
        
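        if mask is not None:
            # Convention here: mask == 1 keeps a position, mask == 0 blocks it
            # (the opposite of the NumPy function above, where 1 = masked)
            scores = scores.masked_fill(mask == 0, -1e9)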
        
        # Attention weights
        attention_weights = F.softmax(scores, dim=-1)
        
        # Apply to values
        output = torch.matmul(attention_weights, V)
        
        return output, attention_weights
    
    def forward(self, X, mask=None):
        """
        Forward pass
        
        Args:
            X: (batch, seq_len, d_model)
            mask: Optional tensor broadcastable to (batch, num_heads, seq_len, seq_len); 1 = keep, 0 = masked
        
        Returns:
            output: (batch, seq_len, d_model)
            attention_weights: (batch, num_heads, seq_len, seq_len)
        """
        # Linear projections
        Q = self.W_q(X)  # (batch, seq_len, d_model)
        K = self.W_k(X)
        V = self.W_v(X)
        
        # Split into heads
        Q = self.split_heads(Q)  # (batch, num_heads, seq_len, d_k)
        K = self.split_heads(K)
        V = self.split_heads(V)
        
        # Apply attention
        attn_output, attention_weights = self.scaled_dot_product_attention(Q, K, V, mask)
        
        # Combine heads
        attn_output = self.combine_heads(attn_output)  # (batch, seq_len, d_model)
        
        # Final projection
        output = self.W_o(attn_output)
        
        return output, attention_weights

# Test
batch_size = 2
seq_len = 6
d_model = 8
num_heads = 2

X = torch.randn(batch_size, seq_len, d_model)

mha_torch = MultiHeadAttentionPyTorch(d_model, num_heads)
output, attn_weights = mha_torch(X)

print(f"Input shape: {X.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {attn_weights.shape}")

# Visualize attention for first sample
fig, axes = plt.subplots(1, num_heads, figsize=(12, 5))

for i in range(num_heads):
    ax = axes[i] if num_heads > 1 else axes
    attn_map = attn_weights[0, i].detach().numpy()  # first sample, i-th head
    sns.heatmap(attn_map, annot=True, fmt='.2f', cmap='viridis', ax=ax)
    ax.set_title(f'Head {i+1} Attention (PyTorch)')
    ax.set_xlabel('Key position')
    ax.set_ylabel('Query position')

plt.tight_layout()
plt.show()

## 7. Using PyTorch's Built-in MultiheadAttention

In [None]:
# PyTorch provides nn.MultiheadAttention out of the box!
d_model = 8
num_heads = 2
seq_len = 6
batch_size = 2

# Create module
multihead_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads, batch_first=True)

# Input
X = torch.randn(batch_size, seq_len, d_model)

# Forward (Q, K, V can be the same for self-attention)
attn_output, attn_weights = multihead_attn(X, X, X, need_weights=True, average_attn_weights=False)

print(f"Input shape: {X.shape}")
print(f"Output shape: {attn_output.shape}")
print(f"Attention weights shape: {attn_weights.shape}")
print("\nUsing PyTorch's built-in is much easier!")

## 8. Masked Attention (for Autoregressive Models)

In language models like GPT, we need **causal masking** so that position $i$ can only attend to positions $\leq i$.

This prevents the model from "cheating" by looking at future tokens.
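PyTorch's built-in `nn.MultiheadAttention` from the previous section accepts a causal mask through `attn_mask`. A minimal sketch, assuming PyTorch 1.11 or newer (for `average_attn_weights`). For a boolean `attn_mask`, `True` marks positions that are *not* allowed to be attended, the opposite of the `mask == 0` convention in `MultiHeadAttentionPyTorch` above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, seq_len, d_model, num_heads = 2, 5, 8, 2

mha = nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads, batch_first=True)
X = torch.randn(batch_size, seq_len, d_model)

# Boolean attn_mask: True above the diagonal = "may NOT attend" (future positions)
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

out, weights = mha(X, X, X, attn_mask=causal_mask,
                   need_weights=True, average_attn_weights=False)

print(weights.shape)      # (batch, num_heads, seq_len, seq_len)
print(weights[0, 0])      # first sample, first head: zeros above the diagonal
```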

In [None]:
def create_causal_mask(seq_len):
    """
    Create a causal mask for autoregressive attention
    
    Returns:
        mask: (seq_len, seq_len) lower triangular matrix
    """
    mask = np.tril(np.ones((seq_len, seq_len)))
    return mask

# Example
seq_len = 5
mask = create_causal_mask(seq_len)

print("Causal Mask (1 = allowed, 0 = masked):")
print(mask.astype(int))

plt.figure(figsize=(6, 5))
sns.heatmap(mask, annot=True, fmt='.0f', cmap='Blues', cbar=False,
            xticklabels=range(1, seq_len+1),
            yticklabels=range(1, seq_len+1))
plt.xlabel('Can attend to position')
plt.ylabel('From position')
plt.title('Causal Attention Mask')
plt.show()

print("\nInterpretation:")
print("- Position 1 can only attend to itself")
print("- Position 2 can attend to positions 1 and 2")
print("- Position 3 can attend to positions 1, 2, and 3")
print("- etc.")

In [None]:
# Apply causal masking to attention
seq_len = 5
d_model = 4

X = np.random.randn(seq_len, d_model)
Q = K = V = X

# Create causal mask
causal_mask = 1 - create_causal_mask(seq_len)  # invert: 1 = mask, 0 = allow

# Compute attention with mask
output, attn_weights = scaled_dot_product_attention(Q, K, V, mask=causal_mask)

print("Attention weights with causal masking:")
print(attn_weights)

plt.figure(figsize=(7, 6))
sns.heatmap(attn_weights, annot=True, fmt='.3f', cmap='viridis',
            xticklabels=range(1, seq_len+1),
            yticklabels=range(1, seq_len+1))
plt.xlabel('Attending to (position)')
plt.ylabel('Attending from (position)')
plt.title('Causal Self-Attention Weights')
plt.show()

print("\nNotice: Upper triangle is all zeros (can't attend to future!)")
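One way to check that causal masking really prevents "cheating": change only the future tokens and confirm that the outputs at earlier positions stay the same. A self-contained sketch (single head, Q = K = V = X as in the cell above; `causal_self_attention` is a hypothetical helper, and it uses `-np.inf` instead of `-1e9` for masked scores, a swapped-in choice that gives exact zeros):

```python
import numpy as np

def causal_self_attention(X):
    """Single-head causal self-attention with Q = K = V = X (toy setup)."""
    seq_len, d_k = X.shape
    scores = X @ X.T / np.sqrt(d_k)
    future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)  # True above diagonal
    scores = np.where(future, -np.inf, scores)     # block attention to future positions
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ X

rng = np.random.default_rng(0)
X = rng.standard_normal((5, 4))
X_changed = X.copy()
X_changed[3:] = rng.standard_normal((2, 4))   # overwrite positions 4 and 5 only

out, out_changed = causal_self_attention(X), causal_self_attention(X_changed)
print(np.allclose(out[:3], out_changed[:3]))   # earlier positions: unaffected
print(np.allclose(out[3:], out_changed[3:]))   # later positions: do change
```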

## 9. Positional Encoding

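**Problem:** Attention has no notion of position! Without positional information, self-attention is **permutation equivariant**: shuffling the input tokens just shuffles the outputs in the same way, so "dog bites man" and "man bites dog" produce the same set of token representations.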

**Solution:** Add positional information to the input embeddings.
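A small sketch of the permutation property, using random projection matrices and a hypothetical `self_attention` helper (no positional encoding): permuting the rows of `X` permutes the rows of the output in exactly the same way.

```python
import numpy as np

def self_attention(X, W_q, W_k, W_v):
    """Single-head self-attention without any positional information."""
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V

rng = np.random.default_rng(0)
seq_len, d_model = 5, 8
X = rng.standard_normal((seq_len, d_model))
W_q, W_k, W_v = (rng.standard_normal((d_model, d_model)) for _ in range(3))

perm = np.array([2, 0, 4, 1, 3])                 # shuffle the token order
out = self_attention(X, W_q, W_k, W_v)
out_perm = self_attention(X[perm], W_q, W_k, W_v)

# Shuffling the inputs just shuffles the outputs the same way
print(np.allclose(out_perm, out[perm]))
```

Adding a position-dependent vector to each row of `X` breaks this symmetry, which is exactly what positional encodings are for.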

### Sinusoidal Positional Encoding:

$PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d}}\right)$

$PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d}}\right)$

where:
- $pos$ = position in sequence
- $i$ = index of the dimension pair, so dimensions $2i$ (sine) and $2i+1$ (cosine) share the same frequency
- $d$ = model dimension

In [None]:
def positional_encoding(seq_len, d_model):
    """
    Generate sinusoidal positional encodings
    
    Args:
        seq_len: Sequence length
        d_model: Model dimension
    
    Returns:
        PE: (seq_len, d_model) positional encoding matrix
    """
    PE = np.zeros((seq_len, d_model))
    
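    for pos in range(seq_len):
        # i runs over even dimensions 0, 2, 4, ..., so i plays the role of "2i" in the formula
        for i in range(0, d_model, 2):
            angle = pos / (10000 ** (i / d_model))
            # Even indices: sin
            PE[pos, i] = np.sin(angle)
            # Odd indices: cos (same frequency as the paired even index)
            if i + 1 < d_model:
                PE[pos, i+1] = np.cos(angle)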
    
    return PE

# Generate positional encodings
seq_len = 50
d_model = 128

PE = positional_encoding(seq_len, d_model)

# Visualize
plt.figure(figsize=(12, 6))
sns.heatmap(PE, cmap='RdBu', center=0, cbar=True)
plt.xlabel('Dimension')
plt.ylabel('Position')
plt.title('Positional Encoding Matrix')
plt.show()

# Plot some dimensions
plt.figure(figsize=(12, 5))
for i in range(0, d_model, 16):
    plt.plot(PE[:, i], label=f'Dim {i}')
plt.xlabel('Position')
plt.ylabel('Encoding Value')
plt.title('Positional Encoding - Selected Dimensions')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

print("Key properties:")
print("- Each position has a unique encoding")
print("- Periodic patterns at different frequencies")
print("- PE(pos + k) is a fixed linear function (a rotation) of PE(pos), which helps the model attend by relative position")
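The relative-position property follows from the angle-addition formulas: within each sine/cosine pair of frequency $\omega$, shifting the position by $k$ rotates the pair by the angle $\omega k$, which depends on $k$ but not on $pos$. A sketch that checks this numerically, using a vectorized encoding that follows the formula above (it assumes an even `d_model`; `sinusoidal_pe` is an illustrative name):

```python
import numpy as np

def sinusoidal_pe(seq_len, d_model):
    """Vectorized sinusoidal encoding (assumes even d_model)."""
    pos = np.arange(seq_len)[:, None]                     # (seq_len, 1)
    two_i = np.arange(0, d_model, 2)[None, :]             # even dims 0, 2, 4, ...
    angles = pos / (10000 ** (two_i / d_model))           # (seq_len, d_model/2)
    PE = np.zeros((seq_len, d_model))
    PE[:, 0::2] = np.sin(angles)
    PE[:, 1::2] = np.cos(angles)
    return PE

d_model, seq_len, k = 16, 50, 7
PE = sinusoidal_pe(seq_len, d_model)
freqs = 1.0 / (10000 ** (np.arange(0, d_model, 2) / d_model))

# Block-diagonal linear map that depends only on the offset k, not on pos
M = np.zeros((d_model, d_model))
for j, w in enumerate(freqs):
    c, s = np.cos(w * k), np.sin(w * k)
    M[2*j:2*j+2, 2*j:2*j+2] = [[c, s], [-s, c]]

# PE[pos + k] == M @ PE[pos] for every pos (rows are vectors, so multiply by M.T)
print(np.allclose(PE[k:], PE[:-k] @ M.T))
```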

## Summary

### What We Learned:

1. **Attention Mechanism**
   - Allows dynamic focus on relevant parts of input
   - Query, Key, Value paradigm
   - Solves fixed-context problem of RNNs

2. **Scaled Dot-Product Attention**
   - Formula: $\text{Attention}(Q,K,V) = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})V$
   - Computes weighted sum of values
   - Weights based on query-key similarity

3. **Self-Attention**
   - Sequence attends to itself
   - Foundation of Transformers
   - Each token can look at all other tokens

4. **Multi-Head Attention**
   - Multiple attention mechanisms in parallel
   - Learn different types of relationships
   - Concat + project to combine

5. **Causal Masking**
   - Prevent attending to future positions
   - Essential for autoregressive models (GPT)

6. **Positional Encoding**
   - Add position information
   - Sinusoidal or learned embeddings

### Why Attention is Revolutionary:

- **Parallelization**: Unlike RNNs, can process entire sequence at once
- **Long-range dependencies**: Direct connections between any positions
- **Interpretability**: Attention weights give a partial view of what the model focuses on (they are not a complete explanation of its behavior)
- **Flexibility**: Works for various tasks (translation, generation, etc.)

### Applications:

- **NLP**: GPT, BERT, T5 (all use attention)
- **Vision**: Vision Transformers (ViT)
- **Multimodal**: CLIP, Flamingo
- **Generation**: Stable Diffusion (uses cross-attention to condition image generation on text)

### Next Steps:

To build a complete Transformer:
1. Multi-head attention (✓ we have this!)
2. Feed-forward networks
3. Layer normalization
4. Residual connections
5. Stack multiple layers

You now understand the core mechanism that powers modern AI!

## Practice Exercises

In [None]:
# Exercise 1: Implement cross-attention
# Cross-attention: Q from one sequence, K and V from another
# Used in encoder-decoder architectures
# Your code here


In [None]:
# Exercise 2: Visualize attention patterns on real text
# Use a pre-trained model to see what it attends to
# Example sentence: "The cat sat on the mat because it was comfortable"
# Your code here


In [None]:
# Exercise 3: Implement learned positional embeddings
# Instead of sinusoidal, use nn.Embedding for positions
# Your code here
