# MedGemma Pediatric Chest X-ray Fine-Tuning

**Important:** Make sure a GPU runtime is enabled before running any cell!
- Go to: Runtime > Change runtime type > Hardware accelerator > T4 GPU

## Step 1: Check GPU and System Info

In [1]:
import torch

# Report the installed PyTorch build and whether it can see a CUDA device.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if not torch.cuda.is_available():
    # Nothing below in the notebook is practical without a GPU, so make this loud.
    print("⚠️ WARNING: No GPU detected! Please enable GPU in Runtime settings.")
else:
    # Device 0 is the only GPU queried; memory is reported in decimal gigabytes (1e9 bytes).
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

PyTorch version: 2.9.0+cu126
CUDA available: True
GPU: Tesla T4
GPU Memory: 15.83 GB


## Step 2: Install Required Packages

In [2]:
!pip install -q accelerate peft transformers bitsandbytes datasets pillow tqdm

## Step 3: Mount Google Drive

In [3]:
import os

from google.colab import drive

# Expose Google Drive to this runtime under /content/drive.
drive.mount('/content/drive')

# TODO: Update this path to where you uploaded your dataset
DATASET_PATH = "/content/drive/MyDrive/pediatric_xray_dataset_chest"

# Tell the user whether the dataset folder is reachable; list its entries when it is.
if not os.path.exists(DATASET_PATH):
    print(f"✗ Dataset NOT found at: {DATASET_PATH}")
    print("  Please update DATASET_PATH to match your Google Drive folder")
else:
    print(f"✓ Dataset found at: {DATASET_PATH}")
    print(f"  Files: {os.listdir(DATASET_PATH)}")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
✓ Dataset found at: /content/drive/MyDrive/pediatric_xray_dataset_chest
  Files: ['test.jsonl', 'val.jsonl', 'train.jsonl', 'dataset_stats.json', 'images']


## Step 4: Import Libraries

In [8]:
# Standard library: JSONL parsing and filesystem paths.
import json
from pathlib import Path
# Pillow: used to open the X-ray image files.
from PIL import Image
# PyTorch: tensors and the Dataset base class for the custom dataset.
import torch
from torch.utils.data import Dataset
# Hugging Face Transformers: processor/model loaders, the training loop,
# and the bitsandbytes quantization config.
from transformers import (
    AutoProcessor,
    AutoModelForImageTextToText,
    TrainingArguments,
    Trainer,
    BitsAndBytesConfig
)
# PEFT: LoRA adapter configuration and k-bit training preparation.
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
# NOTE(review): tqdm is not referenced in the cells visible here — confirm before removing.
from tqdm import tqdm

print("✓ All libraries imported successfully")

✓ All libraries imported successfully


## Step 5: Create Custom Dataset Class

In [9]:
class PediatricXrayDataset(Dataset):
    """Pediatric chest X-ray / report pairs prepared for causal-LM fine-tuning.

    Each line of the JSONL file is one record with the keys ``image`` (path
    relative to ``dataset_root``), ``age_group`` and ``report``.

    Every sample is rendered as a two-turn chat (user: image + instruction,
    assistant: report) and tokenized as ONE sequence. ``labels`` is a copy of
    ``input_ids`` with padding and image-placeholder positions set to -100, so
    the loss lines up token-for-token with the model input.

    Args:
        jsonl_path: Path to the JSONL file listing the samples.
        dataset_root: Directory that the ``image`` paths are relative to.
        processor: Multimodal processor providing ``apply_chat_template`` and a
            ``__call__(text=..., images=...)`` that returns ``input_ids`` and
            ``attention_mask`` (plus ``token_type_ids`` for Gemma 3-style processors).
        max_length: Fixed token length every sample is padded/truncated to.
            NOTE(review): the image expands into many placeholder tokens, so this
            must leave room for the report — confirm 512 is enough for your data.
    """

    def __init__(self, jsonl_path, dataset_root, processor, max_length=512):
        self.dataset_root = Path(dataset_root)
        self.processor = processor
        self.max_length = max_length

        # One JSON object per line; blank lines (e.g. a trailing newline) are skipped
        # instead of crashing json.loads. Explicit UTF-8 keeps parsing independent of
        # the platform's default encoding.
        with open(jsonl_path, 'r', encoding='utf-8') as f:
            self.data = [json.loads(line) for line in f if line.strip()]

        print(f"Loaded {len(self.data)} samples from {jsonl_path}")

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return one training example as a dict of un-batched tensors.

        The dict contains whatever the processor returns (input ids, attention
        mask, pixel values, ...) plus ``labels``.
        """
        item = self.data[idx]

        # Load image
        image = Image.open(self.dataset_root / item['image']).convert('RGB')

        # Create prompt with age information
        age_group = item['age_group']
        prompt = f"Analyze this pediatric chest X-ray (age group: {age_group}) and provide a detailed radiology report."

        # Render prompt AND target report as a single chat. The chat template inserts
        # the image placeholder that the processor needs in order to match the text
        # with the supplied image; a bare prompt string carries no such placeholder.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": prompt},
                ],
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": item['report']}],
            },
        ]
        full_text = self.processor.apply_chat_template(
            messages, add_generation_prompt=False, tokenize=False
        ).strip()

        encoding = self.processor(
            text=full_text,
            images=image,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length
        )

        # Causal-LM labels must be the input sequence itself (the model shifts them
        # internally). Positions that should not contribute to the loss are set to -100.
        labels = encoding["input_ids"].clone()
        labels[encoding["attention_mask"] == 0] = -100  # padding
        if "token_type_ids" in encoding:
            # Presumably 1 marks image placeholder tokens (Gemma 3 processor
            # convention) — nothing meaningful to predict at those positions.
            labels[encoding["token_type_ids"] == 1] = -100
        encoding["labels"] = labels

        # Remove the batch dimension the processor adds; the collator re-batches.
        return {k: v.squeeze(0) for k, v in encoding.items()}

print("✓ Dataset class defined")

✓ Dataset class defined


## Step 6: Load Model with Quantization

**Note:** `MODEL_NAME` below is set to `google/medgemma-4b-it`. Change it only if you want to fine-tune a different MedGemma checkpoint.

In [11]:
# MedGemma model name
MODEL_NAME = "google/medgemma-4b-it"
OUTPUT_DIR = "/content/drive/MyDrive/medgemma_pediatric_finetuned"

# Create output directory (no error if it already exists)
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Quantization config for memory efficiency (4-bit NF4 with double quantization).
# Compute runs in float16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

print("Loading model and processor...")
print(f"Model: {MODEL_NAME}")

processor = AutoProcessor.from_pretrained(MODEL_NAME)
model = AutoModelForImageTextToText.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
    # Transformers recommends the eager attention implementation when *training*
    # Gemma 3 models (it warns when the default sdpa kernel is used for training).
    attn_implementation="eager",
)

print(f"✓ Model loaded: {MODEL_NAME}")
# NOTE(review): numel() is taken over the quantized tensors, so this figure is
# expected to come out below the nominal model size — presumably because 4-bit
# weights are stored packed; confirm against the bitsandbytes documentation.
print(f"  Parameters: {sum(p.numel() for p in model.parameters()) / 1e9:.2f}B")

Loading model and processor...
Model: google/medgemma-4b-it


model.safetensors.index.json:   0%|          | 0.00/90.6k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.64G [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/156 [00:00<?, ?B/s]

✓ Model loaded: google/medgemma-4b-it
  Parameters: 2.49B


## Step 7: Configure LoRA for Efficient Fine-Tuning

In [12]:
# LoRA hyper-parameters: rank-8 adapters on the attention and MLP projection layers.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    target_modules=[
        "q_proj",
        "o_proj",
        "k_proj",
        "v_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
)

# Prepare the quantized model for k-bit training, then wrap it with the adapters.
model = get_peft_model(prepare_model_for_kbit_training(model), lora_config)

# Summarize how much of the wrapped model will actually receive gradients.
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("✓ LoRA applied")
print(f"  Trainable params: {trainable_params:,} ({100 * trainable_params / total_params:.2f}%)")
print(f"  Total params: {total_params:,}")

✓ LoRA applied
  Trainable params: 16,394,240 (0.65%)
  Total params: 2,506,617,200


## Step 8: Load Training and Validation Datasets

In [13]:
def _load_split(split_name):
    """Build the dataset for one split from ``<split_name>.jsonl`` under DATASET_PATH."""
    return PediatricXrayDataset(
        jsonl_path=os.path.join(DATASET_PATH, f"{split_name}.jsonl"),
        dataset_root=DATASET_PATH,
        processor=processor,
    )


# Training split first, then validation (each constructor prints its own sample count).
train_dataset = _load_split("train")
val_dataset = _load_split("val")

print("✓ Datasets loaded")
print(f"  Training samples: {len(train_dataset)}")
print(f"  Validation samples: {len(val_dataset)}")

Loaded 168 samples from /content/drive/MyDrive/pediatric_xray_dataset_chest/train.jsonl
Loaded 13 samples from /content/drive/MyDrive/pediatric_xray_dataset_chest/val.jsonl
✓ Datasets loaded
  Training samples: 168
  Validation samples: 13


## Step 9: Configure Training Parameters

In [14]:
# Schedule note: with the 168 training samples reported when the datasets are loaded
# and an effective batch size of 8, one epoch is 21 optimizer steps and the whole
# run is 63. Step-based settings larger than that never trigger:
#   - warmup_steps=100 would end the run while still warming up, so the learning
#     rate would never reach its configured value;
#   - save_steps=100 would never write a checkpoint, leaving
#     load_best_model_at_end with nothing to load.
# Epoch-based evaluation/saving and a proportional warmup scale with dataset size.
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=3,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=2e-4,
    warmup_ratio=0.1,            # ~10% of total steps, whatever the dataset size
    logging_steps=10,
    eval_strategy="epoch",       # evaluate once per epoch ...
    save_strategy="epoch",       # ... and checkpoint on the same schedule (required
                                 # for load_best_model_at_end)
    save_total_limit=2,
    fp16=True,
    dataloader_num_workers=2,
    remove_unused_columns=False,  # the dataset yields ready-made tensors, not raw columns
    label_names=["labels"],       # explicit, so eval_loss is computed even if the Trainer
                                  # cannot infer label names from the PEFT-wrapped model
    report_to="none",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
)

print("✓ Training configuration:")
print(f"  Epochs: {training_args.num_train_epochs}")
print(f"  Batch size: {training_args.per_device_train_batch_size}")
print(f"  Gradient accumulation: {training_args.gradient_accumulation_steps}")
print(f"  Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
print(f"  Learning rate: {training_args.learning_rate}")

✓ Training configuration:
  Epochs: 3
  Batch size: 1
  Gradient accumulation: 8
  Effective batch size: 8
  Learning rate: 0.0002


## Step 10: Initialize Trainer and Start Training

**This will take several hours!** Keep the browser tab open.

In [17]:
# Wire the LoRA-wrapped model, the training configuration and both data splits together.
trainer = Trainer(
    args=training_args,
    model=model,
    eval_dataset=val_dataset,
    train_dataset=train_dataset,
)

# Banner before the (long-running) training call.
print("=" * 60, "Starting training...", "=" * 60, sep="\n")

trainer.train()

# Banner after training, preceded by a blank line.
print("", "=" * 60, "Training completed!", "=" * 60, sep="\n")

ImportError: cannot import name 'SFTTrainer' from 'transformers' (/usr/local/lib/python3.12/dist-packages/transformers/__init__.py)

## Step 11: Save Fine-Tuned Model

In [None]:
# Persist the model (LoRA adapters) and the processor side by side in OUTPUT_DIR,
# so the folder can later be loaded without this notebook.
processor_target = model_target = OUTPUT_DIR
model.save_pretrained(model_target)
processor.save_pretrained(processor_target)

print(
    f"✓ Model saved to: {OUTPUT_DIR}\n"
    "  You can now download this folder from Google Drive"
)

## Step 12: Test Inference on Sample

In [None]:
import matplotlib.pyplot as plt

# Load the first record of the test split (one JSON object per line).
test_jsonl = os.path.join(DATASET_PATH, "test.jsonl")
with open(test_jsonl, 'r', encoding='utf-8') as f:
    test_sample = json.loads(f.readline())

test_img_path = os.path.join(DATASET_PATH, test_sample['image'])
test_image = Image.open(test_img_path).convert('RGB')

# Display the image
plt.figure(figsize=(6, 6))
plt.imshow(test_image)
plt.axis('off')
plt.title(f"Test X-ray (Age group: {test_sample['age_group']})")
plt.show()

# Create prompt
prompt = f"Analyze this pediatric chest X-ray (age group: {test_sample['age_group']}) and provide a detailed radiology report."

# Render the prompt through the chat template: it inserts the image placeholder the
# processor needs to pair the text with the image, and add_generation_prompt=True
# opens the assistant turn so the model answers instead of continuing the user text.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": prompt},
        ],
    }
]
chat_text = processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=False
)
inputs = processor(text=chat_text, images=test_image, return_tensors="pt").to(model.device)
prompt_length = inputs["input_ids"].shape[-1]

# Generate report. eval() disables dropout (including LoRA dropout), which would
# otherwise stay active if the model is still in training mode.
model.eval()
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=256,
        temperature=0.7,
        do_sample=True
    )

# generate() returns prompt + continuation; decode only the newly generated tokens
# so the printed report does not start with the echoed prompt.
generated_report = processor.decode(outputs[0][prompt_length:], skip_special_tokens=True)

print("="*60)
print("TEST INFERENCE")
print("="*60)
print(f"\nAge group: {test_sample['age_group']}")
print(f"\nGround truth report:\n{test_sample['report']}")
print(f"\nGenerated report:\n{generated_report}")

## Step 13: Evaluate on Full Test Set

In [None]:
# Build the held-out split the same way as the training/validation splits.
test_jsonl_path = os.path.join(DATASET_PATH, "test.jsonl")
test_dataset = PediatricXrayDataset(
    test_jsonl_path,
    DATASET_PATH,
    processor,
)

print(f"Evaluating on {len(test_dataset)} test samples...")

# Reuse the trainer's evaluation loop; the loss is read from the 'eval_loss' key.
test_results = trainer.evaluate(eval_dataset=test_dataset)
test_loss = test_results['eval_loss']
print(f"\n✓ Test loss: {test_loss:.4f}")