## Attention

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# attention
# Scaled Dot-Product Attention implementation
# It is the fundamental building block of multi-head attention.
class ScaledDotProductAttention(nn.Module):
    """Single-head scaled dot-product attention with learned Q/K/V projections.

    Computes ``softmax(Q K^T / sqrt(d_k)) V`` where Q, K and V are linear
    projections of the ``q``, ``k`` and ``v`` inputs.

    Args:
        d_model: Feature size (last dimension) of the ``q``, ``k`` and ``v`` inputs.
        d_k: Projected size of queries and keys; also the scaling dimension.
        d_v: Projected size of values, i.e. the output feature size.
    """

    def __init__(self, d_model, d_k, d_v):
        super().__init__()
        self.d_k = d_k

        # Linear projections for Q, K, V:
        #   Q (query): what each position is looking for.
        #   K (key):   what each position offers; compared against Q to score relevance.
        #   V (value): the content that is actually mixed together using those scores.
        self.W_q = nn.Linear(d_model, d_k)
        self.W_k = nn.Linear(d_model, d_k)
        self.W_v = nn.Linear(d_model, d_v)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over the key/value pairs ``k``/``v``.

        Args:
            q: Query input of shape (..., len_q, d_model).
            k: Key input of shape (..., len_k, d_model).
            v: Value input of shape (..., len_k, d_model); must share len_k with ``k``.
            mask: Optional tensor broadcastable to (..., len_q, len_k).
                Positions where ``mask == 0`` receive zero attention weight.

        Returns:
            Tuple ``(output, attention_weights)`` with shapes
            (..., len_q, d_v) and (..., len_q, len_k).
        """
        # Linear projections
        q = self.W_q(q)  # (..., len_q, d_k)
        k = self.W_k(k)  # (..., len_k, d_k)
        v = self.W_v(v)  # (..., len_k, d_v)

        # Scaled dot-product scores: QK^T / sqrt(d_k)
        scores = torch.matmul(q, k.transpose(-2, -1)) / np.sqrt(self.d_k)  # (..., len_q, len_k)

        if mask is not None:
            # Use the most negative finite value of the scores' dtype instead of a
            # hard-coded -1e9: -1e9 overflows float16 and makes masked_fill raise.
            # A fully masked row still yields uniform (not NaN) weights.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attention_weights = torch.softmax(scores, dim=-1)  # (..., len_q, len_k), rows sum to 1
        output = torch.matmul(attention_weights, v)  # (..., len_q, d_v)

        return output, attention_weights

# Q&A
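# Q: Why k.transpose(-2, -1)?
# A: transpose(-2, -1) swaps the last two dimensions of K, turning (..., len_k, d_k) into (..., d_k, len_k),
#    so torch.matmul(q, k.transpose(-2, -1)) computes QK^T: the attention score between every query and every key.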

# Q: What is an attention score?
# A: An attention score measures how relevant each key position is to a given query position (the raw value before softmax).
#    You can think of it as the similarity between the query and the key.

# Q: Why divide by np.sqrt(self.d_k)?
# A: Without scaling, dot products grow large as d_k increases. Large inputs push softmax into saturated
#    regions where its gradients are extremely small, which slows down learning.

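# Q: Why divide by np.sqrt(self.d_k), not self.d_k or self.d_k^2?
# A: Mathematical explanation: assume the components of q and k are independent random variables with mean 0 and variance 1.
#    Then the dot product q·k = sum_{i=1}^{d_k} q_i k_i has mean 0 and variance d_k (each term has variance 1),
#    i.e. standard deviation sqrt(d_k). Dividing by np.sqrt(self.d_k) therefore normalizes the variance back to 1.
#    Dividing by d_k (or d_k^2) would instead shrink the variance to 1/d_k (or 1/d_k^3), making the scores nearly
#    equal and the softmax close to uniform.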

# Q: Why softmax?
# A: Softmax is used to convert the attention scores into a probability distribution.
#    It makes all outputs sum to 1 and each output is between 0 and 1.
#    This is ideal for attention weights because we want to know "how much attention" (what proportion) to pay to each position.
#    Sigmoid is not used because its outputs do not sum to 1.
#    ReLU is not used because it does not normalize its outputs into a probability distribution.
#    Tanh is not used because it does not produce a probability distribution, and negative values do not make sense for attention weights.


In [3]:
def test_attention():
    """Smoke-test ScaledDotProductAttention on a random self-attention input.

    Prints the input, output and attention-weight shapes, then asserts that the
    shapes are as expected and that every row of attention weights sums to 1.

    Returns:
        Tuple ``(output, attention_weights)`` from the attention layer.
    """
    # Small sizes so the shapes are easy to read in the printed output
    batch_size, seq_len = 2, 3
    d_model, d_k, d_v = 4, 2, 2

    x = torch.randn(batch_size, seq_len, d_model)
    attention = ScaledDotProductAttention(d_model, d_k, d_v)

    # Self-attention: the same tensor serves as query, key and value
    output, attention_weights = attention(x, x, x)

    print("Input shape:", x.shape)
    print("Output shape:", output.shape)
    print("Attention weights shape:", attention_weights.shape)

    # Verify the results instead of relying on visual inspection of the prints
    assert output.shape == (batch_size, seq_len, d_v)
    assert attention_weights.shape == (batch_size, seq_len, seq_len)
    assert torch.allclose(
        attention_weights.sum(dim=-1), torch.ones(batch_size, seq_len)
    ), "attention weights must sum to 1 over the key dimension"

    return output, attention_weights

# Run the smoke test once when this cell executes and keep its results at notebook scope.
(output, attention_weights) = test_attention()

Input shape: torch.Size([2, 3, 4])
Output shape: torch.Size([2, 3, 2])
Attention weights shape: torch.Size([2, 3, 3])
