# Attention Mechanism Explained with PyTorch

This notebook provides a simple explanation and implementation of the attention mechanism, specifically Scaled Dot-Product Attention, using PyTorch.

## Introduction

Attention mechanisms have become a fundamental component in many state-of-the-art deep learning models, particularly in Natural Language Processing (NLP) tasks like machine translation, text summarization, and question answering. 

The core idea behind attention is to let the model dynamically focus on different parts of the input sequence when producing an output. Instead of relying solely on the final hidden state of an encoder (as in traditional sequence-to-sequence models), attention allows the decoder (or subsequent layers) to "look back" at the entire input sequence and assign each input position a weight (its attention weight) according to its relevance to the current output step.

## Core Idea: Query, Key, Value

Attention can be described in terms of three components:

1.  **Query (Q):** Represents the current context or state trying to retrieve information. In a sequence-to-sequence model's decoder, this might be the hidden state of the decoder at the current time step.
2.  **Key (K):** Each key is paired with a value. The query is compared against every key to determine the attention weights. In a sequence-to-sequence model, keys often correspond to the hidden states of the encoder for each input token.
3.  **Value (V):** The actual information associated with the keys. Once attention weights are calculated by comparing the query to the keys, these weights are used to create a weighted sum of the values. Values often correspond to the same source as keys (e.g., encoder hidden states).

The goal is: given a query, compute a weighted sum of the values, where the weight assigned to each value is determined by the compatibility (similarity) of the query with its corresponding key.
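To make this "soft lookup" concrete, here is a minimal sketch for a single query vector. It assumes the compatibility function is a plain (unscaled) dot product; the scaled version is introduced in the next section. All sizes and variable names are illustrative.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_items, dim_k, dim_v = 5, 4, 3          # illustrative sizes (assumed)
query_vec = torch.randn(dim_k)             # one query
key_mat = torch.randn(num_items, dim_k)    # one key per stored item
value_mat = torch.randn(num_items, dim_v)  # one value per stored item

# 1. Compatibility of the query with every key (here: an unscaled dot product).
similarity = key_mat @ query_vec           # shape: (num_items,)

# 2. Normalize into non-negative weights that sum to 1.
lookup_weights = F.softmax(similarity, dim=0)

# 3. Weighted sum of the values: items whose keys match the query better contribute more.
retrieved = (lookup_weights.unsqueeze(1) * value_mat).sum(dim=0)  # shape: (dim_v,)

print("weights:", lookup_weights)
print("retrieved value:", retrieved)
```

Unlike a dictionary lookup, which returns exactly one value for an exact key match, every value contributes here, in proportion to its weight.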

## Scaled Dot-Product Attention

This is one of the most widely used attention mechanisms. It was introduced as part of the Transformer architecture in the paper "Attention Is All You Need" (Vaswani et al., 2017).

The formula is:
```
Attention(Q, K, V) = softmax( (Q * K^T) / sqrt(d_k) ) * V
```
Where:
*   `Q` is the matrix of queries.
*   `K` is the matrix of keys.
*   `V` is the matrix of values.
*   `K^T` is the transpose of the key matrix. In the formula, `*` denotes matrix multiplication, not element-wise multiplication.
*   `d_k` is the dimension of the keys (and queries).
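*   `sqrt(d_k)` is the scaling factor. If the components of the queries and keys are independent with mean 0 and variance 1, their dot product has variance `d_k`, so for large `d_k` the raw scores become large and push the softmax into regions with very small gradients. Dividing by `sqrt(d_k)` brings the variance back to 1.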
*   `softmax` is applied row-wise to the scaled scores to obtain attention weights that sum to 1.
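The effect of the scaling factor can be checked empirically. The sketch below assumes queries and keys with independent standard-normal components, draws many random pairs for a few values of `d_k`, and compares the variance of raw and scaled dot products. It also groups the scores into rows of 8 independent values (a stand-in for one query's scores against 8 keys) and reports the average largest softmax weight per row, a rough measure of how saturated the softmax is.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_samples = 10_000  # random (q, k) pairs per dimension; divisible by 8
results = {}

for d in (4, 64, 512):
    # Assumption: q and k have independent components with mean 0 and variance 1.
    q_s = torch.randn(num_samples, d)
    k_s = torch.randn(num_samples, d)
    dots = (q_s * k_s).sum(dim=-1)   # one raw dot product per pair
    scaled = dots / math.sqrt(d)     # scaled as in the formula

    # Rows of 8 scores: the mean of each row's largest softmax weight.
    peak_raw = F.softmax(dots.view(-1, 8), dim=-1).amax(dim=-1).mean().item()
    peak_scaled = F.softmax(scaled.view(-1, 8), dim=-1).amax(dim=-1).mean().item()

    results[d] = (dots.var().item(), scaled.var().item(), peak_raw, peak_scaled)
    print(f"d_k={d:3d}  var(raw)={results[d][0]:7.1f}  var(scaled)={results[d][1]:.2f}  "
          f"mean max weight: raw={peak_raw:.2f} scaled={peak_scaled:.2f}")
```

The raw variance should track `d_k`, while the scaled variance stays near 1; without scaling, the softmax for large `d_k` puts almost all weight on a single key.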

Let's implement this step-by-step.

In [None]:
import torch
import torch.nn.functional as F
import math

In [None]:
# Dummy data
seq_len = 5    # Length of the input sequence
embed_dim = 8  # Dimension of embeddings / hidden states (d_k)
batch_size = 1 # Number of sequences processed in parallel

# Create Query, Key, Value tensors
# Shape: (batch_size, sequence_length, embedding_dimension)
Q = torch.randn(batch_size, seq_len, embed_dim)
K = torch.randn(batch_size, seq_len, embed_dim)
V = torch.randn(batch_size, seq_len, embed_dim)

print("Query (Q):", Q.shape)
print("Key (K):", K.shape)
print("Value (V):", V.shape)

### Step 1 & 2: Calculate Scaled Scores

First, we compute the dot product between each query and all keys (`Q * K^T`). Then, we scale these scores by dividing by the square root of the key dimension (`d_k`).
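As a sanity check on what the batched matrix multiply computes, the following sketch compares it with an `einsum` formulation and with explicit loops over every (query, key) pair. Sizes and variable names are illustrative and chosen not to overwrite the notebook's `Q` and `K`.

```python
import math
import torch

torch.manual_seed(0)
n_batch, n_seq, dim_k = 2, 3, 4  # small illustrative sizes
q_demo = torch.randn(n_batch, n_seq, dim_k)
k_demo = torch.randn(n_batch, n_seq, dim_k)

# Vectorized: a single batched matrix multiply against the transposed keys.
fast = torch.matmul(q_demo, k_demo.transpose(-2, -1)) / math.sqrt(dim_k)

# Same computation with einsum: contract over the feature axis d.
via_einsum = torch.einsum("bid,bjd->bij", q_demo, k_demo) / math.sqrt(dim_k)

# Explicit loops: slow[b, i, j] = (q_i . k_j) / sqrt(d_k) within batch b.
slow = torch.empty(n_batch, n_seq, n_seq)
for b in range(n_batch):
    for i in range(n_seq):
        for j in range(n_seq):
            slow[b, i, j] = torch.dot(q_demo[b, i], k_demo[b, j]) / math.sqrt(dim_k)

print(torch.allclose(fast, slow, atol=1e-6), torch.allclose(fast, via_einsum, atol=1e-6))  # True True
```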

In [None]:
# 1. Calculate dot products between Query and Key (transposed)
# We need K transposed so the dimensions match for matrix multiplication:
# Q: (batch, seq_len, embed_dim)
# K^T: (batch, embed_dim, seq_len)
# Result (scores): (batch, seq_len, seq_len)
# scores[b, i, j] represents the similarity between the i-th query and j-th key in batch b.
scores = torch.matmul(Q, K.transpose(-2, -1))
print("Raw Scores (Q * K^T):", scores.shape)
# print(scores)

# 2. Scale the scores
# Divide by the square root of the key dimension (d_k)
d_k = K.size(-1) # embed_dim
scaled_scores = scores / math.sqrt(d_k)
print("\nScaled Scores:", scaled_scores.shape)
# print(scaled_scores)

### Step 3: Apply Softmax

Apply the softmax function to the scaled scores along the key dimension (the last dimension of the `scaled_scores` tensor). This converts the scores into probabilities (attention weights) that sum to 1 for each query.
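The sketch below writes the softmax out by hand to show what `F.softmax(..., dim=-1)` does. It uses the standard max-subtraction trick, which leaves the result unchanged but avoids overflow in `exp()`. It also shows that normalizing over the wrong axis (`dim=-2`, the query axis) makes columns, not rows, sum to 1. The input tensor is random and illustrative.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
demo_scores = torch.randn(1, 3, 4)  # (batch, num_queries, num_keys), illustrative

# Manual, numerically stable softmax over the key dimension.
shifted = demo_scores - demo_scores.amax(dim=-1, keepdim=True)
manual = shifted.exp() / shifted.exp().sum(dim=-1, keepdim=True)

builtin = F.softmax(demo_scores, dim=-1)
print(torch.allclose(manual, builtin, atol=1e-6))  # True
print(builtin.sum(dim=-1))  # each query's weights sum to 1

# Softmax over the query axis normalizes columns instead of rows.
wrong_axis = F.softmax(demo_scores, dim=-2)
print(wrong_axis.sum(dim=-1))  # rows no longer sum to 1 in general
```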

In [None]:
# 3. Apply Softmax to get attention weights
# Softmax is applied across the last dimension (keys)
# Result shape: (batch_size, seq_len, seq_len)
# attention_weights[b, i, j] is the weight given to the j-th value vector when computing the output for the i-th query.
attention_weights = F.softmax(scaled_scores, dim=-1)
print("Attention Weights (Softmax):", attention_weights.shape)
# print(attention_weights)
# print("Sum of weights for first query:", attention_weights[0, 0, :].sum()) # Should be close to 1

### Step 4: Multiply Weights by Values

Finally, multiply the attention weights by the Value matrix (`V`). This produces the context vector (the output of the attention layer), which is a weighted sum of the values, where the weights are determined by the query-key similarities.
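The matrix product `weights @ V` is just a compact way of writing one weighted sum per query. This sketch checks that equivalence with an explicit sum, and shows that a one-hot weight row simply copies a single value vector. Shapes and names are illustrative.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_q, n_k, d_v = 2, 4, 3
demo_weights = F.softmax(torch.randn(n_q, n_k), dim=-1)  # (n_q, n_k), rows sum to 1
demo_values = torch.randn(n_k, d_v)                      # (n_k, d_v)

# Matrix form.
demo_context = demo_weights @ demo_values                # (n_q, d_v)

# Explicit form: output_i = sum_j w_ij * v_j
explicit = torch.stack([
    sum(demo_weights[i, j] * demo_values[j] for j in range(n_k))
    for i in range(n_q)
])
print(torch.allclose(demo_context, explicit, atol=1e-6))  # True

# A one-hot weight row "selects" exactly one value vector (here, value 2).
one_hot = torch.tensor([[0.0, 0.0, 1.0, 0.0]])
print(torch.allclose(one_hot @ demo_values, demo_values[2:3]))  # True
```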

In [None]:
# 4. Multiply weights by Value
# attention_weights: (batch, seq_len, seq_len)
# V: (batch, seq_len, embed_dim)
# Result (context_vector): (batch, seq_len, embed_dim)
# context_vector[b, i, :] is the output for the i-th query, computed as a weighted sum of all value vectors.
context_vector = torch.matmul(attention_weights, V)
print("Context Vector (Weights * V):", context_vector.shape)
# print(context_vector)

## Summary

The `context_vector` is the output of the Scaled Dot-Product Attention mechanism. In general its shape is `(batch_size, seq_len_q, d_v)`: one output vector per query, each with the width of the value vectors. In this example, `Q` and `V` both have shape `(1, 5, 8)`, so the output has that shape too. Each vector `context_vector[0, i, :]` aggregates information from the entire value sequence `V`, with each value weighted by how well its key in `K` matches the query `Q[0, i, :]`.

This allows the model to focus on the most pertinent parts of the input sequence when generating each part of the output sequence.
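To see the general shapes, the following sketch uses a cross-attention-style setup in which the number of queries differs from the number of keys and the value width differs from the key width. The sizes are assumptions picked only to make the shapes distinguishable.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_batch, len_q, len_k, dk, dv = 2, 3, 6, 8, 5  # illustrative sizes (assumed)
x_q = torch.randn(n_batch, len_q, dk)  # e.g. 3 decoder positions
x_k = torch.randn(n_batch, len_k, dk)  # e.g. 6 encoder positions
x_v = torch.randn(n_batch, len_k, dv)  # values may be narrower or wider than keys

x_weights = F.softmax(x_q @ x_k.transpose(-2, -1) / math.sqrt(dk), dim=-1)
x_context = x_weights @ x_v

print(x_weights.shape)  # torch.Size([2, 3, 6]) -> (batch, len_q, len_k)
print(x_context.shape)  # torch.Size([2, 3, 5]) -> (batch, len_q, d_v)
```

The only hard constraints are that queries and keys share the feature width `d_k`, and that keys and values share the sequence length.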

In [None]:
# Example: Putting it all together in a function

def scaled_dot_product_attention(Q, K, V, mask=None):
    """Calculate scaled dot product attention.
    
    Args:
        Q (torch.Tensor): Queries. Shape: (batch_size, ..., seq_len_q, d_k).
        K (torch.Tensor): Keys. Shape: (batch_size, ..., seq_len_k, d_k).
        V (torch.Tensor): Values. Shape: (batch_size, ..., seq_len_v, d_v).
                          Note: seq_len_k == seq_len_v
        mask (torch.Tensor, optional): Mask applied before the softmax; positions
                                     where mask == 0 are excluded from attention.
                                     Shape: broadcastable to (batch_size, ..., seq_len_q, seq_len_k).
                                     Defaults to None.
                                     
    Returns:
        torch.Tensor: Context vector. Shape: (batch_size, ..., seq_len_q, d_v).
        torch.Tensor: Attention weights. Shape: (batch_size, ..., seq_len_q, seq_len_k).
    """
    d_k = K.size(-1)
    # Matmul Q and K^T: (..., seq_len_q, d_k) x (..., d_k, seq_len_k) -> (..., seq_len_q, seq_len_k)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    
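    if mask is not None:
        # Positions where mask == 0 get the most negative finite value of the
        # scores' dtype, so softmax gives them (effectively) zero weight.
        # torch.finfo avoids a hard-coded -1e9, which overflows in float16.
        scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)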
        
    # Apply softmax: (..., seq_len_q, seq_len_k)
    attention_weights = F.softmax(scores, dim=-1)
    
    # Matmul weights and V: (..., seq_len_q, seq_len_k) x (..., seq_len_v, d_v) -> (..., seq_len_q, d_v)
    # Note: seq_len_k == seq_len_v
    context = torch.matmul(attention_weights, V)
    
    return context, attention_weights

# --- Test the function ---
context_output, weights_output = scaled_dot_product_attention(Q, K, V)

print("\n--- Function Output ---")
print("Context Vector Shape:", context_output.shape)
print("Attention Weights Shape:", weights_output.shape)
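The `mask` argument is not exercised above. The sketch below condenses the same computation into a self-contained helper (the name `sdpa_reference` is hypothetical) and applies a causal mask, where query `i` may only attend to keys `0..i`. It then cross-checks the result against PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention`, assuming PyTorch 2.0 or later, whose `is_causal=True` option applies the same lower-triangular mask.

```python
import math
import torch
import torch.nn.functional as F

def sdpa_reference(q, k, v, mask=None):
    # Hypothetical helper: same computation as scaled_dot_product_attention above.
    scores = q @ k.transpose(-2, -1) / math.sqrt(k.size(-1))
    if mask is not None:
        scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
    weights = F.softmax(scores, dim=-1)
    return weights @ v, weights

torch.manual_seed(0)
n = 4
q_m, k_m, v_m = (torch.randn(1, n, 8) for _ in range(3))

# Causal mask: 1 = keep, 0 = mask out; row i keeps columns 0..i.
causal = torch.tril(torch.ones(n, n))
ctx_m, w_m = sdpa_reference(q_m, k_m, v_m, mask=causal)
print(w_m[0])  # upper triangle is 0; each row still sums to 1

# Cross-check against PyTorch's fused implementation (PyTorch >= 2.0).
ctx_builtin = F.scaled_dot_product_attention(q_m, k_m, v_m, is_causal=True)
print(torch.allclose(ctx_m, ctx_builtin, atol=1e-5))
```

In practice, the built-in function is usually preferable because it can dispatch to fused, memory-efficient kernels; the explicit version is mainly useful for learning and for inspecting the attention weights, which the built-in does not return.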