# Attention in Machine Translation

This notebook explores how attention mechanisms changed machine translation by letting the decoder look back at the source sentence at every step. We'll cover:

1. The evolution of attention in translation
2. Implementation of attention-based translation
3. Visualizing attention patterns
4. Real-world examples and analysis

## The Evolution of Attention in Translation

Machine translation has evolved through several stages:

1. **Statistical Machine Translation (SMT)**:
   - Successor to earlier rule-based systems
   - Phrase-based translation learned from parallel corpora
   - Limited context understanding

2. **Neural Machine Translation (NMT) with RNNs**:
   - Encoder-decoder architecture
   - Fixed-size context vector
   - Difficulty capturing long-range dependencies

3. **Attention-based NMT**:
   - Dynamic context vector
   - Direct access to source words
   - Better handling of long sequences

4. **Transformer-based NMT**:
   - Self-attention mechanism
   - Parallel processing
   - State-of-the-art performance
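
The step from a fixed-size context vector (stage 2) to a dynamic one (stages 3 and 4) can be sketched in a few lines. The example below is a minimal NumPy illustration of scaled dot-product attention, the scoring used in Transformers; the function name, the random inputs, and the dimensions are assumptions chosen for illustration, not part of any model used later in this notebook.

```python
import numpy as np

def scaled_dot_product_attention(queries, keys, values):
    """Return (context, weights) for a single attention head.

    queries: (tgt_len, d_k), keys: (src_len, d_k), values: (src_len, d_v).
    """
    d_k = queries.shape[-1]
    scores = queries @ keys.T / np.sqrt(d_k)               # (tgt_len, src_len)
    scores = scores - scores.max(axis=-1, keepdims=True)   # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)  # softmax over source positions
    context = weights @ values                             # (tgt_len, d_v)
    return context, weights

# Random stand-ins for encoder and decoder states (illustrative only)
rng = np.random.default_rng(0)
source_states = rng.normal(size=(5, 8))    # 5 source positions, 8 features each
decoder_states = rng.normal(size=(3, 8))   # 3 target positions

context, weights = scaled_dot_product_attention(decoder_states, source_states, source_states)
print(context.shape, weights.shape)
```

Each target position gets its own row of weights over the source positions, so the context vector changes from step to step instead of being computed once for the whole sentence.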

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import MarianMTModel, MarianTokenizer
from typing import List, Tuple

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

## Implementing Attention-based Translation

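Let's implement a simple attention-based translation model with a bidirectional LSTM encoder, an LSTM decoder, and a learned scoring layer that weights the encoder outputs at each decoding step: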

In [None]:
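class AttentionTranslation(nn.Module):
    """LSTM encoder-decoder with concatenation-based attention.

    Assumes `src` and `tgt` are batch-first tensors of already-embedded tokens,
    both with `input_dim` features: (batch, seq_len, input_dim).
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        # batch_first=True matches the (batch, seq_len, features) indexing used in forward()
        self.encoder = nn.LSTM(input_dim, hidden_dim, bidirectional=True, batch_first=True)
        # Each decoder step consumes one target embedding plus the context vector
        self.decoder = nn.LSTM(input_dim + hidden_dim * 2, hidden_dim, batch_first=True)
        # Scores one (decoder state, encoder output) pair: hidden_dim + 2 * hidden_dim features
        self.attention = nn.Linear(hidden_dim * 3, 1)
        self.output = nn.Linear(hidden_dim, output_dim)

    def forward(self, src: torch.Tensor, tgt: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Encode source sequence: (batch, src_len, hidden_dim * 2)
        encoder_outputs, (hidden, cell) = self.encoder(src)

        # Initialize the decoder from the encoder's final backward-direction state
        decoder_hidden = hidden[-1].unsqueeze(0)  # (1, batch, hidden_dim)
        decoder_cell = cell[-1].unsqueeze(0)

        # Initialize attention weights storage
        attention_weights = []
        outputs = []

        # Decode target sequence one step at a time
        for t in range(tgt.size(1)):
            # Pair the current decoder state with every source position and score each pair
            attention_input = torch.cat([
                decoder_hidden.repeat(encoder_outputs.size(1), 1, 1).transpose(0, 1),
                encoder_outputs
            ], dim=2)  # (batch, src_len, hidden_dim * 3)
            attention_scores = self.attention(attention_input).squeeze(2)
            attention_weights_t = F.softmax(attention_scores, dim=1)  # (batch, src_len)
            attention_weights.append(attention_weights_t)

            # Compute context vector: (batch, 1, hidden_dim * 2)
            context = torch.bmm(attention_weights_t.unsqueeze(1), encoder_outputs)

            # Decode one step from the target embedding and the context vector
            decoder_input = torch.cat([tgt[:, t:t+1], context], dim=2)
            output, (decoder_hidden, decoder_cell) = self.decoder(decoder_input, (decoder_hidden, decoder_cell))

            # Project to output vocabulary
            output = self.output(output)
            outputs.append(output)

        # Shapes: (batch, tgt_len, output_dim) and (batch, tgt_len, src_len)
        return torch.cat(outputs, dim=1), torch.stack(attention_weights, dim=1)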
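
To make the tensor shapes inside the decoding loop easier to follow, here is a standalone sketch of a single attention step. It uses the same concatenate-then-score idea as the class above, but the tensors are random stand-ins and the variable names (`score_layer`, `expanded`) are assumptions for illustration only.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
batch, src_len, hidden_dim = 2, 5, 4

# Random stand-ins for real model states (for shape checking only)
encoder_outputs = torch.randn(batch, src_len, hidden_dim * 2)  # bidirectional encoder outputs
decoder_hidden = torch.randn(batch, hidden_dim)                # current decoder state

# One learned score per (decoder state, encoder output) pair
score_layer = nn.Linear(hidden_dim * 3, 1)

# Copy the decoder state across all source positions, then score each pair
expanded = decoder_hidden.unsqueeze(1).expand(-1, src_len, -1)   # (batch, src_len, hidden_dim)
scores = score_layer(torch.cat([expanded, encoder_outputs], dim=2)).squeeze(2)

weights = F.softmax(scores, dim=1)                               # (batch, src_len)
context = torch.bmm(weights.unsqueeze(1), encoder_outputs)       # (batch, 1, hidden_dim * 2)

print(weights.shape, context.shape)
```

The softmax runs over the source dimension, so each row of `weights` is a probability distribution over source positions, and `context` is the corresponding weighted average of encoder outputs.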

## Visualizing Attention Patterns

Let's create functions to visualize attention patterns in translation:

In [None]:
def plot_attention_alignment(
    attention_weights: torch.Tensor,
    source_tokens: List[str],
    target_tokens: List[str],
    title: str = "Attention Alignment"
) -> None:
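    """Plot attention alignment between source and target tokens."""
    # seaborn expects array-like data, so detach the tensor and convert it to NumPy
    weights = attention_weights.detach().cpu().numpy()
    plt.figure(figsize=(10, 8))
    sns.heatmap(
        weights,
        xticklabels=source_tokens,
        yticklabels=target_tokens,
        cmap='viridis'
    )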
    plt.title(title)
    plt.xlabel('Source Tokens')
    plt.ylabel('Target Tokens')
    plt.show()

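def analyze_translation_attention(
    model: MarianMTModel,
    tokenizer: MarianTokenizer,
    text: str,
    layer_idx: int = -1
) -> None:
    """Analyze cross-attention patterns in a translation model."""
    # Tokenize input
    inputs = tokenizer(text, return_tensors="pt")

    with torch.no_grad():
        # Generate the translation; generate() returns a tensor of token IDs by default
        generated = model.generate(**inputs)
        # Re-run the model on the generated tokens so that each decoder layer returns
        # one cross-attention tensor of shape (batch, heads, tgt_len, src_len)
        outputs = model(
            **inputs,
            decoder_input_ids=generated[:, :-1],
            output_attentions=True,
        )
    translation = tokenizer.decode(generated[0], skip_special_tokens=True)

    print(f"Source: {text}")
    print(f"Target: {translation}")

    # Cross-attention for the chosen decoder layer: first batch item, first head
    attention_weights = outputs.cross_attentions[layer_idx][0, 0]
    # Label the axes with the tokens the model saw, including the end-of-sequence token
    source_tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist())
    # Row t belongs to the decoder step that predicts generated[0, t + 1]
    target_tokens = tokenizer.convert_ids_to_tokens(generated[0, 1:].tolist())

    # Plot attention alignment
    plot_attention_alignment(
        attention_weights,
        source_tokens,
        target_tokens,
        title=f"Attention Alignment (Layer {layer_idx})"
    )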
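
A heatmap is easiest to read when most of the weight in each row falls on one source token. As a complement, the sketch below turns an attention matrix into an explicit word alignment by taking the highest-weighted source token for each target token. The helper name `extract_alignment` is hypothetical, and the attention values are hand-written for illustration rather than taken from a model.

```python
import numpy as np

def extract_alignment(attention, source_tokens, target_tokens):
    """Pair each target token with its highest-weighted source token.

    attention: array of shape (len(target_tokens), len(source_tokens)).
    Returns a list of (target_token, source_token, weight) tuples.
    """
    attention = np.asarray(attention, dtype=float)
    best_source = attention.argmax(axis=1)  # index of the strongest source per target row
    return [
        (target_tokens[t], source_tokens[s], float(attention[t, s]))
        for t, s in enumerate(best_source)
    ]

# Hand-written illustrative weights (not model output)
source_tokens = ["the", "black", "cat"]
target_tokens = ["le", "chat", "noir"]
attention = np.array([
    [0.8, 0.1, 0.1],   # "le"   attends mostly to "the"
    [0.1, 0.2, 0.7],   # "chat" attends mostly to "cat"
    [0.1, 0.7, 0.2],   # "noir" attends mostly to "black"
])

alignment = extract_alignment(attention, source_tokens, target_tokens)
for target, source, weight in alignment:
    print(f"{target} -> {source} ({weight:.1f})")
```

To apply this to a real attention tensor, convert it first with `.detach().cpu().numpy()`. Note that argmax alignments are only a rough reading: attention weights in trained models do not always correspond to word-level translation pairs.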

## Real-World Example: MarianMT

Let's analyze attention patterns in a pre-trained English-to-French MarianMT model:

In [None]:
# Load pre-trained model and tokenizer
model_name = 'Helsinki-NLP/opus-mt-en-fr'
tokenizer = MarianTokenizer.from_pretrained(model_name)
model = MarianMTModel.from_pretrained(model_name)

# Example translations
examples = [
    "The cat sat on the mat.",
    "Attention mechanisms have revolutionized machine translation.",
    "Neural networks can learn complex patterns from data."
]

# Analyze attention for each example
for text in examples:
    analyze_translation_attention(model, tokenizer, text)

## Analyzing Different Layers

Let's examine how attention patterns vary across the last few decoder layers:

In [None]:
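def analyze_layer_attention(
    model: MarianMTModel,
    tokenizer: MarianTokenizer,
    text: str,
    num_layers: int = 3
) -> None:
    """Analyze cross-attention patterns across the last `num_layers` decoder layers."""
    # Tokenize input
    inputs = tokenizer(text, return_tensors="pt")

    with torch.no_grad():
        # Generate the translation; generate() returns a tensor of token IDs by default
        generated = model.generate(**inputs)
        # Re-run the model on the generated tokens to get one cross-attention tensor
        # per decoder layer, each of shape (batch, heads, tgt_len, src_len)
        outputs = model(
            **inputs,
            decoder_input_ids=generated[:, :-1],
            output_attentions=True,
        )
    translation = tokenizer.decode(generated[0], skip_special_tokens=True)

    print(f"Source: {text}")
    print(f"Target: {translation}")

    # Label the axes with the tokens the model saw, including the end-of-sequence token
    source_tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist())
    # Row t belongs to the decoder step that predicts generated[0, t + 1]
    target_tokens = tokenizer.convert_ids_to_tokens(generated[0, 1:].tolist())

    # Plot attention for each layer (first batch item, first head)
    for layer_idx in range(-num_layers, 0):
        attention_weights = outputs.cross_attentions[layer_idx][0, 0]
        plot_attention_alignment(
            attention_weights,
            source_tokens,
            target_tokens,
            title=f"Attention Alignment (Layer {layer_idx})"
        )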

# Analyze layer attention for an example
text = "The cat sat on the mat and watched the bird."
analyze_layer_attention(model, tokenizer, text)
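
Comparing heatmaps by eye across layers can be subjective. One possible way to put a number on the difference is the average entropy of the attention rows: low entropy means each target token focuses on a few source tokens, high entropy means the weight is spread out. The sketch below is an assumption about how such a comparison could be done, using synthetic attention tensors rather than model output; `mean_attention_entropy` is a hypothetical helper.

```python
import numpy as np

def mean_attention_entropy(layer_attention, eps=1e-12):
    """Average entropy (in nats) of the attention rows in one layer.

    layer_attention: array of shape (num_heads, tgt_len, src_len) whose rows sum to 1.
    """
    attn = np.asarray(layer_attention, dtype=float)
    # eps avoids log(0) for positions that receive no attention
    entropy = -(attn * np.log(attn + eps)).sum(axis=-1)  # (num_heads, tgt_len)
    return float(entropy.mean())

# Synthetic layers with 2 heads, 4 target tokens, 4 source tokens
sharp_layer = np.tile(np.eye(4), (2, 1, 1))   # each target attends to exactly one source
diffuse_layer = np.full((2, 4, 4), 0.25)      # each target attends uniformly

sharp_entropy = mean_attention_entropy(sharp_layer)
diffuse_entropy = mean_attention_entropy(diffuse_layer)
print(f"sharp: {sharp_entropy:.3f}, diffuse: {diffuse_entropy:.3f}")
```

For a real model, each layer's attention tensor for one sentence would first be converted with `[0].detach().cpu().numpy()` to drop the batch dimension. Averaging over all heads, as done here, also avoids drawing conclusions from the first head alone.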

## Conclusion

In this notebook, we've explored:

1. The evolution of attention in machine translation
2. Implementation of an attention-based translation model
3. Visualization of attention patterns
4. Analysis of real-world translation models

Key takeaways:

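- Attention lets the decoder consult every source position at each step instead of relying on a single fixed-size context vector
- Different layers can capture different aspects of the translation process
- Attention patterns offer a partial view into which source tokens the model relies on for each target token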

In the next notebook, we'll explore attention in other domains like audio and multimodal systems.