# MAIRA-2 Attention Visualization on Google Colab

This notebook helps you visualize attention patterns in MAIRA-2 using free GPU resources.

**⚠️ Important Setup:**
1. Go to **Runtime > Change runtime type**
2. Set **Hardware accelerator** to **GPU**
3. Click **Save**
4. You'll need a Hugging Face token from https://huggingface.co/settings/tokens

## 1. Check GPU Availability

In [None]:
import torch
import subprocess

print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
else:
    print("❌ No GPU detected! Go to Runtime > Change runtime type > GPU")

# Check available RAM
result = subprocess.run(['free', '-h'], capture_output=True, text=True)
print("\nSystem Memory:")
print(result.stdout)

## 2. Install Dependencies

In [None]:
# Install required packages
!pip install torch torchvision transformers accelerate
!pip install pillow matplotlib tqdm requests
!pip install huggingface-hub

## 3. Get the Code

In [None]:
# Clone the repository
!git clone https://github.com/javier-alvarez/maira-interp.git
%cd maira-interp

# Verify files
!ls -la

## 4. Set Up Authentication

**You need a Hugging Face token to access MAIRA-2:**
1. Go to https://huggingface.co/settings/tokens
2. Create a new token with 'Read' permissions
3. Request access to MAIRA-2: https://huggingface.co/microsoft/maira-2
4. Enter your token in the cell below

In [None]:
import os
from getpass import getpass

# Enter your Hugging Face token (it will be hidden)
hf_token = getpass("Enter your Hugging Face token: ")
os.environ['HF_TOKEN'] = hf_token

print("✅ Token set!")

# Test authentication
from huggingface_hub import HfApi
try:
    api = HfApi()
    user = api.whoami(token=hf_token)
    print(f"✅ Authenticated as: {user['name']}")
except Exception as e:
    print(f"❌ Authentication failed: {e}")

## 5. Test Import

In [None]:
# Test importing the visualizer
try:
    from attention_visualizer import MAIRA2AttentionVisualizer
    print("✅ Successfully imported MAIRA2AttentionVisualizer")
except ImportError as e:
    print(f"❌ Import failed: {e}")
    print("Current directory:", os.getcwd())
    print("Files:", os.listdir('.'))

## 6. Memory Optimization Setup

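**Important:** MAIRA-2 requires ~15GB of GPU memory. Free Colab GPUs are nominally 16GB, but PyTorch typically reports about 14.7GB as usable (the figure used in the rest of this notebook), so we need to be careful.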
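
As a rough sanity check on the ~15GB figure, weight memory scales as parameter count times bytes per parameter. The sketch below uses a hypothetical 7-billion-parameter count purely to illustrate the arithmetic (this notebook does not state MAIRA-2's exact size), and it ignores activations and any attention tensors kept for visualization.

```python
def estimate_weight_memory_gb(num_params: float, bytes_per_param: int) -> float:
    """Approximate memory needed for model weights alone, in GiB."""
    return num_params * bytes_per_param / 1024**3

# Hypothetical parameter count, chosen only to illustrate the arithmetic.
ASSUMED_PARAMS = 7e9

fp32_gb = estimate_weight_memory_gb(ASSUMED_PARAMS, 4)  # 4 bytes per float32
fp16_gb = estimate_weight_memory_gb(ASSUMED_PARAMS, 2)  # 2 bytes per float16

print(f"float32 weights: {fp32_gb:.1f} GB")
print(f"float16 weights: {fp16_gb:.1f} GB")
```

Under that assumption, half precision halves the weight footprint, which is why the loading cell falls back to it on smaller GPUs.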

In [None]:
import gc
import torch

def clear_memory():
    """Clear GPU and system memory"""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

def check_memory():
    """Check current GPU memory usage"""
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1024**3
        reserved = torch.cuda.memory_reserved() / 1024**3
        total = torch.cuda.get_device_properties(0).total_memory / 1024**3
        print(f"GPU Memory - Allocated: {allocated:.1f}GB, Reserved: {reserved:.1f}GB, Total: {total:.1f}GB")
        return allocated, reserved, total
    return 0, 0, 0

# Clear memory before starting
clear_memory()
check_memory()
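
`check_memory()` reports what this process has allocated. It can also help to ask the driver how much memory is free on the whole device. A minimal sketch using `torch.cuda.mem_get_info()`; the helper name `free_gpu_memory_gb` is hypothetical, and it returns `None` when PyTorch or a GPU is not available.

```python
try:
    import torch
except ImportError:  # lets the helper degrade gracefully outside Colab
    torch = None

def free_gpu_memory_gb():
    """Return (free_gb, total_gb) for the current CUDA device, or None if no GPU."""
    if torch is None or not torch.cuda.is_available():
        return None
    free_bytes, total_bytes = torch.cuda.mem_get_info()
    return free_bytes / 1024**3, total_bytes / 1024**3

info = free_gpu_memory_gb()
if info is None:
    print("No CUDA device available")
else:
    print(f"Free: {info[0]:.1f}GB of {info[1]:.1f}GB")
```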

## 7. Initialize MAIRA-2 with Memory Management

**⚠️ Critical:** MAIRA-2 needs ~15GB. A typical free Colab GPU reports about 14.7GB, so the cell below converts the model to half precision whenever less than 15GB is available.

In [None]:
import torch
print("🚀 Initializing MAIRA-2 Attention Visualizer...")
print("This will download ~13GB of model weights on first run.")
print("Please be patient - this may take 5-10 minutes.")

# Check available memory before loading
allocated, reserved, total = check_memory()
if total < 15:
    print(f"⚠️  Warning: Only {total:.1f}GB GPU memory available. MAIRA-2 needs ~15GB.")
    print("Using memory optimization strategies...")

try:
    # Memory optimization for smaller GPUs
    if total < 15:
        print("🔧 Loading with memory optimizations...")
        visualizer = MAIRA2AttentionVisualizer()
        
        # Convert the model to half precision to save memory
        # (applied after loading, so it cannot lower peak memory during load)
        if hasattr(visualizer, 'model'):
            visualizer.model = visualizer.model.half()
        
        print("✅ MAIRA-2 loaded with memory optimizations!")
    else:
        visualizer = MAIRA2AttentionVisualizer()
        print("✅ MAIRA-2 loaded successfully!")
    
    check_memory()
    
except Exception as e:
    print(f"❌ Failed to load MAIRA-2: {e}")
    print("\n🔧 Trying aggressive memory optimization...")
    clear_memory()
    
    try:
        # Last resort: retry after clearing memory, with a smaller allocator
        # split size to reduce fragmentation. This does NOT offload to CPU.
        print("Retrying with allocator fragmentation settings...")

        # The allocator config is normally read when CUDA is first initialized,
        # so this may only take effect after a runtime restart.
        import os
        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'

        visualizer = MAIRA2AttentionVisualizer()
        print("✅ MAIRA-2 loaded on retry!")
        check_memory()
        
    except Exception as e2:
        print(f"❌ All loading attempts failed: {e2}")
        print("\n💡 Solutions:")
        print("1. Runtime > Restart Runtime and try again")
        print("2. Try Colab Pro for more GPU memory")
        print("3. Use a different GPU type if available")
        print("4. Try running at a different time when GPUs are less loaded")
        
        # Set visualizer to None so notebook doesn't crash
        visualizer = None

## 8. Download Sample Images

In [None]:
import requests
from PIL import Image
from io import BytesIO

# Download sample chest X-ray images
def download_image(url, filename):
    try:
        response = requests.get(url, timeout=30)  # avoid hanging on a stalled download
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        image.save(filename)
        print(f"✅ Downloaded {filename}")
        return image
    except Exception as e:
        print(f"❌ Failed to download {filename}: {e}")
        return None

# Sample chest X-ray URLs (public domain)
frontal_url = "https://openi.nlm.nih.gov/imgs/512/145/145/CXR145_IM-0290-1001.png"
lateral_url = "https://openi.nlm.nih.gov/imgs/512/145/145/CXR145_IM-0290-2001.png"

frontal_image = download_image(frontal_url, "sample_frontal.png")
lateral_image = download_image(lateral_url, "sample_lateral.png")

# Display the images
if frontal_image:
    print("\nFrontal X-ray:")
    display(frontal_image.resize((256, 256)))
    
if lateral_image:
    print("\nLateral X-ray:")
    display(lateral_image.resize((256, 256)))

## 9. Generate Attention Visualizations (Memory-Safe)

**Ultra-conservative settings to prevent OOM crashes:**

In [None]:
print("🎯 Starting attention visualization...")

# Check if visualizer loaded successfully
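if 'visualizer' not in globals() or visualizer is None:
    print("❌ MAIRA-2 not loaded. Please run the model loading cell first.")
elif globals().get('frontal_image') is None:
    print("❌ Frontal image not available. Please run the image download cell first.")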
else:
    check_memory()

    try:
        # Very conservative settings for 14.7GB GPU
        output_dir, generated_report = visualizer.generate_attention_pngs(
            frontal_image=frontal_image,
            lateral_image=None,  # Skip lateral to save ~2GB memory
            indication="Shortness of breath and chest pain",
            technique="PA chest X-ray",
            comparison="No prior studies available",
            max_new_tokens=10,   # Very low to prevent OOM
            visualize_every_n=10, # Only visualize every 10th token
            output_dir="colab_attention_output"
        )
        
        print(f"\n✅ Visualizations completed!")
        print(f"📁 Output directory: {output_dir}")
        print(f"📄 Generated report: {generated_report}")
        
    except Exception as e:
        print(f"❌ Error during visualization: {e}")
        print("This is likely a GPU memory issue.")
        
        # Try emergency cleanup and retry with even smaller settings
        clear_memory()
        print("\n🔧 Trying with minimal settings...")
        
        try:
            output_dir, generated_report = visualizer.generate_attention_pngs(
                frontal_image=frontal_image,
                lateral_image=None,
                indication="Chest pain",
                max_new_tokens=5,    # Absolute minimum
                visualize_every_n=20, # Almost no visualizations
                output_dir="colab_attention_output"
            )
            print(f"\n✅ Minimal visualization completed!")
            print(f"📄 Generated report: {generated_report}")
            
        except Exception as e2:
            print(f"❌ Even minimal settings failed: {e2}")
            print("\n💡 Your GPU doesn't have enough memory for MAIRA-2.")
            print("Try:")
            print("1. Colab Pro (better GPUs)")
            print("2. Different time of day")
            print("3. Kaggle Notebooks (might have more memory)")

    check_memory()
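
The nested try/except above can also be written as a loop over progressively smaller settings. A minimal sketch of that pattern: `run_with_fallbacks` and the `fake_generate` stand-in are hypothetical and only illustrate the control flow. In this notebook the callable would wrap `visualizer.generate_attention_pngs`, with `clear_memory` passed as `cleanup`.

```python
def run_with_fallbacks(generate, settings_ladder, cleanup=None):
    """Try each settings dict in order; return (settings, result) for the first success."""
    last_error = None
    for settings in settings_ladder:
        try:
            return settings, generate(**settings)
        except Exception as e:  # e.g. a CUDA out-of-memory RuntimeError
            last_error = e
            print(f"Settings {settings} failed: {e}")
            if cleanup is not None:
                cleanup()  # free memory before the next, smaller attempt
    raise RuntimeError("All settings failed") from last_error

# Same two tiers as the cell above: conservative first, then minimal.
ladder = [
    {"max_new_tokens": 10, "visualize_every_n": 10},
    {"max_new_tokens": 5, "visualize_every_n": 20},
]

def fake_generate(max_new_tokens, visualize_every_n):
    """Stand-in that fails for more than 5 tokens, to exercise the fallback."""
    if max_new_tokens > 5:
        raise MemoryError("simulated out-of-memory")
    return f"generated {max_new_tokens} tokens"

used, result = run_with_fallbacks(fake_generate, ladder)
print(used, result)
```

Adding a third, even smaller tier then only means appending one more dictionary to `ladder`.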

## 10. View Results

In [None]:
import os
from PIL import Image
import matplotlib.pyplot as plt

# List generated files
output_dir = "colab_attention_output"
if os.path.exists(output_dir):
    files = os.listdir(output_dir)
    print(f"📁 Generated files ({len(files)} total):")
    for file in sorted(files)[:10]:  # Show first 10 files
        print(f"  - {file}")
    if len(files) > 10:
        print(f"  ... and {len(files) - 10} more files")
    
    # Show the generated report
    report_file = os.path.join(output_dir, "generated_report.txt")
    if os.path.exists(report_file):
        with open(report_file, 'r') as f:
            report = f.read()
        print(f"\n📄 Generated Report:")
        print("=" * 50)
        print(report)
        print("=" * 50)
    
    # Display a few attention visualizations
    png_files = sorted(f for f in files if f.endswith('.png'))[:3]  # Show first 3, in name order
    
    for png_file in png_files:
        try:
            img_path = os.path.join(output_dir, png_file)
            img = Image.open(img_path)
            print(f"\n🎯 {png_file}:")
            display(img)
        except Exception as e:
            print(f"❌ Could not display {png_file}: {e}")
else:
    print(f"❌ Output directory '{output_dir}' not found")

## 11. Download Results

Create a zip file to download all results:

In [None]:
import zipfile
import os

# Create a zip file with all results
zip_filename = "maira2_attention_results.zip"
output_dir = "colab_attention_output"

if os.path.exists(output_dir):
    with zipfile.ZipFile(zip_filename, 'w') as zipf:
        for root, dirs, files in os.walk(output_dir):
            for file in files:
                file_path = os.path.join(root, file)
                zipf.write(file_path, os.path.relpath(file_path, '.'))
    
    print(f"✅ Created {zip_filename}")
    print(f"📁 Size: {os.path.getsize(zip_filename) / 1024**2:.1f} MB")
    print("\n📥 To download: Click the folder icon (🗂️) on the left, find the zip file, and download it")
else:
    print("❌ No results to zip")
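
As an alternative to walking the directory by hand, the standard library's `shutil.make_archive` builds the same kind of archive in one call. A minimal sketch, using a temporary directory with a placeholder file so it runs without the real output; in the notebook, `root` would simply be the current directory.

```python
import os
import shutil
import tempfile

# Work in a temporary directory so the sketch does not touch real results.
root = tempfile.mkdtemp()
results_dir = os.path.join(root, "colab_attention_output")
os.makedirs(results_dir)
with open(os.path.join(results_dir, "generated_report.txt"), "w") as f:
    f.write("placeholder report")  # stands in for real results

# base_name has no extension; make_archive appends ".zip" and returns the full path.
archive_path = shutil.make_archive(
    base_name=os.path.join(root, "maira2_attention_results"),
    format="zip",
    root_dir=root,
    base_dir="colab_attention_output",
)
print(f"Created {archive_path}")
```

Inside Colab, `google.colab.files.download(archive_path)` can then start a browser download instead of using the file panel.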

## 12. Clean Up Memory

Run this when you're done to free up memory:

In [None]:
# Clean up to free memory
try:
    del visualizer
except NameError:
    # visualizer was never created or was already deleted
    pass

clear_memory()
print("✅ Memory cleaned up")
check_memory()

## 🔧 Troubleshooting

**A 14.7GB GPU is right at the limit for MAIRA-2 (~15GB needed).**

**If you get OOM crashes:**

1. **Immediate fixes:**
   - Runtime > Restart Runtime
   - Re-run sections 1-6, then try loading again
   - Close other browser tabs to free system RAM

2. **Try different times:**
   - Early morning or late night (less Colab usage)
   - Weekends might have better GPU availability

3. **Alternative platforms:**
   - **Kaggle Notebooks** (often more generous with memory)
   - **Colab Pro** (about $10/month, with priority access to better GPUs, though a specific GPU is not guaranteed)
   - **Paperspace Gradient** (free tier available)

4. **Memory optimizations that worked:**
   - Only frontal X-ray (no lateral)
   - `max_new_tokens=5` (absolute minimum)
   - `visualize_every_n=20` (few visualizations)

**Expected behavior:**
- Model loading: Uses ~13-14GB
- Generation: Adds 1-2GB temporarily
- A 14.7GB GPU is borderline, so success depends on how fragmented GPU memory is when the model loads
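
The figures above explain why the outcome is borderline. A short worked calculation using only the numbers in this list:

```python
total_gb = 14.7                 # GPU memory reported in this notebook
load_gb = (13.0, 14.0)          # model loading, low and high estimate
generation_gb = (1.0, 2.0)      # temporary memory during generation

best_case = total_gb - load_gb[0] - generation_gb[0]
worst_case = total_gb - load_gb[1] - generation_gb[1]

print(f"Best-case headroom:  {best_case:+.1f} GB")
print(f"Worst-case headroom: {worst_case:+.1f} GB")
```

Headroom is positive only near the low end of both estimates, which is why a fresh runtime and minimal generation settings matter.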

**Success tips:**
- Fresh runtime (no previous models loaded)
- Minimal browser tabs open
- Try 2-3 times if first attempt fails