# MedGemma Pediatric Chest X-ray Fine-Tuning

**Important:** Make sure GPU is enabled!
- Go to: Runtime > Change runtime type > Hardware accelerator > T4 GPU

## Step 1: Check GPU and System Info

In [None]:
# Report the PyTorch build and whether a CUDA GPU is visible to this runtime.
import torch

cuda_ok = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ok}")
if not cuda_ok:
    print("⚠️ WARNING: No GPU detected! Please enable GPU in Runtime settings.")
else:
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {gpu_props.total_memory / 1e9:.2f} GB")

## Step 2: Install Required Packages

In [None]:
# Install the Hugging Face training stack (IPython shell magic; -q keeps pip quiet).
# NOTE(review): versions are unpinned — pin them if runs must be reproducible.
# NOTE(review): matplotlib (used in Step 12) is not listed; presumably preinstalled on Colab.
!pip install -q accelerate peft transformers bitsandbytes datasets pillow tqdm

## Step 3: Mount Google Drive

In [None]:
# Mount Google Drive so the dataset and training outputs persist across sessions.
from google.colab import drive
drive.mount('/content/drive')

# TODO: Update this path to where you uploaded your dataset
DATASET_PATH = "/content/drive/MyDrive/pediatric_xray_dataset"

import os

# Split files that the later cells (7.5, 8, 12, 13) read from DATASET_PATH.
REQUIRED_SPLITS = ("train.jsonl", "val.jsonl", "test.jsonl")

if os.path.exists(DATASET_PATH):
    print(f"✓ Dataset found at: {DATASET_PATH}")
    entries = sorted(os.listdir(DATASET_PATH))
    # Show only a preview: an image folder at the root could hold thousands of files.
    preview_suffix = " ..." if len(entries) > 20 else ""
    print(f"  Files: {entries[:20]}{preview_suffix}")
    # Fail early and visibly instead of with a FileNotFoundError several cells later.
    missing = [name for name in REQUIRED_SPLITS if name not in entries]
    if missing:
        print(f"  ⚠️ Missing split files: {missing}")
else:
    print(f"✗ Dataset NOT found at: {DATASET_PATH}")
    print("  Please update DATASET_PATH to match your Google Drive folder")

## Step 4: Import Libraries

In [None]:
# Libraries used by the dataset class, model loading, LoRA and training cells below.
import json
from pathlib import Path
from PIL import Image
import torch
from torch.utils.data import Dataset
# NOTE(review): newer transformers releases favour AutoModelForImageTextToText
# (used in Step 12.5) over AutoModelForVision2Seq — confirm against the installed version.
from transformers import (
    AutoProcessor,
    AutoModelForVision2Seq,
    TrainingArguments,
    Trainer,
    BitsAndBytesConfig
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from tqdm import tqdm

print("✓ All libraries imported successfully")

## Step 5: Create Custom Dataset Class

In [None]:
class PediatricXrayDataset(Dataset):
    """Pediatric chest X-ray + report pairs for supervised (causal-LM) fine-tuning.

    Each JSONL line must contain ``image`` (path relative to ``dataset_root``),
    ``age_group`` and ``report``. ``__getitem__`` runs the processor on
    ``prompt + report + eos`` and returns its outputs (batch dimension removed)
    plus ``labels``: a copy of ``input_ids`` in which the prompt (including any
    image tokens) and the padding positions are set to -100. The loss is
    therefore computed only on the report tokens, and labels stay aligned with
    the inputs.

    Args:
        jsonl_path: Path to the JSONL split file. Blank lines are skipped.
        dataset_root: Directory that the ``image`` paths are relative to.
        processor: Hugging Face multimodal processor exposing ``.tokenizer``.
        max_length: Padded/truncated sequence length, image tokens included.
        image_token: Text placeholder prepended to the prompt so the processor
            knows where the image goes. ``None`` (default) uses the tokenizer's
            ``boi_token`` when it has one (Gemma 3 / MedGemma). Otherwise no
            prefix is added, and processors that insert image tokens themselves
            (e.g. PaliGemma's) handle it.
    """

    # Same wording and "Report: " suffix as the prompt used by the inference cells.
    PROMPT_TEMPLATE = (
        "Analyze this pediatric chest X-ray (age group: {age_group}) "
        "and provide a detailed radiology report.\n\nReport: "
    )
    # Label value ignored by the cross-entropy loss in Hugging Face models.
    IGNORE_INDEX = -100

    def __init__(self, jsonl_path, dataset_root, processor, max_length=512, image_token=None):
        self.dataset_root = Path(dataset_root)
        self.processor = processor
        self.max_length = max_length
        if image_token is None:
            tokenizer = getattr(processor, "tokenizer", None)
            image_token = getattr(tokenizer, "boi_token", None)
        self.image_token = image_token or ""

        # Load data from JSONL, skipping blank lines (e.g. a trailing newline).
        with open(jsonl_path, "r", encoding="utf-8") as f:
            self.data = [json.loads(line) for line in f if line.strip()]

        print(f"Loaded {len(self.data)} samples from {jsonl_path}")

    def __len__(self):
        return len(self.data)

    def _build_prompt(self, age_group):
        """Return the instruction prompt (with image placeholder) for one sample."""
        return self.image_token + self.PROMPT_TEMPLATE.format(age_group=age_group)

    @staticmethod
    def _build_labels(input_ids, attention_mask, prompt_len):
        """Return a copy of 1-D ``input_ids`` with prompt and padding set to -100.

        The prompt is assumed to occupy the first ``prompt_len`` real (unpadded)
        positions. The start is located through ``attention_mask``, so both left
        and right padding are handled.
        """
        labels = input_ids.clone()
        labels[attention_mask == 0] = PediatricXrayDataset.IGNORE_INDEX
        real_positions = attention_mask.nonzero(as_tuple=True)[0]
        start = int(real_positions[0]) if real_positions.numel() > 0 else 0
        labels[start:start + prompt_len] = PediatricXrayDataset.IGNORE_INDEX
        return labels

    def __getitem__(self, idx):
        item = self.data[idx]

        # Load the image; the context manager closes the file once it is decoded.
        img_path = self.dataset_root / item['image']
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        prompt = self._build_prompt(item['age_group'])
        # Append EOS so the model learns where a report ends.
        eos = self.processor.tokenizer.eos_token or ""
        full_text = f"{prompt}{item['report']}{eos}"

        # Prompt and report form ONE sequence; causal-LM labels must align with it.
        encoding = self.processor(
            text=full_text,
            images=image,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
        )
        # Remove batch dimension
        encoding = {k: v.squeeze(0) for k, v in encoding.items()}

        # Encode the prompt alone (same image, so image tokens expand identically)
        # to find how many leading tokens belong to the prompt.
        prompt_len = self.processor(
            text=prompt, images=image, return_tensors="pt"
        )["input_ids"].shape[-1]

        encoding["labels"] = self._build_labels(
            encoding["input_ids"], encoding["attention_mask"], prompt_len
        )
        return encoding

print("✓ Dataset class defined")

## Step 6: Load Model with Quantization

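**Note:** Replace `MODEL_NAME` with the actual MedGemma model identifier (e.g. `google/medgemma-4b-it`, the base model that Step 12.5 reloads). The LoRA adapters can only be loaded back onto the same base model they were trained on.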

In [None]:
# TODO: Replace with actual MedGemma model name
# NOTE: Step 12.5 reloads "google/medgemma-4b-it" as the base model. LoRA adapters
# only load onto the base model they were trained on, so keep the two in sync.
MODEL_NAME = "google/paligemma-3b-pt-224"  # Using PaliGemma as example
OUTPUT_DIR = "/content/drive/MyDrive/medgemma_pediatric_finetuned"

# Create output directory
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Quantization config for memory efficiency (4-bit NF4 with double quantization).
# float16 compute because the T4 targeted by this notebook lacks bf16 support.
# NOTE(review): Gemma 3-based models are reported to be less stable in fp16 — watch for NaN losses.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

print("Loading model and processor...")
print(f"Model: {MODEL_NAME}")

processor = AutoProcessor.from_pretrained(MODEL_NAME)
# NOTE(review): trust_remote_code runs code from the Hub repo if it ships any;
# PaliGemma / Gemma 3 are built into transformers, so it can likely be dropped.
model = AutoModelForVision2Seq.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True
)


def _count_params(m):
    """Count parameters, undoing bitsandbytes 4-bit packing (2 values per stored byte)."""
    total = 0
    for p in m.parameters():
        n = p.numel()
        if p.__class__.__name__ == "Params4bit":
            n *= 2 * p.element_size()
        total += n
    return total


print(f"✓ Model loaded: {MODEL_NAME}")
print(f"  Parameters: {_count_params(model) / 1e9:.2f}B")

## Step 7: Configure LoRA for Efficient Fine-Tuning

In [None]:
# Prepare the 4-bit model for training (peft helper for quantized models).
model = prepare_model_for_kbit_training(model)

# LoRA configuration: rank-8 adapters on the attention projections.
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

# Apply LoRA
model = get_peft_model(model, lora_config)

# Count parameters with peft's helper. It corrects for 4-bit packing; a raw
# numel() sum over quantized weights reports roughly half the true size.
trainable_params, total_params = model.get_nb_trainable_parameters()
print("✓ LoRA applied")
print(f"  Trainable params: {trainable_params:,} ({100 * trainable_params / total_params:.2f}%)")
print(f"  Total params: {total_params:,}")

## Step 7.5: Test Dataset Loading (Debug)

In [None]:
# Smoke-test one training sample end-to-end (JSONL -> image -> processor)
# before building the full datasets, and detect which image placeholder token
# the tokenizer expects.
print("Testing dataset loading with a single sample...")

test_jsonl = os.path.join(DATASET_PATH, "train.jsonl")
with open(test_jsonl, 'r', encoding='utf-8') as f:
    sample = json.loads(f.readline())

print(f"Sample data: {sample['age_group']}, image: {sample['image']}")

# Try loading the image (context manager closes the file once decoded)
img_path = os.path.join(DATASET_PATH, sample['image'])
with Image.open(img_path) as img:
    test_img = img.convert('RGB')
print(f"✓ Image loaded: {test_img.size}")

# Check tokenizer for image-related tokens
print("\nChecking for image tokens in tokenizer...")
tokenizer = processor.tokenizer
vocab = tokenizer.get_vocab()

# Find all image-related tokens
image_tokens = [token for token in vocab if 'image' in token.lower()]
print(f"Image-related tokens found: {image_tokens}")

# Pick the placeholder token. Gemma 3 / MedGemma spell their begin-of-image
# token "<start_of_image>" (the same fallback Step 12.5 uses).
if hasattr(tokenizer, 'boi_token'):
    print(f"Begin-of-image token: {tokenizer.boi_token}")
    IMAGE_TOKEN = tokenizer.boi_token
elif hasattr(tokenizer, 'image_token'):
    print(f"Image token: {tokenizer.image_token}")
    IMAGE_TOKEN = tokenizer.image_token
elif '<start_of_image>' in vocab:
    print("Found <start_of_image> token in vocabulary")
    IMAGE_TOKEN = '<start_of_image>'
elif '<boi>' in vocab:
    print("Found <boi> token in vocabulary")
    IMAGE_TOKEN = '<boi>'
elif '<image>' in vocab:
    print("Found <image> token in vocabulary")
    IMAGE_TOKEN = '<image>'
else:
    # For Gemma3, fall back to its begin-of-image token string
    IMAGE_TOKEN = '<start_of_image>'
    print(f"Using default: {IMAGE_TOKEN}")

# Test processor with image token
prompt = f"{IMAGE_TOKEN}Analyze this pediatric chest X-ray (age group: {sample['age_group']}) and provide a detailed radiology report."
print(f"\nPrompt with image token: {prompt[:100]}...")

try:
    test_encoding = processor(
        images=test_img,
        text=prompt,
        return_tensors="pt"
    )
    print(f"\n✓ Processor output keys: {test_encoding.keys()}")
    print(f"  Input shape: {test_encoding['input_ids'].shape}")
    if 'pixel_values' in test_encoding:
        print(f"  Pixel values shape: {test_encoding['pixel_values'].shape}")
    print("\n✓ Dataset loading test successful!")
except Exception as e:  # debug cell: report the failure and try the fallback
    print(f"\n✗ Error: {e}")
    print("\nTrying alternative approach without text preprocessing...")

    # Alternative: Let processor handle everything
    try:
        test_encoding = processor(
            images=test_img,
            text=f"Analyze this pediatric chest X-ray (age group: {sample['age_group']}) and provide a detailed radiology report.",
            return_tensors="pt",
            add_special_tokens=True
        )
        print("\n✓ Alternative approach worked!")
        print("  Processor automatically added image tokens")
        IMAGE_TOKEN = None  # Let processor handle it
    except Exception as e2:
        print(f"✗ Alternative also failed: {e2}")

## Step 8: Load Training and Validation Datasets

In [None]:
# Build the train and validation splits from the JSONL files in DATASET_PATH
# (created in order: train first, then val).
train_dataset, val_dataset = (
    PediatricXrayDataset(
        jsonl_path=os.path.join(DATASET_PATH, f"{split}.jsonl"),
        dataset_root=DATASET_PATH,
        processor=processor,
    )
    for split in ("train", "val")
)

print("✓ Datasets loaded")
print(f"  Training samples: {len(train_dataset)}")
print(f"  Validation samples: {len(val_dataset)}")

## Step 9: Configure Training Parameters

In [None]:
# Training hyperparameters: batch size 1 with 8-step gradient accumulation
# (effective batch 8), evaluation every 50 steps, best checkpoint by eval loss.
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=3,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=2e-4,
    warmup_steps=100,
    logging_steps=10,
    eval_strategy="steps",
    eval_steps=50,
    # load_best_model_at_end requires save_steps to be a multiple of eval_steps.
    save_steps=100,
    save_total_limit=2,
    # fp16 to match the float16 compute dtype chosen for the 4-bit model.
    fp16=True,
    dataloader_num_workers=2,
    # Keep dataset keys (e.g. pixel_values) that Trainer cannot match to the
    # wrapped PeftModel's forward signature.
    remove_unused_columns=False,
    report_to="none",
    # PeftModel hides the base model's forward arguments, so name the label key
    # explicitly; without it eval_loss (needed below) may not be computed.
    label_names=["labels"],
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
)

print("✓ Training configuration:")
print(f"  Epochs: {training_args.num_train_epochs}")
print(f"  Batch size: {training_args.per_device_train_batch_size}")
print(f"  Gradient accumulation: {training_args.gradient_accumulation_steps}")
print(f"  Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
print(f"  Learning rate: {training_args.learning_rate}")

## Step 10: Initialize Trainer and Start Training

**This will take several hours!** Keep the browser tab open.

In [None]:
# Wire the model, arguments and datasets into a Trainer and run fine-tuning.
# No data_collator is passed, so Trainer falls back to its default collator,
# which stacks the fixed-length tensors returned by the dataset.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

banner = "=" * 60
print(banner)
print("Starting training...")
print(banner)

trainer.train()

print("\n" + banner)
print("Training completed!")
print(banner)

## Step 11: Save Fine-Tuned Model

In [None]:
# Save the fine-tuned LoRA adapters and the processor into a timestamped
# "final_model_*" sub-folder. Step 12.5 looks for that naming, and it keeps the
# final model apart from the Trainer's checkpoints in OUTPUT_DIR.
from datetime import datetime

# %Y%m%d_%H%M%S sorts lexicographically in time order, so "latest" = sorted()[-1].
FINAL_MODEL_DIR = os.path.join(OUTPUT_DIR, f"final_model_{datetime.now():%Y%m%d_%H%M%S}")
os.makedirs(FINAL_MODEL_DIR, exist_ok=True)

model.save_pretrained(FINAL_MODEL_DIR)
processor.save_pretrained(FINAL_MODEL_DIR)

print(f"✓ Model saved to: {FINAL_MODEL_DIR}")
print("  You can now download this folder from Google Drive")

## Step 12: Test Inference on Sample

In [None]:
# Generate a report for the first test sample with the in-memory fine-tuned
# model and print it next to the ground-truth report.

# Load a test image (context manager closes the file once decoded)
test_jsonl = os.path.join(DATASET_PATH, "test.jsonl")
with open(test_jsonl, 'r', encoding='utf-8') as f:
    test_sample = json.loads(f.readline())

test_img_path = os.path.join(DATASET_PATH, test_sample['image'])
with Image.open(test_img_path) as img:
    test_image = img.convert('RGB')

# Display the image
import matplotlib.pyplot as plt
plt.figure(figsize=(6, 6))
plt.imshow(test_image)
plt.axis('off')
plt.title(f"Test X-ray (Age group: {test_sample['age_group']})")
plt.show()

# Detect image token (if needed)
tokenizer = processor.tokenizer
if hasattr(tokenizer, 'boi_token'):
    IMAGE_TOKEN = tokenizer.boi_token
    use_image_token = True
elif hasattr(tokenizer, 'image_token'):
    IMAGE_TOKEN = tokenizer.image_token
    use_image_token = True
elif '<boi>' in tokenizer.get_vocab():
    IMAGE_TOKEN = '<boi>'
    use_image_token = True
else:
    use_image_token = False

# Create prompt (with or without manual image token)
base_prompt = f"Analyze this pediatric chest X-ray (age group: {test_sample['age_group']}) and provide a detailed radiology report.\n\nReport: "
prompt = f"{IMAGE_TOKEN}{base_prompt}" if use_image_token else base_prompt

# Generate report
inputs = processor(images=test_image, text=prompt, return_tensors="pt", add_special_tokens=True).to(model.device)

# The model may still be in train mode after trainer.train(); switch to eval
# mode so LoRA dropout is disabled during generation.
model.eval()
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=256,
        temperature=0.7,
        do_sample=True
    )

# generate() returns prompt + continuation; decode only the newly generated tokens.
prompt_len = inputs["input_ids"].shape[-1]
generated_report = processor.decode(outputs[0][prompt_len:], skip_special_tokens=True)

print("="*60)
print("TEST INFERENCE")
print("="*60)
print(f"\nAge group: {test_sample['age_group']}")
print(f"\nGround truth report:\n{test_sample['report']}")
print(f"\nGenerated report:\n{generated_report}")

## Step 12.5: Alternative Inference (If Step 12 Fails)

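If Step 12 fails with CUDA errors, restart the runtime and run this cell instead. It reloads the base model together with the saved LoRA adapters from Google Drive, for inference only. Step 13 still needs the `trainer` object from Step 10, so it cannot run after such a restart.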

In [None]:
# Run this if Step 12 fails - it reloads model fresh for inference only:
# base model (4-bit) + the LoRA adapters saved by Step 11, then one generation.

import torch
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
from peft import PeftModel
from PIL import Image
import json
import os

# Clear any existing GPU memory
torch.cuda.empty_cache()

# Paths
# NOTE: BASE_MODEL must be the model that was fine-tuned (MODEL_NAME in Step 6);
# the adapters will not load onto a different architecture.
BASE_MODEL = "google/medgemma-4b-it"
# NOTE(review): Step 3 uses ".../pediatric_xray_dataset" — confirm which folder is correct.
DATASET_PATH = "/content/drive/MyDrive/pediatric_xray_dataset_chest"

# Find the latest saved model: newest "final_model_*" folder, else adapters
# saved directly into OUTPUT_DIR.
OUTPUT_DIR = "/content/drive/MyDrive/medgemma_pediatric_finetuned"
FINETUNED_MODEL = None
if os.path.isdir(OUTPUT_DIR):
    model_dirs = sorted(d for d in os.listdir(OUTPUT_DIR) if d.startswith('final_model_'))
    if model_dirs:
        FINETUNED_MODEL = os.path.join(OUTPUT_DIR, model_dirs[-1])
    elif os.path.exists(os.path.join(OUTPUT_DIR, "adapter_config.json")):
        FINETUNED_MODEL = OUTPUT_DIR

if FINETUNED_MODEL:
    print(f"Loading fine-tuned model from: {FINETUNED_MODEL}")
else:
    print("No fine-tuned model found!")

if FINETUNED_MODEL:
    # Load with 4-bit quantization for inference
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )

    # Load base model
    print("Loading base model...")
    base_model = AutoModelForImageTextToText.from_pretrained(
        BASE_MODEL,
        quantization_config=bnb_config,
        device_map="auto"
    )

    # Load fine-tuned LoRA adapters
    print("Loading fine-tuned LoRA adapters...")
    model = PeftModel.from_pretrained(base_model, FINETUNED_MODEL)

    # Load processor
    processor = AutoProcessor.from_pretrained(FINETUNED_MODEL)
    IMAGE_TOKEN = processor.tokenizer.boi_token if hasattr(processor.tokenizer, 'boi_token') else '<start_of_image>'

    print("✓ Model loaded for inference")

    # Test inference
    test_jsonl = os.path.join(DATASET_PATH, "test.jsonl")
    with open(test_jsonl, 'r', encoding='utf-8') as f:
        test_sample = json.loads(f.readline())

    test_img_path = os.path.join(DATASET_PATH, test_sample['image'])
    with Image.open(test_img_path) as img:
        test_image = img.convert('RGB')

    prompt = f"{IMAGE_TOKEN}Analyze this pediatric chest X-ray (age group: {test_sample['age_group']}) and provide a detailed radiology report.\n\nReport: "

    inputs = processor(images=test_image, text=prompt, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    print("Generating report...")
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=200,
            temperature=0.7,
            do_sample=True,
            pad_token_id=processor.tokenizer.pad_token_id
        )

    # generate() returns prompt + continuation; decode only the new tokens.
    input_len = inputs["input_ids"].shape[-1]
    generated_report = processor.decode(outputs[0][input_len:], skip_special_tokens=True)

    print("="*60)
    print("INFERENCE RESULT")
    print("="*60)
    print(f"\nAge group: {test_sample['age_group']}")
    print(f"\nGround truth:\n{test_sample['report'][:200]}...")
    print(f"\nGenerated:\n{generated_report}")

## Step 13: Evaluate on Full Test Set

In [None]:
# Score the held-out test split with the same loss the Trainer optimised.
# NOTE: needs `trainer` from Step 10 — it is not defined after a runtime
# restart that only ran Step 12.5.
test_dataset = PediatricXrayDataset(
    jsonl_path=os.path.join(DATASET_PATH, "test.jsonl"),
    dataset_root=DATASET_PATH,
    processor=processor,
)

print(f"Evaluating on {len(test_dataset)} test samples...")
test_results = trainer.evaluate(eval_dataset=test_dataset)
test_loss = test_results['eval_loss']
print(f"\n✓ Test loss: {test_loss:.4f}")