# 09: Attention Mechanisms

**Duration:** 3-4 hours | **Difficulty:** Advanced

## Learning Objectives
- Attention mechanism fundamentals
- Multi-head attention implementation
- Attention visualization techniques
- Attention-enhanced seq2seq models

## Table of Contents
1. [Introduction to Attention](#1-introduction)
2. [Basic Attention](#2-basic-attention)
3. [Multi-Head Attention](#3-multihead)
4. [Attention Visualization](#4-visualization)
5. [Practical Exercise](#5-exercise)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import math
from typing import Optional, Tuple

# Import utilities
import sys
sys.path.append('../')
from utils.model_helpers import get_device, count_parameters

# Choose the compute device via the project helper in utils.model_helpers.
# NOTE(review): the meaning of "auto" is defined in that helper (presumably GPU
# if available) — confirm there. `device` is not referenced by the visible cells;
# all tensors below are created on the CPU.
device = get_device("auto")
print(f"Using device: {device}")

# Seed PyTorch's RNG so the random inputs and weight initialisations are reproducible.
torch.manual_seed(42)

## 1. Introduction to Attention {#1-introduction}

**Attention mechanisms** solve the information bottleneck in seq2seq models:

- **Problem**: Fixed context vector loses information in long sequences
- **Solution**: Dynamically attend to all encoder states
- **Benefit**: Better alignment and long-range dependencies

### Core Idea:
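Instead of compressing the entire input into a single fixed-size vector, the decoder builds a new context vector at every output step. That context is a weighted combination of all encoder states, and the weights depend on the current decoder state.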

In [None]:
# Visualize attention concept
def visualize_attention_concept():
    """Draw a heatmap of a hand-made attention alignment.

    The matrix is illustrative only (it is not produced by a model). It pretends
    that translating "How are you?" into "I am fine" makes each output word
    focus mainly on one input word.

    Returns:
        np.ndarray: the (3, 4) alignment matrix. Rows are output words and
        columns are input words; every row sums to 1.
    """
    source_tokens = ['How', 'are', 'you', '?']
    target_tokens = ['I', 'am', 'fine']

    # One row per target token, each a distribution over the source tokens.
    rows = [
        [0.1, 0.2, 0.6, 0.1],  # "I"    -> mostly "you"
        [0.2, 0.7, 0.1, 0.0],  # "am"   -> mostly "are"
        [0.1, 0.1, 0.2, 0.6],  # "fine" -> mostly "?"
    ]
    alignment = np.array(rows)

    plt.figure(figsize=(8, 6))
    sns.heatmap(
        alignment,
        xticklabels=source_tokens,
        yticklabels=target_tokens,
        annot=True,
        cmap='Blues',
        cbar=True,
    )
    plt.title('Attention Alignment Example')
    plt.xlabel('Input Words')
    plt.ylabel('Output Words')
    plt.show()

    return alignment

# Draw the toy alignment, then explain how to read it.
attention_example = visualize_attention_concept()
for _note in (
    "Each row shows where the output word 'attends' to in the input.",
    "Higher values (darker blue) indicate stronger attention.",
):
    print(_note)

## 2. Basic Attention Implementation {#2-basic-attention}

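We implement two scoring mechanisms.

**Additive (Bahdanau) attention** scores each encoder state $h_i$ against the decoder state $s$ with a small feed-forward network:

$$e_i = v^\top \tanh(W_h h_i + W_s s), \qquad \alpha = \mathrm{softmax}(e), \qquad c = \sum_i \alpha_i h_i$$

**Scaled dot-product attention**, the building block of Transformers, computes:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$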

In [None]:
class AdditiveAttention(nn.Module):
    """Additive (Bahdanau-style) attention over a sequence of encoder states.

    Each encoder state h_i is scored against the decoder state s with
    ``v^T tanh(W_enc h_i + W_dec s)``. The scores are softmax-normalised, and the
    context vector is the weighted sum of encoder states. The context is then
    concatenated with the decoder state and passed through a tanh-activated
    linear layer. The first return value is therefore an *attentional hidden
    state*, not the raw context vector.
    """

    def __init__(self, hidden_dim: int, attention_dim: int = 128):
        super().__init__()
        # Encoder outputs and decoder state are both expected to have size hidden_dim.
        self.encoder_proj = nn.Linear(hidden_dim, attention_dim, bias=False)
        self.decoder_proj = nn.Linear(hidden_dim, attention_dim, bias=False)
        self.attention_v = nn.Linear(attention_dim, 1, bias=False)
        # Maps [context; decoder_hidden] (2 * hidden_dim) back to hidden_dim.
        self.output_proj = nn.Linear(hidden_dim * 2, hidden_dim)

    def forward(self, decoder_hidden, encoder_outputs, mask=None):
        """
        Args:
            decoder_hidden: (batch_size, hidden_dim) current decoder state.
            encoder_outputs: (batch_size, seq_len, hidden_dim)
            mask: optional (batch_size, seq_len). Positions where mask == 0
                receive zero attention weight.

        Returns:
            output: (batch_size, hidden_dim), tanh(W_o [context; decoder_hidden]).
            attention_weights: (batch_size, seq_len); each row sums to 1.
        """
        encoder_keys = self.encoder_proj(encoder_outputs)  # (batch, seq_len, att_dim)
        decoder_query = self.decoder_proj(decoder_hidden).unsqueeze(1)  # (batch, 1, att_dim)

        # Broadcast the decoder projection over every encoder position.
        scores = self.attention_v(torch.tanh(encoder_keys + decoder_query)).squeeze(-1)

        if mask is not None:
            # Use the most negative finite value rather than -inf. A row whose
            # positions are all masked then softmaxes to a uniform distribution
            # instead of NaN, matching ScaledDotProductAttention.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attention_weights = F.softmax(scores, dim=1)

        # Weighted sum of encoder states: (batch, 1, seq_len) @ (batch, seq_len, hidden).
        context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs).squeeze(1)

        combined = torch.cat([context, decoder_hidden], dim=1)
        output = torch.tanh(self.output_proj(combined))

        return output, attention_weights

class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention (the core operation of Transformers).

    Computes ``softmax(Q K^T / sqrt(d_k)) V``. Only the last two axes are
    contracted (``matmul`` / ``transpose(-2, -1)``), so any number of leading
    batch dimensions is supported. An example is (batch, heads, seq, d_k), as
    used by multi-head attention.
    """

    def __init__(self, d_k: int, dropout: float = 0.1):
        super().__init__()
        self.d_k = d_k  # query/key size used for the 1/sqrt(d_k) scaling
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        """
        Args:
            query: (..., seq_len_q, d_k)
            key: (..., seq_len_k, d_k)
            value: (..., seq_len_k, d_v)
            mask: optional, broadcastable to (..., seq_len_q, seq_len_k).
                Positions where mask == 0 are not attended to.

        Returns:
            output: (..., seq_len_q, d_v)
            attention_weights: (..., seq_len_q, seq_len_k). This is the softmax
                distribution *before* dropout, so each row sums to 1 even in
                training mode. Dropout only affects how values are mixed.
        """
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            # Most negative finite value for the dtype. A literal -1e9 overflows
            # in float16, and -inf would turn fully masked rows into NaN.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attention_weights = F.softmax(scores, dim=-1)

        # Regularise by dropping attention links when mixing values, but return
        # the clean distribution so callers can inspect/visualise it.
        output = torch.matmul(self.dropout(attention_weights), value)

        return output, attention_weights

# Test basic attention on random data
hidden_dim = 256
seq_len = 8
batch_size = 2

decoder_hidden = torch.randn(batch_size, hidden_dim)
encoder_outputs = torch.randn(batch_size, seq_len, hidden_dim)

attention = AdditiveAttention(hidden_dim)
# The first return value is the attentional state tanh(W_o[context; decoder_hidden]),
# which has the same shape as the context vector.
context, weights = attention(decoder_hidden, encoder_outputs)

print("Additive Attention Test:")
print(f"Context shape: {context.shape}")
print(f"Attention weights shape: {weights.shape}")
print(f"Weights sum (should be 1.0): {weights.sum(dim=1)}")
# A NumPy array does not accept a float format spec such as ':.3f' (TypeError),
# so round/format it with numpy instead.
sample_weights = np.array2string(weights[0].detach().cpu().numpy(), precision=3)
print(f"Sample attention weights: {sample_weights}")

## 3. Multi-Head Attention {#3-multihead}

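Multi-head attention lets the model attend to information from several representation subspaces at once. It works in three steps:

1. Queries, keys and values are linearly projected into `num_heads` subspaces of size $d_k = d_{\text{model}} / \text{num\_heads}$.
2. Scaled dot-product attention runs in every head in parallel.
3. The head outputs are concatenated and projected back to $d_{\text{model}}$.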

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention mechanism.

    Projects queries, keys and values into ``num_heads`` subspaces of size
    ``d_k = d_model // num_heads``. It runs ``ScaledDotProductAttention`` in
    every head in parallel, concatenates the heads, and applies an output
    projection.
    """

    def __init__(self, d_model: int, num_heads: int = 8, dropout: float = 0.1):
        super().__init__()
        # Validate explicitly: `assert` is stripped under `python -O`.
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Linear projections (all heads packed into one d_model x d_model matrix)
        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)
        self.w_o = nn.Linear(d_model, d_model)

        self.attention = ScaledDotProductAttention(self.d_k, dropout)

    def forward(self, query, key, value, mask=None):
        """
        Args:
            query: (batch_size, seq_len_q, d_model)
            key: (batch_size, seq_len_k, d_model)
            value: (batch_size, seq_len_k, d_model)
            mask: optional. Either (batch_size, seq_len_q, seq_len_k), or a
                key-padding mask (batch_size, seq_len_k) shared by all queries.
                Positions where mask == 0 are not attended to.

        Returns:
            output: (batch_size, seq_len_q, d_model)
            attention_weights: (batch_size, num_heads, seq_len_q, seq_len_k)
        """
        batch_size, seq_len_q, _ = query.shape

        def split_heads(projected):
            # (batch, seq, d_model) -> (batch, num_heads, seq, d_k)
            return projected.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        Q = split_heads(self.w_q(query))
        K = split_heads(self.w_k(key))
        V = split_heads(self.w_v(value))

        if mask is not None:
            if mask.dim() == 2:
                # Key-padding mask: identical for every query position.
                mask = mask.unsqueeze(1)  # (batch, 1, seq_len_k)
            # Add a head axis; broadcasting shares the mask across all heads.
            mask = mask.unsqueeze(1)

        output, attention_weights = self.attention(Q, K, V, mask)

        # Concatenate heads: (batch, heads, seq_q, d_k) -> (batch, seq_q, d_model)
        output = output.transpose(1, 2).contiguous().view(
            batch_size, seq_len_q, self.d_model
        )

        # Final linear projection
        output = self.w_o(output)

        return output, attention_weights

# Self-attention smoke test: the same tensor is used as query, key and value.
# Reuses `batch_size` from the additive-attention test cell.
d_model, num_heads, seq_len = 256, 8, 10

x = torch.randn(batch_size, seq_len, d_model)
mha = MultiHeadAttention(d_model, num_heads)
output, attention_weights = mha(x, x, x)

print("Multi-Head Attention Test:")
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {attention_weights.shape}")
total_params = count_parameters(mha)['total']
print(f"Parameters: {total_params:,}")

## 4. Attention Visualization {#4-visualization}

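Visualizing attention patterns helps us understand model behavior. The patterns below are hand-crafted for illustration rather than taken from a trained model. Each head is designed to show a different typical behavior.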

In [None]:
def visualize_attention_heads(attention_weights, input_tokens=None, max_heads=4):
    """Plot the attention maps of the first ``max_heads`` heads side by side.

    Args:
        attention_weights: tensor of shape (batch, num_heads, seq_len_q, seq_len_k),
            as returned by ``MultiHeadAttention``. A single-head
            (batch, seq_len_q, seq_len_k) tensor is also accepted. Only the
            first example in the batch is drawn.
        input_tokens: optional token strings used as tick labels on both axes.
            This assumes self-attention, where queries and keys share tokens.
        max_heads: maximum number of heads to draw.
    """
    batch_idx = 0  # Show first example in batch
    # .cpu() so tensors that live on a GPU can be converted to NumPy.
    attention = attention_weights[batch_idx].detach().cpu().numpy()
    if attention.ndim == 2:
        attention = attention[np.newaxis]  # single-head input: add a head axis

    num_heads = min(attention.shape[0], max_heads)

    # squeeze=False always yields a 2-D array of Axes, even for a single head.
    fig, axes = plt.subplots(1, num_heads, figsize=(4 * num_heads, 4), squeeze=False)

    for head, ax in enumerate(axes[0]):
        im = ax.imshow(attention[head], cmap='Blues', aspect='auto')
        ax.set_title(f'Head {head + 1}')
        ax.set_xlabel('Key Position')
        ax.set_ylabel('Query Position')

        # Add token labels if provided
        if input_tokens:
            ax.set_xticks(range(len(input_tokens)))
            ax.set_xticklabels(input_tokens, rotation=45)
            ax.set_yticks(range(len(input_tokens)))
            ax.set_yticklabels(input_tokens)

        plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    plt.tight_layout()
    plt.show()

def create_sample_attention_pattern():
    """Create interpretable, hand-crafted attention patterns for visualization.

    Returns:
        tuple: ``(attention_patterns, tokens)``.
        ``attention_patterns`` has shape (1, 4, seq_len, seq_len), with one head
        per illustrative behavior; every row is a probability distribution.
        ``tokens`` is the list of token strings.
    """
    tokens = ['What', 'is', 'machine', 'learning', '?']
    seq_len = len(tokens)
    num_heads = 4

    # Unnormalised affinities; each head gets a different hand-made pattern.
    attention_patterns = torch.zeros(1, num_heads, seq_len, seq_len)

    # Head 1: Local attention (the token itself and its immediate neighbours)
    for i in range(seq_len):
        for j in range(max(0, i - 1), min(seq_len, i + 2)):
            attention_patterns[0, 0, i, j] = 0.5 if i != j else 0.3

    # Head 2: Question-structure tokens, plus a little self-attention
    question_words = [0, 4]  # "What" and "?"
    for i in range(seq_len):
        for j in question_words:
            attention_patterns[0, 1, i, j] = 0.4
        attention_patterns[0, 1, i, i] = 0.2

    # Head 3: Content tokens, plus a little self-attention
    content_words = [2, 3]  # "machine", "learning"
    for i in range(seq_len):
        for j in content_words:
            attention_patterns[0, 2, i, j] = 0.4
        attention_patterns[0, 2, i, i] = 0.2

    # Head 4: Global attention (uniform)
    attention_patterns[0, 3, :, :] = 0.2

    # Normalise each row to sum to 1 by dividing by its total. A softmax over
    # these small values would flatten every head to almost uniform and give
    # the zero entries non-zero weight. Every row has a positive total because
    # each head assigns weight to the diagonal.
    attention_patterns = attention_patterns / attention_patterns.sum(dim=-1, keepdim=True)

    return attention_patterns, tokens

# Render the hand-crafted patterns and explain what each head was built to show.
sample_attention, tokens = create_sample_attention_pattern()
print("Sample Attention Patterns:")
print(f"Tokens: {tokens}")
visualize_attention_heads(sample_attention, tokens)

head_descriptions = (
    "Head 1: Local attention (focuses on adjacent words)",
    "Head 2: Question structure (focuses on 'What' and '?')",
    "Head 3: Content focus (focuses on 'machine' and 'learning')",
    "Head 4: Global context (uniform attention)",
)
print("\nInterpretation:")
for description in head_descriptions:
    print(description)

## 5. Practical Exercise {#5-exercise}

**Exercise**: Implement and experiment with attention mechanisms

### Tasks:
1. Compare additive vs multiplicative attention
2. Implement attention in seq2seq decoder
3. Visualize attention alignments
4. Experiment with different numbers of heads

### Questions:
1. How does attention help with long sequences?
2. What do different attention heads capture?
3. When might attention patterns be problematic?

### Extensions:
- Self-attention for encoder
- Cross-attention between sequences
- Attention dropout and regularization
- Positional encoding effects

In [None]:
# Exercise: Attention analysis
def analyze_attention_patterns(seq_lengths=(5, 10, 20), hidden_dim=128):
    """Print summary statistics of additive-attention distributions.

    For each sequence length, a randomly initialised ``AdditiveAttention`` is
    applied to random encoder outputs. The entropy and the peak weight of the
    resulting attention distribution are printed. The maximum possible entropy
    grows as ln(seq_len), so a normalised entropy (1.0 = perfectly uniform) is
    also shown to make different lengths comparable.

    Args:
        seq_lengths: encoder sequence lengths to try.
        hidden_dim: hidden size of the random decoder/encoder states.
    """
    print("=== Attention Analysis ===")

    additive_attn = AdditiveAttention(hidden_dim)

    # Pure inspection: no autograd graph is needed.
    with torch.no_grad():
        for seq_len in seq_lengths:
            # Test data
            decoder_hidden = torch.randn(1, hidden_dim)
            encoder_outputs = torch.randn(1, seq_len, hidden_dim)

            _, weights = additive_attn(decoder_hidden, encoder_outputs)

            # xlogy(p, p) is defined as 0 at p == 0, so no epsilon fudge is needed.
            entropy = -torch.xlogy(weights, weights).sum(dim=1)
            max_attention = weights.max(dim=1).values

            print(f"\nSequence length: {seq_len}")
            print(f"Attention entropy: {entropy.item():.3f} (higher = more uniform)")
            if seq_len > 1:
                normalized = entropy.item() / math.log(seq_len)
                print(f"Normalized entropy: {normalized:.3f} (1.0 = perfectly uniform)")
            print(f"Max attention weight: {max_attention.item():.3f}")
            print(f"Attention distribution: {weights[0].detach().cpu().numpy()[:5]}...")

def compare_attention_types(hidden_dim=256, seq_len=8):
    """Compare additive and scaled dot-product attention on the same inputs.

    Both mechanisms score one random decoder state against the same random
    encoder outputs. For each, the parameter count and the first four attention
    weights are printed.

    Args:
        hidden_dim: size of the decoder state and of each encoder output.
        seq_len: number of encoder positions.
    """
    print("\n=== Attention Type Comparison ===")

    # Test inputs
    decoder_hidden = torch.randn(1, hidden_dim)
    encoder_outputs = torch.randn(1, seq_len, hidden_dim)

    # eval(): ScaledDotProductAttention applies dropout (p=0.1 by default) in
    # training mode, which would randomise the weights being compared.
    # The additive count also includes its output projection layer.
    additive_attn = AdditiveAttention(hidden_dim).eval()
    dot_attn = ScaledDotProductAttention(hidden_dim).eval()

    with torch.no_grad():
        _, weights_add = additive_attn(decoder_hidden, encoder_outputs)
        # The decoder state acts as a length-1 query sequence: (1, 1, hidden_dim).
        query = decoder_hidden.unsqueeze(1)
        _, weights_dot = dot_attn(query, encoder_outputs, encoder_outputs)

    print("Additive attention:")
    print(f"  Parameters: {count_parameters(additive_attn)['total']:,}")
    print(f"  Attention weights: {weights_add[0][:4].cpu().numpy()}...")

    print("\nScaled dot-product attention:")
    print(f"  Parameters: {count_parameters(dot_attn)['total']:,}")
    print(f"  Attention weights: {weights_dot[0, 0, :4].cpu().numpy()}...")

# Run both analyses, then print a recap of the notebook.
analyze_attention_patterns()
compare_attention_types()

print("\n=== Attention Mechanisms Complete ===")
print("Key Concepts Learned:")
key_concepts = (
    "• Attention solves information bottleneck in seq2seq",
    "• Additive vs multiplicative attention mechanisms",
    "• Multi-head attention for diverse representations",
    "• Attention visualization and interpretation",
    "• Foundation for transformer architectures",
)
for concept_line in key_concepts:
    print(concept_line)
print("\nNext: Full transformer implementation!")