# Understanding Attention Mechanisms with JAX and Flax

This notebook provides a deep dive into attention mechanisms, a crucial component of transformer models. We'll cover:
1. Self-attention implementation
2. Multi-head attention
3. Visualization of attention patterns
4. Masked (causal) self-attention

In [None]:
import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
import matplotlib.pyplot as plt

# Set random seed for reproducibility. JAX has no global random state:
# functions in jax.random take an explicit PRNG key such as this one.
key = jax.random.PRNGKey(0)

## 1. Simple Self-Attention

Let's implement a basic self-attention mechanism from scratch.

In [None]:
def simple_self_attention(query, key, value):
    """Scaled dot-product attention.

    Computes ``softmax(Q K^T / sqrt(d_model)) V`` over the last two axes.
    Any number of leading axes (batch, heads, ...) is supported, as long as
    they are compatible under ``jnp.matmul`` broadcasting.

    Args:
        query: Query vectors [..., q_len, d_model]
        key: Key vectors [..., k_len, d_model]
        value: Value vectors [..., k_len, d_value]

    Returns:
        A tuple ``(output, attention_weights)`` where ``output`` has shape
        [..., q_len, d_value] and ``attention_weights`` has shape
        [..., q_len, k_len]; each row of the weights sums to 1.
    """
    # Scale by sqrt of the feature size so the dot products do not grow with
    # the dimensionality of the vectors.
    d_k = query.shape[-1]

    # Swap only the last two axes of `key` so the function is not tied to
    # rank-3 inputs.
    scores = jnp.matmul(query, jnp.swapaxes(key, -1, -2)) / jnp.sqrt(d_k)

    # Normalise over key positions.
    attention_weights = jax.nn.softmax(scores, axis=-1)

    # Weighted sum of the value vectors.
    output = jnp.matmul(attention_weights, value)

    return output, attention_weights

# Test the implementation
batch_size, seq_len, d_model = 2, 4, 8

# Give each tensor its own subkey: reusing one PRNG key would produce
# identical draws. The key tensor is bound to `key_vectors` so that the name
# `key` keeps referring to the PRNG key.
query_rng, key_rng, value_rng = jax.random.split(key, 3)
query = jax.random.normal(query_rng, (batch_size, seq_len, d_model))
key_vectors = jax.random.normal(key_rng, (batch_size, seq_len, d_model))
value = jax.random.normal(value_rng, (batch_size, seq_len, d_model))

output, weights = simple_self_attention(query, key_vectors, value)
print("Output shape:", output.shape)
print("Attention weights shape:", weights.shape)

## 2. Multi-Head Attention using Flax

Now let's implement multi-head attention using Flax's module system.

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is projected to queries, keys and values, each projection is
    split into ``num_heads`` heads of size ``d_model // num_heads``, attention
    is computed independently per head, and the heads are concatenated and
    passed through a final output projection.

    Attributes:
        num_heads: Number of attention heads.
        d_model: Width of all projections; must be a multiple of ``num_heads``.
    """
    num_heads: int
    d_model: int

    def setup(self):
        # Fail fast: with a non-divisible d_model the floor division below
        # would drop features and the `-1` in the head reshape could silently
        # infer a wrong sequence length.
        if self.num_heads <= 0 or self.d_model % self.num_heads != 0:
            raise ValueError(
                f"num_heads ({self.num_heads}) must be positive and evenly "
                f"divide d_model ({self.d_model})"
            )

        # Head dimension
        self.d_k = self.d_model // self.num_heads

        # Linear projections
        self.q_proj = nn.Dense(self.d_model)
        self.k_proj = nn.Dense(self.d_model)
        self.v_proj = nn.Dense(self.d_model)
        self.output_proj = nn.Dense(self.d_model)

    def _split_heads(self, projected):
        """Reshape [batch, seq, d_model] into [batch, heads, seq, d_k]."""
        batch_size = projected.shape[0]
        per_head = projected.reshape(batch_size, -1, self.num_heads, self.d_k)
        return per_head.transpose(0, 2, 1, 3)

    def __call__(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: Input of shape [batch_size, seq_len, d_model].

        Returns:
            A tuple ``(output, attention_weights)`` with shapes
            [batch_size, seq_len, d_model] and
            [batch_size, num_heads, seq_len, seq_len].
        """
        batch_size = x.shape[0]

        # Linear projections, split into heads
        q = self._split_heads(self.q_proj(x))
        k = self._split_heads(self.k_proj(x))
        v = self._split_heads(self.v_proj(x))

        # Scaled dot-product attention, computed per head
        scores = jnp.matmul(q, jnp.swapaxes(k, -1, -2)) / jnp.sqrt(self.d_k)
        attention_weights = jax.nn.softmax(scores, axis=-1)
        attention_output = jnp.matmul(attention_weights, v)

        # Move heads next to the feature axis and merge them back to d_model
        merged = attention_output.transpose(0, 2, 1, 3).reshape(batch_size, -1, self.d_model)
        return self.output_proj(merged), attention_weights

# Initialize and test the multi-head attention
mha = MultiHeadAttention(num_heads=4, d_model=64)

# Use a dedicated PRNG key for this section and split it, so that parameter
# initialisation and the test input come from independent random streams
# (a JAX key should not be consumed by more than one random call).
init_rng, data_rng = jax.random.split(jax.random.PRNGKey(1))
params = mha.init(init_rng, jnp.ones((2, 8, 64)))
x = jax.random.normal(data_rng, (2, 8, 64))
output, weights = mha.apply(params, x)

print("Multi-head attention output shape:", output.shape)
print("Multi-head attention weights shape:", weights.shape)

## 3. Visualizing Attention Patterns

Let's create a function to visualize attention patterns.

In [None]:
def plot_attention(attention_weights, title="Attention Weights"):
    """Show a matrix of attention weights as a heatmap.

    Rows are labelled as query positions and columns as key positions.

    Args:
        attention_weights: Matrix of attention weights to display.
        title: Title placed above the heatmap.
    """
    plt.figure(figsize=(10, 8))
    heatmap = plt.imshow(attention_weights, cmap='viridis')

    # Label the heatmap axes before attaching the colour scale.
    plt.gca().set(title=title, xlabel='Key position', ylabel='Query position')
    plt.colorbar(heatmap)

    plt.show()

# Create some example attention patterns
seq_len = 10

# Use random inputs: if every row of x is identical (e.g. all ones), all
# query and key vectors coincide, the scores are constant, and the softmax
# gives exactly uniform weights, so the heatmap would be a single flat colour.
x = jax.random.normal(jax.random.PRNGKey(2), (1, seq_len, mha.d_model))
output, weights = mha.apply(params, x)

# Plot attention weights for the first head
plot_attention(weights[0, 0], "Attention Pattern (Head 0)")

## 4. Masked Self-Attention

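This section implements masked (causal) self-attention, in which each position may attend only to itself and to earlier positions. Autoregressive models rely on this constraint so that a prediction never depends on future tokens.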

In [None]:
def create_causal_mask(seq_len):
    """Build an additive causal mask of shape [seq_len, seq_len].

    Entries strictly above the diagonal (key position after query position)
    hold -1e9; all other entries are zero. Adding the mask to attention
    scores before the softmax drives the masked weights to zero.
    """
    # jnp.tri has ones on and below the diagonal, so its complement marks
    # exactly the future positions.
    future_positions = 1.0 - jnp.tri(seq_len)
    return future_positions * -1e9

def masked_self_attention(query, key, value):
    """Scaled dot-product self-attention with causal masking.

    Position ``i`` may only attend to positions ``j <= i``. Any number of
    leading axes (batch, heads, ...) is supported; the mask is broadcast
    over them.

    Args:
        query: Query vectors [..., seq_len, d_model]
        key: Key vectors [..., seq_len, d_model]
        value: Value vectors [..., seq_len, d_value]

    Returns:
        A tuple ``(output, attention_weights)`` with shapes
        [..., seq_len, d_value] and [..., seq_len, seq_len].
    """
    d_k = query.shape[-1]

    # Swap only the last two axes of `key` so inputs need not be rank 3.
    scores = jnp.matmul(query, jnp.swapaxes(key, -1, -2)) / jnp.sqrt(d_k)

    # Apply causal mask: a large negative value on future positions makes
    # their softmax weight vanish. The mask is square, so query and key are
    # expected to share the same sequence length.
    mask = create_causal_mask(query.shape[-2])
    scores = scores + mask

    attention_weights = jax.nn.softmax(scores, axis=-1)
    output = jnp.matmul(attention_weights, value)

    return output, attention_weights

# Test masked attention
# Draw dedicated inputs for this demo instead of relying on globals from
# earlier cells (`key` is also the name of the notebook's PRNG key). Each
# tensor gets its own subkey so the three draws are independent.
masked_shape = (2, 4, 8)  # (batch_size, seq_len, d_model)
q_rng, k_rng, v_rng = jax.random.split(jax.random.PRNGKey(3), 3)
masked_query = jax.random.normal(q_rng, masked_shape)
masked_key = jax.random.normal(k_rng, masked_shape)
masked_value = jax.random.normal(v_rng, masked_shape)

output, weights = masked_self_attention(masked_query, masked_key, masked_value)
plot_attention(weights[0], "Masked Attention Pattern")