## Attention

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

In [5]:
# attention
# Scaled Dot-Product Attention implementation
# It is the fundamental building block of multi-head attention
class ScaledDotProductAttention(nn.Module):
    """Single-head scaled dot-product attention with learned Q/K/V projections.

    The inputs are first projected by three linear layers and then combined as
    ``softmax(Q @ K^T / sqrt(d_k)) @ V``.

    Args:
        d_model: Size of the last dimension of the incoming q/k/v tensors.
        d_k: Size of the projected queries and keys.
        d_v: Size of the projected values (and of the output features).
    """

    def __init__(self, d_model, d_k, d_v):
        super().__init__()
        self.d_k = d_k

        # Linear projections for Q, K, V
        # Q: query - what each position is looking for
        # K: key - compared against the queries to produce relevance scores
        # V: value - the content that gets mixed according to those scores
        self.W_q = nn.Linear(d_model, d_k)
        self.W_k = nn.Linear(d_model, d_k)
        self.W_v = nn.Linear(d_model, d_v)

    def forward(self, q, k, v, mask=None):
        """Project q/k/v and apply scaled dot-product attention.

        Args:
            q: Query input, last dimension ``d_model``
                (typically ``(batch_size, seq_len, d_model)``).
            k: Key input, last dimension ``d_model``.
            v: Value input, last dimension ``d_model``.
            mask: Optional tensor broadcastable to the score matrix. Positions
                where ``mask == 0`` are excluded from attention.

        Returns:
            A tuple ``(output, attention_weights)``: the weighted sum of the
            projected values, and the softmax-normalised scores over the keys.
        """
        # Linear projections
        q = self.W_q(q)  # (..., seq_len_q, d_k)
        k = self.W_k(k)  # (..., seq_len_k, d_k)
        v = self.W_v(v)  # (..., seq_len_k, d_v)

        # Similarity of every query with every key, scaled by sqrt(d_k) so the
        # magnitude of the scores does not grow with the key dimension.
        scores = torch.matmul(q, k.transpose(-2, -1)) / np.sqrt(self.d_k)

        if mask is not None:
            # Use the most negative finite value of the score dtype: a
            # hard-coded -1e9 is not representable in float16 and makes
            # masked_fill fail there. For float32 the softmax result is the
            # same (masked positions get weight 0; a fully masked row stays
            # uniform rather than NaN).
            fill_value = torch.finfo(scores.dtype).min
            scores = scores.masked_fill(mask == 0, fill_value)

        attention_weights = torch.softmax(scores, dim=-1)  # (..., seq_len_q, seq_len_k)
        output = torch.matmul(attention_weights, v)  # (..., seq_len_q, d_v)

        return output, attention_weights

# Q&A
# Q: Why k.transpose(-2, -1)?
# A: transpose(-2, -1) swaps the last two dimensions of K, giving K^T, so that Q @ K^T holds the attention score between every query and every key.

# Q: What is attention score?
# A: Attention score is the score of each position to other positions. You can think of it as the similarity between the query and the key.

# Q: Why divide by np.sqrt(self.d_k)?
# A: Dividing by np.sqrt(self.d_k) keeps the dot products from growing large with d_k; large scores saturate the softmax, which leads to very small gradients.

# Q: Why divide by np.sqrt(self.d_k), not self.d_k or self.d_k^2?
# A: Mathematical explanation: assume the components of q and k are independent random variables with mean 0 and variance 1.
#    Then the dot product q . k (a sum of d_k such products) has mean 0 and variance d_k,
#    so dividing by np.sqrt(self.d_k) normalizes the variance back to 1 (dividing by d_k would shrink it to 1/d_k).

# Q: Why softmax?
# A: Softmax is used to convert the attention scores into a probability distribution.
#    It makes all outputs sum to 1 and each output is between 0 and 1.
#    This is perfect for attention weights because we want to know "how much attention" (what proportion) to pay to each position.
#    Sigmoid is not used because its outputs do not sum to 1.
#    ReLU is not used because it does not normalize its outputs into a probability distribution.
#    Tanh is not used because it does not create a probability distribution, and negative values do not make sense for attention weights.

# Q: What is the output of the attention mechanism?
# A: 2 outputs, main output and attention weights.
#    Main output will be used for the next layer. It's a weighted sum of the values.
#    Each position gets a new representation based on what it attended to.
#    Attention weights are used to visualize the attention distribution.

In [None]:
def test_attention():
    """Run ScaledDotProductAttention once on random data and print the shapes.

    Returns:
        The ``(output, attention_weights)`` tuple produced by the layer.
    """
    # Toy problem size
    batch_size, seq_len = 64, 3
    d_model, d_k, d_v = 4, 2, 2

    # Random input first, then the layer (same order of random draws as before)
    x = torch.randn(batch_size, seq_len, d_model)
    layer = ScaledDotProductAttention(d_model, d_k, d_v)

    # Self-attention: the same tensor serves as query, key and value
    output, attention_weights = layer(x, x, x)

    report = (
        ("Input shape:", x),
        ("Output shape:", output),
        ("Attention weights shape:", attention_weights),
    )
    for label, tensor in report:
        print(label, tensor.shape)

    return output, attention_weights

# Run the single-head smoke test at cell level and bind the returned
# (output, attention_weights) tensors to module-level names.
output, attention_weights = test_attention()

Input shape: torch.Size([64, 3, 4])
Output shape: torch.Size([64, 3, 2])
Attention weights shape: torch.Size([64, 3, 3])


In [6]:
# Multi-Head Attention implementation
class MultiHeadAttention(nn.Module):
    """Multi-head attention: project, split into heads, attend, merge, project.

    Args:
        d_model: Feature size of the inputs and of the output.
        num_heads: Number of attention heads; must divide ``d_model``.

    NOTE: ``self.attention`` is a ``ScaledDotProductAttention`` module, which
    applies its own linear projections. The same instance (and therefore the
    same inner weights) is used for every head, on top of ``W_q``/``W_k``/``W_v``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # dimension of each head
        self.d_v = self.d_k  # typically d_v = d_k

        # Linear projections for Q, K, V (all heads at once, d_model wide)
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

        # Output projection that mixes the concatenated heads
        self.W_o = nn.Linear(d_model, d_model)

        # Scaled dot-product attention applied to each d_k-wide head slice
        self.attention = ScaledDotProductAttention(self.d_k, self.d_k, self.d_v)

    def forward(self, q, k, v, mask=None):
        """Apply multi-head attention.

        Args:
            q: Query input of shape ``(batch_size, seq_len_q, d_model)``.
            k: Key input of shape ``(batch_size, seq_len_k, d_model)``.
            v: Value input of shape ``(batch_size, seq_len_k, d_model)``.
            mask: Optional mask handed unchanged to the attention module for
                every head.

        Returns:
            A tuple ``(output, attention_weights)`` where ``output`` has shape
            ``(batch_size, seq_len_q, d_model)`` and ``attention_weights`` is
            the mean over heads of the per-head attention weights.
        """
        batch_size = q.size(0)

        # Project, then reshape to (batch_size, num_heads, seq_len, head_dim)
        q = self.W_q(q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        k = self.W_k(k).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        v = self.W_v(v).view(batch_size, -1, self.num_heads, self.d_v).transpose(1, 2)

        # Apply attention to each head
        head_outputs = []
        head_weights = []
        for h in range(self.num_heads):
            head_out, head_attn = self.attention(q[:, h], k[:, h], v[:, h], mask)
            head_outputs.append(head_out)
            head_weights.append(head_attn)

        # Concatenate the heads along the FEATURE axis:
        # num_heads x (batch_size, seq_len_q, d_v) -> (batch_size, seq_len_q, d_model).
        # Concatenating along dim 0 and reshaping would interleave data from
        # different batch elements into the same output row.
        output = torch.cat(head_outputs, dim=-1)

        # Final linear projection
        output = self.W_o(output)

        # Average attention weights across heads
        attention_weights = torch.stack(head_weights).mean(dim=0)

        return output, attention_weights

# Q&A
# Q: What is common value for d_model, d_k and d_v?
# A: d_model is the dimension of the input and output vectors throughout the transformer architecture. 
#    In the original paper, d_model = 512. In BERT, d_model = 768 (BASE) or 1024 (LARGE). In GPT-2, d_model ranges from 768 (small) to 1600 (XL).
#    d_k = d_v = d_model / h, where h is the number of heads. Use same dimension for query, key, and value.
#    In typical transformer model, d_model = 512, h = 8, so d_k = d_v = 512 / 8 = 64.

# Q: Why d_k = d_v = d_model / h?
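# A: With d_k = d_v = d_model / h, concatenating the h head outputs gives back a d_model-sized vector, and the total cost of all heads stays close to that of a single head with full dimensionality, while each head can still focus on a different aspect of the input.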

# Q: What's the purpose of multi-head?
# A: Each head learns a different aspect of the input, so it can attend to different parts of the input.
#    E.g., head 1 focuses on local dependencies, head 2 focuses on long-range dependencies, head 3 focuses on syntactic relationships, etc.
#    This allows the model to attend to different parts of the input and learn different patterns.

# Q: Why use linear layer for output projection
# A: It helps information integration across different heads.
#    The output projection adds learnable parameters that help the model learn how to best combine the information from the different attention heads.

# Q: Why use average attention weights across heads?
# A: Averaging gives a single aggregated attention map per input, which is convenient for visualization. Note that it hides the differences between individual heads.

In [7]:
# Test multi-head attention
def test_multi_head_attention():
    """Run MultiHeadAttention once on random data and print a short report.

    Returns:
        The ``(output, attention_weights)`` tuple produced by the layer.
    """
    # Toy problem size
    batch_size, seq_len = 2, 3
    d_model = 8  # must be divisible by num_heads
    num_heads = 2  # number of heads

    # Random input first, then the layer (same order of random draws as before)
    x = torch.randn(batch_size, seq_len, d_model)
    mha = MultiHeadAttention(d_model, num_heads)

    # Self-attention: the same tensor serves as query, key and value
    output, attention_weights = mha(x, x, x)

    shape_report = (
        ("Input shape:", x),
        ("Output shape:", output),
        ("Attention weights shape:", attention_weights),
    )
    for label, tensor in shape_report:
        print(label, tensor.shape)
    print("\nNumber of heads:", num_heads)
    print("Dimension per head:", d_model // num_heads)

    return output, attention_weights

# Run the multi-head smoke test at cell level and bind the returned
# (output, attention_weights) tensors to module-level names.
output, attention_weights = test_multi_head_attention()

Input shape: torch.Size([2, 3, 8])
Output shape: torch.Size([2, 3, 8])
Attention weights shape: torch.Size([2, 3, 3])

Number of heads: 2
Dimension per head: 4
