# Implement Attention from Scratch

## Problem Statement

Implement a Scaled Dot-Product Attention mechanism from scratch using PyTorch. Your mission (should you choose to accept it) is to replicate what PyTorch's built-in scaled_dot_product_attention does — manually. This core component is essential in Transformer architectures and helps models focus on relevant parts of a sequence. You'll test your implementation against PyTorch's native one to ensure you nailed it.

## Requirements
Define the Function:

Create a function scaled_dot_product_attention(q, k, v, mask=None) that:
Computes attention scores via the dot product of query and key vectors.
Scales the scores using the square root of the key dimension.
Applies an optional mask to the scores.
Applies softmax to convert scores into attention weights.
Uses these weights to compute a weighted sum of values (V).
Test Your Work:

Use sample tensors for query (Q), key (K), and value (V).
Compare the result of your custom implementation with PyTorch's F.scaled_dot_product_attention using an assert to check numerical accuracy.
Constraints
Do NOT use F.scaled_dot_product_attention inside your custom function — that defeats the whole point.
Your implementation must handle batch dimensions correctly.
Support optional masking for future tokens or padding.
Use only PyTorch ops — no cheating with external attention libs.
Hint Use `torch.matmul()` to compute dot products and `F.softmax()` for the final attention weights. The mask (if used) should be applied **before** the softmax using `masked_fill`.

In [1]:
import torch
# Purpose: Import PyTorch for tensor operations.
# Theory: Provides tensor computations with autograd support for attention mechanism.

import torch.nn.functional as F
# Purpose: Import functional module for softmax operation.
# Theory: F.softmax is used to compute attention weights.

# Set random seed for reproducibility
torch.manual_seed(42)
# Purpose: Fix random seed for consistent input tensors.
# Theory: Ensures reproducible results, aligning with previous problems (e.g., RMS Norm).

def scaled_dot_product_attention(q, k, v, mask=None):
    """
    Compute the scaled dot-product attention.
    
    Args:
        q: Query tensor of shape (..., seq_len_q, d_k)
        k: Key tensor of shape (..., seq_len_k, d_k)
        v: Value tensor of shape (..., seq_len_k, d_v)
        mask: Optional mask tensor of shape (..., seq_len_q, seq_len_k)
    
    Returns:
        output: Attention output tensor of shape (..., seq_len_q, d_v)
        attention_weights: Attention weights tensor of shape (..., seq_len_q, seq_len_k)
    """
    # Purpose: Define function to compute scaled dot-product attention.
    # Theory: Implements the core attention mechanism from "Attention is All You Need".
    
    d_k = q.size(-1)
    # Purpose: Get key dimension (d_k) from query tensor.
    # Theory: d_k is used for scaling factor (sqrt(d_k)) to normalize attention scores.
    
    scores = torch.matmul(q, k.transpose(-2, -1))
    # Purpose: Compute dot product of query and key: Q * K^T.
    # Theory: Produces raw attention scores, shape (..., seq_len_q, seq_len_k).
    
    scores = scores / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
    # Purpose: Scale scores by sqrt(d_k).
    # Theory: Prevents large dot products in high dimensions, stabilizing gradients.
    
    if mask is not None:
        # Purpose: Check if a mask is provided.
        # Theory: Mask handles causal attention (future tokens) or padding.
        
        scores = scores.masked_fill(mask == 1, -1e9)
        # Purpose: Apply mask by setting masked positions to a large negative value.
        # Theory: Ensures masked positions have near-zero weights after softmax.
    
    attention_weights = F.softmax(scores, dim=-1)
    # Purpose: Apply softmax to scores to get attention weights.
    # Theory: Normalizes scores to probabilities, summing to 1 over the key dimension.
    
    output = torch.matmul(attention_weights, v)
    # Purpose: Compute weighted sum of values: A * V.
    # Theory: Produces final output, shape (..., seq_len_q, d_v), combining relevant values.
    
    return output, attention_weights
    # Purpose: Return attention output and weights.
    # Theory: Output is used in Transformer layers; weights are useful for analysis.

# Test implementation
if __name__ == "__main__":
    # Purpose: Test custom attention against PyTorch's implementation.
    # Theory: Verifies correctness and numerical accuracy.
    
    # Input tensors
    batch_size, seq_len, dim = 1, 3, 3
    # Purpose: Define dimensions for test tensors.
    # Theory: Small dimensions for simplicity, matching problem constraints.
    
    q = torch.randn(batch_size, seq_len, dim)
    k = torch.randn(batch_size, seq_len, dim)
    v = torch.randn(batch_size, seq_len, dim)
    # Purpose: Generate random query, key, and value tensors.
    # Theory: Shape (1, 3, 3) simulates a small batch for testing.
    
    # Test without mask
    output_custom, weights_custom = scaled_dot_product_attention(q, k, v)
    # Purpose: Compute attention with custom implementation.
    # Theory: Tests core functionality without masking.
    
    output_pytorch = F.scaled_dot_product_attention(q, k, v)
    # Purpose: Compute attention with PyTorch's implementation.
    # Theory: Serves as ground truth for comparison.
    
    print("Custom Output:", output_custom)
    print("PyTorch Output:", output_pytorch)
    # Purpose: Print outputs for visual inspection.
    # Theory: Allows checking numerical closeness.
    
    assert torch.allclose(output_custom, output_pytorch, atol=1e-08, rtol=1e-05), "Outputs do not match!"
    # Purpose: Verify numerical equivalence.
    # Theory: Ensures custom implementation matches PyTorch within tolerance.
    
    print("Test without mask passed!")
    # Purpose: Confirm successful test.
    # Theory: Validates core attention mechanism.
    
    # Test with causal mask
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).unsqueeze(0)
    # Purpose: Create causal mask to prevent attending to future tokens.
    # Theory: Upper triangular matrix with 1s above diagonal, shape (1, 3, 3).
    
    output_custom, weights_custom = scaled_dot_product_attention(q, k, v, mask)
    # Purpose: Compute attention with custom implementation and mask.
    # Theory: Tests masking functionality for autoregressive models.
    
    output_pytorch = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    # Purpose: Compute attention with PyTorch's implementation and mask.
    # Theory: Ground truth for masked attention.
    
    print("Custom Output (with mask):", output_custom)
    print("PyTorch Output (with mask):", output_pytorch)
    # Purpose: Print masked outputs.
    # Theory: Verifies masking correctness.
    
    assert torch.allclose(output_custom, output_pytorch, atol=1e-08, rtol=1e-05), "Masked outputs do not match!"
    # Purpose: Verify numerical equivalence for masked case.
    # Theory: Ensures masking is applied correctly.
    
    print("Test with causal mask passed!")
    # Purpose: Confirm successful masked test.
    # Theory: Validates attention for causal settings (e.g., GPT).
    
    # Test with padding mask
    padding_mask = torch.tensor([[[0, 0, 1]]], dtype=torch.bool)
    # Purpose: Create padding mask to mask specific positions.
    # Theory: Masks the last token in the sequence, shape (1, 1, 3).
    
    output_custom, weights_custom = scaled_dot_product_attention(q, k, v, padding_mask)
    # Purpose: Compute attention with padding mask.
    # Theory: Tests handling of padding in sequences.
    
    output_pytorch = F.scaled_dot_product_attention(q, k, v, attn_mask=padding_mask)
    # Purpose: Compute PyTorch attention with padding mask.
    # Theory: Ground truth for padding case.
    
    print("Custom Output (with padding mask):", output_custom)
    print("PyTorch Output (with padding mask):", output_pytorch)
    # Purpose: Print padding-masked outputs.
    # Theory: Verifies padding mask application.
    
    assert torch.allclose(output_custom, output_pytorch, atol=1e-08, rtol=1e-05), "Padding masked outputs do not match!"
    # Purpose: Verify numerical equivalence for padding mask.
    # Theory: Ensures correct handling of padding.
    
    print("Test with padding mask passed!")
    # Purpose: Confirm successful padding test.
    # Theory: Validates attention for padded sequences (e.g., BERT).

Custom Output: tensor([[[ 0.0377, -0.3133,  0.8707],
         [ 0.4541, -0.3508,  0.8461],
         [ 0.3709, -0.2885,  0.8044]]])
PyTorch Output: tensor([[[ 0.0377, -0.3133,  0.8707],
         [ 0.4541, -0.3508,  0.8461],
         [ 0.3709, -0.2885,  0.8044]]])
Test without mask passed!
Custom Output (with mask): tensor([[[-0.7658, -0.7506,  1.3525],
         [ 0.4709, -0.3905,  0.8777],
         [ 0.3709, -0.2885,  0.8044]]])
PyTorch Output (with mask): tensor([[[ 0.2526, -0.1964,  0.7419],
         [ 0.4312, -0.2969,  0.8033],
         [ 0.3709, -0.2885,  0.8044]]])


AssertionError: Masked outputs do not match!