# Understanding Attention: The Heart of Transformers

Attention allows models to focus on the relevant parts of the input when processing each position. Instead of processing sequences step-by-step like RNNs, attention connects any two positions directly.

## Core Formula
`Attention(Q,K,V) = softmax(QK^T / √d_k)V`
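
The division by `√d_k` in the formula can be motivated with a small experiment. The sketch below assumes query and key components are independent with unit variance, in which case the variance of a raw dot product grows roughly like `d_k`, while the scaled version stays near 1. The helper name `dot_product_variances` and the trial count are illustrative choices, and only the standard library is used.

```python
import math
import random
import statistics

random.seed(0)  # fixed seed so the run is repeatable

def dot_product_variances(d_k, trials=2000):
    """Return (raw, scaled) variance of q.k for random unit-variance vectors."""
    dots = []
    for _ in range(trials):
        q = [random.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [random.gauss(0.0, 1.0) for _ in range(d_k)]
        dots.append(sum(a * b for a, b in zip(q, k)))
    raw = statistics.pvariance(dots)
    # Dividing every score by sqrt(d_k) divides the variance by d_k
    scaled = statistics.pvariance([d / math.sqrt(d_k) for d in dots])
    return raw, scaled

results = {d_k: dot_product_variances(d_k) for d_k in (4, 64)}
for d_k, (raw, scaled) in results.items():
    print(f"d_k={d_k:3d}  raw variance={raw:6.1f}  scaled variance={scaled:4.2f}")
```

Large unscaled scores would push the softmax toward a one-hot output, which is the saturation the scaling is meant to avoid.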

Think of attention as a soft database lookup:
- **Queries (Q)**: What you're searching for
- **Keys (K)**: The index entries each query is compared against
- **Values (V)**: Actual data returned
- **Attention weights**: How relevant each key is to each query
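
The lookup analogy can be worked through by hand for a single query. The following is a minimal sketch in plain Python. The function names `softmax` and `attend` and the toy keys and values are made up for illustration. Unlike a real database lookup, every value contributes to the result, weighted by how well its key matches the query.

```python
import math

def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def attend(query, keys, values):
    """Single-query scaled dot-product attention on plain lists."""
    d_k = len(query)
    # Similarity of the query to every key, scaled by sqrt(d_k)
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d_k) for key in keys]
    weights = softmax(scores)
    # Weighted average of the values
    dim = len(values[0])
    output = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(dim)]
    return output, weights

keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
values = [[10.0], [20.0], [30.0]]
query = [0.0, 4.0]  # matches the second and third keys equally, the first not at all

output, weights = attend(query, keys, values)
print("weights:", [round(w, 3) for w in weights])
print("output:", [round(o, 2) for o in output])
```

The weights sum to 1, so the output is always a blend of the stored values, dominated here by the two values whose keys match the query.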

## Environment Setup

Import required libraries for attention implementation and visualization.

In [None]:
import sys
import os
sys.path.append('..')

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Tuple

plt.style.use('default')
sns.set_palette("husl")
torch.manual_seed(42)
np.random.seed(42)

## Scaled Dot-Product Attention

Implement the core formula as a function. Later cells test it on simple data where the attention patterns should be interpretable and visualize the resulting weights.

In [None]:
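def scaled_dot_product_attention(Q, K, V, mask=None):
    """Compute softmax(QK^T / sqrt(d_k)) V.

    Q has shape (..., seq_q, d_k), K has shape (..., seq_k, d_k) and
    V has shape (..., seq_k, d_v). If given, mask must broadcast to
    (..., seq_q, seq_k) and uses 0 for positions that may not be attended to.
    Returns (output, attention_weights).
    """
    d_k = Q.size(-1)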
    
    # Step 1: Compute similarity scores (dot product)
    scores = torch.matmul(Q, K.transpose(-2, -1))
    
    # Step 2: Scale by sqrt(d_k) so large dot products do not saturate the softmax
    scores = scores / (d_k ** 0.5)
    
    # Step 3: Apply mask if provided (for causal attention)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    
    # Step 4: Convert to probabilities
    attention_weights = F.softmax(scores, dim=-1)
    
    # Step 5: Apply to values
    output = torch.matmul(attention_weights, V)
    
    return output, attention_weights

In [ ]:
# Create word embeddings for "The cat sat down"
embeddings = torch.tensor([
    [1.0, 0.5, 0.0, 0.5],  # "The"
    [0.0, 1.0, 1.0, 0.0],  # "cat" 
    [0.0, 0.0, 0.5, 1.0],  # "sat"
    [0.5, 0.0, 0.0, 1.0],  # "down"
]).unsqueeze(0)

words = ["The", "cat", "sat", "down"]
output, attn_weights = scaled_dot_product_attention(embeddings, embeddings, embeddings)

plt.figure(figsize=(8, 6))
sns.heatmap(
    attn_weights[0].detach().numpy(),
    annot=True, fmt='.3f',
    xticklabels=words, yticklabels=words,
    cmap='Blues'
)
plt.title('Self-Attention Weights')
plt.xlabel('Keys (attending to)')
plt.ylabel('Queries (attending from)')
plt.show()

print("Each row shows how one query word distributes its attention over all words, including itself.")

## Test Basic Attention

Create simple test data where attention patterns should be interpretable.

In [None]:
def create_simple_sequence_data():
    """
    Create a simple sequence where attention patterns should be interpretable.
    """
    # Embeddings for 4 words ("The", "cat", "sat", "down"), each of dimension d_model = 6
    
    # Manually create embeddings that should have interesting attention patterns
    embeddings = torch.tensor([
        [1.0, 0.5, 0.0, 0.5, 0.0, 0.0],  # "The" - article
        [0.0, 1.0, 1.0, 0.0, 0.5, 0.0],  # "cat" - noun
        [0.0, 0.0, 0.5, 1.0, 1.0, 0.5],  # "sat" - verb
        [0.5, 0.0, 0.0, 0.5, 1.0, 1.0],  # "down" - adverb
    ]).unsqueeze(0)  # Add batch dimension
    
    words = ["The", "cat", "sat", "down"]
    
    return embeddings, words

# Create interpretable data
embeddings, words = create_simple_sequence_data()
print(f"Embeddings shape: {embeddings.shape}")
print(f"Words: {words}")

# Use embeddings as Q, K, V for self-attention
output, attn_weights = scaled_dot_product_attention(embeddings, embeddings, embeddings)

# Visualize attention weights
plt.figure(figsize=(8, 6))
sns.heatmap(
    attn_weights[0].detach().numpy(),
    annot=True,
    fmt='.3f',
    xticklabels=words,
    yticklabels=words,
    cmap='Blues',
    cbar_kws={'label': 'Attention Weight'}
)
plt.title('Self-Attention Weights')
plt.xlabel('Keys (attending to)')
plt.ylabel('Queries (attending from)')
plt.tight_layout()
plt.show()

print("\nInterpretation:")
print("Each row shows how one query word distributes its attention over all words, including itself.")
print("Higher values (darker blue) indicate stronger attention.")

In [None]:
def create_causal_mask(seq_len):
    mask = torch.tril(torch.ones(seq_len, seq_len))
    return mask

seq_len = 4
causal_mask = create_causal_mask(seq_len)
print("Causal mask (1 = can attend, 0 = cannot attend):")
print(causal_mask.numpy())

# scaled_dot_product_attention takes no show_steps argument, so only the mask is passed
output_causal, attn_weights_causal = scaled_dot_product_attention(
    embeddings, embeddings, embeddings, mask=causal_mask.unsqueeze(0)
)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

sns.heatmap(attn_weights[0].detach().numpy(), annot=True, fmt='.3f',
           xticklabels=words, yticklabels=words, cmap='Blues', ax=ax1)
ax1.set_title('Regular Self-Attention')

sns.heatmap(attn_weights_causal[0].detach().numpy(), annot=True, fmt='.3f',
           xticklabels=words, yticklabels=words, cmap='Blues', ax=ax2)
ax2.set_title('Causal Self-Attention')

plt.tight_layout()
plt.show()

print("Causal attention: each word can only attend to itself and previous words.")
print("Notice the upper triangle is zero - no peeking at future!")
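
Multi-head attention runs several scaled dot-product attentions in parallel, each on its own learned projection of the input, and then merges the results. The cell below is a minimal sketch of that idea, not a reference implementation. The class name `MultiHeadAttention`, the helper `_split_heads`, and the sizes `d_model=6` and `num_heads=2` are assumptions chosen to match the toy embeddings. The projections are randomly initialized, so the weights are not meaningful until trained. Masked positions are filled with `-inf` here instead of `-1e9`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    """Illustrative multi-head attention; inputs have shape (batch, seq_len, d_model)."""

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

    def _split_heads(self, x):
        # (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_head)
        batch, seq_len, _ = x.shape
        return x.view(batch, seq_len, self.num_heads, self.d_head).transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        q = self._split_heads(self.w_q(query))
        k = self._split_heads(self.w_k(key))
        v = self._split_heads(self.w_v(value))
        # Each head computes its own scaled dot-product attention
        scores = q @ k.transpose(-2, -1) / self.d_head ** 0.5
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))
        weights = F.softmax(scores, dim=-1)
        context = weights @ v  # (batch, num_heads, seq_q, d_head)
        # Merge the heads back into one d_model-sized vector per position
        batch, _, seq_q, _ = context.shape
        merged = context.transpose(1, 2).reshape(batch, seq_q, -1)
        return self.w_o(merged), weights

torch.manual_seed(0)
mha = MultiHeadAttention(d_model=6, num_heads=2)
x = torch.randn(1, 4, 6)

output, weights = mha(x, x, x)
print(output.shape)   # torch.Size([1, 4, 6])
print(weights.shape)  # torch.Size([1, 2, 4, 4])

# The same module accepts a causal mask, broadcast over batch and heads
causal_mask = torch.tril(torch.ones(4, 4))
_, causal_weights = mha(x, x, x, mask=causal_mask)
```

Each head works on `d_model / num_heads` dimensions, so the total cost stays close to that of one full-width attention while allowing each head to learn a different attention pattern.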

## Summary: Attention Mastery

You've learned the core mechanism that powers all transformers!

### Key Concepts
- **Attention formula**: `softmax(QK^T / √d_k)V` - weighted average based on similarity
- **Scaling factor**: Dividing by `√d_k` keeps score magnitudes independent of `d_k`, so the softmax does not saturate
- **Causal masking**: Prevents future information leakage in language modeling
- **Multi-head attention**: Parallel heads capture different relationship types

### Why Attention is Revolutionary
- **Before**: RNNs processed sequences step-by-step, limiting parallelization
- **After**: Attention connects any two positions directly, enabling parallelization

### Applications
- **Self-attention**: Each position attends to all positions in same sequence
- **Cross-attention**: Queries from one sequence, keys/values from another (e.g., translation)
- **Causal attention**: For autoregressive language modeling
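
Cross-attention differs from self-attention only in where the inputs come from. A minimal sketch, assuming random tensors in place of real decoder and encoder states (the names `decoder_states` and `encoder_states` and all sizes are illustrative), shows that the weight matrix is no longer square: there is one row per query position and one column per key position.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_k = 8
decoder_states = torch.randn(1, 3, d_k)  # 3 target positions supply the queries
encoder_states = torch.randn(1, 5, d_k)  # 5 source positions supply keys and values

# Same formula as self-attention, but Q and K/V come from different sequences
scores = decoder_states @ encoder_states.transpose(-2, -1) / d_k ** 0.5
weights = F.softmax(scores, dim=-1)
output = weights @ encoder_states

print(weights.shape)  # torch.Size([1, 3, 5])
print(output.shape)   # torch.Size([1, 3, 8])
```

In a trained model the queries, keys, and values would first pass through learned projections. They are omitted here to keep the shapes easy to follow.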

### Next Steps
Now you understand attention! Next, we'll see how it combines with other components to build complete transformer blocks.

The foundation is solid - let's build transformers! 🚀