# MAIRA-2 Lightweight Demo for Memory-Constrained GPUs

**🎯 Designed specifically for 14-15GB GPUs (Colab T4)**

This notebook runs MAIRA-2 on limited GPU memory by loading it in float16 with automatic device placement, keeping generation short, and clearing caches aggressively.

**Setup:**
1. Runtime > Change runtime type > GPU > Save
2. Get HF token: https://huggingface.co/settings/tokens
3. Request MAIRA-2 access: https://huggingface.co/microsoft/maira-2
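The 14-15GB target follows from simple arithmetic: the weights alone take (number of parameters) × (bytes per parameter). The sketch below assumes roughly 7 billion parameters, which is an illustrative figure rather than an official count (check the model card). Activations, image-encoder tensors and the KV cache come on top of this.

```python
def weight_memory_gib(num_params: float, bytes_per_param: int) -> float:
    """Memory needed just to hold the weights, in GiB (1 GiB = 1024**3 bytes)."""
    return num_params * bytes_per_param / 1024**3

# Assumption: ~7 billion parameters (illustrative, not an official figure).
ASSUMED_PARAMS = 7e9

for label, nbytes in [("float32", 4), ("float16", 2)]:
    print(f"{label}: {weight_memory_gib(ASSUMED_PARAMS, nbytes):.1f} GiB")
```

Under that assumption the float32 weights (about 26 GiB) cannot fit on a T4, while float16 (about 13 GiB) only just fits, which is why the notebook loads the model in half precision.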

In [None]:
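# Check environment
import os

# PyTorch reads the allocator settings when CUDA is initialized,
# so set them before the first torch.cuda call
os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'max_split_size_mb:256,expandable_segments:True')

import torch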

print(f"CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    gpu_props = torch.cuda.get_device_properties(0)
    gpu_memory_gb = gpu_props.total_memory / 1024**3
    print(f"GPU: {gpu_props.name}")
    print(f"Memory: {gpu_memory_gb:.1f}GB")
    
    if gpu_memory_gb < 14:
        print("❌ GPU has insufficient memory for MAIRA-2")
        print("Try: Runtime > Change runtime type > Select different GPU")
    elif gpu_memory_gb < 16:
        print("⚠️  GPU memory is at the minimum. Using ultra-lightweight mode.")
    else:
        print("✅ GPU memory sufficient")
else:
    print("❌ No GPU! Go to Runtime > Change runtime type > GPU")
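The thresholds in the cell above can be pulled into a pure function so they can be reused or tested without a GPU. This is a small sketch; the name `classify_gpu_memory` is hypothetical and the tier labels simply mirror the messages printed above.

```python
def classify_gpu_memory(total_gb: float) -> str:
    """Map total GPU memory (GB, as computed above) to the notebook's three tiers."""
    if total_gb < 14:
        return "insufficient"  # MAIRA-2 will not fit
    if total_gb < 16:
        return "minimum"       # ultra-lightweight mode (e.g. Colab T4, ~14.7GB)
    return "sufficient"

print(classify_gpu_memory(14.7))
```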

In [None]:
# Install only essential packages
!pip install -q torch torchvision
!pip install -q transformers==4.44.0 accelerate
!pip install -q pillow matplotlib requests
print("✅ Packages installed")
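Because the install cell pins `transformers==4.44.0`, it can help to confirm which version is actually installed. A minimal sketch using only the standard library; the helper name `installed_version` is hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version

def installed_version(package: str):
    """Return the installed version string of `package`, or None if it is missing."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None

# Reads package metadata on disk, not the module already imported in memory
print("transformers:", installed_version("transformers"))
```

Since this reports what is on disk, a session that imported `transformers` before the install still runs the old version until the runtime is restarted.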

In [None]:
import gc
import torch
import os

def aggressive_cleanup():
    """Aggressive memory cleanup"""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        # Force garbage collection multiple times
        for _ in range(3):
            gc.collect()
            torch.cuda.empty_cache()

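def check_gpu_memory(min_free_gb=1.0):
    """Print GPU memory usage; return True if at least `min_free_gb` is available."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        allocated = torch.cuda.memory_allocated() / 1024**3
        reserved = torch.cuda.memory_reserved() / 1024**3
        # Driver-level free memory (accounts for the CUDA context and other processes)
        free_bytes, total_bytes = torch.cuda.mem_get_info()
        total = total_bytes / 1024**3
        # Memory PyTorch has cached but is not using can also be reused by new tensors
        available = free_bytes / 1024**3 + (reserved - allocated)
        print(f"GPU: {allocated:.1f}GB allocated, {available:.1f}GB available, {total:.1f}GB total")
        return available > min_free_gb
    return False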

# Memory-related environment variables. PYTORCH_CUDA_ALLOC_CONF only takes effect
# if it is set before CUDA is initialized in this process
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256,expandable_segments:True'
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

# Initial cleanup
aggressive_cleanup()
check_gpu_memory()
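PyTorch reports two GPU numbers that are easy to confuse: `memory_allocated()` counts bytes held by live tensors, while `memory_reserved()` counts bytes the caching allocator has claimed from the driver (always at least as large). A small GPU-free sketch of how they relate; `summarize_gpu_memory` is a hypothetical helper, and on a real GPU you would pass in `torch.cuda.get_device_properties(0).total_memory`, `torch.cuda.memory_allocated()` and `torch.cuda.memory_reserved()`.

```python
GIB = 1024**3

def summarize_gpu_memory(total: int, allocated: int, reserved: int) -> dict:
    """Break down GPU memory (bytes) as PyTorch's caching allocator sees it."""
    if not 0 <= allocated <= reserved <= total:
        raise ValueError("expected 0 <= allocated <= reserved <= total")
    return {
        # Held by live tensors (model weights, activations, KV cache)
        "allocated_gib": allocated / GIB,
        # Cached but unused: reusable by new tensors, released by torch.cuda.empty_cache()
        "cached_unused_gib": (reserved - allocated) / GIB,
        # Not claimed by this process's allocator (the CUDA context and other
        # processes still take part of this)
        "unreserved_gib": (total - reserved) / GIB,
    }

print(summarize_gpu_memory(total=15 * GIB, allocated=13 * GIB, reserved=14 * GIB))
```

Because "unreserved" still includes the CUDA context and other processes' memory, `torch.cuda.mem_get_info()` (which asks the driver directly) is the more reliable measure of what can still be allocated.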

In [None]:
# Authentication (essential)
from getpass import getpass

hf_token = getpass("Hugging Face token: ")
os.environ['HF_TOKEN'] = hf_token

# Quick auth test
try:
    from huggingface_hub import HfApi
    user = HfApi().whoami(token=hf_token)
    print(f"✅ Authenticated as: {user['name']}")
    del user
except Exception as e:
    print(f"❌ Auth failed: {e}")

In [None]:
# Clone the companion repository (the cells below define their own wrapper
# and do not import from it)
!git clone -q https://github.com/javier-alvarez/maira-interp.git
%cd maira-interp

# Confirm the lightweight visualizer script is present
!ls lightweight_visualizer.py
print("✅ Repository cloned")

In [None]:
# Create ultra-lightweight MAIRA-2 wrapper
import torch
from transformers import AutoModelForCausalLM, AutoProcessor
from PIL import Image
import requests
from io import BytesIO

class UltraLightMAIRA2:
    """Ultra-lightweight MAIRA-2 wrapper for memory-constrained environments"""
    
    def __init__(self):
        self.model = None
        self.processor = None
        
    def load_model_cautiously(self):
        """Load model with maximum memory optimizations"""
        print("🔄 Loading MAIRA-2 with extreme memory optimization...")
        
        # Check memory before loading
        if not check_gpu_memory():
            print("❌ Insufficient GPU memory")
            return False
            
        try:
            print("Loading processor...")
            # The processor only tokenizes text and preprocesses images; it holds no
            # weights, so it takes no torch_dtype argument
            self.processor = AutoProcessor.from_pretrained(
                "microsoft/maira-2",
                trust_remote_code=True,
                token=os.environ.get('HF_TOKEN'),
            )
            
            aggressive_cleanup()
            
            print("Loading model...")
            self.model = AutoModelForCausalLM.from_pretrained(
                "microsoft/maira-2",
                trust_remote_code=True,
                token=os.environ.get('HF_TOKEN'),
                torch_dtype=torch.float16,  # Half precision: half the weight memory of float32
                low_cpu_mem_usage=True,     # Keep peak CPU RAM near 1x model size while loading
                device_map="auto"           # Fill the GPU first, offload remaining layers to CPU if needed
            )

            # Inference only: make sure dropout and similar layers are disabled
            self.model.eval()
                
            aggressive_cleanup()
            
            print("✅ Model loaded successfully!")
            check_gpu_memory()
            return True
            
        except Exception as e:
            print(f"❌ Model loading failed: {e}")
            self.cleanup()
            return False
    
    def generate_minimal_report(self, image_path_or_pil, prompt="What do you see?"):
        """Generate minimal report with maximum memory efficiency"""
        if self.model is None or self.processor is None:
            print("❌ Model not loaded")
            return None
            
        try:
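            # Load image from a URL, a local path, or an existing PIL image
            if isinstance(image_path_or_pil, str):
                if image_path_or_pil.startswith(('http://', 'https://')):
                    response = requests.get(image_path_or_pil, timeout=30)
                    response.raise_for_status()  # Fail clearly on 403/404 instead of a PIL decode error
                    image = Image.open(BytesIO(response.content))
                else:
                    image = Image.open(image_path_or_pil)
            else:
                image = image_path_or_pil

            # Chest X-rays are grayscale; convert for a consistent single-channel input.
            # GPU memory is set by the processor's input resolution, not by this step
            if image.mode != 'L':
                image = image.convert('L')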
            
            # Process inputs with minimal context
            conversation = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                        {"type": "text", "text": prompt}
                    ]
                }
            ]
            
            prompt_text = self.processor.apply_chat_template(
                conversation, add_generation_prompt=True
            )
            
            inputs = self.processor(
                text=prompt_text,
                images=[image],
                return_tensors="pt"
            )
            
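            # Move to the model's device; cast floating tensors (pixel values) to the
            # model's float16 dtype, and leave integer tensors (token ids, masks) as they are
            inputs = {
                k: v.to(self.model.device, dtype=self.model.dtype) if torch.is_floating_point(v)
                else v.to(self.model.device)
                for k, v in inputs.items()
            }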
            
            # Generate with minimal tokens
            with torch.no_grad():  # Save memory
                output = self.model.generate(
                    **inputs,
                    max_new_tokens=15,  # Very minimal
                    do_sample=False,    # Deterministic
                    temperature=None,   # No sampling
                    top_p=None,
                    pad_token_id=self.processor.tokenizer.eos_token_id
                )
            
            # Decode response
            prompt_len = inputs['input_ids'].shape[1]
            generated_tokens = output[0][prompt_len:]
            response = self.processor.tokenizer.decode(
                generated_tokens, skip_special_tokens=True
            )
            
            # Cleanup
            del inputs, output, generated_tokens
            aggressive_cleanup()
            
            return response.strip()
            
        except Exception as e:
            print(f"❌ Generation failed: {e}")
            aggressive_cleanup()
            return None
    
    def cleanup(self):
        """Clean up model from memory"""
        if self.model is not None:
            del self.model
            self.model = None
        if self.processor is not None:
            del self.processor
            self.processor = None
        aggressive_cleanup()

print("✅ Ultra-lightweight wrapper created")
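The decoding step slices `output[0][prompt_len:]` because `generate()` on a decoder-only model returns the prompt ids followed by the newly generated ids. A toy, model-free illustration with made-up token ids; `strip_prompt_tokens` is a hypothetical name.

```python
def strip_prompt_tokens(sequence, prompt_len):
    """Return only the newly generated token ids from a generate() output row.

    Decoder-only generate() echoes the prompt ids first, so slicing at the
    prompt length leaves just the continuation.
    """
    if prompt_len > len(sequence):
        raise ValueError("prompt_len exceeds sequence length")
    return sequence[prompt_len:]

# Toy ids: 4 prompt tokens followed by 3 generated tokens
prompt_ids = [101, 7, 8, 9]
output_ids = prompt_ids + [42, 43, 2]
print(strip_prompt_tokens(output_ids, len(prompt_ids)))
```

Without the slice, the decoded "report" would start by repeating the prompt text.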

In [None]:
# Initialize and load model
visualizer = UltraLightMAIRA2()

print("Starting model loading...")
print("This will take 5-10 minutes on first run (downloading ~13GB)")

success = visualizer.load_model_cautiously()

if success:
    print("🎉 MAIRA-2 loaded successfully!")
    print("Ready for minimal demonstration.")
else:
    print("❌ Failed to load MAIRA-2 (see the error above)")
    print("Common causes: not enough GPU/CPU memory, or access to the gated model not yet granted.")
    print("For more GPU memory, try Colab Pro or Kaggle Notebooks.")

In [None]:
# Test with minimal example
if visualizer.model is not None:
    print("🔍 Testing with sample chest X-ray...")
    
    # Use a small sample image URL
    sample_url = "https://openi.nlm.nih.gov/imgs/512/145/145/CXR145_IM-0290-1001.png"
    
    # Generate minimal report
    prompt = "Describe this chest X-ray briefly."
    
    print(f"Prompt: {prompt}")
    print("Generating...")
    
    result = visualizer.generate_minimal_report(sample_url, prompt)
    
    if result:
        print("\n📄 Generated Report:")
        print("=" * 40)
        print(result)
        print("=" * 40)
        print("\n✅ Minimal demo successful!")
    else:
        print("❌ Generation failed")
        
    check_gpu_memory()
else:
    print("❌ Model not loaded - cannot test")

In [None]:
# Optional: Try with your own image
if visualizer.model is not None:
    print("📸 You can now try with your own chest X-ray:")
    print("1. Upload an image file to Colab (click folder icon 🗂️)")
    print("2. Update the image_path below")
    print("3. Run this cell")
    
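    # CHANGE THIS PATH to your uploaded image. Colab's file browser uploads to
    # /content, while this notebook's working directory is now /content/maira-interp
    image_path = "/content/your_image.png"  # Replace with your image filename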
    
    # Check if image exists
    if os.path.exists(image_path):
        print(f"Processing {image_path}...")
        
        custom_prompt = "What abnormalities do you see in this chest X-ray?"
        result = visualizer.generate_minimal_report(image_path, custom_prompt)
        
        if result:
            print(f"\n📄 Report for {image_path}:")
            print("=" * 40)
            print(result)
            print("=" * 40)
        else:
            print("❌ Failed to process your image")
    else:
        print(f"⚠️  Image {image_path} not found.")
        print("Upload an image first, then update the image_path variable.")
else:
    print("❌ Model not loaded")
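Instead of typing the filename by hand, you can list candidate images in the upload folder. A minimal sketch, assuming Colab's default upload location `/content`; `find_uploaded_images` and `IMAGE_SUFFIXES` are hypothetical names, and the suffix list only covers formats PIL opens out of the box (DICOM would need a separate reader).

```python
from pathlib import Path

# Formats PIL can open directly (assumed list; extend as needed)
IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff"}

def find_uploaded_images(folder="/content"):
    """Return image files directly inside `folder`, sorted by name ([] if the folder is missing)."""
    root = Path(folder)
    if not root.is_dir():
        return []
    return sorted(
        (p for p in root.iterdir() if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES),
        key=lambda p: p.name,
    )

images = find_uploaded_images()
print([str(p) for p in images] or "No images found in /content; upload one via the file browser.")
```

Any path from this list can be assigned to `image_path` before re-running the cell above.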

In [None]:
# Final cleanup
print("🧹 Cleaning up memory...")

if 'visualizer' in locals():
    visualizer.cleanup()
    del visualizer

aggressive_cleanup()
print("✅ Memory cleaned")
check_gpu_memory()

## 📋 Summary

**This lightweight notebook:**
- ✅ Works on 14.7GB GPUs (Colab T4)
- ✅ Uses aggressive memory optimization
- ✅ Generates minimal but functional reports
- ✅ Catches load and generation errors (including OOM) and frees GPU memory

**Limitations:**
- Generates at most 15 new tokens (vs 100+ in the full version)
- No attention visualizations (would require too much memory)
- Single image only (no lateral/prior images)
- float16 weights (may slightly affect output quality)
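The token cap mainly bounds generation time; its memory cost is modest. Each new token adds one key and one value vector per layer to the KV cache. The sketch below assumes a 7B LLaMA-style decoder (32 layers, hidden size 4096, no grouped-query attention, batch size 1); these dimensions are an assumption, so check `model.config` for the real values.

```python
def kv_cache_bytes(num_tokens, num_layers, hidden_size, bytes_per_value):
    """Keys + values cached per layer for every token (standard multi-head attention)."""
    return 2 * num_layers * hidden_size * bytes_per_value * num_tokens

# Assumed decoder dimensions (illustrative, not read from the MAIRA-2 config)
LAYERS, HIDDEN, FP16 = 32, 4096, 2

per_token = kv_cache_bytes(1, LAYERS, HIDDEN, FP16)
print(f"per token: {per_token / 1024**2:.2f} MiB")
# Going from 15 to 100 new tokens adds 85 tokens of cache
print(f"85 extra tokens: {kv_cache_bytes(85, LAYERS, HIDDEN, FP16) / 1024**2:.1f} MiB")
```

Under these assumptions, 85 extra tokens cost only tens of MiB, which suggests the weights and the prompt/image activations, not the output length, dominate peak memory.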

**For full features:**
- Use Colab Pro (access to larger GPUs such as the 40GB A100, subject to availability)
- Try Kaggle Notebooks (offers a 2× T4 option, which `device_map="auto"` can split the model across)
- Use Paperspace or other cloud platforms

**If the cells above ran successfully, MAIRA-2 loads and generates on your hardware!** 🎉