# Attention in Computer Vision

This notebook explores how attention mechanisms are applied in computer vision tasks. We'll cover:

1. Vision Transformers (ViT)
2. Attention in image classification
3. Visualizing attention maps
4. Real-world applications

## Why Attention in Vision?

Traditional CNNs have limitations:

1. **Local Receptive Fields**: CNNs process images through local convolutions
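2. **Fixed, Content-Independent Weights**: A convolution applies the same learned kernel at every position, whatever the image contains
3. **Limited Global Context**: The receptive field grows only gradually with depth, so relating distant regions requires many stacked layers or aggressive downsampling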
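To put a number on the last point, here is a back-of-the-envelope sketch. It assumes a plain stack of 3x3, stride-1 convolutions with no pooling or dilation (real CNNs downsample, which widens the receptive field much faster). Under that assumption each layer adds `kernel_size - 1` pixels to the receptive field:

```python
def receptive_field(num_layers, kernel_size=3):
    """Receptive field (in pixels) of a stack of stride-1 convolutions."""
    # Each stride-1 layer with a k x k kernel widens the field by (k - 1) pixels
    return 1 + num_layers * (kernel_size - 1)

def layers_to_cover(width, kernel_size=3):
    """Smallest number of stride-1 layers whose receptive field spans `width` pixels."""
    layers = 0
    while receptive_field(layers, kernel_size) < width:
        layers += 1
    return layers

print(receptive_field(1))    # 3: a single 3x3 conv sees 3 pixels across
print(layers_to_cover(224))  # 112 layers to span a 224-pixel-wide image
```

Under these assumptions, over a hundred layers are needed before a single output unit can see both edges of a 224-pixel image.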

Attention in vision addresses these by:

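- Enabling direct modeling of relationships between any two image regions within a single layer
- Providing attention maps that can be visualized to inspect which regions the model weights most
- Computing aggregation weights from the input itself, so how regions are combined adapts to each image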
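A minimal single-head self-attention sketch makes the first and third points concrete. It assumes random patch tokens and random projection matrices (stand-ins, not ViT's trained weights). The resulting `(N, N)` weight matrix links every token to every other token in one step, and because it is computed from the tokens themselves, it changes with the input:

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(tokens, w_q, w_k, w_v):
    """Single-head scaled dot-product self-attention over a set of tokens."""
    q, k, v = tokens @ w_q, tokens @ w_k, tokens @ w_v
    # (N, N) scores: entry [i, j] measures how much token i attends to token j
    scores = q @ k.T / np.sqrt(k.shape[-1])
    weights = softmax(scores, axis=-1)  # each row sums to 1
    return weights @ v, weights

rng = np.random.default_rng(0)
num_tokens, dim = 196, 64  # e.g. a 14x14 grid of patch tokens (assumed sizes)
tokens = rng.standard_normal((num_tokens, dim))
w_q, w_k, w_v = (rng.standard_normal((dim, dim)) / np.sqrt(dim) for _ in range(3))

out, weights = self_attention(tokens, w_q, w_k, w_v)
print(out.shape, weights.shape)  # (196, 64) (196, 196)
```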

In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from transformers import ViTForImageClassification, ViTFeatureExtractor

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

## Vision Transformer (ViT) Overview

The Vision Transformer (ViT) applies the Transformer architecture to images by:

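1. **Image Patching**: Dividing the image into fixed-size, non-overlapping patches (16x16 pixels for `vit-base-patch16-224`, so a 224x224 image becomes a 14x14 grid of 196 patches)
2. **Linear Projection**: Flattening each patch and projecting it into a token embedding
3. **[CLS] Token**: Prepending a learnable classification token whose final representation feeds the classifier head, giving 197 tokens in total
4. **Position Embeddings**: Adding learnable position embeddings so the model knows where each patch came from
5. **Transformer Encoder**: Processing the tokens through stacked self-attention layers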
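The input pipeline can be sketched in NumPy to check the shapes. This is a sketch under stated assumptions: a random stand-in image, and random projection, [CLS], and position-embedding parameters in place of the trained ones. The 768-dimensional embedding matches ViT-Base. The Hugging Face implementation performs the patching and projection in one strided convolution, which is equivalent up to the order in which patch pixels are flattened.

```python
import numpy as np

rng = np.random.default_rng(0)
image_size, patch_size, channels, embed_dim = 224, 16, 3, 768

# Random stand-in for a real image, laid out as (H, W, C)
dummy_image = rng.random((image_size, image_size, channels))

# 1. Patching: split into a 14x14 grid of 16x16x3 patches
grid = image_size // patch_size                            # 14
patches = dummy_image.reshape(grid, patch_size, grid, patch_size, channels)
patches = patches.transpose(0, 2, 1, 3, 4)                 # (14, 14, 16, 16, 3)
patches = patches.reshape(grid * grid, -1)                 # (196, 768): one flattened patch per row

# 2. Linear projection of each flattened patch (random weights here)
w_proj = rng.standard_normal((patches.shape[1], embed_dim)) * 0.02
tokens = patches @ w_proj                                  # (196, 768)

# 3. Prepend a [CLS] token (learnable in ViT; random here)
cls_token = rng.standard_normal((1, embed_dim)) * 0.02
tokens = np.concatenate([cls_token, tokens], axis=0)       # (197, 768)

# 4. Add one position embedding per token (learnable in ViT; random here)
pos_embed = rng.standard_normal((tokens.shape[0], embed_dim)) * 0.02
tokens = tokens + pos_embed

print(patches.shape, tokens.shape)  # (196, 768) (197, 768)
```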

In [None]:
def load_and_preprocess_image(image_path):
    """Load and preprocess an image for ViT."""
    # Load image
    image = Image.open(image_path).convert('RGB')
    
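    # Resize to the 224x224 input resolution and scale pixels to [0, 1].
    # google/vit-base-patch16-224 was trained with mean = std = 0.5 per channel
    # (mapping pixels to [-1, 1]), not with the ImageNet statistics.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5],
                             std=[0.5, 0.5, 0.5])
    ])
    
    # Apply preprocessing and add a batch dimension: (1, 3, 224, 224)
    image_tensor = transform(image).unsqueeze(0)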
    
    return image, image_tensor

def visualize_attention_map(image, attention_weights):
    """Overlay per-patch attention weights as a heatmap on the image."""
    # One weight per patch: 196 weights form a 14x14 grid for a 224x224
    # input with 16x16 patches
    grid = int(round(attention_weights.numel() ** 0.5))
    attention_map = attention_weights.reshape(grid, grid).cpu().numpy()
    width, height = image.size
    
    # Create figure
    plt.figure(figsize=(12, 4))
    
    # Plot original image
    plt.subplot(1, 2, 1)
    plt.imshow(image)
    plt.title('Original Image')
    plt.axis('off')
    
    # Plot attention map, stretched over the full image via `extent`
    # (the model saw the whole image resized, so the grid maps proportionally)
    plt.subplot(1, 2, 2)
    plt.imshow(image)
    plt.imshow(attention_map, alpha=0.5, cmap='jet', extent=(0, width, height, 0))
    plt.title('Attention Map')
    plt.axis('off')
    
    plt.tight_layout()
    plt.show()
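A 14x14 attention map is much coarser than the 224x224 input it describes, because each weight belongs to one 16x16 patch. One way to get a pixel-level map is nearest-neighbour upsampling, which repeats each patch weight over the pixels it covers. The sketch below assumes 16x16 patches and uses toy values in place of real attention weights:

```python
import numpy as np

def upsample_patch_map(attention_map, patch_size=16):
    """Nearest-neighbour upsampling: repeat each patch weight over its patch_size x patch_size pixels."""
    return np.kron(attention_map, np.ones((patch_size, patch_size)))

# Toy values, one per patch (stand-ins for real attention weights)
attention_map = np.arange(14 * 14, dtype=float).reshape(14, 14)
pixel_map = upsample_patch_map(attention_map)

print(pixel_map.shape)  # (224, 224)
```

The resulting `pixel_map` lines up with a 224x224 resized image. Bilinear interpolation would give a smoother overlay at the cost of blurring patch boundaries.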

## Loading a Pre-trained ViT Model

Let's load a pre-trained ViT model and examine its attention patterns:

In [None]:
# Load pre-trained ViT model
model_name = 'google/vit-base-patch16-224'
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
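# Eager attention guarantees per-head attention weights are returned with output_attentions=True
model = ViTForImageClassification.from_pretrained(model_name, attn_implementation='eager')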

# Set model to evaluation mode
model.eval()

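def get_attention_maps(model, image_tensor):
    """Return the last layer's [CLS]-to-patch attention, averaged over heads."""
    with torch.no_grad():
        outputs = model(image_tensor, output_attentions=True)
        
    # outputs.attentions holds one tensor per layer, each shaped
    # (batch, num_heads, num_tokens, num_tokens); take the last layer, first image
    attention_weights = outputs.attentions[-1][0]
    
    # Row 0 is the [CLS] query; columns 1: are the patch tokens (skipping [CLS]
    # itself). Averaging over heads leaves one weight per patch (196 for ViT-B/16).
    cls_attention = attention_weights[:, 0, 1:].mean(dim=0)
    
    return cls_attention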
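With `output_attentions=True`, the Hugging Face ViT model returns `outputs.attentions` as one tensor per layer, each shaped `(batch, num_heads, num_tokens, num_tokens)`. For `google/vit-base-patch16-224` that means 12 layers of `(1, 12, 197, 197)`. The indexing that turns this into a 14x14 map can be checked on random data of the same shape. The data here is a row-normalized stand-in, not real model output:

```python
import numpy as np

rng = np.random.default_rng(0)
num_layers, num_heads, num_tokens = 12, 12, 197

# Stand-in for outputs.attentions, softmax-normalized along the key axis like real attention
logits = rng.standard_normal((num_layers, 1, num_heads, num_tokens, num_tokens))
attentions = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)

last = attentions[-1][0]                      # (12, 197, 197): last layer, first image
cls_to_patches = last[:, 0, 1:]               # (12, 196): [CLS] query row, patch key columns
cls_attention = cls_to_patches.mean(axis=0)   # (196,): averaged over heads
attention_map = cls_attention.reshape(14, 14)

print(cls_to_patches.shape, attention_map.shape)  # (12, 196) (14, 14)
```

The order of operations matters. If you select a single head first (`[0, 0]`) and then take the mean over the remaining patch slice, you average away the patch dimension and are left with a single number.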

## Visualizing Attention in Image Classification

Let's analyze how the model attends to different parts of an image:

In [None]:
# Load and process an example image
image_path = 'path_to_your_image.jpg'  # Replace with your image path
image, image_tensor = load_and_preprocess_image(image_path)

# Get attention maps
attention_weights = get_attention_maps(model, image_tensor)

# Visualize attention
visualize_attention_map(image, attention_weights)
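To read a heatmap quantitatively, it helps to list which patches receive the most [CLS] attention and where they sit in the input. Below is a minimal sketch with a hypothetical helper, `top_patches`. It assumes a 196-element weight vector over a 14x14 grid of 16x16 patches, and the example uses toy weights:

```python
import numpy as np

def top_patches(cls_attention, k=3, grid=14, patch_size=16):
    """Return (row, col, pixel_box) for the k patches with the highest attention weight."""
    weights = np.asarray(cls_attention).reshape(-1)
    top = np.argsort(weights)[::-1][:k]  # indices of the k largest weights, descending
    results = []
    for idx in top:
        row, col = divmod(int(idx), grid)
        # Pixel box (left, top, right, bottom) in the 224x224 model input
        box = (col * patch_size, row * patch_size,
               (col + 1) * patch_size, (row + 1) * patch_size)
        results.append((row, col, box))
    return results

# Toy weights: one clearly dominant patch at row 3, column 5
toy = np.full(196, 0.001)
toy[3 * 14 + 5] = 0.5
print(top_patches(toy, k=1))  # [(3, 5, (80, 48, 96, 64))]
```

Assuming `attention_weights` holds one weight per patch, `top_patches(attention_weights.numpy())` lists the most-attended 16x16 regions of the resized input.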

## Analyzing Different Attention Heads

Different attention heads in ViT can focus on different aspects of the image:

In [None]:
def visualize_multi_head_attention(model, image, image_tensor, num_heads=4):
    """Overlay the last layer's [CLS] attention for the first `num_heads` heads."""
    with torch.no_grad():
        outputs = model(image_tensor, output_attentions=True)
    
    # Last layer, first image: shape (num_heads, num_tokens, num_tokens)
    attention_weights = outputs.attentions[-1][0]
    
    # 196 patch tokens -> 14x14 grid for 224x224 inputs with 16x16 patches
    grid = int(round((attention_weights.shape[-1] - 1) ** 0.5))
    width, height = image.size
    
    # Two heads per row
    ncols = 2
    nrows = (num_heads + ncols - 1) // ncols
    fig, axes = plt.subplots(nrows, ncols, figsize=(6 * ncols, 6 * nrows))
    axes = np.atleast_1d(axes).flatten()
    
    for i in range(num_heads):
        # Row 0 is the [CLS] query, columns 1: are the patches (no averaging within a head)
        cls_attention = attention_weights[i, 0, 1:]
        attention_map = cls_attention.reshape(grid, grid).cpu().numpy()
        
        # Stretch the patch grid over the full image
        axes[i].imshow(image)
        axes[i].imshow(attention_map, alpha=0.5, cmap='jet', extent=(0, width, height, 0))
        axes[i].set_title(f'Head {i+1}')
        axes[i].axis('off')
    
    # Hide unused panels when num_heads is odd
    for ax in axes[num_heads:]:
        ax.axis('off')
    
    plt.tight_layout()
    plt.show()

# Visualize multiple attention heads
visualize_multi_head_attention(model, image, image_tensor)
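One way to compare heads numerically is the entropy of each head's [CLS] attention. Low entropy means the head concentrates on a few patches, and high entropy means it spreads attention broadly. This sketch uses synthetic weights. To apply it to the model, you would pass the `(num_heads, 196)` slice `outputs.attentions[-1][0, :, 0, 1:]` converted to NumPy:

```python
import numpy as np

def attention_entropy(weights, eps=1e-12):
    """Shannon entropy (nats) of each row of a (heads, patches) attention matrix.

    Rows are renormalized to sum to 1 first, because dropping the [CLS] column
    leaves each head with slightly less than 1 of probability mass.
    """
    p = np.asarray(weights, dtype=float)
    p = p / p.sum(axis=-1, keepdims=True)
    return -(p * np.log(p + eps)).sum(axis=-1)

num_patches = 196
uniform = np.full(num_patches, 1.0 / num_patches)  # a head spreading attention evenly
focused = np.full(num_patches, 1e-4)
focused[0] = 1.0                                    # a head concentrating on one patch
entropies = attention_entropy(np.stack([uniform, focused]))
print(entropies.round(2))  # uniform head: log(196), about 5.28; focused head: far lower
```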

## Real-World Applications

Attention mechanisms in vision have enabled several breakthroughs:

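1. **Image Classification**: ViT matches or outperforms strong CNN baselines on benchmarks such as ImageNet, particularly when pre-trained on large datasets
2. **Object Detection**: DETR uses a Transformer encoder-decoder to predict a set of boxes directly, removing hand-designed components such as anchor generation and non-maximum suppression
3. **Image Generation**: Self-attention layers (e.g., in SAGAN) help generators keep distant parts of an image consistent, and cross-attention conditions text-to-image diffusion models on the prompt
4. **Medical Imaging**: Attention modules (e.g., the attention gates in Attention U-Net) help segmentation models focus on the relevant anatomy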
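The core of DETR's decoder is cross-attention: a fixed set of learned object queries each attend over the flattened image feature map. Each query's output is then decoded into one box and class prediction. Below is a single-head sketch of that step, using random stand-in features and queries. It omits DETR's learned key/value projections, multi-head split, query self-attention, and feed-forward layers. The 100 queries and 256-dimensional features follow DETR's default configuration, and the 20x20 feature map size is an assumption:

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def cross_attention(queries, features):
    """Single-head cross-attention: each query attends over all image feature tokens."""
    d = queries.shape[-1]
    # (num_queries, num_pixels): where in the image each query looks
    weights = softmax(queries @ features.T / np.sqrt(d), axis=-1)
    return weights @ features, weights

rng = np.random.default_rng(0)
num_queries, d_model = 100, 256   # DETR's default query count and hidden size
h, w = 20, 20                     # assumed backbone feature-map size
features = rng.standard_normal((h * w, d_model))       # flattened feature map (stand-in)
queries = rng.standard_normal((num_queries, d_model))  # object queries (stand-in)

out, weights = cross_attention(queries, features)
print(out.shape, weights.shape)  # (100, 256) (100, 400)
```

Reshaping one row of `weights` to `(h, w)` gives that query's spatial attention map, analogous to the [CLS] maps visualized above for ViT.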

## Conclusion

In this notebook, we've explored:

1. How attention mechanisms are applied in computer vision
2. The Vision Transformer architecture
3. Visualizing attention patterns in images
4. Real-world applications of attention in vision

Key takeaways:

- Attention provides a powerful way to model relationships in images
- Different attention heads can focus on different aspects of the image
- Attention maps show which regions the model weights most, although they are only a partial explanation of its decisions

In the next notebook, we'll explore attention in other domains like audio and multimodal systems.