# Implement Multi-Head Attention from Scratch

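## Description:

Implement Multi-Head Attention (MHA), a core component of Transformer models, as described in “Attention Is All You Need” (Vaswani et al., 2017). The function multi_head_attention(q, k, v, num_heads, d_model, mask=None) projects the query ($Q$), key ($K$), and value ($V$) tensors into multiple attention heads, applies scaled dot-product attention in each head, concatenates the head outputs, and applies a final linear projection. The implementation must match PyTorch’s torch.nn.MultiheadAttention in output (within atol=1e-08, rtol=1e-05), handle batch dimensions, support masking, and use only PyTorch operations. The provided code fails for two reasons. First, the standalone function creates freshly initialized nn.Linear layers on every call, so its weights cannot match the reference module’s. Second, the masked tests pass a mask of shape $(N, L_q, L_k)$ as attn_mask, but nn.MultiheadAttention requires a 3D attn_mask of shape $(N \cdot \text{num\_heads}, L_q, L_k)$, which raises the RuntimeError shown at the end; the causal mask is also a float tensor, which PyTorch adds to the scores instead of using it to block positions. We fix this by copying the reference module’s weights into the custom projections, using boolean masks (True = masked), and repeating each batch element’s mask once per head before passing it to PyTorch.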

## Mathematical Definition:

Inputs:

- Query: $Q \in \mathbb{R}^{N \times L_q \times d_{\text{model}}}$, where $N$ is the batch size, $L_q$ is the query sequence length, and $d_{\text{model}}$ is the embedding dimension.
- Key: $K \in \mathbb{R}^{N \times L_k \times d_{\text{model}}}$, where $L_k$ is the key sequence length.
- Value: $V \in \mathbb{R}^{N \times L_k \times d_{\text{model}}}$.
- Mask: optional boolean tensor $M \in \{0, 1\}^{N \times L_q \times L_k}$ (or broadcastable to that shape), with 0 marking a valid position and 1 a masked one.
- $\text{num\_heads}$: number of attention heads.
- $d_{\text{head}} = d_{\text{model}} / \text{num\_heads}$: dimension per head.


Linear Projections:

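- Project inputs per head: $Q_i = Q W_i^Q$, $K_i = K W_i^K$, $V_i = V W_i^V$, where $W_i^Q, W_i^K, W_i^V \in \mathbb{R}^{d_{\text{model}} \times d_{\text{head}}}$ for each head $i = 1, \ldots, \text{num\_heads}$.
- Stack the heads: $Q_h$ has shape $(N, \text{num\_heads}, L_q, d_{\text{head}})$, while $K_h$ and $V_h$ have shape $(N, \text{num\_heads}, L_k, d_{\text{head}})$. In practice all heads are computed with a single $d_{\text{model}} \times d_{\text{model}}$ matrix whose column blocks are the $W_i$, followed by a reshape.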
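
As a quick check of the projection step, the sketch below uses random matrices and assumes the math convention $x W$ (nn.Linear actually computes $x W^\top$). It confirms that projecting once with a full $d_{\text{model}} \times d_{\text{model}}$ matrix and reshaping gives the same per-head tensors as applying each column block $W_i^Q$ separately.

```python
import torch

torch.manual_seed(0)
N, L, d_model, num_heads = 3, 4, 8, 2
d_head = d_model // num_heads

x = torch.rand(N, L, d_model)
W_q = torch.rand(d_model, d_model)  # math convention: projection is x @ W_q

# One big projection, then split the feature dimension into heads.
q_full = (x @ W_q).view(N, L, num_heads, d_head).transpose(1, 2)  # (N, H, L, d_head)

# Per-head projections using the column block W_i^Q = W_q[:, i*d_head:(i+1)*d_head].
q_heads = torch.stack(
    [x @ W_q[:, i * d_head:(i + 1) * d_head] for i in range(num_heads)], dim=1
)  # (N, H, L, d_head)

print(q_full.shape, torch.allclose(q_full, q_heads))
```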


Scaled Dot-Product Attention (per head):

- Scores: $S = Q_h K_h^\top / \sqrt{d_{\text{head}}}$, where $S \in \mathbb{R}^{N \times \text{num\_heads} \times L_q \times L_k}$.
- Mask: set $S_{b,h,i,j} = -\infty$ wherever $M_{b,i,j} = 1$; the same mask is applied to every head $h$.
- Weights: $A = \text{softmax}(S, \text{dim}=-1)$.
- Output: $O_h = A V_h$, where $O_h \in \mathbb{R}^{N \times \text{num\_heads} \times L_q \times d_{\text{head}}}$.
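
A minimal sketch of the per-head attention step on tensors that are already split into heads, with a causal mask broadcast over the batch and head dimensions. The cross-check against F.scaled_dot_product_attention assumes PyTorch 2.0 or later; that function's boolean mask uses the opposite convention (True means the position may be attended), so the mask is inverted with ~M.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, H, L, d_head = 3, 2, 4, 4
Q_h = torch.rand(N, H, L, d_head)
K_h = torch.rand(N, H, L, d_head)
V_h = torch.rand(N, H, L, d_head)

# Causal mask: M[i, j] = True (1) for j > i, i.e. future keys are blocked.
M = torch.triu(torch.ones(L, L, dtype=torch.bool), diagonal=1)

S = Q_h @ K_h.transpose(-2, -1) / math.sqrt(d_head)   # (N, H, L_q, L_k)
S = S.masked_fill(M, float("-inf"))                    # broadcast over N and H
A = F.softmax(S, dim=-1)                               # each row sums to 1
O_h = A @ V_h                                          # (N, H, L_q, d_head)

# PyTorch >= 2.0 reference; its boolean mask means True = "may attend", hence ~M.
O_ref = F.scaled_dot_product_attention(Q_h, K_h, V_h, attn_mask=~M)
print(torch.allclose(O_h, O_ref, atol=1e-6))
```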


Concatenation and Final Projection:

- Concatenate heads: $O = \text{Concat}(O_1, \ldots, O_{\text{num\_heads}})$ along the feature dimension, giving shape $(N, L_q, d_{\text{model}})$.
- Final projection: $O_{\text{final}} = O W^O$, where $W^O \in \mathbb{R}^{d_{\text{model}} \times d_{\text{model}}}$.
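
The concatenation followed by $W^O$ can also be read head by head: $O W^O = \sum_i O_i W^O_i$, where $W^O_i$ is the $i$-th block of $d_{\text{head}}$ rows of $W^O$. A small numerical check with random tensors (math convention $O W^O$ assumed):

```python
import torch

torch.manual_seed(0)
N, H, L, d_head = 3, 2, 4, 4
d_model = H * d_head
O_heads = torch.rand(N, H, L, d_head)   # per-head outputs O_i, stacked on dim 1
W_o = torch.rand(d_model, d_model)      # math convention: out = O @ W_o

# Concat(O_1, ..., O_H): put heads next to features, then flatten to d_model.
O = O_heads.transpose(1, 2).reshape(N, L, d_model)
out = O @ W_o

# Same result head by head: O_i times its block of d_head rows of W^O, summed.
out_sum = sum(O_heads[:, i] @ W_o[i * d_head:(i + 1) * d_head] for i in range(H))
print(out.shape, torch.allclose(out, out_sum, atol=1e-6))
```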


Output:

Return $O_{\text{final}} \in \mathbb{R}^{N \times L_q \times d_{\text{model}}}$.


Validation:

Compare with torch.nn.MultiheadAttention using torch.allclose.
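
One way to align the weights without re-creating layers is to pass them in explicitly. In the sketch below, mha_with_weights is a hypothetical helper (not part of the original code or of PyTorch). It reads the projections from the reference module's in_proj_weight and out_proj.weight, which use nn.Linear's (out_features, in_features) layout, so $W^Q$ in the math above corresponds to the transpose of the first block. The comparison uses atol=1e-6 rather than 1e-08 to leave room for differences in floating-point operation order.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


def mha_with_weights(q, k, v, num_heads, w_q, w_k, w_v, w_o, mask=None):
    """Hypothetical helper: multi-head attention with explicitly supplied weights.

    Weights use nn.Linear's (out_features, in_features) layout, so each projection
    is F.linear(x, w) == x @ w.T. `mask` is boolean (True = masked) and must
    broadcast to (N, L_q, L_k) or (L_q, L_k).
    """
    n, l_q, d_model = q.shape
    l_k = k.size(1)
    d_head = d_model // num_heads

    def split(x, length):  # (N, L, d_model) -> (N, H, L, d_head)
        return x.view(n, length, num_heads, d_head).transpose(1, 2)

    Q = split(F.linear(q, w_q), l_q)
    K = split(F.linear(k, w_k), l_k)
    V = split(F.linear(v, w_v), l_k)

    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_head)
    if mask is not None:
        if mask.dim() == 3:  # (N, L_q, L_k) -> (N, 1, L_q, L_k): same mask for every head
            mask = mask.unsqueeze(1)
        scores = scores.masked_fill(mask, float("-inf"))
    out = F.softmax(scores, dim=-1) @ V                   # (N, H, L_q, d_head)
    out = out.transpose(1, 2).reshape(n, l_q, d_model)   # concatenate heads
    return F.linear(out, w_o)


torch.manual_seed(42)
N, L, d_model, num_heads = 3, 4, 8, 2
q, k, v = (torch.rand(N, L, d_model) for _ in range(3))
ref = nn.MultiheadAttention(d_model, num_heads, bias=False, batch_first=True)

# in_proj_weight stacks [W_q; W_k; W_v] along dim 0, each of shape (d_model, d_model).
w_q, w_k, w_v = ref.in_proj_weight.detach().chunk(3, dim=0)
w_o = ref.out_proj.weight.detach()

causal = torch.triu(torch.ones(L, L, dtype=torch.bool), diagonal=1)  # True = masked
matches = []
with torch.no_grad():
    for m in (None, causal):
        ours = mha_with_weights(q, k, v, num_heads, w_q, w_k, w_v, w_o, mask=m)
        theirs, _ = ref(q, k, v, attn_mask=m)
        matches.append(torch.allclose(ours, theirs, atol=1e-6))
print(matches)
```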



Requirements:

- Implement multi_head_attention(q, k, v, num_heads, d_model, mask=None):
  - Inputs: $Q$, $K$, $V$ of shapes $(N, L_q, d_{\text{model}})$, $(N, L_k, d_{\text{model}})$, and $(N, L_k, d_{\text{model}})$, plus the number of heads, the model dimension, and an optional mask.
  - Output: attention output of shape $(N, L_q, d_{\text{model}})$.
  - Support batch processing and masking (causal or padding).
- Use PyTorch operations (torch.matmul, F.softmax, nn.Linear).
- Test with synthetic tensors ($N=3$, $L_q=L_k=4$, $d_{\text{model}}=8$, $\text{num\_heads}=2$).
- Match the output of PyTorch’s torch.nn.MultiheadAttention.
- Provide detailed Purpose and Theory comments.
- Fix the provided code’s failures by aligning the weights with the reference module and passing masks in the shape and dtype it expects.

Constraints:

- Use only PyTorch operations (no transformers or external libraries).
- Ensure $d_{\text{model}}$ is divisible by num_heads.
- Support the batch-first format, $(N, L_q, d_{\text{model}})$.
- Handle optional masking for causal or padding scenarios.
- Match PyTorch’s output within atol=1e-08, rtol=1e-05.
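
The tolerance check is element-wise: torch.allclose(a, b, atol, rtol) holds when $|a - b| \le \text{atol} + \text{rtol} \cdot |b|$ for every element, with $b$ the second argument. A small illustration with made-up values shows why entries close to zero are the hardest to match under atol=1e-08: for them the bound is essentially the absolute tolerance alone.

```python
import torch

b = torch.tensor([0.1, 1e-9])          # reference values (made up)
a = b + torch.tensor([5e-7, 5e-8])     # "custom" values with small errors
atol, rtol = 1e-08, 1e-05

# Element-wise criterion used by torch.isclose / torch.allclose:
#   |a - b| <= atol + rtol * |b|
manual = (a - b).abs() <= atol + rtol * b.abs()
print(manual.tolist(), torch.isclose(a, b, atol=atol, rtol=rtol).tolist())
print(torch.allclose(a, b, atol=atol, rtol=rtol))
```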

Synthetic Dataset:

Inputs:

- $Q, K, V$: random tensors of shape $(3, 4, 8)$, generated with torch.rand after seeding with 42.
- $\text{num\_heads} = 2$ and $d_{\text{model}} = 8$, so $d_{\text{head}} = 4$.
- Mask: test with no mask, a causal mask, and a padding mask.


Test Cases:

- Without mask.
- Causal mask (upper triangular) to prevent attending to future tokens.
- Padding mask to ignore specific positions.
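
For the causal test case, the mask's dtype matters when it is passed to nn.MultiheadAttention: a boolean attn_mask blocks positions where it is True, while a float attn_mask is added to the attention scores. The sketch below (self-attention on random inputs, looking at the head-averaged attention weights the module returns) contrasts the float 0/1 mask produced by torch.triu(torch.ones(...)) with a boolean mask and with an equivalent additive -inf mask.

```python
import torch
import torch.nn as nn

torch.manual_seed(42)
N, L, d_model, num_heads = 3, 4, 8, 2
x = torch.rand(N, L, d_model)
mha = nn.MultiheadAttention(d_model, num_heads, bias=False, batch_first=True)

float_01 = torch.triu(torch.ones(L, L), diagonal=1)                  # float 0/1, as in the test code
bool_mask = float_01.bool()                                          # True = blocked
additive = torch.zeros(L, L).masked_fill(bool_mask, float("-inf"))   # float mask that really blocks

with torch.no_grad():
    _, w_float = mha(x, x, x, attn_mask=float_01)   # 1.0 is added to future scores
    _, w_bool = mha(x, x, x, attn_mask=bool_mask)
    _, w_add = mha(x, x, x, attn_mask=additive)

# Attention weights (averaged over heads) of the first query in the first batch element.
print(w_float[0, 0])  # future positions still receive weight
print(w_bool[0, 0])   # future positions receive exactly zero
```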




In [1]:
import torch
# Purpose: Import PyTorch for tensor operations.
# Theory: Provides tensor computations with autograd support for attention.

import torch.nn as nn
# Purpose: Import neural network modules for linear projections.
# Theory: nn.Linear is used for Q, K, V projections and output transformation.

import torch.nn.functional as F
# Purpose: Import functional module for softmax.
# Theory: F.softmax computes attention weights over the key dimension.

# Set random seed for reproducibility
torch.manual_seed(42)
# Purpose: Fix random seed for consistent input tensors and weight initialization.
# Theory: Ensures reproducible results, aligning with previous problems (e.g., Scaled Dot-Product Attention).

def scaled_dot_product_attention(q, k, v, mask=None):
    """
    Compute scaled dot-product attention (from previous problem).
    
    Args:
        q: Query tensor of shape (..., seq_len_q, d_k)
        k: Key tensor of shape (..., seq_len_k, d_k)
        v: Value tensor of shape (..., seq_len_k, d_v)
        mask: Optional mask tensor of shape (..., seq_len_q, seq_len_k)
    
    Returns:
        output: Attention output tensor of shape (..., seq_len_q, d_v)
        attention_weights: Attention weights tensor of shape (..., seq_len_q, seq_len_k)
    """
    # Purpose: Define scaled dot-product attention for use in multi-head attention.
    # Theory: Computes attention scores, applies mask, and weights values.
    
    d_k = q.size(-1)
    # Purpose: Get key dimension for scaling.
    # Theory: Used to scale scores by sqrt(d_k) for numerical stability.
    
    scores = torch.matmul(q, k.transpose(-2, -1))
    # Purpose: Compute dot product: Q * K^T.
    # Theory: Produces raw scores, shape (..., seq_len_q, seq_len_k).
    
    scores = scores / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
    # Purpose: Scale scores by sqrt(d_k).
    # Theory: Prevents large dot products, stabilizing gradients.
    
    if mask is not None:
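        # Purpose: Apply optional mask.
        # Theory: Blocks future tokens or padding by giving those scores a huge negative value.
        
        scores = scores.masked_fill(mask == 1, -1e9)
        # Purpose: Set masked positions (mask == 1 / True) to a large negative value.
        # Theory: After softmax their weights underflow to zero, as with -inf; unlike -inf,
        # a fully masked row yields uniform weights instead of NaN.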
    
    attention_weights = F.softmax(scores, dim=-1)
    # Purpose: Convert scores to probabilities.
    # Theory: Softmax over key dimension ensures weights sum to 1.
    
    output = torch.matmul(attention_weights, v)
    # Purpose: Compute weighted sum: A * V.
    # Theory: Produces output, shape (..., seq_len_q, d_v).
    
    return output, attention_weights
    # Purpose: Return attention output and weights.
    # Theory: Output for further processing, weights for analysis.

def multi_head_attention(q, k, v, num_heads, d_model, mask=None):
    """
    Implements multi-head attention.
    
    Args:
        q: Query tensor of shape (batch_size, seq_len_q, d_model)
        k: Key tensor of shape (batch_size, seq_len_k, d_model)
        v: Value tensor of shape (batch_size, seq_len_k, d_model)
        num_heads: Number of attention heads
        d_model: Total embedding dimension
        mask: Optional boolean mask (True/1 = masked), broadcastable to
            (batch_size, seq_len_q, seq_len_k)
    
    Returns:
        Tensor: Multi-head attention output of shape (batch_size, seq_len_q, d_model)
    """
    # Purpose: Define multi-head attention function.
    # Theory: Projects inputs, applies attention per head, concatenates, and projects output.
    
    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
    # Purpose: Ensure valid head dimension.
    # Theory: d_head = d_model / num_heads must be an integer.
    
    batch_size, seq_len, _ = q.shape
    seq_len_k = k.size(1)
    # Purpose: Get input dimensions.
    # Theory: seq_len is the query length; keys/values may have a different length (seq_len_k).
    
    d_head = d_model // num_heads
    # Purpose: Compute dimension per head.
    # Theory: Each head processes a subspace of d_model.
    
    # Initialize linear layers
    q_linear = nn.Linear(d_model, d_model, bias=False)
    k_linear = nn.Linear(d_model, d_model, bias=False)
    v_linear = nn.Linear(d_model, d_model, bias=False)
    out_linear = nn.Linear(d_model, d_model, bias=False)
    # Purpose: Define linear projections for Q, K, V, and output.
    # Theory: Each weight holds all heads' projections side by side; they are split per head below.
    # Note: These layers are freshly (randomly) initialized on every call, so this standalone
    # version cannot reproduce nn.MultiheadAttention's output. The test below redefines the
    # function to use weights copied from the reference module.
    
    # Project inputs
    Q = q_linear(q)
    K = k_linear(k)
    V = v_linear(v)
    # Purpose: Apply linear projections to Q, K, V.
    # Theory: Shape (batch_size, seq_len, d_model) for Q and (batch_size, seq_len_k, d_model) for K, V.
    
    # Reshape for multi-head: (batch_size, L, num_heads, d_head) -> (batch_size, num_heads, L, d_head)
    Q = Q.view(batch_size, seq_len, num_heads, d_head).transpose(1, 2)
    K = K.view(batch_size, seq_len_k, num_heads, d_head).transpose(1, 2)
    V = V.view(batch_size, seq_len_k, num_heads, d_head).transpose(1, 2)
    # Purpose: Split projections into heads and reorder dimensions.
    # Theory: Heads become a batch-like dimension, so attention runs for all heads in parallel.
    
    # Adjust mask for multi-head
    if mask is not None:
        if mask.dim() == 3:
            mask = mask.unsqueeze(1)  # (batch_size, 1, seq_len, seq_len)
        # Purpose: Ensure mask is compatible with multi-head shape.
        # Theory: Broadcasts mask to (batch_size, num_heads, seq_len, seq_len).
    
    # Apply scaled dot-product attention
    output, _ = scaled_dot_product_attention(Q, K, V, mask)
    # Purpose: Compute attention for each head.
    # Theory: Output shape (batch_size, num_heads, seq_len, d_head).
    
    # Reshape output: (batch_size, num_heads, seq_len, d_head) -> (batch_size, seq_len, d_model)
    output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
    # Purpose: Concatenate heads and flatten to original shape.
    # Theory: Combines head outputs, shape (batch_size, seq_len, d_model).
    
    # Final projection
    output = out_linear(output)
    # Purpose: Apply final linear transformation.
    # Theory: Restores output to (batch_size, seq_len, d_model).
    
    return output
    # Purpose: Return final attention output.
    # Theory: Matches input shape for use in Transformer layers.

# Test implementation
if __name__ == "__main__":
    # Purpose: Test multi-head attention against PyTorch's implementation.
    # Theory: Verifies numerical accuracy and mask handling.
    
    # Input tensors
    batch_size, seq_len, d_model, num_heads = 3, 4, 8, 2
    # Purpose: Define dimensions for test tensors.
    # Theory: Matches problem constraints for synthetic data.
    
    q = torch.rand(batch_size, seq_len, d_model)
    k = torch.rand(batch_size, seq_len, d_model)
    v = torch.rand(batch_size, seq_len, d_model)
    # Purpose: Generate random input tensors.
    # Theory: Shape (3, 4, 8) for batch processing.
    
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Purpose: Set device for computation.
    # Theory: Ensures compatibility with GPU if available.
    
    q, k, v = q.to(device), k.to(device), v.to(device)
    # Purpose: Move tensors to device.
    # Theory: Ensures computations run on the same device.
    
    # Initialize PyTorch's MultiheadAttention
    multihead_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads, bias=False, batch_first=True).to(device)
    # Purpose: Create reference model for comparison.
    # Theory: PyTorch’s implementation is the ground truth.
    
    # Copy weights to custom linear layers
    custom_q_linear = nn.Linear(d_model, d_model, bias=False).to(device)
    custom_k_linear = nn.Linear(d_model, d_model, bias=False).to(device)
    custom_v_linear = nn.Linear(d_model, d_model, bias=False).to(device)
    custom_out_linear = nn.Linear(d_model, d_model, bias=False).to(device)
    # Purpose: Initialize custom linear layers.
    # Theory: Must match PyTorch’s weights to pass assertion.
    
    # Access PyTorch's internal weights
    with torch.no_grad():
        # PyTorch combines Q, K, V projections into one matrix
        in_proj_weight = multihead_attn.in_proj_weight
        # Split into Q, K, V weights
        q_weight = in_proj_weight[:d_model, :]
        k_weight = in_proj_weight[d_model:2*d_model, :]
        v_weight = in_proj_weight[2*d_model:, :]
        # Assign to custom layers
        custom_q_linear.weight.copy_(q_weight)
        custom_k_linear.weight.copy_(k_weight)
        custom_v_linear.weight.copy_(v_weight)
        custom_out_linear.weight.copy_(multihead_attn.out_proj.weight)
    # Purpose: Copy weights from PyTorch’s module to custom layers.
    # Theory: Ensures identical projections to match output.
    
    # Redefine multi_head_attention with custom weights
    def multi_head_attention(q, k, v, num_heads, d_model, mask=None):
        # Purpose: Redefine function with custom weights for testing.
        # Theory: Same logic as above, using initialized weights.
        
        assert d_model % num_heads == 0
        batch_size, seq_len, _ = q.shape
        seq_len_k = k.size(1)  # keys/values may have a different length than queries
        d_head = d_model // num_heads
        
        Q = custom_q_linear(q)
        K = custom_k_linear(k)
        V = custom_v_linear(v)
        
        Q = Q.view(batch_size, seq_len, num_heads, d_head).transpose(1, 2)
        K = K.view(batch_size, seq_len_k, num_heads, d_head).transpose(1, 2)
        V = V.view(batch_size, seq_len_k, num_heads, d_head).transpose(1, 2)
        
        if mask is not None:
            if mask.dim() == 3:
                mask = mask.unsqueeze(1)
        
        output, _ = scaled_dot_product_attention(Q, K, V, mask)
        
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        return custom_out_linear(output)
    
    # Test without mask
    output_custom = multi_head_attention(q, k, v, num_heads, d_model)
    output_pytorch, _ = multihead_attn(q, k, v)
    # Purpose: Compute outputs for comparison.
    # Theory: Tests core multi-head attention functionality.
    
    print("Custom Output:", output_custom)
    print("PyTorch Output:", output_pytorch)
    # Purpose: Print outputs for inspection.
    # Theory: Allows visual confirmation of similarity.
    
    assert torch.allclose(output_custom, output_pytorch, atol=1e-08, rtol=1e-05), "Outputs do not match!"
    # Purpose: Verify numerical equivalence.
    # Theory: Ensures custom implementation matches PyTorch.
    
    print("Test without mask passed!")
    # Purpose: Confirm successful test.
    # Theory: Validates core functionality.
    
    # Test with causal mask
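    mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1).unsqueeze(0).expand(batch_size, -1, -1).to(device)
    # Purpose: Create a boolean causal mask for autoregressive attention (True = masked).
    # Theory: Shape (batch_size, seq_len, seq_len), masks future tokens. The mask must be boolean:
    # nn.MultiheadAttention adds a float attn_mask to the scores instead of blocking positions.
    
    output_custom = multi_head_attention(q, k, v, num_heads, d_model, mask)
    output_pytorch, _ = multihead_attn(q, k, v, attn_mask=mask.repeat_interleave(num_heads, dim=0))
    # Purpose: Compute masked outputs.
    # Theory: A 3D attn_mask must have shape (batch_size * num_heads, seq_len, seq_len) in
    # batch-major order, so each batch element's mask is repeated once per head.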
    print("Custom Output (with causal mask):", output_custom)
    print("PyTorch Output (with causal mask):", output_pytorch)
    # Purpose: Print masked outputs.
    # Theory: Verifies mask application.
    
    assert torch.allclose(output_custom, output_pytorch, atol=1e-08, rtol=1e-05), "Causal masked outputs do not match!"
    # Purpose: Verify masked output equivalence.
    # Theory: Ensures correct causal attention.
    
    print("Test with causal mask passed!")
    # Purpose: Confirm successful masked test.
    # Theory: Validates autoregressive attention (e.g., GPT).
    
    # Test with padding mask
    padding_mask = torch.tensor([[[0, 0, 0, 1],
                                 [0, 0, 0, 1],
                                 [0, 0, 0, 1],
                                 [0, 0, 0, 1]],
                                [[0, 0, 0, 0],
                                 [0, 0, 0, 0],
                                 [0, 0, 0, 0],
                                 [0, 0, 0, 0]],
                                [[0, 0, 1, 1],
                                 [0, 0, 1, 1],
                                 [0, 0, 1, 1],
                                 [0, 0, 1, 1]]], dtype=torch.bool).to(device)
    # Purpose: Create padding mask for specific positions.
    # Theory: Shape (batch_size, seq_len, seq_len), masks padded tokens.
    
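    output_custom = multi_head_attention(q, k, v, num_heads, d_model, padding_mask)
    output_pytorch, _ = multihead_attn(q, k, v, attn_mask=padding_mask.repeat_interleave(num_heads, dim=0))
    # Purpose: Compute padding-masked outputs.
    # Theory: The per-batch mask is repeated per head to shape (batch_size * num_heads, seq_len, seq_len);
    # key_padding_mask=padding_mask[:, 0, :] would express the same padding.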
    
    print("Custom Output (with padding mask):", output_custom)
    print("PyTorch Output (with padding mask):", output_pytorch)
    # Purpose: Print padding-masked outputs.
    # Theory: Verifies padding mask application.
    
    assert torch.allclose(output_custom, output_pytorch, atol=1e-08, rtol=1e-05), "Padding masked outputs do not match!"
    # Purpose: Verify padding-masked output equivalence.
    # Theory: Ensures correct handling of padded sequences (e.g., BERT).
    
    print("Test with padding mask passed!")
    # Purpose: Confirm successful padding test.
    # Theory: Validates attention for padded sequences.

Custom Output: tensor([[[ 7.1700e-02,  5.3411e-02,  3.3034e-01,  1.5892e-01,  6.5924e-02,
           2.3432e-01,  1.2263e-01,  2.1005e-01],
         [ 7.5689e-02,  5.2544e-02,  3.3219e-01,  1.5638e-01,  6.8600e-02,
           2.3832e-01,  1.2223e-01,  2.1144e-01],
         [ 7.2534e-02,  5.3125e-02,  3.3081e-01,  1.5864e-01,  6.6714e-02,
           2.3542e-01,  1.2263e-01,  2.1010e-01],
         [ 6.9490e-02,  5.1792e-02,  3.3143e-01,  1.6153e-01,  6.6810e-02,
           2.3345e-01,  1.2220e-01,  2.0941e-01]],

        [[-4.5193e-02,  2.6360e-02,  2.6671e-01,  2.7284e-01,  9.3438e-02,
           1.5845e-01,  1.2320e-01,  1.5793e-01],
         [-4.3427e-02,  2.6275e-02,  2.6286e-01,  2.7077e-01,  9.2779e-02,
           1.5760e-01,  1.2343e-01,  1.5535e-01],
         [-4.5840e-02,  2.6753e-02,  2.6641e-01,  2.7329e-01,  9.4146e-02,
           1.5864e-01,  1.2388e-01,  1.5618e-01],
         [-4.1810e-02,  2.5378e-02,  2.6161e-01,  2.6910e-01,  9.1674e-02,
           1.5723e-01,  1.2250e-0

RuntimeError: The shape of the 3D attn_mask is torch.Size([3, 4, 4]), but should be (6, 4, 4).
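
The error above is raised by the reference call, not by the custom function: nn.MultiheadAttention accepts an attn_mask of shape $(L_q, L_k)$ or $(N \cdot \text{num\_heads}, L_q, L_k)$, plus an optional key_padding_mask of shape $(N, L_k)$. A sketch of the compatible forms, assuming the same padding pattern as the test; the lengths tensor is a hypothetical way to build it.

```python
import torch
import torch.nn as nn

torch.manual_seed(42)
N, L, d_model, num_heads = 3, 4, 8, 2
q, k, v = (torch.rand(N, L, d_model) for _ in range(3))
mha = nn.MultiheadAttention(d_model, num_heads, bias=False, batch_first=True)

# Hypothetical valid lengths reproducing the test's padding pattern (True = masked).
lengths = torch.tensor([3, 4, 2])
key_pad = torch.arange(L)[None, :] >= lengths[:, None]   # (N, L_k)
mask3d = key_pad[:, None, :].expand(N, L, L)             # (N, L_q, L_k), as in the test

with torch.no_grad():
    # Passing the (N, L_q, L_k) mask directly reproduces the RuntimeError above.
    try:
        mha(q, k, v, attn_mask=mask3d)
        raised = False
    except RuntimeError:
        raised = True

    # Fix 1: repeat each batch element's mask once per head -> (N * num_heads, L_q, L_k).
    out_a, _ = mha(q, k, v, attn_mask=mask3d.repeat_interleave(num_heads, dim=0))
    # Fix 2: for pure padding, key_padding_mask of shape (N, L_k) says the same thing.
    out_b, _ = mha(q, k, v, key_padding_mask=key_pad)

    # A causal mask shared by every sequence can stay 2D (L_q, L_k),
    causal = torch.triu(torch.ones(L, L, dtype=torch.bool), diagonal=1)
    out_c, _ = mha(q, k, v, attn_mask=causal)
    # or be materialized once per (batch element, head) pair.
    out_d, _ = mha(q, k, v, attn_mask=causal.unsqueeze(0).repeat(N * num_heads, 1, 1))

print(raised, torch.allclose(out_a, out_b, atol=1e-6), torch.allclose(out_c, out_d, atol=1e-6))
```

repeat_interleave (rather than repeat) matters here because PyTorch lays out the flattened dimension batch-major, with index b * num_heads + h.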