# MAIRA-2 vs LLaVA-v1.6 Comparison for Memory-Constrained GPUs

**🎯 Two options for different GPU memory constraints:**

- **LLaVA-v1.6-Mistral-7B**: ~7GB VRAM (roughly an 8-bit footprint; float16 weights for 7B parameters are nearer 14GB), general-purpose vision-language model that can also describe medical images
- **MAIRA-2**: ~15GB VRAM, specialized for radiology report generation

**Setup:**
1. Runtime > Change runtime type > GPU > Save
2. For MAIRA-2: Get HF token from https://huggingface.co/settings/tokens
3. Choose your model based on available GPU memory
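
For step 3, a rough sanity check on the memory figures is parameter count times bytes per parameter. The sketch below is illustrative only: `estimate_weight_gb` is a hypothetical helper, the 7-billion parameter count is a round-number assumption, and activations, image features, and the KV cache are ignored.

```python
def estimate_weight_gb(n_params: float, bytes_per_param: int) -> float:
    """Approximate memory for model weights alone, in GiB."""
    return n_params * bytes_per_param / 1024**3

# Assumed round figure for a "7B" model; real checkpoints differ slightly.
N_PARAMS = 7e9

for label, nbytes in [("float16", 2), ("int8", 1)]:
    print(f"{label}: ~{estimate_weight_gb(N_PARAMS, nbytes):.1f} GiB")
```

Under these assumptions float16 weights come to about 13 GiB and 8-bit weights to about 6.5 GiB, so treat the memory thresholds in this notebook as starting points and confirm with `check_gpu_memory()` after loading.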

In [None]:
# Check environment and recommend model
import torch

print(f"CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    gpu_props = torch.cuda.get_device_properties(0)
    gpu_memory_gb = gpu_props.total_memory / 1024**3
    print(f"GPU: {gpu_props.name}")
    print(f"Memory: {gpu_memory_gb:.1f}GB")
    
    print("\n🎯 Model Recommendations:")
    if gpu_memory_gb < 6:
        print("❌ Insufficient memory for vision-language models")
        print("Try: Runtime > Change runtime type > Select different GPU")
        recommended_model = None  # Keeps the summary print below from raising NameError
    elif gpu_memory_gb < 14:
        print(f"✅ LLaVA-v1.6-Mistral-7B (7GB) - RECOMMENDED for {gpu_memory_gb:.1f}GB GPU")
        print("⚠️  MAIRA-2 (15GB) - Will likely fail with OOM")
        recommended_model = "llava-v1.6-mistral-7b"
    elif gpu_memory_gb < 16:
        print("⚠️  Both models possible but MAIRA-2 will be tight")
        print("🎯 Try LLaVA-v1.6 first, then MAIRA-2 if you want")
        recommended_model = "llava-v1.6-mistral-7b"  # Safer choice
    else:
        print("✅ Both models will work fine")
        recommended_model = "maira-2"  # Can afford the medical specialist
        
    print(f"\n🔧 Recommended: {recommended_model}")
else:
    print("❌ No GPU! Go to Runtime > Change runtime type > GPU")
    recommended_model = None
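
The thresholds in the cell above can also be expressed as a pure function, which makes them testable without a GPU. This is a minimal sketch; `recommend_model` is a hypothetical name, and the cut-offs simply mirror the ones used above (below 6GB nothing, below 16GB LLaVA, otherwise MAIRA-2).

```python
from typing import Optional

def recommend_model(gpu_memory_gb: float) -> Optional[str]:
    """Return the model key the notebook would recommend, or None if the GPU is too small."""
    if gpu_memory_gb < 6:
        return None
    if gpu_memory_gb < 16:
        # Covers both the "LLaVA only" and the "MAIRA-2 is tight" branches.
        return "llava-v1.6-mistral-7b"
    return "maira-2"

for gb in (4.0, 14.7, 15.5, 40.0):
    print(gb, "->", recommend_model(gb))
```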

In [None]:
# Install packages (versions are not pinned; pin them if you need reproducible runs)
!pip install -q torch torchvision
!pip install -q transformers accelerate
!pip install -q pillow matplotlib requests
print("✅ Packages installed")

In [None]:
import gc
import torch
import os

def aggressive_cleanup():
    """Aggressive memory cleanup"""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        # Force garbage collection multiple times
        for _ in range(3):
            gc.collect()
            torch.cuda.empty_cache()

def check_gpu_memory():
    """Print GPU memory usage; return True if more than 1GB is unallocated"""
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        allocated = torch.cuda.memory_allocated() / 1024**3
        reserved = torch.cuda.memory_reserved() / 1024**3
        total = torch.cuda.get_device_properties(0).total_memory / 1024**3
        free = total - allocated
        print(f"GPU: {allocated:.1f}GB used, {free:.1f}GB free, {total:.1f}GB total")
        return free > 1.0  # Need at least 1GB free
    return False

# Set allocator options before any CUDA memory is allocated (PYTORCH_CUDA_ALLOC_CONF is read by PyTorch, not transformers)
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256,expandable_segments:True'
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

# Initial cleanup
aggressive_cleanup()
check_gpu_memory()
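
`check_gpu_memory()` only asks whether more than 1GB is free, which says little about whether a 7GB or 15GB model will fit. A stricter variant could compare free memory against the model's expected footprint. The sketch below is an assumption about how that might look: `can_fit`, `needed_gb`, and `margin_gb` are hypothetical names, and the byte counts would come from the same `torch.cuda` calls used above.

```python
GIB = 1024**3

def can_fit(total_bytes: int, allocated_bytes: int,
            needed_gb: float, margin_gb: float = 1.0) -> bool:
    """True if the unallocated memory covers the model plus a safety margin."""
    free_gb = (total_bytes - allocated_bytes) / GIB
    return free_gb >= needed_gb + margin_gb

# Example with made-up numbers: a 15 GiB card with 0.5 GiB already in use.
total, used = 15 * GIB, GIB // 2
print("7GB model fits:", can_fit(total, used, needed_gb=7))
print("15GB model fits:", can_fit(total, used, needed_gb=15))
```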

In [None]:
# Authentication (required for MAIRA-2; LLaVA-v1.6 loads without a token)
from getpass import getpass

hf_token = getpass("Hugging Face token: ")
os.environ['HF_TOKEN'] = hf_token

# Quick auth test
try:
    from huggingface_hub import HfApi
    user = HfApi().whoami(token=hf_token)
    print(f"✅ Authenticated as: {user['name']}")
    del user  # Only defined when the call succeeded, so delete it here
except Exception as e:
    print(f"❌ Auth failed: {e}")

aggressive_cleanup()
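
If the token may already be present in the environment, prompting every time is unnecessary. The following sketch shows one way to prefer an existing `HF_TOKEN` and fall back to a prompt. `resolve_token` is a hypothetical helper, and passing the prompt function in as an argument is a design choice made here so the logic can be tested without user input.

```python
from typing import Callable, Mapping, Optional

def resolve_token(env: Mapping[str, str], prompt: Callable[[], str]) -> Optional[str]:
    """Return HF_TOKEN from env if set, otherwise ask via prompt(); None if still empty."""
    token = env.get("HF_TOKEN", "").strip()
    if not token:
        token = prompt().strip()
    return token or None

# Demonstration with stand-in values instead of a real environment or getpass.
print(resolve_token({"HF_TOKEN": "from-env"}, lambda: "from-prompt"))
print(resolve_token({}, lambda: "from-prompt"))
```

In the notebook this could be called as `resolve_token(os.environ, lambda: getpass("Hugging Face token: "))`.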

In [None]:
# Dual-model wrapper for MAIRA-2 and LLaVA-v1.6
import torch
from transformers import AutoModelForCausalLM, AutoProcessor
from transformers import LlavaNextForConditionalGeneration  # Correct LLaVA class
from PIL import Image
import requests
from io import BytesIO

class DualVisionLanguageModel:
    """Unified wrapper for MAIRA-2 and LLaVA-v1.6 with correct model classes"""
    
    def __init__(self, model_name="llava-v1.6-mistral-7b"):
        self.model_name = model_name
        self.model = None
        self.processor = None
        
        # Model configurations
        self.configs = {
            "llava-v1.6-mistral-7b": {
                "model_id": "llava-hf/llava-v1.6-mistral-7b-hf",
                "memory_gb": 7,
                "needs_token": False,
                "trust_remote_code": False,
                "torch_dtype": torch.float16,
                "max_new_tokens": 50,
                "model_class": LlavaNextForConditionalGeneration  # Correct class
            },
            "maira-2": {
                "model_id": "microsoft/maira-2", 
                "memory_gb": 15,
                "needs_token": True,
                "trust_remote_code": True,
                "torch_dtype": torch.float16,
                "max_new_tokens": 15,
                "model_class": AutoModelForCausalLM  # Standard class
            }
        }
        
        if model_name not in self.configs:
            raise ValueError(f"Unsupported model: {model_name}. Use 'llava-v1.6-mistral-7b' or 'maira-2'")
            
        self.config = self.configs[model_name]
        
    def load_model_cautiously(self):
        """Load model with appropriate optimizations"""
        print(f"🔄 Loading {self.model_name} ({self.config['memory_gb']}GB)...")
        
        # Check memory before loading
        if not check_gpu_memory():
            print("❌ Insufficient GPU memory")
            return False
            
        try:
            # Get token if needed
            token = os.environ.get('HF_TOKEN') if self.config['needs_token'] else None
            
            print("Loading processor...")
            self.processor = AutoProcessor.from_pretrained(
                self.config['model_id'],
                trust_remote_code=self.config['trust_remote_code'],
                token=token
            )
            
            aggressive_cleanup()
            
            print("Loading model...")
            # Use the correct model class for each model
            model_class = self.config['model_class']
            self.model = model_class.from_pretrained(
                self.config['model_id'],
                trust_remote_code=self.config['trust_remote_code'],
                token=token,
                torch_dtype=self.config['torch_dtype'],
                low_cpu_mem_usage=True,
                device_map="auto"
            )
            
            # Set to eval mode
            if hasattr(self.model, 'eval'):
                self.model.eval()
                
            aggressive_cleanup()
            
            print(f"✅ {self.model_name} loaded successfully!")
            check_gpu_memory()
            return True
            
        except Exception as e:
            print(f"❌ Model loading failed: {e}")
            self.cleanup()
            return False
    
    def generate_report(self, image_path_or_pil, prompt=None):
        """Generate report with model-appropriate prompting"""
        if self.model is None or self.processor is None:
            print("❌ Model not loaded")
            return None
            
        try:
            # Load and prepare image
            if isinstance(image_path_or_pil, str):
                if image_path_or_pil.startswith(('http://', 'https://')):
                    response = requests.get(image_path_or_pil, timeout=30)
                    response.raise_for_status()  # Fail early instead of decoding an HTTP error page
                    image = Image.open(BytesIO(response.content))
                else:
                    image = Image.open(image_path_or_pil)
            else:
                image = image_path_or_pil
                
            # Model-specific prompting
            if self.model_name == "llava-v1.6-mistral-7b":
                return self._generate_llava(image, prompt)
            else:  # maira-2
                return self._generate_maira(image, prompt)
                
        except Exception as e:
            print(f"❌ Generation failed: {e}")
            aggressive_cleanup()
            return None
    
    def _generate_llava(self, image, prompt=None):
        """Generate using LLaVA format"""
        if prompt is None:
            prompt = "[INST] <image>\nAnalyze this chest X-ray image. Describe any abnormalities you observe. [/INST]"
        else:
            prompt = f"[INST] <image>\n{prompt} [/INST]"
        
        # Convert to RGB for LLaVA
        if image.mode != 'RGB':
            image = image.convert('RGB')
        
        # Process inputs
        inputs = self.processor(
            text=prompt,
            images=image,
            return_tensors="pt"
        )
        
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        
        # Generate
        with torch.no_grad():
            output = self.model.generate(
                **inputs,
                max_new_tokens=self.config['max_new_tokens'],
                do_sample=False,
                temperature=None
            )
        
        # Decode
        prompt_len = inputs['input_ids'].shape[1]
        generated_tokens = output[0][prompt_len:]
        response = self.processor.tokenizer.decode(
            generated_tokens, skip_special_tokens=True
        )
        
        # Cleanup
        del inputs, output, generated_tokens
        aggressive_cleanup()
        
        return response.strip()
    
    def _generate_maira(self, image, prompt=None):
        """Generate using MAIRA-2 format"""
        if prompt is None:
            prompt = "Describe this chest X-ray briefly."
            
        # Convert to grayscale for MAIRA-2 (medical standard)
        if image.mode != 'L':
            image = image.convert('L')
        
        # Generic chat-template conversation (check the MAIRA-2 model card for the processor's own input helpers)
        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": prompt}
                ]
            }
        ]
        
        prompt_text = self.processor.apply_chat_template(
            conversation, add_generation_prompt=True
        )
        
        inputs = self.processor(
            text=prompt_text,
            images=[image],
            return_tensors="pt"
        )
        
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        
        # Generate
        with torch.no_grad():
            output = self.model.generate(
                **inputs,
                max_new_tokens=self.config['max_new_tokens'],
                do_sample=False,
                pad_token_id=self.processor.tokenizer.eos_token_id
            )
        
        # Decode
        prompt_len = inputs['input_ids'].shape[1]
        generated_tokens = output[0][prompt_len:]
        response = self.processor.tokenizer.decode(
            generated_tokens, skip_special_tokens=True
        )
        
        # Cleanup
        del inputs, output, generated_tokens
        aggressive_cleanup()
        
        return response.strip()
    
    def cleanup(self):
        """Clean up model from memory"""
        if self.model is not None:
            del self.model
            self.model = None
        if self.processor is not None:
            del self.processor
            self.processor = None
        aggressive_cleanup()

print("✅ Dual-model wrapper created (MAIRA-2 vs LLaVA-v1.6) with correct model classes")
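
Both `_generate_llava` and `_generate_maira` recover the response by slicing off the first `prompt_len` token ids, which relies on `generate` returning the prompt followed by the continuation. The sketch below shows the same slicing on plain Python lists; `strip_prompt` and the token ids are made up for illustration.

```python
def strip_prompt(output_ids: list, prompt_len: int) -> list:
    """Keep only the ids generated after the prompt."""
    return output_ids[prompt_len:]

prompt_ids = [1, 733, 16289]           # stand-in ids for the prompt
output_ids = prompt_ids + [415, 3469]  # generate() output: prompt + new tokens

new_ids = strip_prompt(output_ids, len(prompt_ids))
print(new_ids)  # [415, 3469]
```

If nothing was generated, the slice is simply empty, so decoding it yields an empty string rather than an error.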

In [None]:
# Choose and initialize your model
print("🎯 Model Selection")
print("1. LLaVA-v1.6-Mistral-7B: ~7GB, works on most GPUs, excellent general vision-language")
print("2. MAIRA-2: ~15GB, medical specialist, needs HF token")

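# Default to the earlier recommendation, but allow manual override
default_choice = "2" if globals().get("recommended_model") == "maira-2" else "1"
model_choice = input(f"Choose model (1 for LLaVA-v1.6, 2 for MAIRA-2) [default: {default_choice}]: ").strip() or default_choice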

if model_choice == "2":
    model_name = "maira-2"
    print("Selected: MAIRA-2 (medical specialist)")
    
    # Check if token is needed
    if 'HF_TOKEN' not in os.environ:
        from getpass import getpass
        hf_token = getpass("Enter your HF token (required for MAIRA-2): ")
        os.environ['HF_TOKEN'] = hf_token
else:
    model_name = "llava-v1.6-mistral-7b" 
    print("Selected: LLaVA-v1.6-Mistral-7B (general vision-language)")

# Initialize model
print(f"\n🚀 Initializing {model_name}...")
visualizer = DualVisionLanguageModel(model_name)

success = visualizer.load_model_cautiously()

if success:
    print(f"🎉 {model_name} loaded successfully!")
    print("Ready for image analysis.")
else:
    print(f"❌ Failed to load {model_name}")
    if model_name == "maira-2":
        print("Try switching to LLaVA-v1.6 (option 1)")
    else:
        print("Check your GPU memory or try restarting runtime")
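
The selection logic above treats anything other than `"2"` as LLaVA. A small mapping makes that fallback explicit and easy to test. This is a sketch only; `parse_choice` is a hypothetical helper and the default mirrors the `[default: 1]` behaviour of the prompt.

```python
MODEL_OPTIONS = {"1": "llava-v1.6-mistral-7b", "2": "maira-2"}

def parse_choice(raw: str, default: str = "llava-v1.6-mistral-7b") -> str:
    """Map the user's menu input to a model key, falling back to the default."""
    return MODEL_OPTIONS.get(raw.strip(), default)

for raw in ("1", "2", "", " 2 ", "maira"):
    print(repr(raw), "->", parse_choice(raw))
```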

In [None]:
# Test the loaded model with a sample chest X-ray (re-run after switching models to compare)
if visualizer.model is not None:
    print(f"🔍 Testing {visualizer.model_name} with sample chest X-ray...")
    
    # Use the same sample image for comparison
    sample_url = "https://openi.nlm.nih.gov/imgs/512/145/145/CXR145_IM-0290-1001.png"
    
    # Test prompts
    prompts = [
        "Describe this chest X-ray image briefly.",
        "What abnormalities do you see in this chest X-ray?",
        "Analyze this medical image and report your findings."
    ]
    
    for i, prompt in enumerate(prompts, 1):
        print(f"\n--- Test {i}: {prompt} ---")
        
        result = visualizer.generate_report(sample_url, prompt)
        
        if result:
            print(f"📄 {visualizer.model_name} Response:")
            print("=" * 50)
            print(result)
            print("=" * 50)
        else:
            print("❌ Generation failed for this prompt")
            
        # Check memory after each test
        check_gpu_memory()
    
    print(f"\n✅ {visualizer.model_name} testing complete!")
    
    # Show sample image for reference (requests, Image and BytesIO were imported with the wrapper)
    try:
        response = requests.get(sample_url, timeout=30)
        response.raise_for_status()
        img = Image.open(BytesIO(response.content))
        print("\n📸 Sample chest X-ray used for testing:")
        display(img.resize((256, 256)))
    except Exception:
        print("📸 Sample image URL: " + sample_url)
        
else:
    print("❌ Model not loaded - cannot test")
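
Because only one model is loaded at a time, comparing outputs means keeping results around between runs. One possible approach, sketched below, is a dictionary keyed by prompt and then by model name. `record` and `render` are hypothetical helpers, and the response strings are placeholders, not real model output.

```python
from collections import defaultdict

def record(table, prompt: str, model_name: str, response: str) -> None:
    """Store one response under its prompt and model."""
    table[prompt][model_name] = response

def render(table) -> str:
    """Format the collected responses, grouping models under each prompt."""
    lines = []
    for prompt, by_model in table.items():
        lines.append(f"Prompt: {prompt}")
        for model_name in sorted(by_model):
            lines.append(f"  {model_name}: {by_model[model_name]}")
    return "\n".join(lines)

results = defaultdict(dict)
record(results, "Describe this chest X-ray image briefly.", "maira-2", "<placeholder B>")
record(results, "Describe this chest X-ray image briefly.", "llava-v1.6-mistral-7b", "<placeholder A>")
print(render(results))
```

In the test loop, `record(results, prompt, visualizer.model_name, result)` would fill the table; it survives a model switch as long as the kernel is not restarted.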

In [None]:
# Test with minimal example
if visualizer.model is not None:
    print("🔍 Testing with sample chest X-ray...")
    
    # Use a small sample image URL
    sample_url = "https://openi.nlm.nih.gov/imgs/512/145/145/CXR145_IM-0290-1001.png"
    
    # Generate minimal report
    prompt = "Describe this chest X-ray briefly."
    
    print(f"Prompt: {prompt}")
    print("Generating...")
    
    result = visualizer.generate_report(sample_url, prompt)
    
    if result:
        print(f"\n📄 Generated Report:")
        print("=" * 40)
        print(result)
        print("=" * 40)
        print("\n✅ Minimal demo successful!")
    else:
        print("❌ Generation failed")
        
    check_gpu_memory()
else:
    print("❌ Model not loaded - cannot test")

In [None]:
# Optional: Try with your own image
if visualizer.model is not None:
    print("📸 You can now try with your own chest X-ray:")
    print("1. Upload an image file to Colab (click folder icon 🗂️)")
    print("2. Update the image_path below")
    print("3. Run this cell")
    
    # CHANGE THIS PATH to your uploaded image
    image_path = "your_image.png"  # Replace with your image filename
    
    # Check if image exists
    if os.path.exists(image_path):
        print(f"Processing {image_path}...")
        
        custom_prompt = "What abnormalities do you see in this chest X-ray?"
        result = visualizer.generate_report(image_path, custom_prompt)
        
        if result:
            print(f"\n📄 Report for {image_path}:")
            print("=" * 40)
            print(result)
            print("=" * 40)
        else:
            print("❌ Failed to process your image")
    else:
        print(f"⚠️  Image {image_path} not found.")
        print("Upload an image first, then update the image_path variable.")
else:
    print("❌ Model not loaded")

## 📊 Model Comparison Summary

**This notebook lets you compare two vision-language models:**

### 🔬 **LLaVA-v1.6-Mistral-7B** (Recommended for Colab)
- **Memory**: ~7GB (fits most GPUs)
- **Speed**: Fast inference  
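- **Domain**: General vision-language; can describe medical images but is not radiology-specific
- **Token**: Not required
- **Output**: Capped at 50 new tokens in this notebook
- **Compatibility**: Loads with the built-in `LlavaNextForConditionalGeneration` class (no `trust_remote_code`)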

### 🏥 **MAIRA-2** (Medical Specialist)
- **Memory**: ~15GB (tight fit on free Colab)
- **Speed**: Slower due to size
- **Domain**: Radiology specialist
- **Token**: Required (HF account + MAIRA-2 access)
- **Output**: Capped at 15 new tokens in this notebook

### 🎯 **Key Differences in Code:**
1. **Input format**: LLaVA uses RGB, MAIRA uses grayscale
2. **Chat templates**: Different conversation structures
3. **Token requirements**: MAIRA needs authentication
4. **Prompting style**: LLaVA uses the Mistral `[INST] ... [/INST]` format; MAIRA-2 goes through the processor's chat template
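
To make the prompting difference concrete, the sketch below rebuilds the two input structures used by the wrapper: a single `[INST]` string for LLaVA and a list-of-messages conversation for MAIRA-2. The function names are hypothetical; the structures are copied from `_generate_llava` and `_generate_maira`.

```python
def build_llava_prompt(text: str) -> str:
    """LLaVA-v1.6-Mistral expects one string with an <image> placeholder."""
    return f"[INST] <image>\n{text} [/INST]"

def build_maira_conversation(text: str) -> list:
    """The MAIRA-2 path builds a conversation that is later passed to apply_chat_template."""
    return [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": text},
            ],
        }
    ]

question = "Describe this chest X-ray briefly."
print(build_llava_prompt(question))
print(build_maira_conversation(question))
```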

### 💡 **Recommendations:**
- **Start with LLaVA-v1.6** to validate your workflow; it needs no token and no custom model code
- **Switch to MAIRA-2** for specialized radiology applications
- **Use this notebook** to test both and compare outputs
- **LLaVA-v1.6 suits prototyping**, but it is a general-purpose model, so check its medical descriptions against MAIRA-2 before relying on them

LLaVA-v1.6 lets you exercise the whole pipeline before committing to MAIRA-2.

## 📋 Summary

**This lightweight notebook:**
- ✅ Works on 14.7GB GPUs (Colab T4)
- ✅ Uses aggressive memory optimization
- ✅ Generates minimal but functional reports
- ✅ Handles OOM gracefully

**Limitations:**
- MAIRA-2 output is capped at ~15 tokens (vs 100+ in full version)
- No attention visualizations (would require too much memory)
- Single image only (no lateral/prior images)
- Half precision (might affect quality slightly)

**For full features:**
- Use Colab Pro for access to larger GPUs such as the A100 (40GB); the GPU type is not guaranteed
- Try Kaggle Notebooks (sometimes more generous)
- Use Paperspace or other cloud platforms

**If MAIRA-2 loaded above, this demo shows it can run on your hardware!** 🎉