In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

## 1. Basic Self-Attention

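Self-attention lets every position in a sequence attend to every position (including itself). Queries $Q$ are compared with keys $K$ via dot products scaled by $\sqrt{d_k}$ (the key dimension), and the resulting softmax weights are used to take a weighted average of the values $V$: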

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

In [None]:
def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Compute scaled dot-product attention.

    Works for any number of leading (batch-like) dimensions, e.g. 3-D
    (batch, seq, d) or 4-D (batch, heads, seq, d) inputs, because
    ``torch.matmul`` broadcasts over everything except the last two axes.

    Args:
        Q: Query tensor (..., seq_len_q, d_k)
        K: Key tensor (..., seq_len_k, d_k)
        V: Value tensor (..., seq_len_k, d_v)
        mask: Optional tensor broadcastable to (..., seq_len_q, seq_len_k).
            Positions where ``mask == 0`` (or ``False``) are excluded from
            attention; non-zero / ``True`` positions are attended to.

    Returns:
        output: Attention output (..., seq_len_q, d_v)
        attention_weights: Attention weights (..., seq_len_q, seq_len_k);
            each row sums to 1 along the last dimension.
    """
    d_k = Q.size(-1)

    # Compute attention scores: Q @ K^T / sqrt(d_k)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)

    # Apply mask if provided (useful for causal/decoder attention)
    if mask is not None:
        # Fill with the most negative finite value of the scores' dtype rather
        # than a hard-coded -1e9: -1e9 is not representable in float16 (the
        # masked_fill raises an overflow error), while finfo.min exists for
        # every floating dtype and still becomes exactly 0 after softmax.
        scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

    # Softmax over the key axis turns scores into attention weights
    attention_weights = F.softmax(scores, dim=-1)

    # Weighted average of the value vectors
    output = torch.matmul(attention_weights, V)

    return output, attention_weights

### Test Basic Self-Attention

In [None]:
# Create sample input
torch.manual_seed(0)  # make the random embeddings (and printed weights) reproducible

batch_size = 1
seq_len = 4
d_model = 8

# Random embeddings for a sequence of 4 tokens
x = torch.randn(batch_size, seq_len, d_model)

# Simplification: a real self-attention layer uses learned projections
# (Q = x W_q, K = x W_k, V = x W_v, as in MultiHeadAttention further down).
# Here the raw embeddings are fed in directly to isolate the attention math.
Q = K = V = x

# Compute attention
output, attn_weights = scaled_dot_product_attention(Q, K, V)

# Sanity checks: shape is preserved and every row of weights is a distribution
assert output.shape == x.shape
assert attn_weights.shape == (batch_size, seq_len, seq_len)
assert torch.allclose(attn_weights.sum(dim=-1), torch.ones(batch_size, seq_len))

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {attn_weights.shape}")
print("\nAttention weights (how much each position attends to others):")
print(attn_weights[0].detach().numpy())

## 2. Visualize Attention Weights

In [None]:
def plot_attention_weights(attention_weights, tokens=None, batch_idx=0, head_idx=0,
                           title='Attention Weights Heatmap'):
    """
    Visualize a single attention weight matrix as a heatmap.

    Args:
        attention_weights: Tensor of shape (seq_q, seq_k), (batch, seq_q, seq_k)
            or (batch, num_heads, seq_q, seq_k). For 3-D/4-D input one matrix is
            selected with ``batch_idx`` (and ``head_idx`` for 4-D).
        tokens: Optional labels used for BOTH axes, so this assumes a square
            (self-attention) matrix.
        batch_idx: Batch element to plot (ignored for 2-D input).
        head_idx: Attention head to plot (only used for 4-D input).
        title: Figure title.

    Raises:
        ValueError: if ``attention_weights`` is not 2-, 3- or 4-dimensional.
    """
    # Detach from autograd and move to CPU so .numpy() works for any tensor
    weights = attention_weights.detach().cpu()
    if weights.dim() == 4:
        weights = weights[batch_idx, head_idx]
    elif weights.dim() == 3:
        weights = weights[batch_idx]
    elif weights.dim() != 2:
        raise ValueError(
            f"expected 2-D, 3-D or 4-D attention weights, got shape {tuple(weights.shape)}"
        )
    weights = weights.numpy()

    fig, ax = plt.subplots(figsize=(8, 6))
    image = ax.imshow(weights, cmap='viridis', aspect='auto')
    fig.colorbar(image, ax=ax, label='Attention Weight')
    ax.set_xlabel('Key Position')
    ax.set_ylabel('Query Position')
    ax.set_title(title)

    if tokens:
        ax.set_xticks(range(len(tokens)))
        ax.set_xticklabels(tokens)
        ax.set_yticks(range(len(tokens)))
        ax.set_yticklabels(tokens)

    fig.tight_layout()
    plt.show()

# Visualize the attention pattern
plot_attention_weights(attn_weights, tokens=['Token1', 'Token2', 'Token3', 'Token4'])

## 3. Multi-Head Attention

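Multi-head attention runs several attention operations ("heads") in parallel. Each head works on its own learned projection of size $d_k = d_{model} / h$, which lets the model attend to different aspects of the input at the same time. The heads' outputs are then concatenated and projected back to $d_{model}$.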

In [None]:
class MultiHeadAttention(nn.Module):
    """
    Multi-head scaled dot-product attention.

    Q, K and V are linearly projected, split into ``num_heads`` heads of size
    ``d_k = d_model // num_heads``, attended independently per head, then
    concatenated and passed through an output projection.
    """

    def __init__(self, d_model, num_heads):
        """
        Multi-Head Attention module.

        Args:
            d_model: Dimension of the model (embedding size)
            num_heads: Number of attention heads; must divide d_model evenly
        """
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Linear projections for Q, K, V
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

        # Output projection
        self.W_o = nn.Linear(d_model, d_model)

    def split_heads(self, x):
        """Reshape (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)."""
        batch_size, seq_len, _ = x.size()
        return x.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x):
        """Reshape (batch, num_heads, seq_len, d_k) -> (batch, seq_len, d_model)."""
        batch_size, _, seq_len, _ = x.size()
        # transpose() leaves x non-contiguous, so .contiguous() is required before view()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)

    def forward(self, Q, K, V, mask=None):
        """
        Apply multi-head attention.

        Args:
            Q: Queries (batch, seq_len_q, d_model)
            K: Keys (batch, seq_len_k, d_model)
            V: Values (batch, seq_len_k, d_model)
            mask: Optional mask where zeros / False mark positions that must not
                be attended to. Accepts (seq_len_q, seq_len_k), a per-example
                (batch, seq_len_q, seq_len_k) mask shared by all heads, or any
                shape broadcastable to (batch, num_heads, seq_len_q, seq_len_k).

        Returns:
            output: (batch, seq_len_q, d_model)
            attention_weights: per-head weights (batch, num_heads, seq_len_q, seq_len_k)
        """
        # Linear projections, then split into heads: (batch, num_heads, seq_len, d_k)
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))

        # Scaled dot-product scores for every head at once
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)

        if mask is not None:
            if mask.dim() == 3:
                # A (batch, seq_q, seq_k) mask has no head axis. Without this,
                # broadcasting against (batch, heads, seq_q, seq_k) would align
                # the mask's batch axis with the heads axis (wrong, or a shape
                # error whenever batch != 1 and batch != num_heads).
                mask = mask.unsqueeze(1)
            # finfo.min instead of -1e9 so float16 scores do not overflow
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attention_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attention_weights, V)

        # Combine heads and apply the final linear projection
        output = self.W_o(self.combine_heads(output))

        return output, attention_weights

### Test Multi-Head Attention

In [None]:
# Initialize multi-head attention
torch.manual_seed(0)  # reproducible weights and inputs

d_model = 64
num_heads = 8
seq_len = 10
batch_size = 2

mha = MultiHeadAttention(d_model, num_heads)

# Create sample input
x = torch.randn(batch_size, seq_len, d_model)

# For self-attention the same tensor is passed as query, key and value
output, attn_weights = mha(x, x, x)

# The output keeps the input shape; weights are kept separately for every head
assert output.shape == (batch_size, seq_len, d_model)
assert attn_weights.shape == (batch_size, num_heads, seq_len, seq_len)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape (batch, heads, query, key): {attn_weights.shape}")
print(f"Number of attention heads: {num_heads}")
print(f"Dimension per head (d_k): {mha.d_k}")

## 4. Compare with PyTorch Built-in Multi-Head Attention

In [None]:
# PyTorch's built-in multi-head attention (batch_first=True -> inputs are (batch, seq, feature))
pytorch_mha = nn.MultiheadAttention(d_model, num_heads, batch_first=True)

# Test it
x = torch.randn(batch_size, seq_len, d_model)

# By default PyTorch AVERAGES attention weights over heads and returns
# (batch, seq, seq). Request per-head weights so the result is comparable to
# our MultiHeadAttention, which returns (batch, heads, seq, seq).
output_pytorch, attn_weights_pytorch = pytorch_mha(x, x, x, average_attn_weights=False)
output_custom, attn_weights_custom = mha(x, x, x)

# Only shapes are comparable: the two modules have independently initialized
# weights, so their numerical outputs differ.
assert output_pytorch.shape == output_custom.shape
assert attn_weights_pytorch.shape == attn_weights_custom.shape

print(f"PyTorch MHA Output shape: {output_pytorch.shape}")
print(f"PyTorch MHA Attention weights shape: {attn_weights_pytorch.shape}")
print("\nBoth implementations produce the same output shapes!")

## 5. Causal (Masked) Attention

In decoder-only models like GPT, we need causal masking so that each position can attend only to itself and to earlier positions, never to future tokens.

In [None]:
def create_causal_mask(seq_len):
    """
    Build a (seq_len, seq_len) boolean mask for causal attention.

    Entry [i, j] is True when query position i may attend to key position j,
    i.e. when j <= i (the position itself and everything before it). Entries
    for future positions (j > i) are False and get masked out.
    """
    # Lower triangle including the diagonal = allowed positions
    allowed = torch.ones(seq_len, seq_len).tril()
    return allowed.bool()

# Build a small causal mask and render it (rows = queries, columns = keys)
seq_len = 6
causal_mask = create_causal_mask(seq_len)

fig, ax = plt.subplots(figsize=(6, 5))
mask_image = ax.imshow(causal_mask.numpy(), cmap='gray', aspect='auto')
ax.set_title('Causal Attention Mask\n(White = Can Attend, Black = Masked)')
ax.set_xlabel('Key Position')
ax.set_ylabel('Query Position')
fig.colorbar(mask_image, ax=ax)
fig.tight_layout()
plt.show()

# Same mask as integers, easier to read as text
print("Causal mask (1 = attend, 0 = mask):")
print(causal_mask.int())

## 6. Test Causal Attention

In [None]:
# Test with causal mask
torch.manual_seed(0)  # reproducible inputs and printed weights

seq_len = 6
d_model = 8
x = torch.randn(1, seq_len, d_model)

# Causal mask with a leading batch dimension: (1, seq_len, seq_len)
mask = create_causal_mask(seq_len).unsqueeze(0)

# Raw embeddings are used directly as Q, K and V (no learned projections)
output, attn_weights = scaled_dot_product_attention(x, x, x, mask=mask)

# Check the claims printed below rather than only stating them:
# no weight on future positions, and every row is a probability distribution.
assert torch.all(attn_weights[0][~mask[0]] == 0), "attended to a future position"
assert torch.allclose(attn_weights.sum(dim=-1), torch.ones(1, seq_len))

print("Attention weights with causal masking:")
print(attn_weights[0].detach().numpy())
print("\nNotice: Each row sums to 1, but only attends to current and previous positions!")

# Visualize
plot_attention_weights(attn_weights, tokens=[f'T{i}' for i in range(seq_len)])