# 🩻 ExplainMyXray - Advanced MedGemma Training

## Features:
- **Datasets**: Kaggle Chest X-Ray Pneumonia (~5.9K), NIH Chest X-ray Sample (~5.6K), optional COVID-19 Radiography (~21K)
- **Model**: PaliGemma (`google/paligemma-3b-pt-224`, a Gemma-based vision-language model) with 4-bit QLoRA fine-tuning
- **Advanced Training**: 10 epochs by default, cosine LR scheduler with warmup, early stopping, gradient clipping
- **Comprehensive Evaluation**: BLEU, ROUGE, BERTScore, clinical accuracy metrics
- **Proper Data Splits**: Train (80%) / Validation (10%) / Test (10%)
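The cosine LR scheduler listed above ramps the learning rate up linearly during warmup (`warmup_ratio` in Cell 4), then decays it along half a cosine to zero. A minimal sketch of that shape, written to follow the formula behind Hugging Face's `get_cosine_schedule_with_warmup` with its default half cycle; the function name and step counts are illustrative, and the training cell is assumed to select the schedule through the Trainer rather than through this helper:

```python
import math


def cosine_lr_with_warmup(step: int, total_steps: int, warmup_steps: int, base_lr: float) -> float:
    """Learning rate at `step`: linear warmup from 0, then half-cosine decay to 0."""
    if step < warmup_steps:
        # Linear ramp: 0 -> base_lr over the warmup steps
        return base_lr * step / max(1, warmup_steps)
    # Fraction of the post-warmup phase completed, clamped to [0, 1]
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(max(progress, 0.0), 1.0)
    # Half a cosine period: factor 1 at progress=0, factor 0 at progress=1
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


if __name__ == "__main__":
    total, warmup = 1000, 100  # hypothetical step counts
    for s in (0, 50, 100, 550, 1000):
        print(s, f"{cosine_lr_with_warmup(s, total, warmup, 2e-4):.2e}")
```

The halfway point of the decay phase lands at half the peak rate, which makes the curve easy to sanity-check against logged learning rates.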

## 📦 Cell 1: Install Dependencies

**⚠️ IMPORTANT WORKFLOW:**
1. Run Cell 1 (Install) → wait for completion
2. Click "RESTART SESSION" button that appears  
3. Skip Cell 1, run Cell 1b (Verify) onwards

In [None]:
# ============================================================
# COLAB T4 OPTIMIZED - Install Dependencies (Jan 2026)
# ============================================================
# IMPORTANT:
# 1. Run this cell FIRST
# 2. When it completes, click the "RESTART SESSION" button that appears
# 3. Then run Cell 1b onwards (do not re-run this cell)

import os

# ============================================================
# STEP 1: Uninstall conflicting pre-installed packages
# ============================================================
print("üßπ Cleaning up conflicting packages...")
!pip uninstall -y transformers peft accelerate bitsandbytes -q 2>/dev/null || true

# ============================================================
# STEP 2: Install compatible versions (numpy 2.x compatible)
# ============================================================
print("\nüì¶ Installing ML libraries...")

# Core ML libraries - use versions compatible with Colab's numpy 2.x
!pip install -q --upgrade pip setuptools wheel

# Install transformers and related packages (numpy 2.x compatible)
!pip install -q "transformers>=4.47.0"
!pip install -q "peft>=0.14.0"
!pip install -q "accelerate>=1.0.0"
!pip install -q "bitsandbytes>=0.45.0"
!pip install -q "datasets>=3.0.0"

# Data & visualization (don't constrain numpy - use Colab's version)
!pip install -q "pillow>=10.0.0" scikit-learn matplotlib seaborn

# Evaluation metrics
!pip install -q evaluate sacrebleu rouge_score nltk

# Dataset download utilities  
!pip install -q kaggle gdown opendatasets tqdm

# Fix jedi for IPython (optional warning fix)
!pip install -q jedi

print("\n" + "="*60)
print("✅ INSTALLATION COMPLETE!")
print("="*60)
print("")
print("⚠️  IMPORTANT: You MUST restart the runtime now!")
print("")
print("   Click: Runtime → Restart session")
print("   OR click the 'RESTART SESSION' button above")
print("")
print("   After restart, skip this cell and run Cell 1b onwards.")
print("="*60)
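Cell 1 installs with `>=` lower bounds only, and modules that were already imported keep their old versions until the restart. After restarting, a quick check that each bound was actually honored can save a confusing failure later. A minimal sketch using the standard library's `importlib.metadata`; `parse_version` and `meets_minimum` are hypothetical helpers with a simplified major.minor.patch comparison that ignores pre-release tags (`packaging.version.Version` is the robust option):

```python
import re


def parse_version(v: str) -> tuple:
    """Turn '4.47.1' or '2.5.1+cu121' into (4, 47, 1); non-numeric suffixes are dropped."""
    core = v.split("+")[0]                  # strip a local build tag such as +cu121
    parts = []
    for piece in core.split(".")[:3]:
        m = re.match(r"\d+", piece)         # keep the leading digits of pieces like '0rc1'
        parts.append(int(m.group()) if m else 0)
    while len(parts) < 3:
        parts.append(0)                     # '1.0' -> (1, 0, 0)
    return tuple(parts)


def meets_minimum(installed: str, minimum: str) -> bool:
    """True if `installed` >= `minimum`, comparing major.minor.patch only."""
    return parse_version(installed) >= parse_version(minimum)


# Lower bounds copied from the pip commands in Cell 1
MIN_VERSIONS = {
    "transformers": "4.47.0",
    "peft": "0.14.0",
    "accelerate": "1.0.0",
    "bitsandbytes": "0.45.0",
    "datasets": "3.0.0",
}

if __name__ == "__main__":
    from importlib.metadata import PackageNotFoundError, version

    for pkg, minimum in MIN_VERSIONS.items():
        try:
            got = version(pkg)
            status = "OK" if meets_minimum(got, minimum) else "TOO OLD"
        except PackageNotFoundError:
            got, status = "-", "MISSING"
        print(f"{pkg:<13} {got:<12} (>= {minimum}) {status}")
```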

In [None]:
# ============================================================
# Cell 1b: VERIFY INSTALLATION (Run after restart)
# ============================================================
# Run this cell AFTER restarting the runtime

import torch
import transformers
import peft
import accelerate
import numpy as np

print("="*60)
print("üîß SYSTEM INFO - COLAB T4")
print("="*60)
print(f"PyTorch:       {torch.__version__}")
print(f"Transformers:  {transformers.__version__}")
print(f"PEFT:          {peft.__version__}")
print(f"Accelerate:    {accelerate.__version__}")
print(f"NumPy:         {np.__version__}")

print(f"\nCUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU:           {torch.cuda.get_device_name(0)}")
    gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"VRAM:          {gpu_mem:.1f} GB")
    if "T4" in torch.cuda.get_device_name(0):
        print("✅ T4 GPU detected - config optimized for 16GB VRAM")
    else:
        print("⚠️  Non-T4 GPU detected - may need batch_size adjustment")

# Check bitsandbytes
try:
    import bitsandbytes as bnb
    print(f"Bitsandbytes:  {bnb.__version__}")
    print("✅ All imports successful!")
except Exception as e:
    print(f"⚠️  Bitsandbytes issue: {e}")

print("="*60)
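Cell 1b recognizes the T4 by its device name, and Cell 4 hard-codes `float16` because the T4 (compute capability 7.5) has no native bfloat16 support. A sketch of a more general check keyed on compute capability, where bf16 is native from 8.0 (Ampere) onward; `pick_compute_dtype` is a hypothetical helper, and PyTorch's built-in `torch.cuda.is_bf16_supported()` is an alternative:

```python
def pick_compute_dtype(capability: tuple) -> str:
    """Return 'bfloat16' on GPUs with native bf16 (compute capability >= 8.0), else 'float16'."""
    major, _minor = capability
    # Ampere (8.x) and newer have native bf16; a Turing T4 reports (7, 5)
    return "bfloat16" if major >= 8 else "float16"


if __name__ == "__main__":
    try:
        import torch

        if torch.cuda.is_available():
            cap = torch.cuda.get_device_capability(0)  # (major, minor) tuple
            print(f"Compute capability {cap} -> {pick_compute_dtype(cap)}")
        else:
            print("No CUDA device; dtype choice not applicable")
    except ImportError:
        print("PyTorch not installed")
```

The returned string can be mapped to `torch.float16` / `torch.bfloat16` when building `BitsAndBytesConfig` in Cell 4.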

## 🔐 Cell 2: Authentication (Hugging Face + Kaggle)

In [None]:
# ============================================================
# Cell 2: HUGGING FACE + KAGGLE AUTHENTICATION
# ============================================================

from huggingface_hub import login
import os
import shutil
from pathlib import Path

# ============================================================
# HUGGING FACE LOGIN
# ============================================================
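# Read the token from the environment or Colab secrets; never hard-code it in the notebook
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    try:
        from google.colab import userdata
        HF_TOKEN = userdata.get("HF_TOKEN")  # Colab secret named HF_TOKEN
    except Exception:
        HF_TOKEN = None
login(token=HF_TOKEN)  # prompts interactively if no token was found
print("✅ Logged in to Hugging Face!")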

# ============================================================
# KAGGLE AUTHENTICATION - Multiple fallback methods
# ============================================================
kaggle_ready = False

# Method 1: Check if kaggle.json exists in /content/ (uploaded file)
content_kaggle = Path("/content/kaggle.json")
if content_kaggle.exists():
    # Copy to ~/.kaggle/
    kaggle_dir = Path.home() / ".kaggle"
    kaggle_dir.mkdir(exist_ok=True)
    shutil.copy(content_kaggle, kaggle_dir / "kaggle.json")
    os.chmod(kaggle_dir / "kaggle.json", 0o600)
    print("✅ Kaggle credentials copied from /content/kaggle.json to ~/.kaggle/")
    kaggle_ready = True

# Method 2: Check ~/.kaggle/kaggle.json
if not kaggle_ready:
    kaggle_path = Path.home() / ".kaggle" / "kaggle.json"
    if kaggle_path.exists():
        print("✅ Kaggle credentials found in ~/.kaggle/kaggle.json")
        kaggle_ready = True

# Method 3: Try Colab secrets
if not kaggle_ready:
    try:
        from google.colab import userdata
        username = userdata.get('KAGGLE_USERNAME')
        key = userdata.get('KAGGLE_KEY')
        if username and key:
            os.environ['KAGGLE_USERNAME'] = username
            os.environ['KAGGLE_KEY'] = key
            print("✅ Kaggle credentials loaded from Colab secrets!")
            kaggle_ready = True
    except Exception:
        pass  # not running in Colab, or the secrets are not set

# Method 4: Check environment variables
if not kaggle_ready:
    if os.environ.get('KAGGLE_USERNAME') and os.environ.get('KAGGLE_KEY'):
        print("✅ Kaggle credentials found in environment variables!")
        kaggle_ready = True

if not kaggle_ready:
    print("⚠️ Kaggle credentials not found!")
    print("   Please upload kaggle.json to /content/ or set KAGGLE_USERNAME / KAGGLE_KEY Colab secrets")

print("\n✅ Authentication complete!")
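Methods 3 and 4 above leave the Kaggle credentials only in the `KAGGLE_USERNAME` / `KAGGLE_KEY` environment variables. The Kaggle CLI reads those variables too, but a cell that checks for `~/.kaggle/kaggle.json` specifically will not see them. A minimal sketch that persists them to that file with owner-only permissions, the same thing Cell 3 does when credentials are typed in directly; `write_kaggle_json` is a hypothetical helper:

```python
import json
import os
from pathlib import Path


def write_kaggle_json(username: str, key: str, home: Path) -> Path:
    """Write <home>/.kaggle/kaggle.json with owner-only permissions."""
    kaggle_dir = home / ".kaggle"
    kaggle_dir.mkdir(parents=True, exist_ok=True)
    path = kaggle_dir / "kaggle.json"
    path.write_text(json.dumps({"username": username, "key": key}))
    os.chmod(path, 0o600)  # the Kaggle CLI warns if the file is readable by others
    return path


if __name__ == "__main__":
    user, key = os.environ.get("KAGGLE_USERNAME"), os.environ.get("KAGGLE_KEY")
    if user and key:
        print(f"Wrote {write_kaggle_json(user, key, Path.home())}")
    else:
        print("KAGGLE_USERNAME / KAGGLE_KEY not set; nothing written")
```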


## 📊 Cell 3: Download Kaggle Datasets (AUTOMATIC)

### Datasets Used:
1. **Chest X-Ray Pneumonia** (~5,863 images) - `paultimothymooney/chest-xray-pneumonia`
2. **NIH Chest X-ray Sample** (~5,606 images) - `nih-chest-xrays/sample`
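3. **COVID-19 Radiography** (~21,165 images) - `tawsifurrahman/covid19-radiography-database` (optional: not downloaded by this cell; Cell 6 uses it only if it exists at `/content/datasets/covid19_radiography`)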

**No manual image upload required.** Kaggle credentials from Cell 2 (or a `kaggle.json` uploaded when prompted) are enough; just run the cell below.

In [None]:
import os
import pandas as pd
from pathlib import Path
from tqdm.auto import tqdm
import shutil
import json

# ============================================================
# GOOGLE DRIVE SETUP - PERSISTENT STORAGE
# ============================================================
from google.colab import drive
drive.mount('/content/drive')

# Store datasets in Google Drive (persists across sessions!)
DRIVE_DATA_ROOT = Path("/content/drive/MyDrive/ExplainMyXray_Datasets")
DRIVE_DATA_ROOT.mkdir(parents=True, exist_ok=True)

# Local cache for faster training
DATA_ROOT = Path("/content/datasets")
DATA_ROOT.mkdir(exist_ok=True)

print(f"üìÅ Drive storage: {DRIVE_DATA_ROOT}")
print(f"üìÅ Local cache: {DATA_ROOT}")

# Define directories
PNEUMONIA_DIR = DATA_ROOT / "chest_xray_pneumonia"
NIH_SAMPLE_DIR = DATA_ROOT / "nih_sample"

drive_pneumonia = DRIVE_DATA_ROOT / "chest_xray_pneumonia"
drive_nih = DRIVE_DATA_ROOT / "nih_sample"

# ============================================================
# CHECK IF DATA ALREADY EXISTS IN DRIVE
# ============================================================
pneumonia_cached = drive_pneumonia.exists() and any(drive_pneumonia.rglob("*.jpeg"))
nih_cached = drive_nih.exists() and any(drive_nih.rglob("*.png"))

need_kaggle = not (pneumonia_cached and nih_cached)

if pneumonia_cached and nih_cached:
    print("\n✅ Both datasets found in Google Drive! Loading from cache...")
    print("   (No Kaggle credentials needed)")
    kaggle_configured = True
    
    # Copy from Drive to local
    if not PNEUMONIA_DIR.exists():
        print(f"   üìÇ Loading chest_xray_pneumonia...")
        shutil.copytree(drive_pneumonia, PNEUMONIA_DIR)
    
    if not NIH_SAMPLE_DIR.exists():
        print(f"   üìÇ Loading nih_sample...")
        shutil.copytree(drive_nih, NIH_SAMPLE_DIR)

else:
    # ============================================================
    # KAGGLE API SETUP (only if data not fully cached)
    # ============================================================
    print("\nüîê Setting up Kaggle API...")
    
    KAGGLE_USERNAME = ""  # ‚Üê Enter your Kaggle username here
    KAGGLE_KEY = ""       # ‚Üê Enter your Kaggle API key here
    
    kaggle_configured = False
    
    if KAGGLE_USERNAME and KAGGLE_KEY:
        kaggle_dir = Path.home() / ".kaggle"
        kaggle_dir.mkdir(exist_ok=True)
        kaggle_json = kaggle_dir / "kaggle.json"
        
        with open(kaggle_json, "w") as f:
            json.dump({"username": KAGGLE_USERNAME, "key": KAGGLE_KEY}, f)
        os.chmod(kaggle_json, 0o600)
        print("✅ Kaggle configured with direct credentials!")
        kaggle_configured = True
    else:
        from google.colab import files
        kaggle_dir = Path.home() / ".kaggle"
        kaggle_dir.mkdir(exist_ok=True)
        kaggle_json = kaggle_dir / "kaggle.json"
        
        if not kaggle_json.exists():
            print("üì§ Upload your kaggle.json file:")
            try:
                uploaded = files.upload()
                if uploaded:
                    filename = list(uploaded.keys())[0]
                    with open(kaggle_json, "wb") as f:
                        f.write(uploaded[filename])
                    os.chmod(kaggle_json, 0o600)
                    print("‚úÖ Kaggle API configured!")
                    kaggle_configured = True
            except:
                print("‚ö†Ô∏è kaggle.json not uploaded")
        else:
            print("‚úÖ Kaggle already configured!")
            kaggle_configured = True
    
    # ============================================================
    # DOWNLOAD DATASETS
    # ============================================================
    if kaggle_configured:
        
        # DATASET 1: Chest X-Ray Pneumonia
        print("\nüì• Dataset 1: Chest X-Ray Pneumonia")
        if pneumonia_cached:
            print("   ‚úÖ Found in Drive cache, copying to local...")
            if not PNEUMONIA_DIR.exists():
                shutil.copytree(drive_pneumonia, PNEUMONIA_DIR)
        else:
            print("   üì• Downloading from Kaggle...")
            !kaggle datasets download -d paultimothymooney/chest-xray-pneumonia -p {DATA_ROOT} --unzip -q
            # Rename extracted folder
            if (DATA_ROOT / "chest_xray").exists():
                if PNEUMONIA_DIR.exists():
                    shutil.rmtree(PNEUMONIA_DIR)
                shutil.move(str(DATA_ROOT / "chest_xray"), str(PNEUMONIA_DIR))
            # Cache to Drive
            if PNEUMONIA_DIR.exists():
                print("   üíæ Caching to Drive...")
                if drive_pneumonia.exists():
                    shutil.rmtree(drive_pneumonia)
                shutil.copytree(PNEUMONIA_DIR, drive_pneumonia)
        
        pneumonia_count = sum(1 for _ in PNEUMONIA_DIR.rglob("*.jpeg")) if PNEUMONIA_DIR.exists() else 0
        print(f"   ✅ Chest X-Ray Pneumonia: {pneumonia_count:,} images")

        # DATASET 2: NIH Chest X-ray Sample  
        # Unzipped directly into NIH_SAMPLE_DIR so sample_labels.csv and the images/ folder end up inside it
        print("\n📥 Dataset 2: NIH Chest X-ray Sample")
        if nih_cached:
            print("   ✅ Found in Drive cache, copying to local...")
            if not NIH_SAMPLE_DIR.exists():
                shutil.copytree(drive_nih, NIH_SAMPLE_DIR)
        else:
            print("   📥 Downloading from Kaggle...")
            # Download straight into nih_sample (no rename step needed, unlike Dataset 1)
            NIH_SAMPLE_DIR.mkdir(exist_ok=True)
            !kaggle datasets download -d nih-chest-xrays/sample -p {NIH_SAMPLE_DIR} --unzip -q

            # Cache to Drive
            if NIH_SAMPLE_DIR.exists() and any(NIH_SAMPLE_DIR.rglob("*.png")):
                print("   💾 Caching to Drive...")
                if drive_nih.exists():
                    shutil.rmtree(drive_nih)
                shutil.copytree(NIH_SAMPLE_DIR, drive_nih)

# ============================================================
# VERIFY NIH DATASET STRUCTURE
# ============================================================
nih_csv = None
nih_images = None

if NIH_SAMPLE_DIR.exists():
    # Search for sample_labels.csv
    for csv_path in NIH_SAMPLE_DIR.rglob("sample_labels.csv"):
        nih_csv = csv_path
        break
    
    # Search for images directory
    for img_dir in NIH_SAMPLE_DIR.rglob("images"):
        if img_dir.is_dir() and any(img_dir.glob("*.png")):
            nih_images = img_dir
            break

# ============================================================
# SUMMARY
# ============================================================
print("\n" + "="*60)
print("üìä DATASET SUMMARY")
print("="*60)

# Pneumonia
pneumonia_count = sum(1 for _ in PNEUMONIA_DIR.rglob("*.jpeg")) if PNEUMONIA_DIR.exists() else 0
if pneumonia_count > 0:
    print(f"   ‚úÖ Chest X-Ray Pneumonia: {pneumonia_count:,} images")
else:
    print(f"   ‚ö†Ô∏è Chest X-Ray Pneumonia: not found")

# NIH
nih_count = sum(1 for _ in NIH_SAMPLE_DIR.rglob("*.png")) if NIH_SAMPLE_DIR.exists() else 0
if nih_count > 0:
    print(f"   ‚úÖ NIH Chest X-ray Sample: {nih_count:,} images")
    if nih_csv:
        df_nih_temp = pd.read_csv(nih_csv)
        print(f"      Labels CSV found: {len(df_nih_temp):,} entries")
    else:
        print(f"      ‚ö†Ô∏è Labels CSV not found")
        # Debug
        print(f"      Directory contents:")
        for item in list(NIH_SAMPLE_DIR.iterdir())[:5]:
            print(f"         - {item.name}")
else:
    print(f"   ‚ö†Ô∏è NIH Chest X-ray Sample: not found")
    if NIH_SAMPLE_DIR.exists():
        print(f"      Directory contents:")
        for item in list(NIH_SAMPLE_DIR.iterdir())[:5]:
            print(f"         - {item.name}")

total_images = pneumonia_count + nih_count
if total_images == 0:
    print("\n‚ö†Ô∏è No datasets downloaded!")
    print("   Please enter your Kaggle credentials above and re-run this cell.")
else:
    print(f"\n   TOTAL: {total_images:,} images available!")
    print(f"\nüíæ Data location: {DRIVE_DATA_ROOT}")
    print(f"   Next time, data loads from Drive (no re-download)")
print("="*60)
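Cell 3 repeats one pattern per dataset: restore from the Drive cache if it is populated, otherwise download from Kaggle and copy the result back to Drive. Factored into a single helper it might look like the sketch below. `restore_or_fetch`, `has_files`, and the `download` callback are hypothetical names; in the notebook the callback would wrap the `kaggle datasets download ... --unzip` command, for example through `subprocess.run`. Unlike the cell, the sketch also skips the download when the local folder already holds matching files:

```python
import shutil
from pathlib import Path
from typing import Callable


def has_files(root: Path, pattern: str) -> bool:
    """True if `root` exists and contains at least one path matching `pattern` (recursive)."""
    return root.exists() and any(root.rglob(pattern))


def restore_or_fetch(local: Path, cache: Path, pattern: str,
                     download: Callable[[Path], None]) -> str:
    """Populate `local` from `cache` if possible; otherwise download, then refresh the cache."""
    if has_files(local, pattern):
        return "local"                          # already on the fast local disk
    if has_files(cache, pattern):
        shutil.copytree(cache, local, dirs_exist_ok=True)
        return "cache"                          # restored from Drive, no network needed
    download(local)                             # caller-supplied download step
    if has_files(local, pattern):
        if cache.exists():
            shutil.rmtree(cache)                # replace a stale or partial cache
        shutil.copytree(local, cache)
    return "download"
```

The return value records where the data came from, which is handy for the summary printout at the end of the cell.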

## 🔧 Cell 4: Advanced Configuration

In [None]:
import torch
from transformers import BitsAndBytesConfig
from peft import LoraConfig, TaskType
from dataclasses import dataclass

@dataclass
class TrainingConfig:
    """Training configuration OPTIMIZED for Colab T4 (16GB VRAM)"""
    
    # Model Selection - PaliGemma (Gemma-based VLM: SigLIP vision encoder + Gemma 2B decoder)
    model_id: str = "google/paligemma-3b-pt-224"  # Works on T4!
    
    # Output
    output_dir: str = "/content/medgemma_advanced_lora"
    
    # ============================================================
    # T4 OPTIMIZED Training Hyperparameters
    # ============================================================
    batch_size: int = 2  # T4 safe - increase to 4 if no OOM
    gradient_accumulation_steps: int = 16  # Effective batch = 32
    num_epochs: int = 10  # Good balance for T4 training time
    warmup_ratio: float = 0.1
    learning_rate: float = 2e-4
    weight_decay: float = 0.01
    max_grad_norm: float = 1.0
    
    # LoRA Configuration (T4 optimized - lower rank to save VRAM)
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05
    
    # Sequence Configuration
    max_length: int = 384  # 256 image tokens (224px model) + prompt/report text; reduced for T4 VRAM
    
    # Early Stopping
    early_stopping_patience: int = 3
    early_stopping_threshold: float = 0.001
    
    # Evaluation
    eval_steps: int = 100
    save_steps: int = 200
    logging_steps: int = 25

config = TrainingConfig()

# 4-bit Quantization Config (QLoRA) - T4 uses float16!
BNB_CONFIG = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,  # T4 uses float16, NOT bfloat16!
    bnb_4bit_use_double_quant=True,
)

# LoRA Configuration (T4 optimized)
LORA_CONFIG = LoraConfig(
    r=config.lora_r,
    lora_alpha=config.lora_alpha,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # Attention only (q/k/v names also match the SigLIP vision tower)
    lora_dropout=config.lora_dropout,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
)

print("✅ Configuration loaded (T4 OPTIMIZED):")
print(f"   Model: {config.model_id}")
print(f"   Epochs: {config.num_epochs}")
print(f"   Effective Batch Size: {config.batch_size * config.gradient_accumulation_steps}")
print(f"   LoRA Rank: {config.lora_r}")
print(f"   Learning Rate: {config.learning_rate}")
print(f"   Compute Dtype: float16 (T4 compatible)")
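The effective batch size printed above (`batch_size × gradient_accumulation_steps = 32`) also fixes how many optimizer steps an epoch takes, which in turn determines how `warmup_ratio`, `eval_steps`, and early-stopping patience play out. A rough back-of-envelope sketch; `schedule_summary` is a hypothetical helper that ignores multi-GPU sharding and dataloader edge cases, and the training-set size in the example is a hypothetical placeholder:

```python
import math


def schedule_summary(n_train: int, batch_size: int, grad_accum: int,
                     num_epochs: int, warmup_ratio: float) -> dict:
    """Approximate optimizer-step budget for a single-GPU run."""
    effective_batch = batch_size * grad_accum
    steps_per_epoch = math.ceil(n_train / effective_batch)  # one optimizer step per effective batch
    total_steps = steps_per_epoch * num_epochs
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    return {
        "effective_batch": effective_batch,
        "steps_per_epoch": steps_per_epoch,
        "total_steps": total_steps,
        "warmup_steps": warmup_steps,
    }


if __name__ == "__main__":
    # Values from TrainingConfig; n_train=8800 is a hypothetical training-split size
    print(schedule_summary(n_train=8800, batch_size=2, grad_accum=16,
                           num_epochs=10, warmup_ratio=0.1))
```

With those example numbers an epoch is 275 optimizer steps, so `eval_steps=100` evaluates about 2.75 times per epoch and `early_stopping_patience=3` can stop training after roughly 300 steps (a little over one epoch) without improvement.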

## 🤖 Cell 5: Load PaliGemma Model (4-bit QLoRA)

In [None]:
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
from peft import get_peft_model, prepare_model_for_kbit_training
import gc

# Clear GPU memory before loading
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

print("üì¶ Loading PaliGemma (Gemma 3 Vision-Language Model)...")
print("This may take 3-5 minutes on T4...\n")

# ============================================================
# Load Processor
# ============================================================
try:
    processor = AutoProcessor.from_pretrained(
        config.model_id,
        token=HF_TOKEN,
    )
    print("✅ Processor loaded")
except Exception as e:
    print(f"⚠️ Processor loading error: {e}")
    print("Trying without token...")
    processor = AutoProcessor.from_pretrained(config.model_id)

# ============================================================
# Load Model with 4-bit Quantization (QLoRA)
# ============================================================
try:
    model = PaliGemmaForConditionalGeneration.from_pretrained(
        config.model_id,
        quantization_config=BNB_CONFIG,
        device_map="auto",
        torch_dtype=torch.float16,  # T4 uses float16!
        token=HF_TOKEN,
        low_cpu_mem_usage=True,
        attn_implementation="eager",  # For compatibility
    )
    print("✅ Base model loaded in 4-bit")
except Exception as e:
    print(f"⚠️ Model loading error: {e}")
    # Fallback without attn_implementation
    model = PaliGemmaForConditionalGeneration.from_pretrained(
        config.model_id,
        quantization_config=BNB_CONFIG,
        device_map="auto",
        torch_dtype=torch.float16,
        token=HF_TOKEN,
        low_cpu_mem_usage=True,
    )
    print("✅ Base model loaded (fallback)")

# ============================================================
# Prepare for QLoRA Training
# ============================================================
# Handle gradient checkpointing for different PEFT versions
try:
    model = prepare_model_for_kbit_training(
        model,
        use_gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
    )
except TypeError:
    # Older PEFT version without gradient_checkpointing_kwargs
    model = prepare_model_for_kbit_training(
        model,
        use_gradient_checkpointing=True,
    )
print("✅ Model prepared for k-bit training")

# Apply LoRA adapters
model = get_peft_model(model, LORA_CONFIG)

# Enable gradient checkpointing for memory efficiency
try:
    model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
except TypeError:
    model.gradient_checkpointing_enable()

print("\n" + "="*60)
print("üìä MODEL SUMMARY")
print("="*60)
model.print_trainable_parameters()

# Memory usage
if torch.cuda.is_available():
    used = torch.cuda.memory_allocated()/1e9
    total = torch.cuda.get_device_properties(0).total_memory/1e9
    print(f"\nGPU Memory: {used:.2f} GB / {total:.1f} GB ({used/total*100:.1f}%)")
    
    if used > 14:
        print("⚠️  High VRAM usage - consider reducing batch_size to 1")
    else:
        print("✅ VRAM usage OK for training")
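`print_trainable_parameters()` reports a small trainable fraction because LoRA freezes each target weight and trains only two low-rank factors beside it: A (r × d_in) and B (d_out × r), i.e. r·(d_in + d_out) extra parameters per adapted linear layer. A back-of-envelope sketch; the layer shapes and block count in the example are hypothetical placeholders rather than PaliGemma's actual dimensions:

```python
def lora_param_count(d_in: int, d_out: int, r: int) -> int:
    """Extra trainable parameters LoRA adds to one d_in -> d_out linear layer (bias='none')."""
    # A: r x d_in (down-projection), B: d_out x r (up-projection)
    return r * d_in + d_out * r


def total_lora_params(layer_shapes: list, r: int, n_layers: int) -> int:
    """Sum over the adapted projections in one block, times the number of blocks."""
    per_block = sum(lora_param_count(d_in, d_out, r) for d_in, d_out in layer_shapes)
    return per_block * n_layers


if __name__ == "__main__":
    # Hypothetical block: four square 2048x2048 projections (q, k, v, o), 18 blocks
    shapes = [(2048, 2048)] * 4
    print(f"{total_lora_params(shapes, r=16, n_layers=18):,} trainable LoRA params")
```

Doubling `lora_r` doubles this count linearly, which is why Cell 4 keeps the rank at 16 on a T4.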

## 📁 Cell 6: Prepare Dataset from Kaggle Downloads

Uses the automatically downloaded Kaggle datasets. No manual upload needed!
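The cell below builds training targets from labels rather than from real radiology reports: NIH's pipe-separated `Finding Labels` string becomes a fixed sentence template, so the model learns those templates. The same mapping pulled out as a standalone function (`labels_to_report` is a hypothetical name; the strings are copied from the cell) makes it easy to inspect or test:

```python
def labels_to_report(labels: str) -> str:
    """Turn an NIH 'Finding Labels' value such as 'Effusion|Infiltration' into a template report."""
    if labels == "No Finding":
        return "This chest X-ray appears normal with no significant findings."
    conditions = labels.replace("|", ", ")  # NIH separates multiple findings with '|'
    return (f"This chest X-ray shows signs of: {conditions}. "
            "Please consult your doctor for detailed interpretation.")


if __name__ == "__main__":
    for example in ("No Finding", "Effusion|Infiltration"):
        print(f"{example!r} -> {labels_to_report(example)}")
```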

In [None]:
import pandas as pd
from pathlib import Path
import os

# ============================================================
# PREPARE COMBINED DATASET FROM KAGGLE DOWNLOADS
# ============================================================
DATA_ROOT = Path("/content/datasets")

# Create unified dataframe from all downloaded datasets
all_data = []

# ============================================================
# 1. NIH Sample Dataset (has labels in CSV)
# ============================================================
NIH_SAMPLE_DIR = DATA_ROOT / "nih_sample"

# Search for sample_labels.csv in multiple possible locations
nih_csv = None
possible_csv_paths = [
    NIH_SAMPLE_DIR / "sample_labels.csv",
    NIH_SAMPLE_DIR / "sample" / "sample_labels.csv",
]
# Also search recursively
possible_csv_paths.extend(list(NIH_SAMPLE_DIR.rglob("sample_labels.csv")))

for path in possible_csv_paths:
    if path.exists():
        nih_csv = path
        print(f"✅ Found NIH labels CSV: {nih_csv}")
        break

if nih_csv and nih_csv.exists():
    df_nih = pd.read_csv(nih_csv)
    
    # Find image directory (search for images folder)
    nih_images = None
    possible_img_dirs = [
        NIH_SAMPLE_DIR / "sample" / "images",
        NIH_SAMPLE_DIR / "images", 
        NIH_SAMPLE_DIR,
    ]
    possible_img_dirs.extend(list(NIH_SAMPLE_DIR.rglob("images")))
    
    for img_dir in possible_img_dirs:
        if img_dir.exists() and img_dir.is_dir() and any(img_dir.glob("*.png")):
            nih_images = img_dir
            print(f"✅ Found NIH images: {nih_images}")
            break
    
    if nih_images:
        for _, row in df_nih.iterrows():
            img_name = row.get("Image Index", row.get("Image_Index", ""))
            labels = row.get("Finding Labels", row.get("Finding_Labels", "No Finding"))
            
            # Create patient-friendly description from labels
            if labels == "No Finding":
                report = "This chest X-ray appears normal with no significant findings."
            else:
                conditions = labels.replace("|", ", ")
                report = f"This chest X-ray shows signs of: {conditions}. Please consult your doctor for detailed interpretation."
            
            all_data.append({
                "ImageID": str(img_name),
                "ImagePath": str(nih_images / img_name),
                "Report": report,
                "Labels": labels,
                "Source": "NIH"
            })
        print(f"✅ NIH Sample: {len(df_nih)} images loaded")
    else:
        print("⚠️ NIH images directory not found")
else:
    print("⚠️ NIH sample_labels.csv not found")
    # Debug: show what's in the directory
    if NIH_SAMPLE_DIR.exists():
        print(f"   Contents of {NIH_SAMPLE_DIR}:")
        for item in list(NIH_SAMPLE_DIR.iterdir())[:10]:
            print(f"      - {item.name}")

# ============================================================
# 2. Pneumonia Dataset (folder structure = labels)
# ============================================================
PNEUMONIA_DIR = DATA_ROOT / "chest_xray_pneumonia"
if not PNEUMONIA_DIR.exists():
    PNEUMONIA_DIR = DATA_ROOT / "chest_xray"
    
for split in ["train", "test", "val"]:
    split_dir = PNEUMONIA_DIR / split
    if split_dir.exists():
        for label_dir in split_dir.iterdir():
            if label_dir.is_dir():
                label = label_dir.name  # NORMAL or PNEUMONIA
                for img_path in label_dir.glob("*.jpeg"):
                    if label.upper() == "NORMAL":
                        report = "This chest X-ray appears normal with clear lung fields and no signs of pneumonia."
                    else:
                        report = "This chest X-ray shows signs consistent with pneumonia. The lung fields show areas of opacity that may indicate infection."
                    
                    all_data.append({
                        "ImageID": img_path.name,
                        "ImagePath": str(img_path),
                        "Report": report,
                        "Labels": label,
                        "Source": "Pneumonia"
                    })

print(f"✅ Pneumonia Dataset: {len([d for d in all_data if d['Source'] == 'Pneumonia'])} images loaded")

# ============================================================
# 3. COVID-19 Radiography (optional - folder structure)
# ============================================================
COVID_DIR = DATA_ROOT / "covid19_radiography"
if COVID_DIR.exists():
    for category_dir in COVID_DIR.rglob("*"):
        if category_dir.is_dir() and category_dir.name in ["COVID", "Normal", "Viral Pneumonia", "Lung_Opacity"]:
            label = category_dir.name
            report_map = {
                "COVID": "This chest X-ray shows patterns that may be associated with COVID-19 infection, including ground-glass opacities.",
                "Normal": "This chest X-ray appears normal with no significant abnormalities detected.",
                "Viral Pneumonia": "This chest X-ray shows signs consistent with viral pneumonia.",
                "Lung_Opacity": "This chest X-ray shows areas of lung opacity that may require further evaluation."
            }
            # Take at most 500 images per class; sorting makes the selection reproducible
            for img_path in sorted(category_dir.glob("*.png"))[:500]:
                all_data.append({
                    "ImageID": img_path.name,
                    "ImagePath": str(img_path),
                    "Report": report_map.get(label, f"Finding: {label}"),
                    "Labels": label,
                    "Source": "COVID19"
                })
    print(f"✅ COVID-19 Dataset: {len([d for d in all_data if d['Source'] == 'COVID19'])} images loaded")
else:
    print("ℹ️ COVID-19 Dataset: not downloaded (optional)")

# ============================================================
# Create final dataframe
# ============================================================
df_combined = pd.DataFrame(all_data)
df_combined = df_combined.drop_duplicates(subset=["ImageID"]).reset_index(drop=True)

# Filter to only existing images
print(f"\nüîç Verifying image paths...")
df_combined = df_combined[df_combined["ImagePath"].apply(lambda x: Path(x).exists())]
df_combined = df_combined.reset_index(drop=True)

print(f"\n{'='*60}")
print(f"üìä FINAL COMBINED DATASET: {len(df_combined):,} images")
print(f"{'='*60}")
print(f"Sources: {df_combined['Source'].value_counts().to_dict()}")
print(f"Labels: {df_combined['Labels'].nunique()} unique labels")
print(f"\nSample:")
print(df_combined.head())
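Cell 6 ends with a single combined DataFrame, so the 80/10/10 train/validation/test split listed under Features has to come after it. One hedged way to do it is to split row indices per label so every split keeps roughly the same label mix; `stratified_split` is a hypothetical helper and the seed is arbitrary:

```python
import random
from collections import defaultdict


def stratified_split(labels: list, fractions=(0.8, 0.1, 0.1), seed: int = 42):
    """Return (train, val, test) index lists, splitting each label group by `fractions`."""
    by_label = defaultdict(list)
    for idx, label in enumerate(labels):
        by_label[label].append(idx)
    rng = random.Random(seed)                   # fixed seed -> reproducible split
    train, val, test = [], [], []
    for label in sorted(by_label):              # sorted for a deterministic order
        idxs = by_label[label]
        rng.shuffle(idxs)
        n_train = int(len(idxs) * fractions[0])
        n_val = int(len(idxs) * fractions[1])
        train += idxs[:n_train]
        val += idxs[n_train:n_train + n_val]
        test += idxs[n_train + n_val:]          # remainder goes to test
    return train, val, test


if __name__ == "__main__":
    toy = ["NORMAL"] * 50 + ["PNEUMONIA"] * 150  # hypothetical label column
    tr, va, te = stratified_split(toy)
    print(len(tr), len(va), len(te))
```

Applied to this notebook it would be called as `stratified_split(df_combined["Labels"].tolist())`, followed by `df_combined.iloc[train_idx]` and so on. Splitting per image can place different images of the same NIH patient in different splits; grouping by patient before splitting avoids that leakage.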

## 🗄️ Cell 7: Advanced Dataset Class with Augmentation

In [None]:
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import random
import numpy as np
from torchvision import transforms
import warnings
warnings.filterwarnings("ignore")

class ChestXrayDataset(Dataset):
    """Dataset for chest X-ray training with augmentation - PaliGemma optimized"""
    
    def __init__(
        self,
        df: pd.DataFrame,
        processor,
        max_length: int = 384,
        is_train: bool = True,
    ):
        self.df = df.reset_index(drop=True)
        self.processor = processor
        self.max_length = max_length
        self.is_train = is_train
        
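        # Light augmentation for training only; note that horizontal flips mirror
        # left/right anatomy (e.g. heart position), so the flip probability is kept low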
        if is_train:
            self.transform = transforms.Compose([
                transforms.RandomHorizontalFlip(p=0.3),
                transforms.RandomRotation(degrees=5),
                transforms.ColorJitter(brightness=0.1, contrast=0.1),
            ])
        else:
            self.transform = None
        
        # Cache failed image paths to avoid repeated errors
        self.failed_images = set()
        
        print(f"{'Train' if is_train else 'Eval'} dataset: {len(self.df)} samples")
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx: int):
        row = self.df.iloc[idx]
        
        # Load image with caching of failures
        img_path = row["ImagePath"]
        try:
            if img_path not in self.failed_images:
                image = Image.open(img_path).convert("RGB")
            else:
                image = Image.new("RGB", (224, 224), color=(128, 128, 128))
        except Exception as e:
            self.failed_images.add(img_path)
            image = Image.new("RGB", (224, 224), color=(128, 128, 128))
        
        # Apply augmentation
        if self.transform:
            image = self.transform(image)
        
        # Get report text
        text = str(row.get("Report", "")).strip()
        if not text or text == "nan":
            text = "Normal chest radiograph."
        
        # Create varied prompts for better generalization
        prompts = [
            "describe this chest xray",
            "explain this chest radiograph", 
            "analyze this xray image",
            "what does this chest xray show",
            "interpret this chest radiograph",
        ]
        prompt = random.choice(prompts) if self.is_train else "describe this chest xray"
        
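        # Process with the PaliGemma processor; `suffix` makes it return `labels`
        try:
            inputs = self.processor(
                text=prompt,
                images=image,
                suffix=text,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=self.max_length,
            )
        except Exception:
            # Fallback without the suffix parameter. Without `suffix` the processor
            # returns no `labels`, so build them from input_ids and ignore padding
            # (-100); otherwise the Trainer has no loss. Note this also trains on the prompt.
            full_text = f"{prompt}: {text}"
            inputs = self.processor(
                text=full_text,
                images=image,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=self.max_length,
            )
            labels = inputs["input_ids"].clone()
            labels[inputs["attention_mask"] == 0] = -100
            inputs["labels"] = labels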
        
        return {k: v.squeeze(0) for k, v in inputs.items()}

print("✅ ChestXrayDataset class defined (optimized for PaliGemma)")
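For intuition about what the `labels` tensor contains: the `suffix=` path of the processor is expected to supervise only the report tokens, while a processor call without `suffix` has to have its labels constructed by hand. The sketch below shows that masking rule in plain Python (no PyTorch), assuming right padding and a hypothetical `prompt_len` parameter. `-100` is the default `ignore_index` of PyTorch's `CrossEntropyLoss`, which Hugging Face models also use to skip positions in the loss.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def build_labels(input_ids, attention_mask, prompt_len=0):
    """Copy input_ids into labels, hiding positions the loss should skip.

    Masked positions: padding (attention_mask == 0) and the first
    `prompt_len` tokens (image + prompt). Assumes right padding;
    `prompt_len` is a hypothetical parameter for illustration.
    """
    labels = []
    for pos, (token, keep) in enumerate(zip(input_ids, attention_mask)):
        if keep == 0 or pos < prompt_len:
            labels.append(IGNORE_INDEX)
        else:
            labels.append(token)
    return labels


# Toy sequence: 3 prompt tokens, 2 report tokens, 2 padding tokens (pad id 0)
ids = [2, 11, 12, 21, 22, 0, 0]
mask = [1, 1, 1, 1, 1, 0, 0]
labels = build_labels(ids, mask, prompt_len=3)
print(labels)  # [-100, -100, -100, 21, 22, -100, -100]
```

With `prompt_len=0` this reduces to padding-only masking, which still trains on the prompt text but at least ignores the padded positions.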

## 📊 Cell 8: Create Train/Val/Test Splits

In [None]:
from sklearn.model_selection import train_test_split
import pandas as pd
import torch
import gc

# Use combined dataset from Kaggle downloads
df = df_combined.copy()

# ============================================================
# FILTER RARE CLASSES (stratified splitting needs several samples per class)
# ============================================================
print("📊 Original label distribution:")
label_counts = df["Labels"].value_counts()
print(label_counts)

# Keep only classes with at least 10 samples: the 20% held-out pool then gets
# about 2 per class, i.e. roughly 1 each in validation and test
MIN_SAMPLES_PER_CLASS = 10
valid_labels = label_counts[label_counts >= MIN_SAMPLES_PER_CLASS].index.tolist()
rare_labels = label_counts[label_counts < MIN_SAMPLES_PER_CLASS].index.tolist()

if rare_labels:
    print(f"\n⚠️ Removing {len(rare_labels)} rare classes with <{MIN_SAMPLES_PER_CLASS} samples:")
    print(f"   {rare_labels[:10]}{'...' if len(rare_labels) > 10 else ''}")
    df = df[df["Labels"].isin(valid_labels)].reset_index(drop=True)
    print(f"   Remaining: {len(df):,} samples across {len(valid_labels)} classes")

# ============================================================
# DATASET SIZE CONTROL
# ============================================================
# Options:
#   - 5000: Fast training (~2-3 hours) - good for testing
#   - 8000: Better quality (~4-5 hours) - recommended
#   - Full: Best quality but longer training

MAX_SAMPLES = 8000  # Increased for better model quality
if len(df) > MAX_SAMPLES:
    print(f"\n📊 Sampling {MAX_SAMPLES:,} samples for training")
    # Sample each class in proportion to its share of the data, but keep at least
    # MIN_SAMPLES_PER_CLASS per class (so the total can slightly exceed MAX_SAMPLES)
    n_total = len(df)
    df = df.groupby("Labels", group_keys=False).apply(
        lambda x: x.sample(
            max(MIN_SAMPLES_PER_CLASS, min(len(x), int(MAX_SAMPLES * len(x) / n_total))),
            random_state=42,
        )
    ).reset_index(drop=True)

# Final filter: re-check the minimum class size after sampling
label_counts = df["Labels"].value_counts()
valid_for_split = label_counts[label_counts >= MIN_SAMPLES_PER_CLASS].index.tolist()
df = df[df["Labels"].isin(valid_for_split)].reset_index(drop=True)

print(f"\n📊 Final dataset: {len(df):,} samples, {df['Labels'].nunique()} classes")

# ============================================================
# STRATIFIED SPLIT: 80% train, 10% val, 10% test
# ============================================================
train_df, temp_df = train_test_split(
    df, test_size=0.2, random_state=42, stratify=df["Labels"]
)

# Second split (val/test). Classes with <2 rows in temp_df cannot be stratified,
# so they are split randomly (a single leftover row goes to validation).
temp_label_counts = temp_df["Labels"].value_counts()
problematic_labels = temp_label_counts[temp_label_counts < 2].index.tolist()

if problematic_labels:
    print(f"\n⚠️ Handling {len(problematic_labels)} edge case classes")
    temp_df_clean = temp_df[~temp_df["Labels"].isin(problematic_labels)].reset_index(drop=True)
    temp_df_problematic = temp_df[temp_df["Labels"].isin(problematic_labels)]

    if len(temp_df_clean) > 0:
        val_df_clean, test_df_clean = train_test_split(
            temp_df_clean, test_size=0.5, random_state=42, stratify=temp_df_clean["Labels"]
        )
        if len(temp_df_problematic) >= 2:
            val_df_prob, test_df_prob = train_test_split(
                temp_df_problematic, test_size=0.5, random_state=42
            )
        else:
            # train_test_split cannot split a single row
            val_df_prob, test_df_prob = temp_df_problematic, temp_df_problematic.iloc[0:0]
        val_df = pd.concat([val_df_clean, val_df_prob]).reset_index(drop=True)
        test_df = pd.concat([test_df_clean, test_df_prob]).reset_index(drop=True)
    else:
        val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42)
else:
    val_df, test_df = train_test_split(
        temp_df, test_size=0.5, random_state=42, stratify=temp_df["Labels"]
    )

print(f"\n✅ Data Splits (Stratified by Labels):")
print(f"   Train: {len(train_df):,} samples ({len(train_df)/len(df)*100:.1f}%)")
print(f"   Val:   {len(val_df):,} samples ({len(val_df)/len(df)*100:.1f}%)")
print(f"   Test:  {len(test_df):,} samples ({len(test_df)/len(df)*100:.1f}%)")

print(f"\n📋 Label distribution (Train - Top 10):")
print(train_df["Labels"].value_counts().head(10))

# Memory cleanup
gc.collect()
torch.cuda.empty_cache()

# Create datasets
train_dataset = ChestXrayDataset(
    train_df, processor,
    max_length=config.max_length,
    is_train=True,
)

val_dataset = ChestXrayDataset(
    val_df, processor,
    max_length=config.max_length,
    is_train=False,
)

test_dataset = ChestXrayDataset(
    test_df, processor,
    max_length=config.max_length,
    is_train=False,
)

print(f"\n✅ Datasets created successfully!")
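The per-class quotas in the sampling step, and the reason a minimum of 10 rows per class matters for the split, can be checked in plain Python. `class_quotas` reproduces the arithmetic of the sampling lambda; `toy_stratified_split` is a hypothetical per-class 80/10/10 split written for illustration only. It is not scikit-learn's allocation algorithm, whose per-class rounding can differ slightly, and the class names and sizes are made up.

```python
import random
from collections import defaultdict


def class_quotas(counts, max_samples, floor=10):
    """Per-class sample sizes using the same arithmetic as the sampling lambda:
    proportional share of max_samples, capped at the class size, floored at `floor`."""
    total = sum(counts.values())
    return {
        label: max(floor, min(n, int(max_samples * n / total)))
        for label, n in counts.items()
    }


def toy_stratified_split(labels, seed=42):
    """Illustrative per-class 80/10/10 split (not scikit-learn's algorithm).

    Each class sends max(2, round(20%)) rows to the held-out pool, which is
    halved into val and test, so a class of 10 lands 8 / 1 / 1.
    Intended for classes with at least 10 rows.
    """
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    rng = random.Random(seed)
    train, val, test = [], [], []
    for idxs in by_class.values():
        rng.shuffle(idxs)
        n_hold = max(2, round(0.2 * len(idxs)))
        n_val = n_hold // 2
        val += idxs[:n_val]
        test += idxs[n_val:n_hold]
        train += idxs[n_hold:]
    return train, val, test


# Hypothetical class sizes: one dominant class and one at the minimum size
quotas = class_quotas({"No Finding": 99_990, "Hernia": 10}, max_samples=8000)
print(quotas, sum(quotas.values()))  # {'No Finding': 7999, 'Hernia': 10} 8009

toy_labels = ["A"] * 10 + ["B"] * 20
train_idx, val_idx, test_idx = toy_stratified_split(toy_labels)
print(len(train_idx), len(val_idx), len(test_idx))  # 24 3 3
```

The 8009 total shows why the sampled dataset can end up slightly above `MAX_SAMPLES`: for tiny classes the per-class floor wins over the proportional share.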

## 🎯 Cell 9: Data Collator and Metrics

In [None]:
from dataclasses import dataclass
from typing import Dict, List, Any
import torch

@dataclass
class DataCollator:
    """Simple collator for batching"""
    def __call__(self, features: List[Dict]) -> Dict[str, torch.Tensor]:
        batch = {}
        for key in features[0].keys():
            batch[key] = torch.stack([f[key] for f in features])
        return batch

collator = DataCollator()

# Test collator with a sample
if len(train_dataset) > 0:
    sample = train_dataset[0]
    print(f"✅ Sample batch keys: {list(sample.keys())}")
    for k, v in sample.items():
        print(f"   {k}: shape={v.shape}, dtype={v.dtype}")
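`DataCollator` can rely on `torch.stack` only because every example was padded to the same `max_length`. With dynamic padding (each batch padded to its own longest example) the collator would have to pad itself. A minimal plain-Python sketch of that, assuming a hypothetical feature schema of `input_ids` and `labels` lists and a pad id of 0; image tensors, which already have a fixed shape, are left out.

```python
IGNORE_INDEX = -100  # label value skipped by the loss


def pad_collate(features, pad_id=0):
    """Right-pad variable-length examples to the longest one in the batch.

    Each feature is a dict with 'input_ids' and 'labels' lists (hypothetical
    schema for illustration). Returns nested lists; wrap them in
    torch.tensor(...) for real training.
    """
    max_len = max(len(f["input_ids"]) for f in features)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        n_pad = max_len - len(f["input_ids"])
        batch["input_ids"].append(f["input_ids"] + [pad_id] * n_pad)
        batch["attention_mask"].append([1] * len(f["input_ids"]) + [0] * n_pad)
        batch["labels"].append(f["labels"] + [IGNORE_INDEX] * n_pad)
    return batch


batch = pad_collate([
    {"input_ids": [5, 6, 7], "labels": [-100, 6, 7]},
    {"input_ids": [8, 9], "labels": [-100, 9]},
])
print(batch["input_ids"])  # [[5, 6, 7], [8, 9, 0]]
```

Padding per batch trades a little collation work for fewer wasted padded positions, which matters most when report lengths vary a lot.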

In [None]:
# ============================================================
# EVALUATION METRICS SETUP
# ============================================================
import evaluate
import nltk

# Download NLTK data (best effort; download failures are ignored)
try:
    nltk.download('punkt', quiet=True)
    nltk.download('punkt_tab', quiet=True)
    nltk.download('wordnet', quiet=True)
except Exception:
    pass

print("📊 Loading evaluation metrics...")

# Load metrics with error handling
try:
    bleu_metric = evaluate.load("sacrebleu")
    print("✅ BLEU metric loaded")
except Exception as e:
    print(f"⚠️ BLEU loading failed: {e}")
    bleu_metric = None

try:
    rouge_metric = evaluate.load("rouge")
    print("✅ ROUGE metric loaded")
except Exception as e:
    print(f"⚠️ ROUGE loading failed: {e}")
    rouge_metric = None

def compute_metrics_safe(predictions, references):
    """Safely compute metrics with error handling.

    BLEU (sacrebleu) is reported on a 0-100 scale; ROUGE F-measures on 0-1.
    """
    results = {}

    if bleu_metric is not None and predictions and references:
        try:
            bleu_result = bleu_metric.compute(
                predictions=predictions,
                references=[[r] for r in references]  # one reference per prediction
            )
            results["bleu"] = bleu_result["score"]
        except Exception as e:
            print(f"⚠️ BLEU computation error: {e}")
            results["bleu"] = 0.0

    if rouge_metric is not None and predictions and references:
        try:
            rouge_result = rouge_metric.compute(
                predictions=predictions,
                references=references
            )
            results["rouge1"] = rouge_result.get("rouge1", 0.0)
            results["rouge2"] = rouge_result.get("rouge2", 0.0)
            results["rougeL"] = rouge_result.get("rougeL", 0.0)
        except Exception as e:
            print(f"⚠️ ROUGE computation error: {e}")

    return results

print("✅ Metrics setup complete")
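To see what ROUGE-1 measures, here is a simplified version of its F1: unigram overlap between prediction and reference. This is a sketch, not the scorer behind `evaluate.load("rouge")`; that package applies its own tokenization (and optional stemming), so its numbers can differ.

```python
from collections import Counter


def rouge1_f1(prediction, reference):
    """Unigram-overlap F1 on lowercase whitespace tokens (simplified ROUGE-1)."""
    pred = Counter(prediction.lower().split())
    ref = Counter(reference.lower().split())
    overlap = sum((pred & ref).values())  # clipped count of shared unigrams
    if overlap == 0:
        return 0.0
    precision = overlap / sum(pred.values())
    recall = overlap / sum(ref.values())
    return 2 * precision * recall / (precision + recall)


score = rouge1_f1("no acute findings", "no acute cardiopulmonary findings")
print(round(score, 4))  # 0.8571
```

Here every predicted word appears in the reference (precision 1.0) but one reference word is missing (recall 0.75), which is the kind of trade-off ROUGE rewards and BLEU (with its brevity penalty) treats differently.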

## 🚀 Cell 10: Advanced Training with Callbacks

In [None]:
from transformers import Trainer, TrainingArguments
from transformers.trainer_callback import TrainerCallback
import os
import shutil

# ============================================================
# CONFIG FALLBACK (in case earlier cells weren't run)
# ============================================================
if 'config' not in globals():
    from dataclasses import dataclass
    @dataclass
    class Config:
        # Minimal fields for this cell only; Cells 8 and 13 also read
        # max_length, model_id, lora_r, lora_alpha and batch_size.
        output_dir: str = "/content/medgemma_advanced_lora"
        num_epochs: int = 10
        learning_rate: float = 2e-4
        weight_decay: float = 0.01
        max_grad_norm: float = 1.0
    config = Config()
    print("⚠️ Using default config (run earlier cells for custom config)")

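# ============================================================
# OUTPUT DIRECTORIES - Save to Google Drive for persistence!
# ============================================================
# Mount Drive BEFORE creating folders under /content/drive. Otherwise makedirs
# creates a local, non-persistent folder there, and a later drive.mount() fails
# because the mount point is no longer empty.
from google.colab import drive
if not os.path.ismount("/content/drive"):
    drive.mount("/content/drive")

DRIVE_OUTPUT_DIR = "/content/drive/MyDrive/ExplainMyXray_Models"
LOCAL_OUTPUT_DIR = config.output_dir

os.makedirs(LOCAL_OUTPUT_DIR, exist_ok=True)
os.makedirs(DRIVE_OUTPUT_DIR, exist_ok=True)

print(f"📁 Local checkpoints: {LOCAL_OUTPUT_DIR}")
print(f"💾 Drive backup: {DRIVE_OUTPUT_DIR}")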

# ============================================================
# CRITICAL: Auto-backup callback that saves to Drive IMMEDIATELY
# ============================================================
class DriveBackupCallback(TrainerCallback):
    """Backup every checkpoint to Google Drive immediately after save"""
    
    def __init__(self, local_dir, drive_dir):
        self.local_dir = local_dir
        self.drive_dir = drive_dir
        self.backed_up = set()
    
    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs and state.global_step % 50 == 0:
            loss = logs.get("loss", "N/A")
            lr = logs.get("learning_rate", 0)
            if isinstance(loss, float):
                print(f"   Step {state.global_step}: loss={loss:.4f}, lr={lr:.2e}")
    
    def on_save(self, args, state, control, **kwargs):
        """Called AFTER trainer saves a checkpoint - backup to Drive immediately"""
        step = state.global_step
        checkpoint_name = f"checkpoint-{step}"
        src_path = os.path.join(self.local_dir, checkpoint_name)
        dst_path = os.path.join(self.drive_dir, checkpoint_name)
        
        if os.path.exists(src_path) and checkpoint_name not in self.backed_up:
            try:
                print(f"\n   💾 Backing up {checkpoint_name} to Drive...", end=" ")
                shutil.copytree(src_path, dst_path, dirs_exist_ok=True)
                self.backed_up.add(checkpoint_name)
                print("✅ Done!")
            except Exception as e:
                print(f"⚠️ Failed: {e}")

# ============================================================
# Training Arguments - T4 OPTIMIZED + FREQUENT SAVES
# ============================================================
import transformers
tf_version = tuple(map(int, transformers.__version__.split('.')[:2]))

training_args_dict = {
    "output_dir": LOCAL_OUTPUT_DIR,
    
    # Batch size
    "per_device_train_batch_size": 4,
    "per_device_eval_batch_size": 4,
    "gradient_accumulation_steps": 8,
    
    # Training duration
    "num_train_epochs": config.num_epochs,
    
    # Learning rate
    "learning_rate": config.learning_rate,
    "lr_scheduler_type": "cosine",
    "warmup_steps": 50,
    "weight_decay": config.weight_decay,
    
    # Gradient
    "max_grad_norm": config.max_grad_norm,
    "gradient_checkpointing": True,
    
    # Precision - T4 uses fp16!
    "fp16": True,
    "bf16": False,
    
    # Logging
    "logging_steps": 50,
    "logging_first_step": True,
    "report_to": "none",
    
    # ============================================================
    # EVALUATION - DISABLED FOR SPEED (was causing timeouts)
    # ============================================================
    "eval_strategy": "no",  # DISABLED - saves time!
    
    # ============================================================
    # CHECKPOINTING - SAVE EVERY 100 STEPS
    # ============================================================
    "save_strategy": "steps",
    "save_steps": 100,
    "save_total_limit": 5,
    "load_best_model_at_end": False,  # DISABLED since no eval
    
    # Performance optimizations
    "dataloader_num_workers": 2,  # Reduced to prevent memory issues
    "dataloader_pin_memory": True,
    "remove_unused_columns": False,
    
    # Speed optimizations
    "optim": "adamw_torch_fused",
    "torch_compile": False,
    
    "seed": 42,
}

# Add gradient_checkpointing_kwargs for newer versions
if tf_version >= (4, 40):
    training_args_dict["gradient_checkpointing_kwargs"] = {"use_reentrant": False}

training_args = TrainingArguments(**training_args_dict)

# Create backup callback
backup_callback = DriveBackupCallback(LOCAL_OUTPUT_DIR, DRIVE_OUTPUT_DIR)

# Initialize Trainer (NO EarlyStoppingCallback - we disabled eval for speed)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=None,  # No eval for faster training
    data_collator=collator,
    callbacks=[backup_callback],  # Only backup callback
)

print("\n" + "="*60)
print("🚀 TRAINING CONFIGURATION (FAST MODE - NO EVAL)")
print("="*60)
print(f"Transformers Version: {transformers.__version__}")
print(f"Epochs: {config.num_epochs}")
bs = training_args.per_device_train_batch_size
ga = training_args.gradient_accumulation_steps
print(f"Batch Size: {bs} × {ga} = {bs * ga} (effective)")
print(f"Learning Rate: {config.learning_rate}")
print(f"Train Samples: {len(train_dataset)}")
print(f"Precision: FP16 (T4 optimized)")
print(f"⚡ Evaluation: DISABLED (faster training)")
print(f"⚡ Checkpoint: Every {training_args.save_steps} steps → Auto-backup to Drive")
print(f"📂 Drive: {DRIVE_OUTPUT_DIR}")
print("="*60)
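The effective batch size is `per_device_train_batch_size × gradient_accumulation_steps` (times the number of GPUs, which is one on a T4). A back-of-the-envelope helper, assuming a single GPU and roughly 6,400 training rows (80% of `MAX_SAMPLES = 8000`), shows what `save_steps=100` and a 500-step cap correspond to in epochs. Both helper names are hypothetical.

```python
import math


def effective_batch_size(per_device, grad_accum, n_devices=1):
    """Examples contributing to one optimizer update."""
    return per_device * grad_accum * n_devices


def approx_steps_per_epoch(n_train, per_device, grad_accum, n_devices=1):
    """Approximate optimizer steps per epoch: dataloader batches divided by
    the accumulation factor. The Trainer's exact rounding can differ by a
    step between transformers versions."""
    batches = math.ceil(n_train / (per_device * n_devices))
    return max(batches // grad_accum, 1)


# Hypothetical run: ~80% of MAX_SAMPLES = 8000 rows ends up in the train split
n_train = 6400
steps = approx_steps_per_epoch(n_train, per_device=4, grad_accum=8)
print(effective_batch_size(4, 8), steps, 500 / steps)  # 32 200 2.5
```

Under these assumptions a 500-step test run covers about 2.5 epochs, and a checkpoint is written roughly every half epoch.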

In [None]:
# ============================================================
# CELL 25: RESTORE CHECKPOINT FROM GOOGLE DRIVE
# ============================================================
# Run this cell to restore checkpoints (e.g. checkpoint-355, or the latest available) before training

from google.colab import drive
import shutil
import os
from pathlib import Path

# Mount Google Drive if it is not already mounted. os.path.ismount() is used
# because a plain existence check is fooled by a local folder of the same name.
if not os.path.ismount('/content/drive'):
    drive.mount('/content/drive')
    print("✅ Google Drive mounted")
else:
    print("✅ Google Drive already mounted")

# ============================================================
# CHECKPOINT LOCATIONS - ALL POSSIBLE PATHS
# ============================================================
DRIVE_OUTPUT_DIR = "/content/drive/MyDrive/ExplainMyXray_Models"
LOCAL_CHECKPOINT_DIR = "/content/medgemma_advanced_lora"

# Specific checkpoint paths to check
CHECKPOINT_PATHS = [
    # checkpoint-355 (your latest progress)
    f"{DRIVE_OUTPUT_DIR}/checkpoint-355",
    # checkpoint-250 backup location
    "/content/drive/MyDrive/medgemma_advanced_lora/checkpoint-250",
]

# Create directories
Path(LOCAL_CHECKPOINT_DIR).mkdir(parents=True, exist_ok=True)
Path(DRIVE_OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

print("🔍 Searching for checkpoints...\n")

# ============================================================
# HELPER FUNCTIONS
# ============================================================
def is_valid_checkpoint(path):
    """Check if checkpoint has trainer_state.json (required for resume)"""
    return os.path.exists(os.path.join(path, "trainer_state.json"))

def copy_checkpoint(src_dir, dst_dir):
    """Copy checkpoint files, skipping Google Drive artifacts"""
    os.makedirs(dst_dir, exist_ok=True)
    skip_extensions = {'.gdoc', '.gsheet', '.gslides', '.gform', '.gdraw'}
    copied = 0
    for item in os.listdir(src_dir):
        src_path = os.path.join(src_dir, item)
        dst_path = os.path.join(dst_dir, item)
        _, ext = os.path.splitext(item)
        if ext.lower() in skip_extensions:
            continue
        if os.path.isfile(src_path):
            shutil.copy2(src_path, dst_path)
            copied += 1
        elif os.path.isdir(src_path):
            shutil.copytree(src_path, dst_path, dirs_exist_ok=True)
            copied += 1
    return copied

# ============================================================
# FIND ALL VALID CHECKPOINTS
# ============================================================
found_checkpoints = []

# Check specific known paths
for cp_path in CHECKPOINT_PATHS:
    if os.path.exists(cp_path) and is_valid_checkpoint(cp_path):
        name = os.path.basename(cp_path)
        found_checkpoints.append((name, cp_path))
        print(f"✅ Found: {name} at {cp_path}")

# Also scan ExplainMyXray_Models folder for any checkpoints
if os.path.exists(DRIVE_OUTPUT_DIR):
    for item in os.listdir(DRIVE_OUTPUT_DIR):
        full_path = os.path.join(DRIVE_OUTPUT_DIR, item)
        if os.path.isdir(full_path) and item.startswith("checkpoint-"):
            if is_valid_checkpoint(full_path):
                # Avoid duplicates
                if not any(name == item for name, _ in found_checkpoints):
                    found_checkpoints.append((item, full_path))
                    print(f"✅ Found: {item} (in ExplainMyXray_Models)")

# ============================================================
# RESTORE CHECKPOINTS TO LOCAL
# ============================================================
checkpoint_restored = False

for name, src_path in found_checkpoints:
    dst_path = os.path.join(LOCAL_CHECKPOINT_DIR, name)
    if not os.path.exists(dst_path):
        copied = copy_checkpoint(src_path, dst_path)
        print(f"   → Restored {copied} files to: {dst_path}")
        checkpoint_restored = True
    else:
        print(f"   → Already exists locally: {name}")
        checkpoint_restored = True

# ============================================================
# SUMMARY
# ============================================================
print("\n" + "="*60)
if checkpoint_restored:
    local_checkpoints = [d for d in os.listdir(LOCAL_CHECKPOINT_DIR) 
                         if os.path.isdir(f"{LOCAL_CHECKPOINT_DIR}/{d}") 
                         and d.startswith("checkpoint-")]
    print(f"📂 Local checkpoints: {local_checkpoints}")
    
    # Find highest step
    valid = []
    for cp in local_checkpoints:
        cp_path = os.path.join(LOCAL_CHECKPOINT_DIR, cp)
        if is_valid_checkpoint(cp_path):
            parts = cp.split("-")
            if len(parts) >= 2 and parts[1].isdigit():
                valid.append((int(parts[1]), cp))
    
    if valid:
        latest_step, latest_name = max(valid)
        print(f"🔄 Will resume from: {latest_name} (step {latest_step})")
else:
    print("⚠️ No valid checkpoints found. Training will start from scratch.")
print("="*60)
print("\n✅ Now run CELL 26 (Training)!")
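The restore and training cells choose the checkpoint by parsing the step number rather than sorting folder names. That matters because string order ranks `checkpoint-500` above `checkpoint-1000`. A compact, self-contained version of that selection, using a hypothetical `latest_checkpoint` helper with an optional upper bound on the step:

```python
import re

_CKPT = re.compile(r"^checkpoint-(\d+)$")


def latest_checkpoint(names, max_step=None):
    """Return (step, name) for the highest-step 'checkpoint-N' entry, or None.

    Names that do not match exactly (e.g. 'checkpoint-500-test', 'runs')
    are ignored; max_step optionally caps the accepted step.
    """
    found = []
    for name in names:
        match = _CKPT.match(name)
        if match:
            step = int(match.group(1))
            if max_step is None or step <= max_step:
                found.append((step, name))
    return max(found) if found else None


names = ["checkpoint-1000", "checkpoint-500", "checkpoint-500-test", "runs"]
print(max(["checkpoint-1000", "checkpoint-500"]))  # checkpoint-500 (string order misleads)
print(latest_checkpoint(names))                    # (1000, 'checkpoint-1000')
print(latest_checkpoint(names, max_step=500))      # (500, 'checkpoint-500')
```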

In [None]:
# ============================================================
# CELL 26: TRAINING (STOP AT STEP 500 FOR TESTING)
# ============================================================

print("\n🏋️ Starting Training on T4 GPU...")
print(f"   Training {len(train_dataset)} samples")
print(f"   ⚡ Will STOP at step 500 for testing")
print(f"   Checkpoints saved every 100 steps → Auto-backup to Drive\n")

import gc
import os
import time
import shutil
from pathlib import Path

gc.collect()
torch.cuda.empty_cache()

start_time = time.time()

# ============================================================
# PATHS
# ============================================================
LOCAL_CHECKPOINT_DIR = "/content/medgemma_advanced_lora"
DRIVE_OUTPUT_DIR = "/content/drive/MyDrive/ExplainMyXray_Models"

Path(DRIVE_OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

# ============================================================
# OVERRIDE: cap training at 500 optimizer steps (a fixed cap, not metric-based early stopping)
# ============================================================
trainer.args.max_steps = 500  # STOP AT STEP 500!
trainer.args.num_train_epochs = 100  # Ignored: max_steps > 0 takes precedence
print(f"🎯 Training will stop at step 500 (for testing)")

# ============================================================
# FIND VALID CHECKPOINT TO RESUME FROM
# ============================================================
def is_valid_checkpoint(path):
    """Check if checkpoint has trainer_state.json (required for resume)"""
    return os.path.exists(os.path.join(path, "trainer_state.json"))

def fix_checkpoint_for_resume(checkpoint_path):
    """Remove scaler.pt if present (workaround for a resume error).

    Any saved fp16 loss-scaler state is then not restored on resume.
    """
    scaler_path = os.path.join(checkpoint_path, "scaler.pt")
    if os.path.exists(scaler_path):
        os.remove(scaler_path)
        print(f"   ⚠️ Removed scaler.pt from checkpoint (fixes resume bug)")

resume_checkpoint = None

if os.path.exists(LOCAL_CHECKPOINT_DIR):
    all_items = os.listdir(LOCAL_CHECKPOINT_DIR)
    checkpoint_folders = [d for d in all_items 
                          if os.path.isdir(f"{LOCAL_CHECKPOINT_DIR}/{d}") 
                          and d.startswith("checkpoint-")]
    
    valid_checkpoints = []
    for cp in checkpoint_folders:
        cp_path = os.path.join(LOCAL_CHECKPOINT_DIR, cp)
        if is_valid_checkpoint(cp_path):
            parts = cp.split("-")
            if len(parts) >= 2 and parts[1].isdigit():
                step = int(parts[1])
                # Only resume if step < 500
                if step < 500:
                    valid_checkpoints.append((step, cp_path))
                    print(f"📂 Found valid checkpoint: {cp} (step {step})")
    
    if valid_checkpoints:
        latest_step, resume_checkpoint = max(valid_checkpoints)
        print(f"\n✅ Will resume from: checkpoint-{latest_step}")
        # Fix the checkpoint to avoid scaler loading issue
        fix_checkpoint_for_resume(resume_checkpoint)
else:
    os.makedirs(LOCAL_CHECKPOINT_DIR, exist_ok=True)

if not resume_checkpoint:
    print("🆕 Starting training from scratch.")

# ============================================================
# RUN TRAINING
# ============================================================
try:
    if resume_checkpoint:
        print(f"\n🔄 Resuming training from checkpoint...")
        train_result = trainer.train(resume_from_checkpoint=resume_checkpoint)
    else:
        print("\n🚀 Starting fresh training...")
        train_result = trainer.train()
    
    elapsed = time.time() - start_time
    print(f"\n{'='*60}")
    print(f"✅ Training stopped at step {trainer.state.global_step} in {elapsed/60:.1f} minutes!")
    print(f"{'='*60}")
    print(f"\n📊 Training Results:")
    for key, value in train_result.metrics.items():
        if isinstance(value, float):
            print(f"   {key}: {value:.4f}")
        else:
            print(f"   {key}: {value}")

    # Save model for testing
    print(f"\n💾 Saving model for testing...")
    TEST_MODEL_DIR = f"{DRIVE_OUTPUT_DIR}/checkpoint-500-test"
    trainer.save_model(TEST_MODEL_DIR)
    processor.save_pretrained(TEST_MODEL_DIR)
    print(f"✅ Model saved to: {TEST_MODEL_DIR}")
    print(f"\n🎉 Ready for testing! Run the test cells next.")

except KeyboardInterrupt:
    print("\n⚠️ Training interrupted!")
    # save_model() stores weights only (no optimizer/scheduler state). Resuming
    # uses the latest checkpoint-* folder, not this directory.
    trainer.save_model(f"{DRIVE_OUTPUT_DIR}/interrupted_model")
    print(f"✅ Weights saved to {DRIVE_OUTPUT_DIR}/interrupted_model.")
    print("   To resume, rerun the restore cell (CELL 25) and then this cell.")
    
except RuntimeError as e:
    if "out of memory" in str(e).lower():
        print("\n❌ CUDA Out of Memory!")
    raise
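Because this cell sets `max_steps = 500`, the `"cosine"` scheduler configured in Cell 10 (50 warmup steps) decays the learning rate to zero at step 500, including after a resume from an earlier checkpoint. The function below is a sketch of the linear-warmup-plus-cosine multiplier as we understand `transformers.get_cosine_schedule_with_warmup` computes it with its default half cycle; the 2e-4 base rate is the fallback config value and `cosine_lr` is a hypothetical name.

```python
import math


def cosine_lr(step, base_lr, warmup_steps=50, total_steps=500):
    """Learning rate at `step`: linear warmup, then half-cosine decay to 0."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


for step in (0, 25, 50, 275, 355, 500):
    print(step, f"{cosine_lr(step, 2e-4):.2e}")
```

The midpoint of the decay (step 275) sits at half the base rate, so a run resumed at step 355 continues with a rate well below its peak.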

In [None]:
# ============================================================
# AFTER CELL 26: BACKUP CHECKPOINTS TO GOOGLE DRIVE
# ============================================================
# Run this cell to back up all local checkpoints to Google Drive.
# This keeps your progress even if Colab disconnects. The destination is the
# same folder the restore cell (CELL 25) scans for checkpoint-* directories.

import shutil
import os
from pathlib import Path

LOCAL_CHECKPOINT = "/content/medgemma_advanced_lora"
DRIVE_CHECKPOINT = "/content/drive/MyDrive/ExplainMyXray_Models"

print("💾 Backing up checkpoints to Google Drive...")

if os.path.exists(LOCAL_CHECKPOINT):
    # List all checkpoints
    all_files = os.listdir(LOCAL_CHECKPOINT)
    checkpoint_folders = [d for d in all_files if d.startswith("checkpoint-")]
    
    if checkpoint_folders:
        print(f"   Found {len(checkpoint_folders)} checkpoint(s):")
        for cp in sorted(checkpoint_folders):
            print(f"   - {cp}")
        
        # Copy everything to Drive
        shutil.copytree(LOCAL_CHECKPOINT, DRIVE_CHECKPOINT, dirs_exist_ok=True)
        
        print(f"\n✅ All checkpoints backed up to Google Drive!")
        print(f"   Location: {DRIVE_CHECKPOINT}")
        print(f"\n📌 Your training progress is now SAFE!")
        print(f"   When you come back, run the restore cell (CELL 25) first,")
        print(f"   then run the training cell to resume.")
    else:
        print("⚠️ No checkpoints found in local directory.")
        print("   Training may not have started or saved yet.")
else:
    print("⚠️ No checkpoint directory found.")
    print("   Make sure training has reached its first checkpoint (save_steps).")

# Show Drive usage
if os.path.exists(DRIVE_CHECKPOINT):
    total_size = sum(
        os.path.getsize(os.path.join(dirpath, filename))
        for dirpath, dirnames, filenames in os.walk(DRIVE_CHECKPOINT)
        for filename in filenames
    ) / (1024**2)  # Convert to MB
    print(f"\n📊 Total checkpoint size: {total_size:.1f} MB")
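The cell above copies the whole local directory on every run. An incremental variant that copies only checkpoint folders not yet present in the backup location is sketched below. `backup_new_checkpoints` is a hypothetical helper, and the demo uses temporary directories instead of real Drive paths.

```python
import os
import shutil
import tempfile


def backup_new_checkpoints(local_dir, backup_dir):
    """Copy checkpoint-* folders from local_dir that are missing in backup_dir.

    Returns the names of the folders copied in this call.
    """
    os.makedirs(backup_dir, exist_ok=True)
    copied = []
    for name in sorted(os.listdir(local_dir)):
        src = os.path.join(local_dir, name)
        dst = os.path.join(backup_dir, name)
        if name.startswith("checkpoint-") and os.path.isdir(src) and not os.path.exists(dst):
            shutil.copytree(src, dst)
            copied.append(name)
    return copied


# Demo with temporary folders standing in for the local and Drive directories
with tempfile.TemporaryDirectory() as local, tempfile.TemporaryDirectory() as backup:
    for step in (100, 200):
        ckpt = os.path.join(local, f"checkpoint-{step}")
        os.makedirs(ckpt)
        with open(os.path.join(ckpt, "trainer_state.json"), "w") as f:
            f.write("{}")
    os.makedirs(os.path.join(local, "runs"))             # not a checkpoint
    os.makedirs(os.path.join(backup, "checkpoint-100"))  # already backed up
    first_pass = backup_new_checkpoints(local, backup)
    second_pass = backup_new_checkpoints(local, backup)
    copied_file_exists = os.path.exists(
        os.path.join(backup, "checkpoint-200", "trainer_state.json")
    )

print(first_pass, second_pass)  # ['checkpoint-200'] []
```

Skipping existing destinations makes repeated runs cheap, but a folder left half-copied by a disconnect would also be skipped; such a folder would need to be deleted (or re-copied with `dirs_exist_ok=True`) by hand.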

## 📈 Cell 11: Training Visualization

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Get training history
history = trainer.state.log_history

# Extract metrics
train_losses = [x['loss'] for x in history if 'loss' in x]
eval_losses = [x['eval_loss'] for x in history if 'eval_loss' in x]
train_steps = [x['step'] for x in history if 'loss' in x]
eval_steps = [x['step'] for x in history if 'eval_loss' in x]
learning_rates = [x['learning_rate'] for x in history if 'learning_rate' in x]
lr_steps = [x['step'] for x in history if 'learning_rate' in x]

# Create visualization
fig, axes = plt.subplots(1, 3, figsize=(16, 4))

# 1. Loss Curves
ax1 = axes[0]
ax1.plot(train_steps, train_losses, label='Train Loss', alpha=0.7, color='blue')
if eval_losses:
    ax1.plot(eval_steps, eval_losses, label='Eval Loss', linewidth=2, color='orange')
ax1.set_xlabel('Steps')
ax1.set_ylabel('Loss')
ax1.set_title('Training & Validation Loss')
ax1.legend()
ax1.grid(True, alpha=0.3)

# 2. Learning Rate Schedule
ax2 = axes[1]
if learning_rates:
    ax2.plot(lr_steps, learning_rates, color='green')
    ax2.set_xlabel('Steps')
    ax2.set_ylabel('Learning Rate')
    ax2.set_title('Cosine LR Schedule')
    ax2.grid(True, alpha=0.3)

# 3. Loss Distribution
ax3 = axes[2]
if train_losses:
    ax3.hist(train_losses, bins=30, alpha=0.7, label='Train', color='blue')
if eval_losses:
    ax3.hist(eval_losses, bins=20, alpha=0.7, label='Eval', color='orange')
ax3.set_xlabel('Loss')
ax3.set_ylabel('Frequency')
ax3.set_title('Loss Distribution')
ax3.legend()

plt.tight_layout()
plt.savefig(f"{config.output_dir}/training_curves.png", dpi=150)
plt.show()

# Print summary
print(f"\n📊 Training Summary:")
if train_losses:
    print(f"   Initial Loss: {train_losses[0]:.4f}")
    print(f"   Final Loss: {train_losses[-1]:.4f}")
    print(f"   Improvement: {(train_losses[0] - train_losses[-1])/train_losses[0]*100:.1f}%")
if eval_losses:
    print(f"   Best Eval Loss: {min(eval_losses):.4f}")
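The "Improvement" figure compares two single logged values, and each of them is noisy. Comparing smoothed values gives a steadier picture. A simple exponential moving average (our suggestion, not part of this notebook; `ema_smooth` is a hypothetical helper) could be applied to `train_losses` before plotting or before the first/last comparison:

```python
def ema_smooth(values, weight=0.9):
    """Exponential moving average of a loss curve.

    A weight closer to 1 smooths more; the first value seeds the average
    (no bias correction).
    """
    smoothed = []
    last = None
    for v in values:
        last = v if last is None else weight * last + (1 - weight) * v
        smoothed.append(last)
    return smoothed


print(ema_smooth([1.0, 0.0, 0.0], weight=0.5))  # [1.0, 0.5, 0.25]
```

For example, `ax1.plot(train_steps, ema_smooth(train_losses), ...)` could be drawn alongside the raw curve.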

## 🧪 Cell 12: Comprehensive Test Evaluation

In [None]:
print("🧪 Running Comprehensive Test Evaluation...\n")

# 1. Basic Evaluation
if len(test_dataset) > 0:
    test_results = trainer.evaluate(test_dataset)
    
    print("="*60)
    print("📊 TEST SET RESULTS")
    print("="*60)
    for key, value in test_results.items():
        if isinstance(value, float):
            print(f"   {key}: {value:.4f}")
    print("="*60)
else:
    print("⚠️ Test dataset is empty")
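`eval_loss` is easier to interpret as perplexity. Assuming it is the mean per-token cross-entropy in nats over the supervised (report) tokens, as Hugging Face language-model losses are, perplexity is its exponential. The value is only approximate here because the Trainer averages per-batch losses. A one-line sketch:

```python
import math


def perplexity(mean_token_nll):
    """exp of the mean per-token negative log-likelihood (natural log)."""
    return math.exp(mean_token_nll)


print(perplexity(1.0))  # about 2.718: as uncertain as a uniform choice among ~2.7 tokens
```

For example, `perplexity(test_results["eval_loss"])` would translate the test loss printed above.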

In [None]:
# ============================================================
# COMPREHENSIVE TESTING - Generate Predictions & Evaluate
# ============================================================
from tqdm.auto import tqdm
from PIL import Image
import numpy as np
import torch

def generate_prediction(model, processor, image_path, prompt="describe this chest xray"):
    """Generate a single report with greedy decoding.

    The default prompt is the fixed prompt ChestXrayDataset uses for val/test.
    """
    try:
        image = Image.open(image_path).convert("RGB")
    except Exception:
        return "Error loading image"

    inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device)
    input_len = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=150,
            do_sample=False,  # greedy decoding: deterministic, reproducible metrics
        )

    # Decode only the newly generated tokens (the output also contains the input)
    return processor.decode(output[0][input_len:], skip_special_tokens=True).strip()

# Generate predictions on test set
print("\n🔮 Generating predictions on test set...")
predictions = []
ground_truths = []
num_test_samples = min(50, len(test_dataset))  # Limit for speed

model.eval()
for idx in tqdm(range(num_test_samples), desc="Generating"):
    row = test_df.iloc[idx]

    # Generate prediction
    pred = generate_prediction(model, processor, row["ImagePath"])
    # Normalize the reference the same way ChestXrayDataset does (empty/NaN -> default)
    gt = str(row.get("Report", "")).strip()
    if not gt or gt == "nan":
        gt = "Normal chest radiograph."

    predictions.append(pred)
    ground_truths.append(gt)

# Display sample predictions
print("\n" + "="*60)
print("📋 SAMPLE PREDICTIONS")
print("="*60)
for i in range(min(5, len(predictions))):
    row = test_df.iloc[i]
    print(f"\n📷 Image: {row['ImageID']}")
    print(f"   Label: {row['Labels']}")
    print(f"   Ground Truth: {ground_truths[i][:100]}...")
    print(f"   Prediction:   {predictions[i][:100]}...")

In [None]:
# ============================================================
# ADVANCED METRICS - BLEU, ROUGE Scores
# ============================================================
import numpy as np

print("\n📊 Computing Advanced Metrics...")

if predictions and ground_truths:
    # Compute metrics using safe function
    metrics = compute_metrics_safe(predictions, ground_truths)
    
    print("\n" + "="*60)
    print("📈 ADVANCED EVALUATION METRICS")
    print("="*60)
    
    if "bleu" in metrics:
        print(f"   BLEU Score:      {metrics['bleu']:.2f}")
    if "rouge1" in metrics:
        print(f"   ROUGE-1 (F1):    {metrics['rouge1']:.4f}")
    if "rouge2" in metrics:
        print(f"   ROUGE-2 (F1):    {metrics['rouge2']:.4f}")
    if "rougeL" in metrics:
        print(f"   ROUGE-L (F1):    {metrics['rougeL']:.4f}")
    
    print("="*60)
    
    # Per-label summary: sample counts and mean prediction length (not accuracy)
    print("\n📊 Per-Label Analysis:")
    label_results = {}
    for i, (pred, gt) in enumerate(zip(predictions, ground_truths)):
        if i < len(test_df):
            label = test_df.iloc[i]["Labels"]
            if label not in label_results:
                label_results[label] = {"count": 0, "pred_lengths": []}
            label_results[label]["count"] += 1
            label_results[label]["pred_lengths"].append(len(pred.split()))
    
    for label, data in sorted(label_results.items(), key=lambda x: -x[1]["count"])[:5]:
        avg_len = np.mean(data["pred_lengths"]) if data["pred_lengths"] else 0
        print(f"   {label}: {data['count']} samples, avg prediction length: {avg_len:.1f} words")
else:
    print("⚠️ No predictions available for metrics computation")
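The scores printed above come from `compute_metrics_safe`, which is defined in an earlier cell and not shown here. To make ROUGE-1 and ROUGE-L concrete, here is a minimal pure-Python sketch of their F1 variants. It assumes lowercase whitespace tokenization with no stemming, so it will not reproduce library scores exactly; it is an illustration of the metrics, not the notebook's implementation.

```python
from collections import Counter


def _tokens(text: str) -> list[str]:
    # Assumption: lowercase whitespace tokenization, no stemming
    return text.lower().split()


def _f1(overlap: int, pred_len: int, ref_len: int) -> float:
    if overlap == 0 or pred_len == 0 or ref_len == 0:
        return 0.0
    precision = overlap / pred_len
    recall = overlap / ref_len
    return 2 * precision * recall / (precision + recall)


def rouge1_f1(prediction: str, reference: str) -> float:
    """Unigram-overlap F1; repeated words are matched at most as often as they occur in both."""
    pred, ref = _tokens(prediction), _tokens(reference)
    overlap = sum((Counter(pred) & Counter(ref)).values())
    return _f1(overlap, len(pred), len(ref))


def _lcs_length(a: list[str], b: list[str]) -> int:
    # Longest common subsequence via dynamic programming, keeping one row at a time
    prev = [0] * (len(b) + 1)
    for x in a:
        curr = [0]
        for j, y in enumerate(b, start=1):
            curr.append(prev[j - 1] + 1 if x == y else max(prev[j], curr[j - 1]))
        prev = curr
    return prev[-1]


def rougeL_f1(prediction: str, reference: str) -> float:
    """F1 based on the longest common subsequence of tokens (word order matters)."""
    pred, ref = _tokens(prediction), _tokens(reference)
    return _f1(_lcs_length(pred, ref), len(pred), len(ref))


def mean_rouge(predictions: list[str], references: list[str]) -> dict[str, float]:
    """Average per-pair F1 scores over a corpus."""
    pairs = list(zip(predictions, references))
    n = max(len(pairs), 1)
    return {
        "rouge1": sum(rouge1_f1(p, r) for p, r in pairs) / n,
        "rougeL": sum(rougeL_f1(p, r) for p, r in pairs) / n,
    }
```

For example, `mean_rouge(predictions, ground_truths)` gives corpus-level averages for the 50 generated reports. ROUGE-1 ignores word order, while ROUGE-L rewards predictions that keep findings in the same order as the reference.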

## 💾 Cell 13: Save Model & Adapters

In [None]:
import os
import shutil
import json

# Save LoRA adapters
print("üíæ Saving LoRA adapters...")
model.save_pretrained(config.output_dir)
processor.save_pretrained(config.output_dir)

# Save training config & metrics
config_dict = {
    "model_id": config.model_id,
    "lora_r": config.lora_r,
    "lora_alpha": config.lora_alpha,
    "num_epochs": config.num_epochs,
    "learning_rate": config.learning_rate,
    "batch_size": config.batch_size,
    "max_length": config.max_length,
    "final_train_loss": train_losses[-1] if train_losses else None,
    "best_eval_loss": min(eval_losses) if eval_losses else None,
    "train_samples": len(train_dataset),
    "val_samples": len(val_dataset),
    "test_samples": len(test_dataset),
}

with open(f"{config.output_dir}/training_config.json", "w") as f:
    json.dump(config_dict, f, indent=2)

# Total size of the files directly inside output_dir (subfolders are not counted)
total_size = sum(
    os.path.getsize(os.path.join(config.output_dir, f)) 
    for f in os.listdir(config.output_dir) 
    if os.path.isfile(os.path.join(config.output_dir, f))
)

print(f"\n✅ Model saved to: {config.output_dir}")
print(f"📦 Total size: {total_size / 1024 / 1024:.2f} MB")
print("\n📁 Saved files:")
for f in sorted(os.listdir(config.output_dir)):
    fpath = os.path.join(config.output_dir, f)
    if os.path.isfile(fpath):
        size = os.path.getsize(fpath) / 1024
        print(f"   {f}: {size:.1f} KB")
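The size check above only counts files sitting directly in `config.output_dir`. If that folder also contains subdirectories (for example `checkpoint-*` folders, which the Hugging Face `Trainer` writes when its `output_dir` points at the same place; whether that happens here is an assumption), their contents are skipped. A minimal recursive variant using `os.walk`, offered as a sketch rather than the notebook's own code:

```python
import os


def dir_size_bytes(path: str) -> int:
    """Total size in bytes of all regular files under `path`, including subfolders."""
    total = 0
    for root, _dirs, filenames in os.walk(path):
        for name in filenames:
            fpath = os.path.join(root, name)
            if os.path.isfile(fpath):  # skips broken symlinks
                total += os.path.getsize(fpath)
    return total
```

Usage: `print(f"{dir_size_bytes(config.output_dir) / 1024 / 1024:.2f} MB")`. A large gap between this number and the top-level total usually means checkpoints will also end up in the zip archive.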

In [None]:
# Download adapters to local machine
import shutil
from google.colab import files

# Create zip archive (make_archive returns the full path of the file it wrote)
zip_path = shutil.make_archive("/content/medgemma_advanced_lora", "zip", config.output_dir)

print(f"📦 Created: {zip_path}")
print("📥 Downloading...")
files.download(zip_path)
print("\n✅ Download complete! Extract and use with the inference script.")
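Before downloading, it can be worth confirming the archive is readable and holds the expected files, such as the `adapter_config.json` written by PEFT's `save_pretrained` and the `training_config.json` written in the previous cell. A small sketch using the standard-library `zipfile` module:

```python
import zipfile


def list_archive(zip_path: str) -> list[str]:
    """Return the sorted member names of a zip archive after a CRC integrity check."""
    with zipfile.ZipFile(zip_path) as zf:
        bad = zf.testzip()  # name of the first corrupt member, or None if all pass
        if bad is not None:
            raise ValueError(f"Corrupt member in {zip_path}: {bad}")
        return sorted(zf.namelist())
```

Usage: `print(list_archive("/content/medgemma_advanced_lora.zip"))`. If `adapter_config.json` is missing from the list, `PeftModel.from_pretrained` will not be able to load the adapters later.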

## 🎯 Cell 14: Interactive Inference Demo

In [None]:
import torch
import matplotlib.pyplot as plt
from PIL import Image

def explain_xray(image_path: str, model, processor) -> str:
    """Generate a patient-friendly explanation for a chest X-ray."""
    
    image = Image.open(image_path).convert("RGB")
    prompt = "describe this chest xray in simple terms for a patient:"
    
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device)
    
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=200,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.1,
        )
    
    # The output sequence starts with the image + prompt tokens; decode only the new tokens
    prompt_len = inputs["input_ids"].shape[-1]
    return processor.decode(output[0][prompt_len:], skip_special_tokens=True).strip()

# Interactive demo with test images
print("\nü©ª INTERACTIVE X-RAY EXPLANATION DEMO\n")

# Get some test images
demo_images = test_df.sample(min(6, len(test_df))).reset_index(drop=True)

fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

for i, (_, row) in enumerate(demo_images.iterrows()):
    if i >= 6:
        break
    
    img_path = row["ImagePath"]
    
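    try:
        img = Image.open(img_path).convert("L")  # single channel so cmap='gray' takes effect
        axes[i].imshow(img, cmap='gray')
        axes[i].axis('off')
        axes[i].set_title(f"{row['Labels'][:20]}", fontsize=10)
        
        # Generate explanation
        explanation = explain_xray(img_path, model, processor)
        print(f"📷 {row['ImageID']}:")
        print(f"   Label: {row['Labels']}")
        print(f"   AI Explanation: {explanation[:150]}...")
        print()
    except Exception as e:
        print(f"⚠️ Failed on {img_path}: {e}")
        axes[i].text(0.5, 0.5, "Error", ha='center', transform=axes[i].transAxes)
        axes[i].axis('off')

# Hide unused panels when fewer than 6 test images are available
for ax in axes[len(demo_images):]:
    ax.axis('off')

plt.tight_layout()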
plt.savefig(f"{config.output_dir}/demo_predictions.png", dpi=150)
plt.show()

## 📊 Cell 15: Final Summary & Metrics

In [None]:
print("\n" + "="*70)
print("üéâ TRAINING COMPLETE - FINAL SUMMARY")
print("="*70)

print(f"\nüì¶ Model: {config.model_id}")
print(f"üìÅ Output: {config.output_dir}")

print(f"\nüîß Training Configuration:")
print(f"   ‚Ä¢ Epochs: {config.num_epochs}")
print(f"   ‚Ä¢ Batch Size: {config.batch_size} √ó {config.gradient_accumulation_steps} = {config.batch_size * config.gradient_accumulation_steps}")
print(f"   ‚Ä¢ Learning Rate: {config.learning_rate}")
print(f"   ‚Ä¢ LoRA Rank: {config.lora_r}")
print(f"   ‚Ä¢ Max Length: {config.max_length}")

print(f"\nüìä Dataset (from Kaggle):")
print(f"   ‚Ä¢ Total: {len(df_combined):,} images")
print(f"   ‚Ä¢ Train: {len(train_dataset):,} samples")
print(f"   ‚Ä¢ Val: {len(val_dataset):,} samples")
print(f"   ‚Ä¢ Test: {len(test_dataset):,} samples")

if train_losses:
    print(f"\nüìà Training Metrics:")
    print(f"   ‚Ä¢ Initial Loss: {train_losses[0]:.4f}")
    print(f"   ‚Ä¢ Final Loss: {train_losses[-1]:.4f}")
    print(f"   ‚Ä¢ Best Eval Loss: {min(eval_losses) if eval_losses else 'N/A'}")
    print(f"   ‚Ä¢ Total Steps: {trainer.state.global_step}")

print(f"\nüíæ Saved Artifacts:")
for f in sorted(os.listdir(config.output_dir))[:5]:
    print(f"   ‚Ä¢ {f}")

print("\n" + "="*70)
print("‚úÖ Model ready for deployment!")
print("   Download medgemma_advanced_lora.zip and use with inference script.")
print("="*70)
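The "Total Steps" figure can be sanity-checked against the configuration. The sketch below estimates optimizer steps from the training-set size, per-device batch size, gradient-accumulation steps and epoch count, assuming the last partial batch is kept and every started accumulation window produces an update. The Hugging Face `Trainer` may round a trailing partial window differently, so treat the result as an estimate, not an exact prediction.

```python
import math


def estimate_schedule(train_samples: int, batch_size: int,
                      grad_accum_steps: int, num_epochs: int) -> dict[str, int]:
    """Rough optimizer-step count for a run that trains for all epochs."""
    effective_batch = batch_size * grad_accum_steps
    # Batches the dataloader yields per epoch (last partial batch kept)
    batches_per_epoch = math.ceil(train_samples / batch_size)
    # One optimizer update per `grad_accum_steps` batches
    steps_per_epoch = math.ceil(batches_per_epoch / grad_accum_steps)
    return {
        "effective_batch": effective_batch,
        "steps_per_epoch": steps_per_epoch,
        "max_total_steps": steps_per_epoch * num_epochs,
    }
```

Usage: `estimate_schedule(len(train_dataset), config.batch_size, config.gradient_accumulation_steps, config.num_epochs)`. If `trainer.state.global_step` is well below `max_total_steps`, early stopping most likely ended the run before the final epoch.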

---

## 🔗 Kaggle Dataset Links Used

1. **Chest X-Ray Pneumonia** (5,863 images):
   - https://www.kaggle.com/datasets/paultimothymooney/chest-xray-pneumonia

2. **NIH Chest X-ray Sample** (5,606 images):
   - https://www.kaggle.com/datasets/nih-chest-xrays/sample

3. **COVID-19 Radiography Database** (21,165 images):
   - https://www.kaggle.com/datasets/tawsifurrahman/covid19-radiography-database
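To fetch these three datasets programmatically, the official `kaggle` package exposes `KaggleApi.dataset_download_files`, which takes an `<owner>/<dataset>` slug. The sketch below derives the slugs from the URLs above and downloads each one; it is not necessarily how this notebook fetched the data. It assumes `pip install kaggle` and a valid API token (for example `~/.kaggle/kaggle.json`), and the `/content/kaggle_data` destination layout is hypothetical.

```python
from urllib.parse import urlparse

# The three dataset pages listed above
KAGGLE_DATASET_URLS = [
    "https://www.kaggle.com/datasets/paultimothymooney/chest-xray-pneumonia",
    "https://www.kaggle.com/datasets/nih-chest-xrays/sample",
    "https://www.kaggle.com/datasets/tawsifurrahman/covid19-radiography-database",
]


def kaggle_slug(url: str) -> str:
    """Turn 'https://www.kaggle.com/datasets/<owner>/<name>' into '<owner>/<name>'."""
    parts = [p for p in urlparse(url).path.split("/") if p]
    if len(parts) < 3 or parts[0] != "datasets":
        raise ValueError(f"Not a Kaggle dataset URL: {url}")
    return f"{parts[1]}/{parts[2]}"


def download_all(urls: list[str], dest: str = "/content/kaggle_data") -> None:
    """Download and unzip each dataset into dest/<name> (hypothetical layout)."""
    # Imported lazily: requires the `kaggle` package and API credentials
    from kaggle.api.kaggle_api_extended import KaggleApi

    api = KaggleApi()
    api.authenticate()
    for url in urls:
        slug = kaggle_slug(url)
        target = f"{dest}/{slug.split('/')[1]}"
        api.dataset_download_files(slug, path=target, unzip=True)
```

Usage: `download_all(KAGGLE_DATASET_URLS)`. The datasets are several gigabytes in total, so check free disk space on the Colab runtime first.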

## 🚀 How to Use the Trained Model

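```python
import json

import torch
from peft import PeftModel
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
from PIL import Image

adapter_dir = "./medgemma_advanced_lora"  # folder where the downloaded zip was extracted

# Load the same base model the adapters were trained on (recorded in training_config.json)
with open(f"{adapter_dir}/training_config.json") as f:
    base_model_id = json.load(f)["model_id"]
model = PaliGemmaForConditionalGeneration.from_pretrained(base_model_id)
processor = AutoProcessor.from_pretrained(adapter_dir)  # processor was saved with the adapters

# Load LoRA adapters
model = PeftModel.from_pretrained(model, adapter_dir)
model.eval()

# Inference
image = Image.open("chest_xray.png").convert("RGB")
inputs = processor(images=image, text="describe this chest xray:", return_tensors="pt").to(model.device)
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=150)

# Decode only the newly generated tokens (the output also contains the prompt)
new_tokens = output[0][inputs["input_ids"].shape[-1]:]
print(processor.decode(new_tokens, skip_special_tokens=True))
```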

## 📚 Resources

- [PaliGemma Documentation](https://huggingface.co/google/paligemma-3b-pt-224)
- [PEFT/LoRA Guide](https://huggingface.co/docs/peft)
- [Kaggle Datasets API](https://www.kaggle.com/docs/api)

## ⚠️ Medical Disclaimer

This model is for educational purposes only. Always consult qualified healthcare professionals for medical diagnosis and treatment.