# Understanding Attention Mechanism in JAX

<!--* freshness: { reviewed: '2024-12-08' } *-->

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/attention_mechanism_in_jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/attention_mechanism_in_jax.ipynb)

**Copyright 2024 The JAX Authors.**

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

![JAX](https://raw.githubusercontent.com/jax-ml/jax/main/images/jax_logo_250px.png)

This tutorial provides a comprehensive guide to implementing the **Attention Mechanism** from scratch using JAX. The attention mechanism is the core building block of modern deep learning architectures like Transformers, which power models such as GPT, BERT, and LLaMA.

By the end of this tutorial, you will understand:
- The mathematical foundation of attention
- How to implement Scaled Dot-Product Attention
- How to build Multi-Head Attention with batched reshapes and matrix multiplications
- How to create Positional Encodings
- How to combine everything into a simple Transformer block

**Prerequisites:**
- Basic understanding of JAX (`jit`, `grad`, `vmap`)
- Familiarity with neural networks and linear algebra
- Python and NumPy experience

## Setup

Let's start by importing the necessary libraries.

In [None]:
import jax
import jax.numpy as jnp
from jax import random, grad, jit, vmap
import numpy as np

# For visualization
import matplotlib.pyplot as plt

# Check JAX version
print(f"JAX version: {jax.__version__}")
print(f"Available devices: {jax.devices()}")

## Part 1: What is Attention?

### The Intuition

Imagine you're reading a sentence: *"The cat sat on the mat because it was tired."*

When you process the word "it", your brain automatically focuses on "cat" to understand what "it" refers to. This selective focus is what the **attention mechanism** models computationally.

### The Key Concepts: Query, Key, Value

The attention mechanism works with three vectors:

- **Query (Q)**: "What am I looking for?"
- **Key (K)**: "What do I contain?"
- **Value (V)**: "What information do I provide?"

The attention score between a query and a key determines how much the value should contribute to the output.

### Scaled Dot-Product Attention

The formula for scaled dot-product attention is:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Where:
- $Q$ is the query matrix of shape `(seq_len, d_k)`
- $K$ is the key matrix of shape `(seq_len, d_k)`
- $V$ is the value matrix of shape `(seq_len, d_v)`
- $d_k$ is the dimension of keys (used for scaling)
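
To make the formula concrete, here is a minimal sketch that evaluates it by hand on a tiny input. The values of `Q`, `K`, and `V` are made up for illustration and are not part of the tutorial's data.

```python
import jax
import jax.numpy as jnp

# Toy inputs (illustrative values): 2 positions, d_k = d_v = 2
Q = jnp.array([[1.0, 0.0],
               [0.0, 1.0]])
K = jnp.array([[1.0, 0.0],
               [0.0, 1.0]])
V = jnp.array([[10.0, 0.0],
               [0.0, 20.0]])

d_k = Q.shape[-1]
scores = Q @ K.T / jnp.sqrt(d_k)           # (2, 2): similarity of each query to each key
weights = jax.nn.softmax(scores, axis=-1)  # each row is a distribution over keys
output = weights @ V                       # (2, 2): weighted average of the value rows

print(weights)
print(output)
```

Each query places more weight on the key it matches, so each output row is pulled toward the corresponding row of `V`.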

## Part 2: Implementing Scaled Dot-Product Attention

Let's implement the attention mechanism step by step.

In [None]:
def scaled_dot_product_attention(query, key, value, mask=None):
    """
    Compute scaled dot-product attention.
    
    Args:
        query: Array of shape (..., seq_len_q, d_k)
        key: Array of shape (..., seq_len_k, d_k)
        value: Array of shape (..., seq_len_k, d_v)
        mask: Optional boolean mask broadcastable to (..., seq_len_q, seq_len_k);
            True marks positions that may be attended to, False masks them out
    
    Returns:
        output: Weighted sum of values, shape (..., seq_len_q, d_v)
        attention_weights: Attention weights, shape (..., seq_len_q, seq_len_k)
    """
    # Get the dimension of keys for scaling
    d_k = query.shape[-1]
    
    # Step 1: Compute attention scores (QK^T)
    # query: (..., seq_len_q, d_k)
    # key.T: (..., d_k, seq_len_k)
    # scores: (..., seq_len_q, seq_len_k)
    scores = jnp.matmul(query, jnp.swapaxes(key, -2, -1))
    
    # Step 2: Scale by sqrt(d_k) to prevent softmax saturation
    scaled_scores = scores / jnp.sqrt(d_k)
    
    # Step 3: Apply mask if provided (for decoder self-attention)
    if mask is not None:
        # Replace masked positions with very negative values
        scaled_scores = jnp.where(mask, scaled_scores, -1e9)
    
    # Step 4: Apply softmax to get attention weights
    attention_weights = jax.nn.softmax(scaled_scores, axis=-1)
    
    # Step 5: Multiply weights by values
    output = jnp.matmul(attention_weights, value)
    
    return output, attention_weights
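
Step 2 divides the scores by $\sqrt{d_k}$. The sketch below illustrates why, under the assumption that query and key entries are independent standard normal values: the standard deviation of an unscaled dot product then grows roughly like $\sqrt{d_k}$, while the scaled version stays near 1. The sample size and the `d_k` values are arbitrary choices for the demonstration.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)
num_samples = 10_000  # illustrative sample size

for d_k in (4, 64, 512):
    key, q_key, k_key = jax.random.split(key, 3)
    q = jax.random.normal(q_key, (num_samples, d_k))
    k = jax.random.normal(k_key, (num_samples, d_k))

    raw = jnp.sum(q * k, axis=-1)   # unscaled dot products, one per sample
    scaled = raw / jnp.sqrt(d_k)    # scaled as in Step 2

    print(f"d_k={d_k:4d} | std(raw)={float(raw.std()):.2f} "
          f"| std(scaled)={float(scaled.std()):.2f}")
```

Large unscaled scores push softmax toward a near one-hot output, where gradients become very small; scaling keeps the scores in a range where softmax stays responsive.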

### Testing Scaled Dot-Product Attention

Let's create some sample data and test our implementation.

In [None]:
# Initialize random key
key = random.key(42)

# Create sample Q, K, V matrices
seq_len = 4
d_k = 8  # dimension of keys
d_v = 8  # dimension of values

# Split key for random generation
key, q_key, k_key, v_key = random.split(key, 4)

# Random Q, K, V matrices
Q = random.normal(q_key, (seq_len, d_k))
K = random.normal(k_key, (seq_len, d_k))
V = random.normal(v_key, (seq_len, d_v))

print(f"Query shape: {Q.shape}")
print(f"Key shape: {K.shape}")
print(f"Value shape: {V.shape}")

# Compute attention
output, attention_weights = scaled_dot_product_attention(Q, K, V)

print(f"\nOutput shape: {output.shape}")
print(f"Attention weights shape: {attention_weights.shape}")

# Verify attention weights sum to 1
print(f"\nAttention weights sum (per query): {attention_weights.sum(axis=-1)}")
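
The `mask` argument is not exercised in the test above. As a minimal sketch, a causal (lower-triangular) mask can be built with `jnp.tril`. The helper `attention_weights_with_mask` is a hypothetical, compact restatement of Steps 1 to 4 so that the example is self-contained.

```python
import jax
import jax.numpy as jnp

def attention_weights_with_mask(q, k, mask=None):
    """Compact restatement of Steps 1-4, returning only the weights."""
    scores = q @ jnp.swapaxes(k, -2, -1) / jnp.sqrt(q.shape[-1])
    if mask is not None:
        scores = jnp.where(mask, scores, -1e9)  # True = keep, False = mask out
    return jax.nn.softmax(scores, axis=-1)

seq_len, d_k = 4, 8
q_key, k_key = jax.random.split(jax.random.key(0))
q = jax.random.normal(q_key, (seq_len, d_k))
k = jax.random.normal(k_key, (seq_len, d_k))

# Lower-triangular boolean mask: position i may attend to positions 0..i
causal_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))

weights = attention_weights_with_mask(q, k, causal_mask)
print(causal_mask)
print(jnp.round(weights, 2))
```

Entries above the diagonal receive zero weight, and each row still sums to 1 over the positions that remain visible.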

### Visualizing Attention Weights

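Attention weights show which key positions each query "attends to". Because Q and K are random here, the pattern itself is arbitrary; the heatmap simply shows that each row is a probability distribution over keys.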

In [None]:
def plot_attention_weights(attention_weights, title="Attention Weights"):
    """
    Visualize attention weights as a heatmap.
    """
    fig, ax = plt.subplots(figsize=(6, 5))
    im = ax.imshow(attention_weights, cmap='viridis', aspect='auto')
    
    ax.set_xlabel('Key Position')
    ax.set_ylabel('Query Position')
    ax.set_title(title)
    
    # Add colorbar
    plt.colorbar(im, ax=ax)
    
    # Add text annotations
    for i in range(attention_weights.shape[0]):
        for j in range(attention_weights.shape[1]):
            text = ax.text(j, i, f'{attention_weights[i, j]:.2f}',
                          ha='center', va='center', color='white', fontsize=9)
    
    plt.tight_layout()
    plt.show()

# Visualize our attention weights
plot_attention_weights(np.array(attention_weights), "Scaled Dot-Product Attention")

## Part 3: Multi-Head Attention

### Why Multiple Heads?

A single attention head computes only one attention distribution per query, so it can emphasize only one pattern of relationships at a time. **Multi-Head Attention** runs several attention operations in parallel, allowing the model to:

1. Attend to different positions
2. Learn different types of relationships
3. Capture both local and global patterns

The formula is:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W^O$$

Where:
$$\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$$
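
In practice the per-head matrices $W_i^Q$ are often stored as column blocks of one fused `(d_model, d_model)` matrix. The sketch below checks, on illustrative sizes, that projecting once and then splitting the feature axis gives the same result as projecting with each block separately.

```python
import jax
import jax.numpy as jnp

seq_len, d_model, num_heads = 5, 16, 4  # illustrative sizes
d_k = d_model // num_heads

x_key, w_key = jax.random.split(jax.random.key(0))
x = jax.random.normal(x_key, (seq_len, d_model))
W_q = jax.random.normal(w_key, (d_model, d_model))  # all heads stacked column-wise

# Fused projection, then split the feature axis into (num_heads, d_k)
fused = (x @ W_q).reshape(seq_len, num_heads, d_k).transpose(1, 0, 2)  # (heads, seq, d_k)

# Per-head projection, taking W_i^Q as the i-th column block of W_q
per_head = jnp.stack([
    x @ W_q[:, i * d_k:(i + 1) * d_k] for i in range(num_heads)
])  # (heads, seq, d_k)

print(fused.shape, per_head.shape)
print(jnp.allclose(fused, per_head, atol=1e-4))
```

The same reasoning applies to $W_i^K$ and $W_i^V$, which is why a single matrix multiplication followed by a reshape can serve all heads at once.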

In [None]:
def init_multihead_attention_params(key, d_model, num_heads):
    """
    Initialize parameters for Multi-Head Attention.
    
    Args:
        key: JAX random key
        d_model: Model dimension
        num_heads: Number of attention heads
    
    Returns:
        Dictionary of parameters
    """
    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
    d_k = d_model // num_heads
    
    # Split key for each weight matrix
    keys = random.split(key, 4)
    
    # Initialize weight matrices with a Xavier/Glorot-style scale,
    # using d_model as fan-in and the per-head dimension d_k as fan-out
    scale = jnp.sqrt(2.0 / (d_model + d_k))
    
    params = {
        'W_q': random.normal(keys[0], (d_model, d_model)) * scale,
        'W_k': random.normal(keys[1], (d_model, d_model)) * scale,
        'W_v': random.normal(keys[2], (d_model, d_model)) * scale,
        'W_o': random.normal(keys[3], (d_model, d_model)) * scale,
    }
    
    return params


def multihead_attention(params, query, key, value, num_heads, mask=None):
    """
    Compute Multi-Head Attention.
    
    Args:
        params: Dictionary containing W_q, W_k, W_v, W_o
        query: Input query of shape (batch_size, seq_len, d_model)
        key: Input key of shape (batch_size, seq_len, d_model)
        value: Input value of shape (batch_size, seq_len, d_model)
        num_heads: Number of attention heads
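        mask: Optional boolean mask broadcastable to
            (batch_size, num_heads, seq_len, seq_len); True = attend
    
    Returns:
        output: Multi-head attention output of shape (batch_size, seq_len, d_model)
        attention_weights: Weights from all heads, shape
            (batch_size, num_heads, seq_len, seq_len)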
    """
    batch_size = query.shape[0]
    seq_len = query.shape[1]
    d_model = query.shape[2]
    d_k = d_model // num_heads
    
    # Step 1: Linear projections
    Q = jnp.matmul(query, params['W_q'])  # (batch, seq_len, d_model)
    K = jnp.matmul(key, params['W_k'])
    V = jnp.matmul(value, params['W_v'])
    
    # Step 2: Reshape for multi-head attention
    # (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)
    Q = Q.reshape(batch_size, seq_len, num_heads, d_k).transpose(0, 2, 1, 3)
    K = K.reshape(batch_size, seq_len, num_heads, d_k).transpose(0, 2, 1, 3)
    V = V.reshape(batch_size, seq_len, num_heads, d_k).transpose(0, 2, 1, 3)
    
    # Step 3: Apply attention to all heads in parallel
    # Q, K, V are now (batch, num_heads, seq_len, d_k)
    attention_output, attention_weights = scaled_dot_product_attention(Q, K, V, mask)
    
    # Step 4: Concatenate heads
    # (batch, num_heads, seq_len, d_k) -> (batch, seq_len, d_model)
    attention_output = attention_output.transpose(0, 2, 1, 3).reshape(
        batch_size, seq_len, d_model
    )
    
    # Step 5: Final linear projection
    output = jnp.matmul(attention_output, params['W_o'])
    
    return output, attention_weights

In [None]:
# Test Multi-Head Attention
batch_size = 2
seq_len = 6
d_model = 64
num_heads = 8

# Initialize parameters
key = random.key(0)
key, params_key, input_key = random.split(key, 3)

mha_params = init_multihead_attention_params(params_key, d_model, num_heads)

# Create random input
x = random.normal(input_key, (batch_size, seq_len, d_model))

# Apply multi-head attention (self-attention: Q=K=V)
output, attn_weights = multihead_attention(mha_params, x, x, x, num_heads)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {attn_weights.shape}")
print(f"  -> (batch_size, num_heads, seq_len_q, seq_len_k)")
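
The implementation above handles all heads through broadcasting over leading axes. An alternative sketch, assuming a hypothetical `single_head_attention` helper written for one head of one example, uses `jax.vmap` to add the head and batch axes and checks that both formulations agree.

```python
import jax
import jax.numpy as jnp

def single_head_attention(q, k, v):
    """Attention for one head of one example: inputs are (seq_len, d_k)."""
    scores = q @ k.T / jnp.sqrt(q.shape[-1])
    return jax.nn.softmax(scores, axis=-1) @ v

# Map over the head axis first, then over the batch axis
over_heads = jax.vmap(single_head_attention)   # inputs: (heads, seq, d_k)
over_batch_and_heads = jax.vmap(over_heads)    # inputs: (batch, heads, seq, d_k)

batch_size, num_heads, seq_len, d_k = 2, 8, 6, 8
q_key, k_key, v_key = jax.random.split(jax.random.key(0), 3)
shape = (batch_size, num_heads, seq_len, d_k)
Q = jax.random.normal(q_key, shape)
K = jax.random.normal(k_key, shape)
V = jax.random.normal(v_key, shape)

vmapped = over_batch_and_heads(Q, K, V)

# Reference: the broadcasting formulation used by multihead_attention
scores = Q @ jnp.swapaxes(K, -2, -1) / jnp.sqrt(d_k)
broadcast = jax.nn.softmax(scores, axis=-1) @ V

print(vmapped.shape)
print(jnp.allclose(vmapped, broadcast, atol=1e-5))
```

Both versions compute the same result. The `vmap` form keeps the core function free of batch bookkeeping, while the broadcasting form avoids the extra wrapping.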

## Part 4: Positional Encoding

### Why Positional Encoding?

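Unlike RNNs, attention is **position-invariant**: it has no built-in notion of order, so reordering the input positions simply reorders the outputs in the same way. In effect it treats the sequence as a set, not an ordered list. To incorporate position information, we add **positional encodings** to the input embeddings.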
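
This property can be checked numerically. The sketch below uses a hypothetical `self_attention` helper with Q = K = V = x and no learned projections, which is a simplification for illustration, and compares "attend, then permute" against "permute, then attend".

```python
import jax
import jax.numpy as jnp

def self_attention(x):
    """Self-attention with Q = K = V = x and no learned projections."""
    scores = x @ x.T / jnp.sqrt(x.shape[-1])
    return jax.nn.softmax(scores, axis=-1) @ x

x_key, perm_key = jax.random.split(jax.random.key(0))
x = jax.random.normal(x_key, (5, 8))         # (seq_len, d_model)
perm = jax.random.permutation(perm_key, 5)   # a random reordering of positions

out_then_permute = self_attention(x)[perm]
permute_then_out = self_attention(x[perm])

# Reordering the inputs only reorders the outputs in the same way
print(jnp.allclose(out_then_permute, permute_then_out, atol=1e-5))
```

Because the two results match, nothing in the output reveals where a token originally sat in the sequence, which is the gap positional encodings fill.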

### Sinusoidal Positional Encoding

The original Transformer uses sinusoidal functions:

$$PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$
$$PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$

This allows the model to learn to attend to relative positions because for any fixed offset $k$, $PE_{pos+k}$ can be represented as a linear function of $PE_{pos}$.

In [None]:
def sinusoidal_positional_encoding(seq_len, d_model):
    """
    Create sinusoidal positional encoding.
    
    Args:
        seq_len: Length of the sequence
        d_model: Dimension of the model
    
    Returns:
        Positional encoding matrix of shape (seq_len, d_model)
    """
    # Create position indices: (seq_len, 1)
    position = jnp.arange(seq_len)[:, jnp.newaxis]
    
    # Frequency for each sin/cos pair: (d_model/2,); assumes d_model is even
    div_term = jnp.exp(
        jnp.arange(0, d_model, 2) * (-jnp.log(10000.0) / d_model)
    )
    
    # Compute sin and cos
    pe = jnp.zeros((seq_len, d_model))
    pe = pe.at[:, 0::2].set(jnp.sin(position * div_term))
    pe = pe.at[:, 1::2].set(jnp.cos(position * div_term))
    
    return pe

# Create and visualize positional encoding
seq_len = 50
d_model = 64
pe = sinusoidal_positional_encoding(seq_len, d_model)

print(f"Positional encoding shape: {pe.shape}")

# Visualize
plt.figure(figsize=(12, 4))
plt.imshow(np.array(pe).T, cmap='RdBu', aspect='auto')
plt.xlabel('Position')
plt.ylabel('Dimension')
plt.title('Sinusoidal Positional Encoding')
plt.colorbar()
plt.tight_layout()
plt.show()
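
The claim that $PE_{pos+k}$ is a linear function of $PE_{pos}$ can be verified directly. For each frequency, the `[sin, cos]` pair at position `pos + k` is a 2x2 rotation of the pair at `pos`, and the rotation depends only on the offset `k`. The sketch below uses illustrative sizes and a hypothetical `encode` helper that returns the encoding grouped into pairs.

```python
import jax.numpy as jnp

d_model, offset = 8, 3  # illustrative sizes
# Same frequencies as div_term in sinusoidal_positional_encoding
freqs = jnp.exp(jnp.arange(0, d_model, 2) * (-jnp.log(10000.0) / d_model))

def encode(pos):
    """Return PE for one position as (d_model/2, 2) pairs of [sin, cos]."""
    angles = pos * freqs
    return jnp.stack([jnp.sin(angles), jnp.cos(angles)], axis=-1)

# One 2x2 rotation per frequency; it depends on the offset but not on pos
c, s = jnp.cos(offset * freqs), jnp.sin(offset * freqs)
rotations = jnp.stack([jnp.stack([c, s], axis=-1),
                       jnp.stack([-s, c], axis=-1)], axis=-2)  # (d_model/2, 2, 2)

for pos in (0, 7, 20):
    shifted = jnp.einsum('fij,fj->fi', rotations, encode(pos))
    print(pos, jnp.allclose(shifted, encode(pos + offset), atol=1e-4))
```

This follows from the angle-addition identities for sine and cosine, so the same `rotations` array works for every starting position.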

## Part 5: Transformer Block

Now let's combine everything into a complete **Transformer Encoder Block**.

A Transformer block consists of:
1. Multi-Head Self-Attention
2. Add & Normalize (Residual connection + Layer Normalization)
3. Feed-Forward Network (FFN)
4. Add & Normalize

### Layer Normalization

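Layer normalization normalizes each position's feature vector using its own mean $\mu$ and variance $\sigma^2$ (computed across the feature dimension), then applies a learned scale $\gamma$ and shift $\beta$: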

$$\text{LayerNorm}(x) = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

In [None]:
def layer_norm(x, gamma, beta, eps=1e-6):
    """
    Apply Layer Normalization.
    
    Args:
        x: Input tensor of shape (..., d_model)
        gamma: Scale parameter of shape (d_model,)
        beta: Shift parameter of shape (d_model,)
        eps: Small constant for numerical stability
    
    Returns:
        Normalized tensor of same shape as x
    """
    mean = jnp.mean(x, axis=-1, keepdims=True)
    variance = jnp.var(x, axis=-1, keepdims=True)
    normalized = (x - mean) / jnp.sqrt(variance + eps)
    return gamma * normalized + beta


def feed_forward_network(x, params):
    """
    Feed-Forward Network with ReLU activation.
    
    FFN(x) = max(0, xW1 + b1)W2 + b2
    
    Args:
        x: Input tensor of shape (..., d_model)
        params: Dictionary with W1, b1, W2, b2
    
    Returns:
        Output tensor of shape (..., d_model)
    """
    # First linear layer + ReLU
    hidden = jax.nn.relu(jnp.matmul(x, params['W1']) + params['b1'])
    # Second linear layer
    output = jnp.matmul(hidden, params['W2']) + params['b2']
    return output
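
As a quick sanity check of the layer normalization formula, the sketch below repeats the same computation on an input with a deliberately large offset and scale (illustrative values). With `gamma = 1` and `beta = 0`, each feature vector should come out with mean close to 0 and standard deviation close to 1.

```python
import jax
import jax.numpy as jnp

def layer_norm(x, gamma, beta, eps=1e-6):
    """Same computation as the layer_norm defined above."""
    mean = jnp.mean(x, axis=-1, keepdims=True)
    variance = jnp.var(x, axis=-1, keepdims=True)
    return gamma * (x - mean) / jnp.sqrt(variance + eps) + beta

d_model = 16
# Input with a large offset and scale: (batch, seq_len, d_model)
x = 5.0 + 3.0 * jax.random.normal(jax.random.key(0), (2, 4, d_model))

y = layer_norm(x, jnp.ones(d_model), jnp.zeros(d_model))

print("input  mean/std per position:", x.mean(axis=-1)[0], x.std(axis=-1)[0])
print("output mean/std per position:", y.mean(axis=-1)[0], y.std(axis=-1)[0])
```

The statistics are computed independently for every position in every example, so the result does not depend on the batch size.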

In [None]:
def init_transformer_block_params(key, d_model, num_heads, d_ff):
    """
    Initialize all parameters for a Transformer block.
    
    Args:
        key: JAX random key
        d_model: Model dimension
        num_heads: Number of attention heads
        d_ff: Dimension of feed-forward hidden layer
    
    Returns:
        Dictionary of all parameters
    """
    keys = random.split(key, 3)  # one key each for attention, FFN W1, FFN W2
    
    # Multi-head attention parameters
    mha_params = init_multihead_attention_params(keys[0], d_model, num_heads)
    
    # Layer norm parameters
    ln1_gamma = jnp.ones(d_model)
    ln1_beta = jnp.zeros(d_model)
    ln2_gamma = jnp.ones(d_model)
    ln2_beta = jnp.zeros(d_model)
    
    # Feed-forward parameters
    scale1 = jnp.sqrt(2.0 / (d_model + d_ff))
    scale2 = jnp.sqrt(2.0 / (d_ff + d_model))
    
    ffn_params = {
        'W1': random.normal(keys[1], (d_model, d_ff)) * scale1,
        'b1': jnp.zeros(d_ff),
        'W2': random.normal(keys[2], (d_ff, d_model)) * scale2,
        'b2': jnp.zeros(d_model),
    }
    
    return {
        'mha': mha_params,
        'ln1_gamma': ln1_gamma,
        'ln1_beta': ln1_beta,
        'ln2_gamma': ln2_gamma,
        'ln2_beta': ln2_beta,
        'ffn': ffn_params,
    }


def transformer_block(params, x, num_heads, mask=None, dropout_key=None, dropout_rate=0.1):
    """
    Apply a single Transformer encoder block.
    
    Args:
        params: Block parameters
        x: Input tensor of shape (batch_size, seq_len, d_model)
        num_heads: Number of attention heads
        mask: Optional attention mask
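        dropout_key: Optional key for dropout (accepted but not used here)
        dropout_rate: Dropout probability (not used; no dropout is applied)
    
    Returns:
        output: Output tensor of same shape as input
        attn_weights: Attention weights of shape
            (batch_size, num_heads, seq_len, seq_len)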
    """
    # 1. Multi-Head Self-Attention
    attn_output, attn_weights = multihead_attention(
        params['mha'], x, x, x, num_heads, mask
    )
    
    # 2. Add & Normalize (first residual connection)
    x = layer_norm(
        x + attn_output,
        params['ln1_gamma'],
        params['ln1_beta']
    )
    
    # 3. Feed-Forward Network
    ffn_output = feed_forward_network(x, params['ffn'])
    
    # 4. Add & Normalize (second residual connection)
    output = layer_norm(
        x + ffn_output,
        params['ln2_gamma'],
        params['ln2_beta']
    )
    
    return output, attn_weights

In [None]:
# Test the complete Transformer block
batch_size = 2
seq_len = 10
d_model = 64
num_heads = 8
d_ff = 256  # Feed-forward hidden dimension (usually 4x d_model)

# Initialize
key = random.key(42)
key, params_key, input_key = random.split(key, 3)

block_params = init_transformer_block_params(params_key, d_model, num_heads, d_ff)

# Create input with positional encoding
x = random.normal(input_key, (batch_size, seq_len, d_model))
pe = sinusoidal_positional_encoding(seq_len, d_model)
x = x + pe  # Add positional encoding

# Apply transformer block
output, attn_weights = transformer_block(block_params, x, num_heads)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"\nTransformer block maintains the input shape!")
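
`transformer_block` accepts `dropout_key` and `dropout_rate` but does not apply dropout. A minimal sketch of inverted dropout is shown below; the `dropout` function and its `training` flag are hypothetical helpers, not part of the tutorial's code. In the original Transformer, dropout is applied to each sub-layer's output before the residual addition, so a helper like this could be called on `attn_output` and `ffn_output`.

```python
import jax
import jax.numpy as jnp

def dropout(key, x, rate, training=True):
    """Inverted dropout: zero entries with probability `rate`, rescale the rest."""
    if not training or rate == 0.0:
        return x
    keep_prob = 1.0 - rate
    keep = jax.random.bernoulli(key, keep_prob, x.shape)  # True = keep the entry
    return jnp.where(keep, x / keep_prob, 0.0)

x = jnp.ones((2, 10, 64))
dropped = dropout(jax.random.key(0), x, rate=0.1)

print("fraction zeroed:", float(jnp.mean(dropped == 0.0)))
print("mean after dropout:", float(dropped.mean()))
```

Rescaling the kept entries by `1 / keep_prob` keeps the expected activation unchanged, so no adjustment is needed at evaluation time, where `training=False` returns the input as is.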

## Part 6: Training Example - Sequence Classification

Let's put it all together and train a simple Transformer for sequence classification.

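We'll create a toy task: classify whether the sum of the token IDs in a sequence exceeds a fixed threshold of `vocab_size * seq_len / 2`, which is close to the expected sum for uniformly random tokens.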

In [None]:
def init_classifier_params(key, vocab_size, d_model, num_heads, d_ff, num_classes):
    """
    Initialize a simple Transformer classifier.
    """
    keys = random.split(key, 3)
    
    # Embedding layer
    embedding = random.normal(keys[0], (vocab_size, d_model)) * 0.02
    
    # Transformer block
    block_params = init_transformer_block_params(keys[1], d_model, num_heads, d_ff)
    
    # Classification head
    classifier = {
        'W': random.normal(keys[2], (d_model, num_classes)) * 0.02,
        'b': jnp.zeros(num_classes)
    }
    
    return {
        'embedding': embedding,
        'transformer': block_params,
        'classifier': classifier
    }


def forward(params, x, num_heads):
    """
    Forward pass of the classifier.
    
    Args:
        params: Model parameters
        x: Input token indices of shape (batch_size, seq_len)
        num_heads: Number of attention heads
    
    Returns:
        Logits of shape (batch_size, num_classes)
    """
    batch_size, seq_len = x.shape
    d_model = params['embedding'].shape[1]
    
    # Embed tokens
    embedded = params['embedding'][x]  # (batch, seq_len, d_model)
    
    # Add positional encoding
    pe = sinusoidal_positional_encoding(seq_len, d_model)
    embedded = embedded + pe
    
    # Apply transformer block
    transformed, _ = transformer_block(params['transformer'], embedded, num_heads)
    
    # Global average pooling
    pooled = jnp.mean(transformed, axis=1)  # (batch, d_model)
    
    # Classify
    logits = jnp.matmul(pooled, params['classifier']['W']) + params['classifier']['b']
    
    return logits

In [None]:
# Hyperparameters
vocab_size = 100
d_model = 32
num_heads = 4
d_ff = 128
num_classes = 2
seq_len = 8
batch_size = 32
learning_rate = 0.001
num_epochs = 100  # each "epoch" here is one gradient step on a fresh random batch

# Initialize model
key = random.key(42)
key, init_key = random.split(key)
params = init_classifier_params(init_key, vocab_size, d_model, num_heads, d_ff, num_classes)

# Create toy dataset: classify if sum of sequence is above/below threshold
def generate_batch(key, batch_size, seq_len, vocab_size):
    """Generate a batch of sequences with labels."""
    x = random.randint(key, (batch_size, seq_len), 0, vocab_size)
    sums = x.sum(axis=1)
    threshold = vocab_size * seq_len / 2
    labels = (sums > threshold).astype(jnp.int32)
    return x, labels

# Loss function
def cross_entropy_loss(params, x, labels, num_heads):
    logits = forward(params, x, num_heads)
    one_hot = jax.nn.one_hot(labels, num_classes)
    log_probs = jax.nn.log_softmax(logits)
    return -jnp.mean(jnp.sum(one_hot * log_probs, axis=-1))

# Gradient function
@jit
def update(params, x, labels, learning_rate):
    loss, grads = jax.value_and_grad(cross_entropy_loss)(params, x, labels, num_heads)
    # Simple SGD update
    params = jax.tree.map(lambda p, g: p - learning_rate * g, params, grads)
    return params, loss

# Accuracy function
@jit
def compute_accuracy(params, x, labels):
    logits = forward(params, x, num_heads)
    predictions = jnp.argmax(logits, axis=-1)
    return jnp.mean(predictions == labels)

# Training loop
print("Training Transformer Classifier...")
print("=" * 50)

losses = []
accuracies = []

for epoch in range(num_epochs):
    key, batch_key = random.split(key)
    x_batch, y_batch = generate_batch(batch_key, batch_size, seq_len, vocab_size)
    
    params, loss = update(params, x_batch, y_batch, learning_rate)
    acc = compute_accuracy(params, x_batch, y_batch)
    
    losses.append(float(loss))
    accuracies.append(float(acc))
    
    if (epoch + 1) % 20 == 0:
        print(f"Epoch {epoch + 1:3d} | Loss: {loss:.4f} | Accuracy: {acc:.4f}")

print("=" * 50)
print("Training complete!")
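
The `update` function relies on the gradients having the same nested structure as the parameters, which lets `jax.tree.map` pair each parameter with its gradient. The sketch below isolates that pattern on a small linear-regression problem. The model, data, and learning rate are illustrative assumptions, unrelated to the classifier above.

```python
import jax
import jax.numpy as jnp

# Nested parameter dictionary, similar in spirit to the classifier params
params = {
    'linear': {'W': jnp.zeros((3, 1)), 'b': jnp.zeros(1)},
}

def loss_fn(params, x, y):
    pred = x @ params['linear']['W'] + params['linear']['b']
    return jnp.mean((pred - y) ** 2)

x = jax.random.normal(jax.random.key(0), (64, 3))
y = x @ jnp.array([[1.0], [-2.0], [0.5]]) + 0.3  # synthetic targets

@jax.jit
def sgd_step(params, x, y, learning_rate):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # grads mirrors the structure of params, so tree.map pairs the leaves up
    params = jax.tree.map(lambda p, g: p - learning_rate * g, params, grads)
    return params, loss

initial_loss = loss_fn(params, x, y)
for _ in range(200):
    params, loss = sgd_step(params, x, y, 0.1)

print(f"loss: {float(initial_loss):.4f} -> {float(loss):.6f}")
```

The same step function works unchanged if more layers are added to the dictionary, since the update never names individual parameters.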

In [None]:
# Plot training progress
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

ax1.plot(losses)
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training Loss')
ax1.grid(True, alpha=0.3)

ax2.plot(accuracies)
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy')
ax2.set_title('Training Accuracy')
ax2.set_ylim([0, 1])
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Conclusion

In this tutorial, we've learned:

1. **Scaled Dot-Product Attention**: The fundamental building block that computes weighted combinations based on query-key similarities.

2. **Multi-Head Attention**: Running multiple attention operations in parallel to capture different types of relationships.

3. **Positional Encoding**: Adding position information to the otherwise position-invariant attention mechanism.

4. **Transformer Block**: Combining attention with feed-forward networks, residual connections, and layer normalization.

5. **Training**: Putting it all together to train a simple classifier.

### Key JAX Features Used

- `jax.numpy` for NumPy-like array operations
- `jax.nn.softmax` for softmax activation
- `jax.random` for reproducible random operations
- `jax.value_and_grad` for efficient gradient computation
- `jax.jit` for JIT compilation
- `jax.tree.map` for applying functions to nested structures

### Next Steps

- Add a decoder with masked self-attention
- Implement cross-attention for encoder-decoder architectures
- Explore advanced attention variants (Flash Attention, Linear Attention)
- Scale up with multiple Transformer layers

### Resources

- [Attention Is All You Need](https://arxiv.org/abs/1706.03762) - The original Transformer paper
- [The Illustrated Transformer](https://jalammar.github.io/illustrated-transformer/) - Visual guide
- [JAX Documentation](https://jax.readthedocs.io/) - Official JAX docs