# Scaled Dot-Product Attention

An implementation of scaled dot-product attention in PyTorch, with an optional mask (e.g. for causal attention). The module has no learnable parameters of its own, but it is the core computation inside the multi-head attention (MHA) layers of a Transformer.

## Code

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class ScaledDotProductAttention(nn.Module):
    def __init__(self, dropout=0.1):
        super().__init__()
        # Dropout applied to the attention weights (active only in training mode)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        d_k = query.size(-1)
        # Similarity scores scaled by 1/sqrt(d_k): shape (..., seq_len_q, seq_len_k)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            # mask: 1 = attend, 0 = ignore; ignored positions get -inf so softmax gives them 0
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Weighted sum of the value vectors
        output = torch.matmul(attention_weights, value)
        return output, attention_weights
```
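As a sanity check, the computation can be compared against PyTorch's fused `torch.nn.functional.scaled_dot_product_attention`. The sketch below assumes PyTorch 2.0 or later (where that function exists) and re-implements the forward pass as a plain function, `manual_attention` (a hypothetical helper name), so the cell runs on its own. Note that the built-in function takes a boolean `attn_mask` where `True` means "attend", so the 1/0 mask used here is converted with `.bool()`.

```python
import math
import torch
import torch.nn.functional as F

def manual_attention(query, key, value, mask=None):
    # Same computation as ScaledDotProductAttention.forward, without dropout
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, value), weights

torch.manual_seed(0)
batch_size, seq_len, d_k = 2, 4, 8
query = torch.randn(batch_size, seq_len, d_k)
key = torch.randn(batch_size, seq_len, d_k)
value = torch.randn(batch_size, seq_len, d_k)

# Causal mask (1 = attend, 0 = ignore), shared across the batch by broadcasting
causal = torch.tril(torch.ones(seq_len, seq_len))

out_manual, _ = manual_attention(query, key, value, mask=causal)
# Built-in version: boolean mask where True = attend; dropout_p defaults to 0.0
out_builtin = F.scaled_dot_product_attention(query, key, value, attn_mask=causal.bool())

print(torch.allclose(out_manual, out_builtin, atol=1e-5))
```

The built-in function may dispatch to a fused kernel, so small floating-point differences are expected; hence the comparison uses a tolerance rather than exact equality.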

## Line-by-Line Explanation

### def \_\_init\_\_(self, dropout=0.1)

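- `super().__init__()`: first, call the constructor of the parent class `nn.Module`, which sets up the internal state PyTorch needs to register submodules and parameters.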

- `self.dropout = nn.Dropout(dropout)`: create a Dropout layer that is applied to the attention weights, randomly zeroing some of them during training.
  - Why: dropout is a regularizer; it helps prevent overfitting so that the model generalizes better.
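A quick look at what `nn.Dropout` does to attention weights: in training mode it zeroes each element with probability `p` and scales the survivors by `1 / (1 - p)`, so the expected value of each weight is unchanged; in eval mode it is the identity. The uniform 4×4 weight matrix below is a made-up stand-in for real attention weights.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dropout = nn.Dropout(p=0.5)
weights = torch.full((1, 4, 4), 0.25)  # each row is uniform and sums to 1

dropout.train()  # training mode: dropout is active
dropped = dropout(weights)
# Surviving entries are scaled by 1 / (1 - 0.5) = 2, so they become 0.5; the rest become 0
print(dropped)
print(dropped.sum(dim=-1))  # row sums are generally no longer exactly 1

dropout.eval()  # evaluation mode: dropout passes the input through unchanged
eval_out = dropout(weights)
print(torch.equal(eval_out, weights))  # True
```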

### def forward(self, query, key, value, mask=None)

This function implements the scaled dot-product attention formula below:

$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left( \frac{Q K^\top}{\sqrt{d_k}} \right) V$
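A tiny worked example of the formula, with made-up numbers: a single query, two key/value pairs, and $d_k = 2$.

$$
\begin{aligned}
q &= [1, 0], \quad k_1 = v_1 = [1, 0], \quad k_2 = v_2 = [0, 1] \\
\text{scores} &= \left[\tfrac{q \cdot k_1}{\sqrt{2}},\ \tfrac{q \cdot k_2}{\sqrt{2}}\right] \approx [0.707,\ 0] \\
\text{weights} &= \mathrm{softmax}([0.707,\ 0]) = \left[\tfrac{e^{0.707}}{e^{0.707} + 1},\ \tfrac{1}{e^{0.707} + 1}\right] \approx [0.670,\ 0.330] \\
\text{output} &= 0.670\, v_1 + 0.330\, v_2 \approx [0.670,\ 0.330]
\end{aligned}
$$

The query is more similar to $k_1$ than to $k_2$, so the output leans toward $v_1$.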

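The input tensors:
- `query`, `key`, `value`: in this notebook all three have the same shape, `(batch_size, seq_len, d_k)`
  - `d_k` is the dimensionality of the query and key vectors. In single-head attention it equals `d_model` (the embedding size); in multi-head attention it equals `d_model / num_heads`.
  - Because `matmul` and `transpose(-2, -1)` only operate on the last two dimensions, the same code also works with extra leading dimensions, e.g. `(batch_size, num_heads, seq_len, d_k)` in multi-head attention.
- `mask`: shape `(batch_size, seq_len, seq_len)`, or any shape broadcastable to the score matrix
  - It is a binary mask: 1 = attend, 0 = ignore. It is used to hide future tokens (causal attention) or padding tokens.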
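A sketch of how the two kinds of masks might be built, assuming the 1 = attend / 0 = ignore convention above and hypothetical per-sequence lengths. The mask only needs to broadcast to the `(batch_size, seq_len, seq_len)` score matrix, so a single `(seq_len, seq_len)` causal mask can be shared across the batch.

```python
import torch

batch_size, seq_len = 2, 5

# Causal mask: query position i may attend to key positions 0..i (1 = attend, 0 = ignore)
causal_mask = torch.tril(torch.ones(seq_len, seq_len))           # (seq_len, seq_len)

# Padding mask from hypothetical lengths: sequence 1 has 3 real tokens, then 2 padding tokens
lengths = torch.tensor([5, 3])
key_is_real = torch.arange(seq_len)[None, :] < lengths[:, None]  # (batch_size, seq_len), bool
padding_mask = key_is_real[:, None, :].float()                   # (batch_size, 1, seq_len)

# Combine: attend only where both masks are 1; broadcasting yields (batch_size, seq_len, seq_len)
mask = causal_mask[None, :, :] * padding_mask
print(mask.shape)  # torch.Size([2, 5, 5])
print(mask[1])     # columns 3 and 4 (padding keys) are all zero
```

Rows index query positions and columns index key positions, so padding is masked along the last (key) dimension, which is the dimension softmax normalizes over.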

Line-by-line:
- Get the scale dimension: 
  - `d_k = query.size(-1)`  
    - This is used for **scaling** the dot product later.
- Compute attention scores: 
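  - `scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)`
    - First, compute the dot products `query @ keyᵀ`. Entry `[i, j]` of the result measures how similar query position `i` is to key position `j`, i.e. how much attention position `i` should pay to position `j`.
    - Then, divide by √d_k to **scale down large values** and stabilize the softmax: dot products of d_k-dimensional vectors tend to grow with d_k, and large scores push softmax toward a near one-hot output whose gradients are very small. This scaling was proposed in the original Transformer paper (Vaswani et al., 2017, "Attention Is All You Need").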
- If a mask is provided, apply it:
  - `scores = scores.masked_fill(mask == 0, float('-inf'))` 
    - Wherever the mask is 0, the score is replaced with `-inf`. Since exp(-inf) = 0, softmax assigns those positions zero weight, so they receive no attention.
- Convert attention scores to attention weights (probability distribution):
  - `attention_weights = F.softmax(scores, dim=-1)`  
    - All weights then lie between 0 and 1, and each row (along the last dimension) sums to 1.
- Apply dropout: 
  - `attention_weights = self.dropout(attention_weights)`  
    - Applying dropout to the attention weights regularizes the model by randomly zeroing some of them. This prevents the model from relying too heavily on any single token during training, which improves robustness and generalization.
- Calculate output, which is a weighted sum of the value vectors using the attention weights: 
  - `output = torch.matmul(attention_weights, value)` 
    - This combines information from all tokens (their values) according to how much attention should be paid to each of them
    - The output shape is `(batch_size, seq_len, d_k)` here; in general its last dimension is that of `value`.
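A quick numerical check of the scaling argument, assuming (as the original Transformer paper does) that query and key entries are independent with mean 0 and variance 1. Under that assumption the raw dot product has variance d_k, and dividing by √d_k brings it back to about 1.

```python
import torch

torch.manual_seed(0)
n = 10_000  # number of random (query, key) pairs per setting

results = {}
for d_k in (4, 64, 512):
    # Assumption: entries are i.i.d. standard normal
    q = torch.randn(n, d_k)
    k = torch.randn(n, d_k)
    raw = (q * k).sum(dim=-1)   # unscaled dot products: variance ≈ d_k
    scaled = raw / d_k ** 0.5   # scaled dot products: variance ≈ 1
    results[d_k] = (raw.var().item(), scaled.var().item())
    print(f"d_k={d_k:3d}  var(raw)={results[d_k][0]:7.1f}  var(scaled)={results[d_k][1]:.2f}")
```

Without scaling, the spread of the scores fed to softmax would grow with the head dimension; with it, the scores stay on the same order of magnitude for any d_k.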

The output tensors:
- `output`: context-aware representations that mix information from all attended tokens. Shape: `(batch_size, seq_len, d_k)`
  - This is passed on as input to the next layer. During training, gradients flow back through it to the layers that produced `query`, `key`, and `value`.
- `attention_weights`: the distribution of attention over all tokens. Shape: `(batch_size, seq_len, seq_len)` 
  - Row `i` shows how position `i` distributes its attention; this is useful for visualization and interpretability analysis.
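The shapes listed above assume self-attention with equal query/key/value shapes. In general (e.g. cross-attention), the query length can differ from the key/value length, and the value dimension `d_v` can differ from `d_k`; the output then takes its length from the query and its last dimension from the value. A shape check with arbitrary made-up sizes, using a stripped-down `attention` function (a hypothetical helper with no mask or dropout):

```python
import math
import torch
import torch.nn.functional as F

def attention(query, key, value):
    # Same formula as the forward pass above, without mask or dropout
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, value), weights

batch_size, len_q, len_k, d_k, d_v = 2, 3, 5, 8, 16
query = torch.randn(batch_size, len_q, d_k)
key = torch.randn(batch_size, len_k, d_k)
value = torch.randn(batch_size, len_k, d_v)

output, weights = attention(query, key, value)
print(output.shape)   # torch.Size([2, 3, 16]), i.e. (batch_size, len_q, d_v)
print(weights.shape)  # torch.Size([2, 3, 5]), i.e. (batch_size, len_q, len_k)
```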

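Notes
- Mask, softmax, and dropout are applied in this order:
  - mask -> softmax -> dropout
- Mask is applied before softmax, because
  - Setting masked scores to `-inf` before softmax gives those positions exactly zero weight while the remaining weights still sum to 1. Zeroing probabilities after softmax would discard the masked positions' probability mass, so the rows would no longer sum to 1.
- Softmax is applied before dropout, because 
  - We want to randomly remove whole token contributions (probabilities). Applying dropout to the raw scores would instead set some scores to 0 (which does not remove those tokens) and rescale the others, distorting the distribution that softmax computes from the similarity scores.
  - As a result, the attention weights no longer sum exactly to 1 during training; this is intentional. `nn.Dropout` scales the surviving weights by `1 / (1 - p)`, so each row still sums to 1 in expectation. Re-normalizing after dropout is possible but not widely used.
  - At inference time, dropout is disabled, so each row of attention weights sums to 1.
    - This happens automatically because dropout is implemented with PyTorch's `nn.Dropout` module. PyTorch tracks whether the model is in training mode (`model.train()`) or evaluation mode (`model.eval()`), and `nn.Dropout` only drops elements in training mode.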
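A small illustration of why the mask goes before softmax, using made-up scores for a single query over four positions, with the last position masked out:

```python
import torch
import torch.nn.functional as F

scores = torch.tensor([[2.0, 1.0, 0.5, 3.0]])
mask = torch.tensor([[1, 1, 1, 0]])  # ignore the last position (1 = attend, 0 = ignore)

# Mask -> softmax: the masked position gets exactly 0 and the rest renormalize to sum to 1
mask_first = F.softmax(scores.masked_fill(mask == 0, float('-inf')), dim=-1)

# Softmax -> mask (wrong order): the masked position is zeroed, but its probability mass is lost
softmax_first = F.softmax(scores, dim=-1).masked_fill(mask == 0, 0.0)

print(mask_first, mask_first.sum().item())        # sum ≈ 1.0
print(softmax_first, softmax_first.sum().item())  # sum ≈ 0.37
```

One caveat of the `-inf` approach: if every position in a row is masked, all scores in that row are `-inf` and softmax returns NaN, so a mask should leave at least one attendable position for every query.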

## Test Code

```python
def test_basic_functionality():
    """Test basic functionality of ScaledDotProductAttention"""
    print("=== Test 1: Basic Functionality ===")
    
    # Setup
    batch_size, seq_len, d_k = 2, 4, 8
    attention_layer = ScaledDotProductAttention(dropout=0.0)  # No dropout for deterministic test
    
    # Create test inputs
    query = torch.randn(batch_size, seq_len, d_k)
    key = torch.randn(batch_size, seq_len, d_k)
    value = torch.randn(batch_size, seq_len, d_k)
    
    # Forward pass
    output, attention_weights = attention_layer(query, key, value)
    
    # Test output shapes
    assert output.shape == (batch_size, seq_len, d_k), f"Expected output shape {(batch_size, seq_len, d_k)}, got {output.shape}"
    assert attention_weights.shape == (batch_size, seq_len, seq_len), f"Expected attention shape {(batch_size, seq_len, seq_len)}, got {attention_weights.shape}"
    
    # Test attention weights properties
    # 1. All weights should be non-negative
    assert torch.all(attention_weights >= 0), "Attention weights should be non-negative"
    
    # 2. Each row should sum to 1 (probability distribution)
    row_sums = torch.sum(attention_weights, dim=-1)
    assert torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-6), "Attention weights should sum to 1 along last dimension"
    
    print("✓ Basic functionality test passed!")
    print(f"  Output shape: {output.shape}")
    print(f"  Attention weights shape: {attention_weights.shape}")
    print(f"  Attention weights sum: {row_sums[0, 0].item():.6f} (should be ~1.0)")

test_basic_functionality()
```

```
=== Test 1: Basic Functionality ===
✓ Basic functionality test passed!
  Output shape: torch.Size([2, 4, 8])
  Attention weights shape: torch.Size([2, 4, 4])
  Attention weights sum: 1.000000 (should be ~1.0)
```
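A possible second test in the same style, checking the masked path with a causal mask. This is a sketch, not part of the original notebook: it repeats the class so the cell runs on its own, and `test_causal_mask` is a hypothetical name. It relies on the fact that, under a causal mask, position 0 can only attend to itself, so its output must equal its own value vector.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class ScaledDotProductAttention(nn.Module):
    # Same class as above, repeated so this cell runs on its own
    def __init__(self, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        d_k = query.size(-1)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attention_weights = self.dropout(F.softmax(scores, dim=-1))
        return torch.matmul(attention_weights, value), attention_weights

def test_causal_mask():
    """Test that a causal mask blocks attention to future positions"""
    print("=== Test 2: Causal Mask ===")
    batch_size, seq_len, d_k = 2, 4, 8
    attention_layer = ScaledDotProductAttention(dropout=0.0)  # No dropout for deterministic test
    query = torch.randn(batch_size, seq_len, d_k)
    key = torch.randn(batch_size, seq_len, d_k)
    value = torch.randn(batch_size, seq_len, d_k)

    # (seq_len, seq_len) causal mask, broadcast over the batch
    mask = torch.tril(torch.ones(seq_len, seq_len))
    output, attention_weights = attention_layer(query, key, value, mask=mask)

    # 1. Future positions (strictly above the diagonal) get zero weight
    future = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    assert torch.all(attention_weights.masked_select(future) == 0), "Future positions should get zero weight"

    # 2. Rows still sum to 1 over the allowed positions
    row_sums = attention_weights.sum(dim=-1)
    assert torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-6), "Rows should sum to 1"

    # 3. Position 0 attends only to itself, so its output equals its value vector
    assert torch.allclose(output[:, 0], value[:, 0], atol=1e-6), "First output should equal first value"

    print("✓ Causal mask test passed!")

test_causal_mask()
```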
