# Attention Visualization

This notebook visualizes attention patterns and compares standard vs linear attention.

## Learning Objectives
- Visualize attention weight distributions
- Compare standard (softmax) vs linear attention patterns
- Understand how causal masking affects attention

In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F

# Set up matplotlib: apply the 'seaborn-v0_8-whitegrid' style to all subsequent figures
# NOTE(review): the 'seaborn-v0_8-*' style names appear to require a recent matplotlib — confirm the minimum version
plt.style.use('seaborn-v0_8-whitegrid')
# IPython magic (not plain Python): render figures inline in the notebook output
%matplotlib inline

## Standard Scaled Dot-Product Attention

First, let's implement and visualize standard attention:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

In [None]:
def standard_attention(query, key, value, causal=False, return_weights=True):
    """
    Compute scaled dot-product (softmax) attention.

    Args:
        query: (batch, heads, seq_len, head_dim)
        key: (batch, heads, seq_len, head_dim)
        value: (batch, heads, seq_len, head_dim)
        causal: If True, query position i may only attend to key positions <= i
        return_weights: If True, also return the softmax attention matrix

    Returns:
        output: (batch, heads, seq_len, head_dim)
        attention_weights: (batch, heads, seq_len, seq_len), only when
            return_weights is True
    """
    n = query.size(-2)
    scale = query.size(-1) ** 0.5

    # Similarity logits, scaled by sqrt(d_k) so softmax is not saturated for large head_dim
    logits = query @ key.transpose(-2, -1) / scale

    if causal:
        # Strictly upper-triangular entries are "future" keys; -inf gives them softmax weight 0
        future = torch.triu(torch.ones(n, n, device=query.device), diagonal=1).bool()
        logits = logits.masked_fill(future, float('-inf'))

    probs = logits.softmax(dim=-1)
    out = probs @ value
    return (out, probs) if return_weights else out

## Linear Attention

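Linear attention replaces the softmax with a positive feature map $\phi$ (here $\phi(x) = \text{ELU}(x) + 1$) and reorders the computation so that $\phi(K)^T V$ is computed first, which makes the cost linear rather than quadratic in sequence length: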

$$\text{LinearAttention}(Q, K, V) = \frac{\phi(Q)(\phi(K)^T V)}{\phi(Q)\phi(K)^T \mathbf{1}}$$

In [None]:
def elu_plus_one(x):
    """Feature map phi(x) = ELU(x) + 1 used by linear attention.

    With ELU's default alpha=1 this is x + 1 for x > 0 and exp(x) for x <= 0,
    so outputs are non-negative.
    """
    return 1 + F.elu(x)

def linear_attention(query, key, value, causal=False, return_weights=True, eps=1e-6,
                     feature_map=None):
    """
    Linear attention with a positive feature map (ELU+1 by default).

    Args:
        query: (batch, heads, seq_len, head_dim)
        key: (batch, heads, seq_len, head_dim)
        value: (batch, heads, seq_len, value_dim)
        causal: Whether to apply causal masking (query i only sees keys 0..i)
        return_weights: Whether to also materialize the implied attention
            matrix for visualization. This costs O(seq_len^2) memory and is
            not used to compute the output.
        eps: Added to every normalizer to avoid division by zero.
        feature_map: Callable applied elementwise to query and key. Defaults
            to elu_plus_one. It should return non-negative values so the
            normalizer stays positive.

    Returns:
        output: (batch, heads, seq_len, value_dim)
        attention_weights: (batch, heads, seq_len, seq_len) if return_weights.
            Both use the same eps, so attention_weights @ value reproduces
            output up to floating-point error.
    """
    phi = elu_plus_one if feature_map is None else feature_map
    q = phi(query)
    k = phi(key)

    if not causal:
        # Bidirectional: summarize all keys/values once, then query the summary.
        kv = torch.einsum('bhnd,bhnv->bhdv', k, value)  # (B, H, D, Dv)
        k_sum = k.sum(dim=2)  # (B, H, D)

        output = torch.einsum('bhnd,bhdv->bhnv', q, kv)  # (B, H, N, Dv)
        normalizer = torch.einsum('bhnd,bhd->bhn', q, k_sum).unsqueeze(-1)  # (B, H, N, 1)
    else:
        # Causal: prefix sums over the sequence so position i only sees keys 0..i.
        # NOTE: this materializes a (B, H, N, D, Dv) tensor, which is fine for
        # visualization but memory-heavy; a production kernel would carry a
        # running (D, Dv) state instead.
        kv = torch.einsum('bhnd,bhnv->bhndv', k, value)  # (B, H, N, D, Dv)
        kv_cumsum = torch.cumsum(kv, dim=2)
        k_cumsum = torch.cumsum(k, dim=2)  # (B, H, N, D)

        output = torch.einsum('bhnd,bhndv->bhnv', q, kv_cumsum)
        normalizer = torch.einsum('bhnd,bhnd->bhn', q, k_cumsum).unsqueeze(-1)

    output = output / (normalizer + eps)

    if return_weights:
        # Implied attention matrix phi(q_i) . phi(k_j), row-normalized with the
        # same eps as above so that weights @ value matches `output`.
        scores = torch.einsum('bhnd,bhmd->bhnm', q, k)
        if causal:
            n_q, n_k = scores.shape[-2:]
            future = torch.triu(torch.ones(n_q, n_k, device=query.device), diagonal=1).bool()
            scores = scores.masked_fill(future, 0)
        weights = scores / (scores.sum(dim=-1, keepdim=True) + eps)
        return output, weights

    return output

## Visualizing Attention Patterns

Let's create some sample data and visualize the attention patterns.

In [None]:
# Create sample data: a single batch and head keeps the heatmaps readable
torch.manual_seed(42)
batch_size, num_heads = 1, 1
seq_len, head_dim = 32, 16
qkv_shape = (batch_size, num_heads, seq_len, head_dim)

# Random queries, keys, values, drawn in this order so the seed fixes all three
q, k, v = (torch.randn(*qkv_shape) for _ in range(3))

print(f"Q, K, V shapes: {q.shape}")

In [None]:
# Attention weights from both methods, without and with causal masking
# (index [1] selects the weights from the (output, weights) pair)
standard_weights = standard_attention(q, k, v, causal=False)[1]
standard_causal_weights = standard_attention(q, k, v, causal=True)[1]

linear_weights = linear_attention(q, k, v, causal=False)[1]
linear_causal_weights = linear_attention(q, k, v, causal=True)[1]

In [None]:
def plot_attention(weights, title, ax):
    """
    Draw the first batch/head of an attention tensor as a heatmap on `ax`.

    Args:
        weights: tensor indexed as weights[batch, head, query, key]; only
            weights[0, 0] is plotted
        title: Axes title
        ax: matplotlib Axes to draw on

    Returns:
        The image returned by ax.imshow (usable with fig.colorbar).
    """
    # detach + cpu so tensors that require grad or live on a GPU can be converted
    w = weights[0, 0].detach().cpu().numpy()
    im = ax.imshow(w, cmap='Blues', aspect='auto')
    ax.set_title(title)
    ax.set_xlabel('Key Position')
    ax.set_ylabel('Query Position')
    return im

# Create comparison plot: rows = bidirectional/causal, columns = standard/linear
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

panels = [
    (standard_weights, 'Standard Attention (Bidirectional)', axes[0, 0]),
    (linear_weights, 'Linear Attention (Bidirectional)', axes[0, 1]),
    (standard_causal_weights, 'Standard Attention (Causal)', axes[1, 0]),
    (linear_causal_weights, 'Linear Attention (Causal)', axes[1, 1]),
]
for weights, title, ax in panels:
    im = plot_attention(weights, title, ax)
    # imshow auto-scales each panel independently; a colorbar is needed to compare magnitudes
    fig.colorbar(im, ax=ax)

plt.tight_layout()
# savefig does not create missing directories, so make sure the output folder exists
plot_dir = Path('../results/plots')
plot_dir.mkdir(parents=True, exist_ok=True)
plt.savefig(plot_dir / 'attention_patterns.png', dpi=150, bbox_inches='tight')
plt.show()

## Attention Distribution Comparison

Let's look at how the attention distributions differ for a single query position.

In [None]:
# Compare distributions for a single query
query_pos = 15

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

panels = [
    (axes[0], 'Bidirectional', standard_weights, linear_weights),
    (axes[1], 'Causal', standard_causal_weights, linear_causal_weights),
]
for ax, mode, std_w, lin_w in panels:
    # Overlaid translucent bars: one attention row (a distribution over keys) per method
    ax.bar(range(seq_len), std_w[0, 0, query_pos].detach().cpu().numpy(),
           alpha=0.7, label='Standard (Softmax)')
    ax.bar(range(seq_len), lin_w[0, 0, query_pos].detach().cpu().numpy(),
           alpha=0.7, label='Linear (ELU+1)')
    # Draw the marker before legend() so its label actually appears in the legend
    ax.axvline(x=query_pos, color='red', linestyle='--', alpha=0.5, label='Query Position')
    ax.set_xlabel('Key Position')
    ax.set_ylabel('Attention Weight')
    ax.set_title(f'Attention Distribution for Query Position {query_pos} ({mode})')
    ax.legend()

plt.tight_layout()
# savefig does not create missing directories, so make sure the output folder exists
plot_dir = Path('../results/plots')
plot_dir.mkdir(parents=True, exist_ok=True)
plt.savefig(plot_dir / 'attention_distributions.png', dpi=150, bbox_inches='tight')
plt.show()

## Key Observations

1. **Softmax attention is "peakier"**: Standard attention tends to concentrate its weight on a few key positions.
2. **Linear attention is smoother**: The ELU+1 feature map produces more diffuse attention.
3. **Causal masking works in both**: Both causal patterns are lower triangular, so no query attends to a future key.

These observations are for random queries and keys. The difference in attention patterns is one reason why linear attention is not always a drop-in replacement for softmax attention.

## Entropy Analysis

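Let's quantify how "spread out" the attention is using the entropy of each query's attention row, $H_i = -\sum_j w_{ij} \log w_{ij}$. Higher entropy means more diffuse attention. A uniform distribution over `seq_len` keys attains the maximum, $\log(\text{seq\_len})$.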

In [None]:
def attention_entropy(weights, eps=1e-10):
    """
    Shannon entropy (natural log, i.e. nats) of each attention row.

    Args:
        weights: tensor whose last dimension holds a distribution over keys
        eps: floor applied to the weights before the log so zeros do not give log(0)

    Returns:
        Tensor of entropies with the last dimension reduced away.
    """
    p = torch.clamp(weights, min=eps)
    return -(p * p.log()).sum(dim=-1)

# Compute mean per-row entropy of the bidirectional attention weights
standard_entropy = attention_entropy(standard_weights).mean().item()
linear_entropy = attention_entropy(linear_weights).mean().item()
max_entropy = np.log(seq_len)  # Uniform distribution

print(f"Standard Attention Entropy: {standard_entropy:.3f}")
print(f"Linear Attention Entropy: {linear_entropy:.3f}")
print(f"Maximum Possible Entropy (uniform): {max_entropy:.3f}")

# Entropy is logarithmic, so a ratio of entropies is not a ratio of "spread".
# exp(entropy) (perplexity) is the effective number of keys a query attends to,
# which can be compared multiplicatively.
standard_eff = np.exp(standard_entropy)
linear_eff = np.exp(linear_entropy)
print(f"\nEffective keys attended (exp(entropy)): standard {standard_eff:.1f}, "
      f"linear {linear_eff:.1f}, out of {seq_len}")
print(f"Linear attention spreads over {linear_eff / standard_eff:.1f}x as many effective keys")

## Next Steps

- [02_complexity_analysis.ipynb](02_complexity_analysis.ipynb): Analyze computational complexity
- [03_memory_profiling.ipynb](03_memory_profiling.ipynb): Profile memory usage
- [04_benchmark_comparison.ipynb](04_benchmark_comparison.ipynb): Full benchmark comparison