# Basic Attention Mechanisms

This notebook introduces the fundamental concepts of attention mechanisms in deep learning. We'll explore:

1. The intuition behind attention
2. Basic attention computation
3. Implementation of additive and multiplicative attention
4. Visualizing attention weights

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Optional, Tuple

# Seed both random number generators (PyTorch and NumPy) so the random
# tensors and layer initializations in this notebook repeat on every run.
torch.manual_seed(seed=42)
np.random.seed(seed=42)

## 1. Basic Attention Implementation

Let's implement a single attention module that supports both additive (MLP-scored, Bahdanau-style) and multiplicative (scaled dot-product) attention:

In [None]:
class BasicAttention(nn.Module):
    """Single-head attention with either additive or scaled dot-product scoring.

    Args:
        query_dim: Size of the last dimension of ``query``.
        key_dim: Size of the last dimension of ``key``. Must equal
            ``query_dim`` for ``'dot'`` attention.
        value_dim: Size of the last dimension of ``value`` (stored for
            reference; the output keeps this size).
        attention_type: ``'additive'`` (an MLP scores each concatenated
            query/key pair) or ``'dot'`` (dot product scaled by
            ``sqrt(query_dim)``).
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``attention_type`` is not ``'additive'`` or ``'dot'``,
            or if ``query_dim != key_dim`` for ``'dot'`` attention.
    """

    def __init__(
        self,
        query_dim: int,
        key_dim: int,
        value_dim: int,
        attention_type: str = 'dot',
        dropout: float = 0.1
    ):
        super().__init__()

        self.attention_type = attention_type
        self.query_dim = query_dim
        self.key_dim = key_dim
        self.value_dim = value_dim

        if attention_type == 'additive':
            # Additive attention: score(q, k) = w2^T tanh(W1 [q; k] + b1) + b2.
            self.attention = nn.Sequential(
                nn.Linear(query_dim + key_dim, query_dim),
                nn.Tanh(),
                nn.Linear(query_dim, 1)
            )
        elif attention_type == 'dot':
            # Multiplicative (scaled dot-product) attention.
            if query_dim != key_dim:
                raise ValueError(
                    f"'dot' attention requires query_dim == key_dim, "
                    f"got {query_dim} and {key_dim}"
                )
            # Registered as a non-persistent buffer so it follows the module
            # through .to(device) / .half() without appearing in state_dict.
            # It is 0-dim, so it broadcasts without promoting the scores' dtype.
            self.register_buffer(
                'scale',
                torch.sqrt(torch.tensor(float(query_dim))),
                persistent=False,
            )
        else:
            raise ValueError(
                f"Unsupported attention_type {attention_type!r}; "
                "expected 'additive' or 'dot'"
            )

        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Attend from each query over all keys and mix the matching values.

        Args:
            query: Query tensor of shape (batch_size, num_queries, query_dim)
            key: Key tensor of shape (batch_size, num_keys, key_dim)
            value: Value tensor of shape (batch_size, num_keys, value_dim)
            mask: Optional tensor broadcastable to
                (batch_size, num_queries, num_keys); positions where
                ``mask == 0`` receive zero attention weight.

        Returns:
            Tuple of (output, attention_weights):
              - output: (batch_size, num_queries, value_dim)
              - attention_weights: (batch_size, num_queries, num_keys). These
                are returned *after* dropout, so in training mode rows need
                not sum to 1; call ``.eval()`` to get the plain softmax.

        Note:
            A query whose keys are all masked gets NaN weights, because
            softmax over an all ``-inf`` row is undefined.
        """
        if self.attention_type == 'additive':
            num_queries = query.shape[1]
            num_keys = key.shape[1]
            # torch.cat does not broadcast, so build every (query, key) pair
            # explicitly: both operands become (batch, num_queries, num_keys, dim).
            query_pairs = query.unsqueeze(2).expand(-1, -1, num_keys, -1)
            key_pairs = key.unsqueeze(1).expand(-1, num_queries, -1, -1)

            energy = self.attention(torch.cat([query_pairs, key_pairs], dim=-1))
            energy = energy.squeeze(-1)  # (batch_size, num_queries, num_keys)
        else:
            # (batch, Q, D) @ (batch, D, K) -> (batch, Q, K), scaled by sqrt(d).
            energy = torch.matmul(query, key.transpose(-2, -1)) / self.scale

        if mask is not None:
            # exp(-inf) == 0, so masked positions get exactly zero weight.
            energy = energy.masked_fill(mask == 0, float('-inf'))

        attention_weights = F.softmax(energy, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Weighted sum of values: (batch, Q, K) @ (batch, K, value_dim).
        output = torch.matmul(attention_weights, value)

        return output, attention_weights

## 2. Visualizing Attention

Let's create a function to visualize attention weights:

In [None]:
def plot_attention_weights(
    attention_weights: torch.Tensor,
    x_labels: Optional[list] = None,
    y_labels: Optional[list] = None,
    title: str = "Attention Weights"
) -> None:
    """Plot a single attention map as a heatmap.

    Args:
        attention_weights: 2-D tensor of shape (num_queries, num_keys). A
            3-D tensor whose leading (batch) dimension has size 1 is also
            accepted and squeezed.
        x_labels: Tick labels for the key axis; ``None`` lets seaborn choose.
        y_labels: Tick labels for the query axis; ``None`` lets seaborn choose.
        title: Figure title.

    Raises:
        ValueError: If the weights do not describe a single 2-D map.
    """
    weights = attention_weights.detach().cpu()
    if weights.dim() == 3 and weights.shape[0] == 1:
        weights = weights[0]
    if weights.dim() != 2:
        raise ValueError(
            "expected a 2-D attention map (num_queries, num_keys), "
            f"got shape {tuple(attention_weights.shape)}"
        )

    plt.figure(figsize=(10, 8))
    sns.heatmap(
        # .float() so reduced-precision dtypes (e.g. bfloat16) convert to NumPy.
        weights.float().numpy(),
        # seaborn accepts "auto"/bool/int/list-like here; None would fail.
        xticklabels="auto" if x_labels is None else x_labels,
        yticklabels="auto" if y_labels is None else y_labels,
        cmap='viridis'
    )
    plt.title(title)
    plt.xlabel('Key')
    plt.ylabel('Query')
    plt.show()

## 3. Example: Machine Translation

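Let's demonstrate attention in a toy machine translation scenario. The source and target "sentences" are random vectors standing in for encoder and decoder hidden states, so the resulting weights illustrate the mechanics rather than a real word alignment: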

In [None]:
# Create sample data: random vectors standing in for encoder/decoder states
batch_size = 1
seq_len = 5
hidden_dim = 8

# Source sequence (e.g., English)
source = torch.randn(batch_size, seq_len, hidden_dim)

# Target sequence (e.g., French)
target = torch.randn(batch_size, seq_len, hidden_dim)

# Initialize attention mechanism
attention = BasicAttention(
    query_dim=hidden_dim,
    key_dim=hidden_dim,
    value_dim=hidden_dim,
    attention_type='dot'
)
# A freshly built nn.Module is in training mode, where dropout zeroes random
# attention weights and rescales the rest. Switch to eval mode so the plot
# shows the actual softmax distribution (each row sums to 1).
attention.eval()

# Compute attention (no gradients are needed just to visualize it)
with torch.no_grad():
    output, attention_weights = attention(target, source, source)

# Visualize attention weights
plot_attention_weights(
    attention_weights[0],
    x_labels=[f'Source {i+1}' for i in range(seq_len)],
    y_labels=[f'Target {i+1}' for i in range(seq_len)],
    title='Attention Weights in Machine Translation'
)

## 4. Comparing Additive and Multiplicative Attention

Let's compare the two types of attention:

In [None]:
# Initialize both types of attention. .eval() disables dropout so the two
# weight maps being compared are undistorted softmax distributions.
additive_attention = BasicAttention(
    query_dim=hidden_dim,
    key_dim=hidden_dim,
    value_dim=hidden_dim,
    attention_type='additive'
).eval()

multiplicative_attention = BasicAttention(
    query_dim=hidden_dim,
    key_dim=hidden_dim,
    value_dim=hidden_dim,
    attention_type='dot'
).eval()

# Compute attention with both mechanisms (inference only, no gradients)
with torch.no_grad():
    _, additive_weights = additive_attention(target, source, source)
    _, multiplicative_weights = multiplicative_attention(target, source, source)

# Visualize both
plt.figure(figsize=(15, 6))

plt.subplot(1, 2, 1)
sns.heatmap(additive_weights[0].cpu().numpy(), cmap='viridis')
plt.title('Additive Attention')
plt.xlabel('Source')
plt.ylabel('Target')

plt.subplot(1, 2, 2)
sns.heatmap(multiplicative_weights[0].cpu().numpy(), cmap='viridis')
plt.title('Multiplicative Attention')
plt.xlabel('Source')
plt.ylabel('Target')

plt.tight_layout()
plt.show()

## 5. Conclusion

In this notebook, we've explored:

1. Basic attention mechanisms (additive and multiplicative)
2. Implementation of attention in PyTorch
3. Visualization of attention weights
4. Comparison of different attention types

Key takeaways:

- Additive attention uses a feed-forward network to compute attention scores
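- Multiplicative attention scores query–key pairs with (scaled) dot products. These reduce to a single batched matrix multiplication, making it faster and more memory-efficient than additive attention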
- Both types can be effective, with multiplicative attention being more commonly used in practice
- Attention weights show which inputs the model attends to, but they should not be treated as a complete explanation of its decisions