# Understanding the Attention Mechanism

This notebook explores the core attention mechanism that powers transformer models. We'll build intuition about how attention works and implement it step by step.

## What is Attention?

Attention allows the model to focus on different parts of the input when making predictions. The key insight is that not all parts of the input are equally important for predicting a given output.

The attention mechanism computes a weighted sum of values, where the weights are determined by the similarity between queries and keys.
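Here is a minimal sketch of that idea for a single query. The numbers are toy values chosen for illustration and do not come from any model. The query is compared with each key by a dot product, softmax turns the similarities into weights, and the output is the weighted sum of the values.

```python
import torch
import torch.nn.functional as F

# One query attending over three key/value pairs (toy numbers for illustration)
query = torch.tensor([1.0, 0.0])
keys = torch.tensor([[1.0, 0.0],    # very similar to the query
                     [0.0, 1.0],    # orthogonal to the query
                     [-1.0, 0.0]])  # opposite to the query
values = torch.tensor([[10.0], [20.0], [30.0]])

similarity = keys @ query               # dot product with each key: [1, 0, -1]
weights = F.softmax(similarity, dim=0)  # non-negative and sums to 1
output = weights @ values               # weighted sum of the values

print("weights:", weights)
print("output:", output)
```

The most similar key gets the largest weight, so the output lands closest to that key's value.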


In [None]:
import sys
import os
sys.path.append('..')

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Tuple

# Set style for better plots
plt.style.use('default')
sns.set_palette("husl")

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

## 1. Scaled Dot-Product Attention

The fundamental attention operation is:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Where:
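- $Q$ is the matrix of queries (one row per position)
- $K$ is the matrix of keys
- $V$ is the matrix of values
- $d_k$ is the dimension of each query and key vector. Dividing by $\sqrt{d_k}$ stops the dot products from growing with $d_k$. Without it, large scores push the softmax toward near one-hot outputs with very small gradients.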
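Why divide by $\sqrt{d_k}$? Suppose the entries of $q$ and $k$ are independent, with mean 0 and variance 1. Then $q \cdot k$ has variance $d_k$. The sketch below checks this empirically under that illustrative i.i.d. Gaussian assumption. It also shows how score spread affects the softmax. `dot_product_variance` is a hypothetical helper written for this check.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

def dot_product_variance(d_k: int, n_samples: int = 10_000) -> tuple[float, float]:
    """Empirical variance of q.k for random q, k with i.i.d. N(0, 1) entries,
    before and after dividing by sqrt(d_k)."""
    q = torch.randn(n_samples, d_k)
    k = torch.randn(n_samples, d_k)
    dots = (q * k).sum(dim=-1)  # one dot product per sample
    return dots.var().item(), (dots / d_k ** 0.5).var().item()

for d_k in [4, 64, 512]:
    raw, scaled = dot_product_variance(d_k)
    print(f"d_k={d_k:4d}  var(q.k)={raw:8.2f}  var(q.k / sqrt(d_k))={scaled:.2f}")

# Effect on softmax: widely spread scores make the distribution close to one-hot
scores = torch.randn(8) * 512 ** 0.5  # spread similar to unscaled d_k=512 dot products
p_unscaled = F.softmax(scores, dim=-1).max().item()
p_scaled = F.softmax(scores / 512 ** 0.5, dim=-1).max().item()
print("largest weight, unscaled:", p_unscaled)
print("largest weight, scaled:  ", p_scaled)
```

The raw variance tracks $d_k$, while the scaled variance stays near 1 whatever the dimension.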

Let's implement this step by step:

In [None]:
def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Compute scaled dot-product attention.
    
    Args:
        Q: Queries [batch_size, seq_len_q, d_k]
        K: Keys [batch_size, seq_len_k, d_k]
        V: Values [batch_size, seq_len_k, d_v]
        mask: Optional mask broadcastable to [batch_size, seq_len_q, seq_len_k];
            positions where mask == 0 are blocked from attention
    
    Returns:
        output: [batch_size, seq_len_q, d_v]
        attention_weights: [batch_size, seq_len_q, seq_len_k]
    """
    d_k = Q.size(-1)
    
    # Step 1: Compute attention scores
    scores = torch.matmul(Q, K.transpose(-2, -1))
    print(f"Raw scores shape: {scores.shape}")
    
    # Step 2: Scale by sqrt(d_k)
    scores = scores / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
    print(f"Scaled scores range: [{scores.min():.3f}, {scores.max():.3f}]")
    
    # Step 3: Apply mask if provided
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    
    # Step 4: Apply softmax
    attention_weights = F.softmax(scores, dim=-1)
    print(f"Attention weights sum: {attention_weights.sum(dim=-1)[0, 0]:.3f} (should be 1.0)")
    
    # Step 5: Apply attention to values
    output = torch.matmul(attention_weights, V)
    
    return output, attention_weights

# Example usage
batch_size, seq_len, d_model = 1, 4, 8

# Create simple example data
Q = torch.randn(batch_size, seq_len, d_model)
K = torch.randn(batch_size, seq_len, d_model) 
V = torch.randn(batch_size, seq_len, d_model)

print("Input shapes:")
print(f"Q: {Q.shape}, K: {K.shape}, V: {V.shape}")
print()

output, attn_weights = scaled_dot_product_attention(Q, K, V)

print(f"\nOutput shape: {output.shape}")
print(f"Attention weights shape: {attn_weights.shape}")
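PyTorch 2.0 and later include a built-in `torch.nn.functional.scaled_dot_product_attention`. Comparing against it is a quick sanity check for a hand-written version. The sketch below re-implements the same math without the debug prints (`attention_reference` is just a local, hypothetical name). It then compares the unmasked case and the causal case (`is_causal=True`).

```python
import torch
import torch.nn.functional as F

def attention_reference(Q, K, V, mask=None):
    """Scaled dot-product attention, same math as the notebook's function."""
    scores = Q @ K.transpose(-2, -1) / Q.size(-1) ** 0.5
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    weights = F.softmax(scores, dim=-1)
    return weights @ V, weights

torch.manual_seed(0)
Q, K, V = (torch.randn(1, 4, 8) for _ in range(3))

ours, _ = attention_reference(Q, K, V)
builtin = F.scaled_dot_product_attention(Q, K, V)  # requires PyTorch >= 2.0
print("unmasked allclose:", torch.allclose(ours, builtin, atol=1e-5))

causal = torch.tril(torch.ones(4, 4))  # 1 = may attend, 0 = blocked
ours_causal, _ = attention_reference(Q, K, V, mask=causal)
builtin_causal = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
print("causal allclose:  ", torch.allclose(ours_causal, builtin_causal, atol=1e-5))
```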

## 2. Visualizing Attention

Let's create a more interpretable example and visualize the attention patterns:

In [None]:
def create_simple_sequence_data():
    """
    Create a simple sequence where attention patterns should be interpretable.
    """
    # Create embeddings for words: "The", "cat", "sat", "down"
    seq_len, d_model = 4, 6
    
    # Hand-picked embeddings: words that share nonzero dimensions get larger dot
    # products, so they attend to each other more strongly
    embeddings = torch.tensor([
        [1.0, 0.5, 0.0, 0.5, 0.0, 0.0],  # "The" - article
        [0.0, 1.0, 1.0, 0.0, 0.5, 0.0],  # "cat" - noun
        [0.0, 0.0, 0.5, 1.0, 1.0, 0.5],  # "sat" - verb
        [0.5, 0.0, 0.0, 0.5, 1.0, 1.0],  # "down" - adverb
    ]).unsqueeze(0)  # Add batch dimension
    
    words = ["The", "cat", "sat", "down"]
    
    return embeddings, words

# Create interpretable data
embeddings, words = create_simple_sequence_data()
print(f"Embeddings shape: {embeddings.shape}")
print(f"Words: {words}")

# Use the raw embeddings as Q, K, and V (self-attention with no learned projections)
output, attn_weights = scaled_dot_product_attention(embeddings, embeddings, embeddings)

# Visualize attention weights
plt.figure(figsize=(8, 6))
sns.heatmap(
    attn_weights[0].detach().numpy(),
    annot=True,
    fmt='.3f',
    xticklabels=words,
    yticklabels=words,
    cmap='Blues',
    cbar_kws={'label': 'Attention Weight'}
)
plt.title('Self-Attention Weights')
plt.xlabel('Keys (attending to)')
plt.ylabel('Queries (attending from)')
plt.tight_layout()
plt.show()

print("\nInterpretation:")
print("Each row is a query word; its weights over all words (including itself) sum to 1.")
print("Higher values (darker blue) indicate stronger attention.")
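This example feeds the raw embeddings in as Q, K, and V, so the score matrix $XX^T/\sqrt{d}$ is symmetric: word $i$ scores word $j$ exactly as word $j$ scores word $i$. The attention weights in the heatmap are generally not symmetric, though, because softmax normalizes each row separately. Here is a small check that reuses the same toy embeddings:

```python
import torch
import torch.nn.functional as F

# Same toy embeddings as create_simple_sequence_data(), without the batch dimension
x = torch.tensor([
    [1.0, 0.5, 0.0, 0.5, 0.0, 0.0],  # "The"
    [0.0, 1.0, 1.0, 0.0, 0.5, 0.0],  # "cat"
    [0.0, 0.0, 0.5, 1.0, 1.0, 0.5],  # "sat"
    [0.5, 0.0, 0.0, 0.5, 1.0, 1.0],  # "down"
])

scores = x @ x.T / x.size(-1) ** 0.5  # Q = K = x, so scores[i, j] == scores[j, i]
weights = F.softmax(scores, dim=-1)   # each row is normalized on its own

print("scores symmetric: ", torch.allclose(scores, scores.T))
print("weights symmetric:", torch.allclose(weights, weights.T))
print("row sums:", weights.sum(dim=-1))
```

Learned projections $W^Q \neq W^K$ would break the symmetry of the scores as well.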

## 3. Causal (Masked) Attention

For language modeling, we need causal attention where each position can only attend to previous positions (including itself). This prevents the model from "cheating" by looking at future tokens.

In [None]:
def create_causal_mask(seq_len):
    """
    Create a causal mask where position i can only attend to positions j <= i.
    """
    mask = torch.tril(torch.ones(seq_len, seq_len))
    return mask

# Create causal mask
seq_len = 4
causal_mask = create_causal_mask(seq_len)

print("Causal mask (1 = can attend, 0 = cannot attend):")
print(causal_mask.numpy())

# Apply causal attention to our example
output_causal, attn_weights_causal = scaled_dot_product_attention(
    embeddings, embeddings, embeddings, mask=causal_mask.unsqueeze(0)
)

# Visualize causal attention
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

# Regular attention
sns.heatmap(
    attn_weights[0].detach().numpy(),
    annot=True, fmt='.3f',
    xticklabels=words, yticklabels=words,
    cmap='Blues', ax=ax1
)
ax1.set_title('Regular Self-Attention')
ax1.set_xlabel('Keys')
ax1.set_ylabel('Queries')

# Causal attention
sns.heatmap(
    attn_weights_causal[0].detach().numpy(),
    annot=True, fmt='.3f',
    xticklabels=words, yticklabels=words,
    cmap='Blues', ax=ax2
)
ax2.set_title('Causal Self-Attention')
ax2.set_xlabel('Keys')
ax2.set_ylabel('Queries')

plt.tight_layout()
plt.show()

print("Notice how causal attention has zeros in the upper triangle.")
print("Each word can only attend to itself and previous words.")
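The masking step relies on a numerical detail. After `masked_fill(mask == 0, -1e9)`, `exp` underflows to exactly zero at the blocked positions in float32. The upper triangle is therefore exactly 0, not merely small. The sketch below uses random stand-in scores for illustration and checks this. It also shows an edge case. If every key in a row is masked, filling with -1e9 gives a uniform distribution, while filling with `float("-inf")` gives NaNs.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len = 4
scores = torch.randn(seq_len, seq_len)                  # stand-in for scaled Q @ K.T
causal_mask = torch.tril(torch.ones(seq_len, seq_len))  # 1 = may attend, 0 = blocked

# Same masking convention as scaled_dot_product_attention in Section 1
weights = F.softmax(scores.masked_fill(causal_mask == 0, -1e9), dim=-1)
print(weights)
print("upper triangle exactly zero:", bool((weights.triu(diagonal=1) == 0).all()))
print("first row:", weights[0])  # position 0 can only see itself

# Edge case: a row where every key is blocked (e.g. a padding-only query)
row_scores = torch.randn(seq_len)
row_mask = torch.zeros(seq_len)
uniform_row = F.softmax(row_scores.masked_fill(row_mask == 0, -1e9), dim=-1)
nan_row = F.softmax(row_scores.masked_fill(row_mask == 0, float("-inf")), dim=-1)
print("fill with -1e9:", uniform_row)  # every entry 1/seq_len
print("fill with -inf:", nan_row)      # every entry NaN
```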

## 4. Multi-Head Attention

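Multi-head attention allows the model to attend to different types of relationships simultaneously. In the standard formulation, each of the $h$ heads uses its own learned weight matrices to project the queries, keys, and values into a subspace of size $d_{model}/h$. Each head runs scaled dot-product attention in its subspace. The heads' outputs are then concatenated and passed through a final linear projection:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)W^O, \quad \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$$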
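To make "run several heads in parallel and concatenate" concrete, here is a hypothetical minimal multi-head attention written from scratch. It follows the standard Transformer layout: Q/K/V projections, per-head scaled dot-product attention, concatenation, and an output projection. It is a sketch for illustration and not necessarily how `src.model.attention.MultiHeadAttention` is implemented.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttentionSketch(nn.Module):
    """Hypothetical minimal multi-head attention (standard Transformer layout)."""

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        # One linear layer per role; each one covers all heads at once
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)  # output projection after concatenation

    def _split_heads(self, x):
        # [batch, seq, d_model] -> [batch, n_heads, seq, d_head]
        b, s, _ = x.shape
        return x.view(b, s, self.n_heads, self.d_head).transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        q = self._split_heads(self.w_q(query))
        k = self._split_heads(self.w_k(key))
        v = self._split_heads(self.w_v(value))

        scores = q @ k.transpose(-2, -1) / self.d_head ** 0.5  # [batch, heads, seq_q, seq_k]
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)        # mask broadcasts over heads
        weights = F.softmax(scores, dim=-1)
        heads = weights @ v                                     # [batch, heads, seq_q, d_head]

        # Concatenate heads: [batch, heads, seq_q, d_head] -> [batch, seq_q, d_model]
        b, _, s, _ = heads.shape
        concat = heads.transpose(1, 2).contiguous().view(b, s, self.n_heads * self.d_head)
        return self.w_o(concat), weights

torch.manual_seed(0)
mha_sketch = MultiHeadAttentionSketch(d_model=8, n_heads=2)
x = torch.randn(1, 4, 8)
out, w = mha_sketch(x, x, x)
print(out.shape, w.shape)
```

Reshaping to `[batch, heads, seq, d_head]` lets one batched matmul compute every head at once. That is why each role's per-head projections can be stored together as a single `d_model × d_model` linear layer.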

In [None]:
from src.model.attention import MultiHeadAttention

# Create multi-head attention layer
d_model, n_heads = 8, 2
mha = MultiHeadAttention(d_model, n_heads)

print(f"Multi-head attention with {n_heads} heads")
print(f"Model dimension: {d_model}")
print(f"Dimension per head: {d_model // n_heads}")

# Create input
batch_size, seq_len = 1, 4
x = torch.randn(batch_size, seq_len, d_model)

# Apply multi-head attention
output, attention_weights = mha(x, x, x, return_attention=True)

print(f"\nInput shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {attention_weights.shape}")
print("(batch_size, n_heads, seq_len, seq_len)")

# Visualize attention for each head
fig, axes = plt.subplots(1, n_heads, figsize=(12, 5))
if n_heads == 1:
    axes = [axes]

for head in range(n_heads):
    head_attention = attention_weights[0, head].detach().numpy()
    
    sns.heatmap(
        head_attention,
        annot=True, fmt='.3f',
        cmap='Blues',
        ax=axes[head],
        cbar=head == n_heads - 1  # Only show colorbar for last plot
    )
    axes[head].set_title(f'Head {head + 1}')
    axes[head].set_xlabel('Keys')
    if head == 0:
        axes[head].set_ylabel('Queries')

plt.tight_layout()
plt.show()

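print("\nThis layer is freshly initialized (untrained), so these patterns are arbitrary;")
print("after training, different heads can learn to focus on different types of relationships.")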

## 5. Why Attention Works

Let's explore the key benefits of the attention mechanism:

In [None]:
def demonstrate_attention_benefits():
    """
    Demonstrate key benefits of attention mechanism.
    """
    print("🔍 Key Benefits of Attention:")
    print()
    
    print("1. **Parallel Processing**")
    print("   - Unlike RNNs, all positions can be computed in parallel")
    print("   - No sequential dependency during training")
    print()
    
    print("2. **Long-Range Dependencies**")
    print("   - Direct connections between any two positions")
    print("   - Constant path length between positions, vs. O(n) recurrent steps in an RNN")
    print()
    
    print("3. **Interpretability**")
    print("   - Attention weights hint at which positions the model draws information from")
    print("   - Useful for inspecting behavior, though not a complete explanation of it")
    print()
    
    print("4. **Flexibility**")
    print("   - Can attend to any part of the sequence")
    print("   - Multiple heads can capture different relationships")
    print()
    
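    # Demonstrate computational complexity
    seq_lengths = [100, 500, 1000, 2000]
    
    print("📊 Computational Complexity (per layer, ignoring the model dimension d):")
    print(f"{'Sequence length':>15} | {'Attention scores (n²)':>21} | {'RNN sequential steps (n)':>24}")
    print("-" * 66)
    
    for seq_len in seq_lengths:
        attention_scores = seq_len ** 2  # one score per (query, key) pair, all computable in parallel
        rnn_steps = seq_len  # each step waits for the previous hidden state
        print(f"{seq_len:>15,} | {attention_scores:>21,} | {rnn_steps:>24,}")
    
    print()
    print("Note: self-attention costs O(n²·d) time and O(n²) memory for the score matrix,")
    print("while an RNN costs O(n·d²) time but needs n sequential steps.")
    print("Attention trades more total work for full parallelism across positions.")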

demonstrate_attention_benefits()
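The O(n²) memory note becomes concrete when you count bytes. The sketch below assumes the full float32 score matrix is materialized, as the notebook's implementation does. `attention_matrix_bytes` is a hypothetical helper. The 12-head, batch-of-8 setting is an illustrative configuration and does not describe any specific model.

```python
def attention_matrix_bytes(seq_len: int, n_heads: int = 1, batch_size: int = 1,
                           bytes_per_element: int = 4) -> int:
    """Bytes needed for a [batch, heads, seq_len, seq_len] attention matrix (float32 by default)."""
    return batch_size * n_heads * seq_len * seq_len * bytes_per_element

for n in [100, 500, 1000, 2000]:
    mb = attention_matrix_bytes(n) / 1e6
    print(f"seq_len={n:5d}: {mb:8.2f} MB per head per sequence")

# Illustrative configuration: 12 heads, batch of 8, one layer
gb = attention_matrix_bytes(2000, n_heads=12, batch_size=8) / 1e9
print(f"seq_len=2000, 12 heads, batch 8: {gb:.3f} GB for one layer's attention weights")
```

Doubling the sequence length quadruples this memory, while the per-position work of an RNN grows only linearly.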

## 6. Exercise: Build Your Own Attention

Try implementing a simple attention mechanism yourself:

In [None]:
def simple_attention_exercise():
    """
    Exercise: Implement attention step by step.
    """
    print("🎯 Exercise: Implement Simple Attention")
    print()
    
    # Given data
    Q = torch.tensor([[1.0, 0.0], [0.0, 1.0]])  # 2x2 queries
    K = torch.tensor([[1.0, 1.0], [1.0, 0.0]])  # 2x2 keys  
    V = torch.tensor([[2.0, 0.0], [0.0, 3.0]])  # 2x2 values
    
    print("Given:")
    print(f"Q = \n{Q}")
    print(f"K = \n{K}")
    print(f"V = \n{V}")
    print()
    
    print("Step 1: Compute Q @ K.T")
    scores = torch.matmul(Q, K.transpose(-2, -1))
    print(f"Scores = \n{scores}")
    print()
    
    print("Step 2: Scale by sqrt(d_k)")
    d_k = Q.size(-1)
    scores = scores / d_k ** 0.5
    print(f"Scaled scores = \n{scores}")
    print()
    
    print("Step 3: Apply softmax")
    attention_weights = F.softmax(scores, dim=-1)
    print(f"Attention weights = \n{attention_weights}")
    print(f"Row sums: {attention_weights.sum(dim=-1)} (should be [1, 1])")
    print()
    
    print("Step 4: Apply attention to values")
    output = torch.matmul(attention_weights, V)
    print(f"Output = \n{output}")
    print()
    
    print("✅ Try calculating this by hand and compare with the result!")

simple_attention_exercise()
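To check a hand calculation without PyTorch, here is a plain-Python sketch of the same 2×2 example. `hand_attention` is a hypothetical helper name. It follows the scaled formula from Section 1, and `scale=False` skips the division by $\sqrt{d_k}$ if you want the unscaled variant.

```python
import math

def matmul(A, B):
    """Plain-Python product of two matrices given as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def softmax(row):
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def hand_attention(Q, K, V, scale=True):
    # scores[i][j] = q_i . k_j (this is Q @ K.T)
    scores = [[sum(q * k for q, k in zip(q_row, k_row)) for k_row in K] for q_row in Q]
    if scale:
        d_k = len(Q[0])
        scores = [[s / math.sqrt(d_k) for s in row] for row in scores]
    weights = [softmax(row) for row in scores]
    return scores, weights, matmul(weights, V)

Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 1.0], [1.0, 0.0]]
V = [[2.0, 0.0], [0.0, 3.0]]

scores, weights, output = hand_attention(Q, K, V)
for name, m in [("scores", scores), ("weights", weights), ("output", output)]:
    print(name, [[round(x, 4) for x in row] for row in m])
```

The first query's two scores are equal, so its weights are exactly 0.5 each and its output is the plain average of the two value rows.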

## Summary

In this notebook, we've explored:

1. **Scaled Dot-Product Attention** - The fundamental attention operation
2. **Attention Visualization** - How to interpret attention weights
3. **Causal Attention** - Preventing information leakage in language modeling
4. **Multi-Head Attention** - Parallel attention heads for richer representations
5. **Benefits of Attention** - Parallelism and direct long-range connections, at a quadratic cost in sequence length

### Key Takeaways:

- Attention allows models to dynamically focus on relevant parts of the input
- The mechanism is based on similarity between queries and keys
- Causal masking is essential for autoregressive language modeling
- Multi-head attention captures different types of relationships
- Attention enables parallelization and handles long-range dependencies, at O(n²) cost in sequence length

### Next Steps:

In the next notebook, we'll see how attention is combined with other components to build complete transformer blocks!