# Self-Attention Mechanism

This notebook explores self-attention, a fundamental component of Transformer models. We'll cover:

1. Implementing self-attention from scratch
2. Visualizing self-attention patterns
3. A toy text-processing example
4. A real-world example with a pre-trained BERT model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Optional, Tuple

# Fix the NumPy and PyTorch global RNG seeds so repeated runs of the
# notebook produce the same random tensors and layouts.
np.random.seed(42)
torch.manual_seed(42)

## 1. Self-Attention Implementation

Let's implement a self-attention mechanism:

In [None]:
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    The input is projected to queries, keys and values with three linear
    layers of size ``input_dim -> input_dim``. Attention scores are the
    query/key dot products divided by ``sqrt(input_dim)``, normalised with a
    softmax over the key axis and passed through dropout before being used
    to mix the values.
    """

    def __init__(
        self,
        input_dim: int,
        dropout: float = 0.1
    ):
        """
        Args:
            input_dim: Size of the last dimension of the input; also the size
                of the Q, K and V projections.
            dropout: Dropout probability applied to the attention weights.
        """
        super().__init__()

        self.input_dim = input_dim

        # Linear projections for Q, K, V
        self.query = nn.Linear(input_dim, input_dim)
        self.key = nn.Linear(input_dim, input_dim)
        self.value = nn.Linear(input_dim, input_dim)

        self.dropout = nn.Dropout(dropout)

        # Register the scaling factor as a buffer so it follows the module
        # across .to(device) / .double() / .half() calls; a plain tensor
        # attribute would stay on the CPU and break the division in forward().
        # persistent=False keeps it out of state_dict, so checkpoints saved
        # before this change still load.
        self.register_buffer(
            "scale",
            torch.sqrt(torch.tensor([input_dim], dtype=torch.float32)),
            persistent=False,
        )

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: Input tensor of shape (batch_size, seq_len, input_dim)
            mask: Optional mask tensor of shape (batch_size, seq_len, seq_len);
                positions where the mask equals 0 are excluded from attention.

        Returns:
            Tuple of (output, attention_weights), where the returned
            attention weights are the post-dropout values.
        """
        # Project input to Q, K, V
        Q = self.query(x)  # (batch_size, seq_len, input_dim)
        K = self.key(x)    # (batch_size, seq_len, input_dim)
        V = self.value(x)  # (batch_size, seq_len, input_dim)

        # Scaled dot-product scores: (batch_size, seq_len, seq_len)
        energy = torch.matmul(Q, K.transpose(-2, -1)) / self.scale

        # Masked positions get -inf so softmax assigns them zero weight.
        # NOTE: a query row whose keys are all masked yields NaN weights.
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float('-inf'))

        # Normalise over the key axis, then regularise with dropout
        attention_weights = F.softmax(energy, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Weighted sum of the values
        output = torch.matmul(attention_weights, V)

        return output, attention_weights

## 2. Visualizing Self-Attention

Let's create functions to visualize self-attention patterns:

In [None]:
def plot_self_attention(
    attention_weights: torch.Tensor,
    tokens: Optional[list] = None,
    title: str = "Self-Attention Weights"
) -> None:
    """Plot self-attention weights as a heatmap.

    Args:
        attention_weights: 2-D tensor of attention weights; rows are drawn on
            the y-axis (query positions) and columns on the x-axis (key
            positions).
        tokens: Optional tick labels used for both axes. When omitted,
            seaborn chooses the tick labels automatically.
        title: Title shown above the heatmap.
    """
    # seaborn does not accept None for tick labels; "auto" is its own
    # default, so the no-tokens call path renders instead of raising.
    tick_labels = "auto" if tokens is None else tokens

    plt.figure(figsize=(10, 8))
    sns.heatmap(
        attention_weights.detach().cpu().numpy(),
        xticklabels=tick_labels,
        yticklabels=tick_labels,
        cmap='viridis'
    )
    plt.title(title)
    plt.xlabel('Key Position')
    plt.ylabel('Query Position')
    plt.show()

def plot_attention_flow(
    attention_weights: torch.Tensor,
    tokens: list,
    title: str = "Attention Flow"
) -> None:
    """Draw the stronger attention links between tokens as a directed graph.

    Each token becomes a node. An edge i -> j is added whenever
    ``attention_weights[i, j]`` exceeds 0.1, and the graph is rendered with a
    spring layout.

    Args:
        attention_weights: 2-D tensor indexed as [query position, key position].
        tokens: Token strings, one per position; used as node labels.
        title: Title shown above the figure.
    """
    import networkx as nx

    positions = range(len(tokens))
    node_labels = dict(enumerate(tokens))

    # One node per token, keyed by its position in the sequence
    graph = nx.DiGraph()
    graph.add_nodes_from(
        (idx, {"label": tok}) for idx, tok in node_labels.items()
    )

    # Keep only connections above the significance threshold
    scores = attention_weights.detach().cpu().numpy()
    graph.add_edges_from(
        (src, dst, {"weight": scores[src, dst]})
        for src in positions
        for dst in positions
        if scores[src, dst] > 0.1
    )

    plt.figure(figsize=(12, 8))

    # Lay out and render nodes, edges and token labels
    layout = nx.spring_layout(graph)
    nx.draw_networkx_nodes(graph, layout, node_color='lightblue', node_size=1000)
    nx.draw_networkx_edges(graph, layout, edge_color='gray', width=1, alpha=0.5)
    nx.draw_networkx_labels(graph, layout, node_labels)

    plt.title(title)
    plt.axis('off')
    plt.show()

## 3. Example: Text Processing

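Let's demonstrate self-attention on a simple text example. Because the token embeddings and the projection weights are random (untrained), the resulting attention pattern is illustrative only and carries no linguistic meaning: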

In [None]:
# Sample sentence: "The cat sat on the mat"
tokens = ["The", "cat", "sat", "on", "the", "mat"]

# Create sample data; the sequence length follows the token list so the
# heatmap/graph labels always line up with the attention matrix
batch_size = 1
seq_len = len(tokens)
hidden_dim = 8

# Create random embeddings for demonstration
x = torch.randn(batch_size, seq_len, hidden_dim)

# Initialize self-attention
self_attention = SelfAttention(input_dim=hidden_dim)

# Switch to evaluation mode: in training mode dropout would randomly zero
# (and rescale) attention weights, so the plotted rows would not sum to 1
self_attention.eval()

# Compute self-attention; no gradients are needed for visualization
with torch.no_grad():
    output, attention_weights = self_attention(x)

# Visualize attention weights
plot_self_attention(
    attention_weights[0],
    tokens=tokens,
    title='Self-Attention in "The cat sat on the mat"'
)

# Visualize attention flow
plot_attention_flow(
    attention_weights[0],
    tokens=tokens,
    title='Attention Flow in "The cat sat on the mat"'
)

## 4. Real-World Example: Using Pre-trained Model

Let's examine self-attention in a pre-trained model:

In [None]:
from transformers import BertTokenizer, BertModel

# Load pre-trained model and tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')

# Make inference mode explicit so dropout cannot perturb the attention maps
model.eval()

# Example sentence
text = "The quick brown fox jumps over the lazy dog"

# Tokenize and get model output; gradients are not needed for inspection
inputs = tokenizer(text, return_tensors='pt')
with torch.no_grad():
    outputs = model(**inputs, output_attentions=True)

# outputs.attentions is indexed by layer first, then by batch item:
# take the first layer, then the first (only) sequence in the batch
attention_weights = outputs.attentions[0][0]

# Get tokens (includes the special tokens the tokenizer added)
tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

# Visualize attention
plot_self_attention(
    attention_weights[0],  # First attention head
    tokens=tokens,
    title='BERT Self-Attention (First Head)'
)

## 5. Conclusion

In this notebook, we've explored:

1. Implementation of self-attention
2. Visualization of attention patterns
3. Application to text processing
4. Real-world example using BERT

Key takeaways:

- Self-attention allows each position to attend to all positions
- It helps capture long-range dependencies in sequences
- Attention weights can be visualized, offering some (though not conclusive) insight into what the model attends to
- Different attention heads can learn different patterns