# MedViT Attention Map Visualization

This notebook demonstrates how to visualize attention maps from trained MedViT models.

MedViT uses two types of attention mechanisms:
1. **LFP (Local Feature Processing)**: Neighborhood Attention or Standard Multi-Head Attention
2. **GFP (Global Feature Processing)**: E-MHSA (Efficient Multi-Head Self Attention)

In [None]:
import sys
import tempfile
from pathlib import Path

sys.path.insert(0, '../src')  # make ../src importable for the local module imported below

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image

from visualize_attention import AttentionVisualizer, load_model

# Auto-reload modules for development
%load_ext autoreload
%autoreload 2

## 1. Setup

In [None]:
# Configuration
MODEL_NAME = 'MedViT_tiny'
CHECKPOINT_PATH = '../checkpoint/MedViT_tiny_brain_tumor.pth'  # Update this path

# Class names for the brain tumor dataset, listed in label-index order:
# integer labels and predictions are rendered as CLASS_NAMES[label].
# NOTE(review): confirm this order matches the dataset's label encoding.
CLASS_NAMES = ['glioma', 'meningioma', 'no-tumor', 'pituitary']

# Derived from CLASS_NAMES so the class count and the name list cannot disagree (4 classes)
NUM_CLASSES = len(CLASS_NAMES)

# Device setup: prefer CUDA, then Apple MPS, otherwise fall back to the CPU.
# The hasattr guard keeps this working on torch builds without an MPS backend.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

## 2. Load Model

In [None]:
# Build the model through the project helper, using the settings from the configuration cell
load_kwargs = dict(
    model_name=MODEL_NAME,
    checkpoint_path=CHECKPOINT_PATH,
    num_classes=NUM_CLASSES,
    device=device,
)
model = load_model(**load_kwargs)

print(f"Model loaded: {MODEL_NAME}")

## 3. Create Attention Visualizer

In [None]:
# Wrap the loaded model (and its device) in the attention visualizer used by the cells below
visualizer = AttentionVisualizer(model, device)

print('Visualizer created successfully!')

## 4. Load Sample Image from Dataset

In [None]:
from datasets import load_dataset
from torchvision import transforms

# Fetch the test split of the brain tumor MRI dataset
dataset = load_dataset('PranomVignesh/MRI-Images-of-Brain-Tumor', split='test')

# Pick one example; edit sample_idx to inspect a different one
sample_idx = 0
sample = dataset[sample_idx]
sample_image, sample_label = sample['image'], sample['label']

print(f"Sample {sample_idx}")
print(f"Label: {sample_label} ({CLASS_NAMES[sample_label]})")

# Show the untouched image together with its ground-truth class name
fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(sample_image)
ax.set_title(f"Original Image\nLabel: {CLASS_NAMES[sample_label]}")
ax.axis('off')
plt.show()

## 5. Visualize Attention Maps

In [None]:
# The visualizer takes an image path, so write the sample to disk first.
# Use the platform's temp directory instead of a hard-coded '/tmp', which does
# not exist on Windows. Kept as a plain str because later cells pass this
# value straight to the visualizer.
temp_image_path = str(Path(tempfile.gettempdir()) / 'sample_brain_mri.png')
sample_image.save(temp_image_path)

# Visualize attention maps for all layers
attention_maps, pred_class = visualizer.visualize(
    image_path=temp_image_path,
    output_path=None,  # Set to a path to save the figure
    show=True,
    cmap='jet',
    alpha=0.5
)

print(f"\nPredicted class: {pred_class} ({CLASS_NAMES[pred_class]})")
print(f"True label: {sample_label} ({CLASS_NAMES[sample_label]})")
print(f"\nCaptured {len(attention_maps)} attention layers")

## 6. Visualize Individual Attention Heads

In [None]:
# Print each captured layer name with its position so one can be chosen in the next cell
print("Available attention layers:")
for layer_idx, layer_name in enumerate(attention_maps):
    print(f"  {layer_idx}: {layer_name}")

In [None]:
# Render every head of a single layer.
# Change the index below to any position printed by the layer listing.
layer_names = list(attention_maps)
layer_to_visualize = layer_names[0]  # first captured layer

overlay_style = dict(cmap='jet', alpha=0.5)
visualizer.visualize_all_heads(
    image_path=temp_image_path,
    layer_name=layer_to_visualize,
    show=True,
    **overlay_style
)

## 7. Visualize Multiple Samples

In [None]:
def visualize_sample(dataset, idx, visualizer, class_names):
    """Show one dataset sample next to overlays of its first few attention maps.

    Args:
        dataset: Indexable collection whose items are mappings with an
            'image' entry (an object exposing ``save(path)``) and a 'label'
            entry used to index ``class_names``.
        idx: Position of the sample inside ``dataset``.
        visualizer: Object providing ``get_attention_maps``,
            ``aggregate_attention`` and ``reshape_attention_to_image``.
        class_names: Sequence mapping label / predicted-class indices to
            display names.

    Returns:
        None. The figure is rendered with ``plt.show()``.
    """
    sample = dataset[idx]
    sample_image = sample['image']
    sample_label = sample['label']

    # The visualizer takes a path, so write the image to the platform temp
    # directory (a hard-coded '/tmp' does not exist on every OS).
    temp_path = str(Path(tempfile.gettempdir()) / f'sample_{idx}.png')
    sample_image.save(temp_path)

    # Get attention maps
    attention_maps, img_np, pred_class = visualizer.get_attention_maps(temp_path)

    # One panel for the original image plus at most three attention overlays.
    max_overlays = 3
    shown_layers = list(attention_maps.items())[:max_overlays]

    # squeeze=False always yields a 2-D grid of axes; without it a single-panel
    # figure (no attention maps captured) returns a bare Axes and axes[0] fails.
    fig, axes_grid = plt.subplots(
        1, len(shown_layers) + 1, figsize=(16, 4), squeeze=False
    )
    axes = axes_grid[0]

    # Original image
    axes[0].imshow(img_np)
    correct = "✓" if pred_class == sample_label else "✗"
    axes[0].set_title(f'Original\nTrue: {class_names[sample_label]}\nPred: {class_names[pred_class]} {correct}')
    axes[0].axis('off')

    # Attention maps
    for i, (name, attn) in enumerate(shown_layers):
        attn_map = visualizer.aggregate_attention(attn, method='mean')
        # NOTE(review): (224, 224) is assumed to match the size of img_np so
        # the overlay lines up with the image -- confirm against the visualizer.
        attn_2d = visualizer.reshape_attention_to_image(attn_map, (224, 224))
        # Min-max normalise to [0, 1]; the epsilon avoids division by zero on a constant map.
        attn_2d = (attn_2d - attn_2d.min()) / (attn_2d.max() - attn_2d.min() + 1e-8)

        axes[i+1].imshow(img_np)
        axes[i+1].imshow(attn_2d, cmap='jet', alpha=0.5)
        short_name = name.split('.')[-1]
        axes[i+1].set_title(f'Layer {i}: {short_name}')
        axes[i+1].axis('off')

    plt.tight_layout()
    plt.show()

# Visualize the first sample of each class.
# Read the label column once: looking labels up row by row
# (dataset[i]['label']) fetches the whole row, image included, for every
# index scanned.
first_index_by_label = {}
for row_idx, label in enumerate(dataset['label']):
    first_index_by_label.setdefault(label, row_idx)

for class_idx in range(NUM_CLASSES):
    if class_idx not in first_index_by_label:
        continue  # classes absent from this split are skipped silently
    print(f"\n=== Class: {CLASS_NAMES[class_idx]} ===")
    visualize_sample(dataset, first_index_by_label[class_idx], visualizer, CLASS_NAMES)

## 8. Custom Attention Analysis

In [None]:
# Re-run the visualizer on the saved sample to obtain the raw attention tensors
attention_maps, img_np, pred_class = visualizer.get_attention_maps(temp_image_path)

# Per-layer summary: shape plus min / max / mean / std of the attention values
print("Attention Statistics:")
print("-" * 50)
for layer_name, layer_attn in attention_maps.items():
    lowest = layer_attn.min().item()
    highest = layer_attn.max().item()
    average = layer_attn.mean().item()
    spread = layer_attn.std().item()

    print(f"\nLayer: {layer_name}")
    print(f"  Shape: {layer_attn.shape}")
    print(f"  Min: {lowest:.4f}")
    print(f"  Max: {highest:.4f}")
    print(f"  Mean: {average:.4f}")
    print(f"  Std: {spread:.4f}")

In [None]:
# Visualize attention distribution: one histogram of raw weights per captured layer
num_layers = len(attention_maps)
if num_layers == 0:
    # A zero-column subplot grid is not valid, so report instead of raising.
    print("No attention maps captured; nothing to plot.")
else:
    # squeeze=False always returns a 2-D grid of axes, so the single-layer
    # case needs no special wrapping.
    fig, axes_grid = plt.subplots(
        1, num_layers, figsize=(4 * num_layers, 4), squeeze=False
    )
    axes = axes_grid[0]

    for ax, (name, attn) in zip(axes, attention_maps.items()):
        # Flatten and plot histogram. detach().cpu() makes the NumPy conversion
        # work for tensors on an accelerator or still attached to the autograd
        # graph; it is a no-op for plain CPU tensors.
        attn_flat = attn.detach().cpu().flatten().numpy()
        ax.hist(attn_flat, bins=50, edgecolor='black', alpha=0.7)
        ax.set_title(f'{name.split(".")[-1]}')
        ax.set_xlabel('Attention Weight')
        ax.set_ylabel('Frequency')

    plt.suptitle('Attention Weight Distribution')
    plt.tight_layout()
    plt.show()