# Introduction to Attention Mechanisms

This notebook introduces the fundamental concepts of attention mechanisms in deep learning. We'll explore:

1. Why attention is needed
2. The basic mathematics behind attention
3. A simple implementation from scratch
4. Visualizing how attention works

## Why Attention?

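Recurrent sequence models such as RNNs and LSTMs have several limitations:

1. **Information Bottleneck**: Everything the model knows about the input so far must be compressed into a single fixed-size hidden state
2. **Long-term Dependencies**: A signal linking two distant elements has to survive every recurrent step between them, so long-range relationships are hard to learn
3. **Parallelization**: Each step depends on the previous hidden state, so computation cannot be parallelized across time steps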

Attention mechanisms address these issues by:

- Allowing direct access to any part of the input sequence
- Computing relevance scores between elements
- Enabling parallel processing of the entire sequence
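
To make the contrast concrete, here is a minimal sketch that processes the same toy sequence two ways. It rests on several assumptions: a plain `tanh` recurrence stands in for an RNN or LSTM, the weight matrices `W_x` and `W_h` are hypothetical random values, and the attention part skips the learned projections a real model would use.

```python
import torch

torch.manual_seed(0)
seq_len, dim = 6, 8
x = torch.randn(seq_len, dim)  # one toy sequence of 6 token embeddings

# Recurrent-style pass: each step needs the previous hidden state,
# so the loop cannot be parallelized across positions.
W_x = torch.randn(dim, dim) * 0.1  # hypothetical input weights
W_h = torch.randn(dim, dim) * 0.1  # hypothetical recurrent weights
h = torch.zeros(dim)
for t in range(seq_len):
    h = torch.tanh(x[t] @ W_x + h @ W_h)
# h is the single fixed-size summary of the whole sequence.

# Attention-style pass: one matrix product scores every pair of positions.
scores = x @ x.T / dim ** 0.5            # (seq_len, seq_len)
weights = torch.softmax(scores, dim=-1)  # each row sums to 1
context = weights @ x                    # (seq_len, dim)

print(h.shape)        # torch.Size([8])
print(context.shape)  # torch.Size([6, 8])
```

The recurrent pass ends with one vector for the whole sequence, while the attention pass produces one context vector per position, each computed from direct access to every other position.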

## The Math Behind Attention

The core of attention is computing relevance scores between queries and keys, then using these scores to weight the values:

1. **Query-Key-Value Triplet**:
   - Query (Q): What we're looking for
   - Key (K): What we're matching against
   - Value (V): What we're retrieving

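2. **Scaled Dot-Product Attention**:
   $$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Here $d_k$ is the dimension of the key vectors. $QK^T$ holds the raw relevance scores, the softmax turns each row of scores into weights that sum to 1, and dividing by $\sqrt{d_k}$ keeps the scores from growing with the vector dimension, which would otherwise saturate the softmax.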
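
The role of the $\sqrt{d_k}$ factor can be checked numerically. The sketch below assumes query and key components are independent with zero mean and unit variance (random data, not taken from any trained model). Under that assumption the dot product of two $d_k$-dimensional vectors has variance close to $d_k$, and the scaled version has variance close to 1.

```python
import torch

torch.manual_seed(0)
d_k = 64
n = 10_000

# Random queries and keys with zero-mean, unit-variance components
q = torch.randn(n, d_k)
k = torch.randn(n, d_k)

raw = (q * k).sum(dim=-1)   # n independent query-key dot products
scaled = raw / d_k ** 0.5   # the same scores after scaling

raw_var = raw.var().item()
scaled_var = scaled.var().item()
print(f"variance of raw dot products:    {raw_var:.1f}")     # close to d_k
print(f"variance of scaled dot products: {scaled_var:.2f}")  # close to 1
```

Keeping the score variance near 1 regardless of $d_k$ stops the softmax from collapsing onto a single key as the embedding dimension grows.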

In [None]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

## Simple Attention Implementation

Let's implement a basic attention mechanism from scratch:

In [None]:
class SimpleAttention(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.input_dim = input_dim
        self.scale = math.sqrt(input_dim)
        
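    def forward(self, query, key, value, mask=None):
        """Return (output, attention_weights).

        query: (..., query_len, input_dim); key: (..., key_len, input_dim);
        value: (..., key_len, value_dim).
        mask: optional tensor broadcastable to (..., query_len, key_len);
        positions where mask == 0 are excluded from attention.
        """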
        # Compute attention scores
        scores = torch.matmul(query, key.transpose(-2, -1)) / self.scale
        
        # Apply mask if provided
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        # Compute attention weights
        attention_weights = F.softmax(scores, dim=-1)
        
        # Apply attention weights to values
        output = torch.matmul(attention_weights, value)
        
        return output, attention_weights
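
The `mask` argument is easiest to understand with a causal mask, where each position may only attend to itself and earlier positions. The sketch below repeats the same steps as `forward` as standalone code so it runs on its own. The lower-triangular mask and the toy tensor `x` are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, dim = 4, 8
x = torch.randn(1, seq_len, dim)  # toy self-attention input: query = key = value

# Lower-triangular mask: position i may attend to positions 0..i only
causal_mask = torch.tril(torch.ones(seq_len, seq_len))

scores = torch.matmul(x, x.transpose(-2, -1)) / dim ** 0.5
scores = scores.masked_fill(causal_mask == 0, float('-inf'))
weights = F.softmax(scores, dim=-1)

print(weights[0])  # entries above the diagonal are exactly 0
```

Because `exp(-inf)` is 0, masked positions get zero weight and the remaining weights in each row still sum to 1. A row in which every position is masked would produce NaN values, so a mask should leave at least one key visible per query.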

## Visualizing Attention

Let's create a simple example to visualize how attention works:

In [None]:
def visualize_attention(attention_weights, labels):
    plt.figure(figsize=(10, 8))
    sns.heatmap(attention_weights, 
                xticklabels=labels,
                yticklabels=labels,
                cmap='viridis')
    plt.title('Attention Weights')
    plt.xlabel('Key')
    plt.ylabel('Query')
    plt.show()

# Create a simple example
sequence_length = 5
embedding_dim = 4

# Generate random embeddings (the resulting pattern is arbitrary; only the mechanics matter)
query = torch.randn(1, sequence_length, embedding_dim)
key = torch.randn(1, sequence_length, embedding_dim)
value = torch.randn(1, sequence_length, embedding_dim)

# Create attention mechanism
attention = SimpleAttention(embedding_dim)

# Compute attention
output, attention_weights = attention(query, key, value)

# Visualize attention weights
labels = [f'Token {i+1}' for i in range(sequence_length)]
visualize_attention(attention_weights[0].detach().numpy(), labels)

## Real-World Example: Machine Translation

Let's see how attention helps in machine translation by visualizing the attention weights between source and target words:

In [None]:
from transformers import MarianMTModel, MarianTokenizer

# Load pre-trained model and tokenizer
model_name = 'Helsinki-NLP/opus-mt-en-fr'
tokenizer = MarianTokenizer.from_pretrained(model_name)
model = MarianMTModel.from_pretrained(model_name)

# Example translation
text = "The cat sat on the mat."
inputs = tokenizer(text, return_tensors="pt")

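# Generate the translation (generate() returns only token ids by default)
generated = model.generate(**inputs)
translation = tokenizer.decode(generated[0], skip_special_tokens=True)

print(f"English: {text}")
print(f"French: {translation}")

# Re-run one forward pass over the generated sequence to collect the
# cross-attention weights. Depending on the installed transformers version,
# this may require loading the model with attn_implementation="eager".
with torch.no_grad():
    outputs = model(**inputs,
                    decoder_input_ids=generated[:, :-1],
                    output_attentions=True)

# Last decoder layer, first batch item, averaged over heads: (target_len, source_len)
attention_weights = outputs.cross_attentions[-1][0].mean(dim=0).numpy()
source_tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist())
# Row i is the decoder step that predicts generated token i + 1
target_tokens = tokenizer.convert_ids_to_tokens(generated[0, 1:].tolist())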

plt.figure(figsize=(12, 8))
sns.heatmap(attention_weights, 
            xticklabels=source_tokens,
            yticklabels=target_tokens,
            cmap='viridis')
plt.title('Attention Weights in Translation')
plt.xlabel('Source Tokens')
plt.ylabel('Target Tokens')
plt.show()
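
The heatmap above has one row per target token and one column per source token because, in translation, the queries come from the decoder while the keys and values come from the encoder. Here is a minimal sketch of that shape relationship. It uses random tensors as hypothetical stand-ins for encoder and decoder states and leaves out the learned projections and multiple heads of a real model.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
src_len, tgt_len, dim = 7, 5, 16

# Hypothetical stand-ins for encoder outputs and decoder states
encoder_states = torch.randn(1, src_len, dim)
decoder_states = torch.randn(1, tgt_len, dim)

# Queries come from the decoder; keys and values come from the encoder
scores = torch.matmul(decoder_states, encoder_states.transpose(-2, -1)) / dim ** 0.5
weights = F.softmax(scores, dim=-1)
context = torch.matmul(weights, encoder_states)

print(weights.shape)  # torch.Size([1, 5, 7])
print(context.shape)  # torch.Size([1, 5, 16])
```

Each of the 5 target positions gets its own distribution over the 7 source positions, so the source and target sequences do not need to have the same length.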

## Conclusion

In this notebook, we've explored:

1. The motivation behind attention mechanisms
2. The mathematical foundation of attention
3. A simple implementation from scratch
4. Visualization of attention weights
5. A real-world example in machine translation

In the next notebook, we'll explore different types of attention mechanisms and their specific applications.