# MedGemma Pediatric Chest X-ray Fine-Tuning

**Important:** Make sure GPU is enabled!
- Go to: Runtime > Change runtime type > Hardware accelerator > T4 GPU

## Step 1: Check GPU and System Info

In [1]:
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
else:
    print("⚠️ WARNING: No GPU detected! Please enable GPU in Runtime settings.")

PyTorch version: 2.9.0+cu126
CUDA available: True
GPU: Tesla T4
GPU Memory: 15.83 GB


## Step 2: Install Required Packages

In [2]:
!pip install -q accelerate peft transformers bitsandbytes datasets pillow tqdm

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.1/59.1 MB[0m [31m12.9 MB/s[0m eta [36m0:00:00[0m
[?25h

## Step 3: Mount Google Drive

In [3]:
from google.colab import drive
drive.mount('/content/drive')

# TODO: Update this path to where you uploaded your dataset
DATASET_PATH = "/content/drive/MyDrive/pediatric_xray_dataset_chest"

import os
if os.path.exists(DATASET_PATH):
    print(f"✓ Dataset found at: {DATASET_PATH}")
    print(f"  Files: {os.listdir(DATASET_PATH)}")
else:
    print(f"✗ Dataset NOT found at: {DATASET_PATH}")
    print("  Please update DATASET_PATH to match your Google Drive folder")

Mounted at /content/drive
✓ Dataset found at: /content/drive/MyDrive/pediatric_xray_dataset_chest
  Files: ['test.jsonl', 'val.jsonl', 'train.jsonl', 'dataset_stats.json', 'images']


## Step 4: Import Libraries

In [4]:
import json
from pathlib import Path
from PIL import Image
import torch
from torch.utils.data import Dataset
from transformers import (
    AutoProcessor,
    AutoModelForImageTextToText,
    TrainingArguments,
    Trainer,
    BitsAndBytesConfig
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from tqdm import tqdm

print("✓ All libraries imported successfully")

✓ All libraries imported successfully


In [5]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
print("CUDA_LAUNCH_BLOCKING is set to 1.")

CUDA_LAUNCH_BLOCKING is set to 1.


## Step 5: Create Custom Dataset Class

In [6]:
class PediatricXrayDataset(Dataset):
    """Custom dataset for pediatric chest X-rays with reports"""

    def __init__(self, jsonl_path, dataset_root, processor, max_length=512, image_token = ""):
        self.dataset_root = Path(dataset_root)
        self.processor = processor
        self.max_length = max_length
        self.image_token = image_token

        # Load data from JSONL
        self.data = []
        with open(jsonl_path, 'r') as f:
            for line in f:
                self.data.append(json.loads(line))

        print(f"Loaded {len(self.data)} samples from {jsonl_path}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        # Load image
        img_path = self.dataset_root / item['image']
        image = Image.open(img_path).convert('RGB')
        #print(f"✓ Image loaded: {image.size}")

        # Create prompt with age information
        age_group = item['age_group']

        # Test processor with image token
        prompt = f"{self.image_token}Analyze this pediatric chest X-ray (age group: {age_group}) and provide a detailed radiology report."
        #print(f"\nPrompt with image token: {prompt[:100]}...")

        # Target report
        report = item['report']

        # Process with the model's processor
        encoding = self.processor(
            images=image,
            text=prompt,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length
        )

        # Add labels (the report text)
        labels = self.processor.tokenizer(
            report,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length
        )["input_ids"]

        encoding["labels"] = labels

        # Remove batch dimension
        encoding = {k: v.squeeze(0) for k, v in encoding.items()}

        return encoding

print("✓ Dataset class defined")

✓ Dataset class defined


## Step 6: Load Model with Quantization

**Note:** Replace MODEL_NAME with the actual MedGemma model identifier

In [7]:
# MedGemma model name
MODEL_NAME = "google/medgemma-4b-it"
OUTPUT_DIR = "/content/drive/MyDrive/medgemma_pediatric_finetuned"

# Create output directory
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Quantization config for memory efficiency (4-bit)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

print("Loading model and processor...")
print(f"Model: {MODEL_NAME}")

processor = AutoProcessor.from_pretrained(MODEL_NAME, use_fast=True)
model = AutoModelForImageTextToText.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto"
)

print(f"✓ Model loaded: {MODEL_NAME}")
print(f"  Parameters: {sum(p.numel() for p in model.parameters()) / 1e9:.2f}B")

Loading model and processor...
Model: google/medgemma-4b-it


processor_config.json:   0%|          | 0.00/70.0 [00:00<?, ?B/s]

chat_template.jinja:   0%|          | 0.00/1.53k [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.16M [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/2.47k [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/90.6k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.64G [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/156 [00:00<?, ?B/s]

✓ Model loaded: google/medgemma-4b-it
  Parameters: 2.49B


## 6.5 Identify image token

In [8]:
# Check tokenizer for image-related tokens
print("\nChecking for image tokens in tokenizer...")
tokenizer = processor.tokenizer
vocab = tokenizer.get_vocab()

# Find all image-related tokens
image_tokens = [token for token in vocab.keys() if 'image' in token.lower()]
print(f"Image-related tokens found: {image_tokens}")

# Check for special tokens
if hasattr(tokenizer, 'boi_token'):
    print(f"Begin-of-image token: {tokenizer.boi_token}")
    IMAGE_TOKEN = tokenizer.boi_token
elif hasattr(tokenizer, 'image_token'):
    print(f"Image token: {tokenizer.image_token}")
    IMAGE_TOKEN = tokenizer.image_token
elif '<boi>' in vocab:
    print("Found <boi> token in vocabulary")
    IMAGE_TOKEN = '<boi>'
elif '<image>' in vocab:
    print("Found <image> token in vocabulary")
    IMAGE_TOKEN = '<image>'
else:
    # For Gemma3, try the begin-of-image token
    IMAGE_TOKEN = '<boi>'
    print(f"Using default: {IMAGE_TOKEN}")



Checking for image tokens in tokenizer...
Image-related tokens found: ['Image', 'BufferedImage', '▁imagen', 'image', 'CategoryImage', '▁imaged', 'ImageType', 'imageColour', 'imageList', 'ImageView', 'ProductImages', '▁getImage', '<start_of_image>', '▁productImage', 'ImageBeforeText', '▁Imagen', 'ImageQueue', 'PhotoImage', 'imageNamed', 'ImageFilter', '▁imageView', 'imageBase', '▁ImageView', 'BackgroundImage', 'imageId', 'drawImage', 'imageUrl', '▁preimage', 'Imagery', 'backgroundImage', '<end_of_image>', '▁imagenes', '▁Imagery', '▁Images', 'Imagenes', '▁Image', 'Imagen', '▁AssetImage', 'images', '▁loadImage', 'ImageAsset', '▁ImageIcon', '<image_soft_token>', 'UIImage', 'CurrentImageData', 'imagery', '▁backgroundImage', '▁imagens', '▁imageName', 'getImage', 'ImageIcon', 'imagen', 'PictImage', '▁lgPlatformImage', 'imagem', 'imagens', 'ImageData', 'CurrentImage', 'ImageLayout', 'imageCache', '▁IMAGE', 'setImageResource', '▁createImage', '▁setImage', '▁UIImageView', 'TextImage', 'addImage

## Step 7: Configure LoRA for Efficient Fine-Tuning

In [9]:
# Prepare model for training
model = prepare_model_for_kbit_training(model)

# LoRA configuration
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

# Apply LoRA
model = get_peft_model(model, lora_config)

# Print trainable parameters
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print(f"✓ LoRA applied")
print(f"  Trainable params: {trainable_params:,} ({100 * trainable_params / total_params:.2f}%)")
print(f"  Total params: {total_params:,}")

✓ LoRA applied
  Trainable params: 16,394,240 (0.65%)
  Total params: 2,506,617,200


## Step 7.5: Test Dataset Loading (Debug)

In [10]:
# Test loading a single sample to verify everything works
print("Testing dataset loading with a single sample...")

test_jsonl = os.path.join(DATASET_PATH, "train.jsonl")
with open(test_jsonl, 'r') as f:
    sample = json.loads(f.readline())

print(f"Sample data: {sample['age_group']}, image: {sample['image']}")

# Try loading the image
img_path = os.path.join(DATASET_PATH, sample['image'])
test_img = Image.open(img_path).convert('RGB')
print(f"✓ Image loaded: {test_img.size}")

# Test processor with image token
prompt = f"{IMAGE_TOKEN}Analyze this pediatric chest X-ray (age group: {sample['age_group']}) and provide a detailed radiology report."
print(f"\nPrompt with image token: {prompt[:100]}...")

try:
    test_encoding = processor(
        images=test_img,
        text=prompt,
        return_tensors="pt"
    )
    print(f"\n✓ Processor output keys: {test_encoding.keys()}")
    print(f"  Input shape: {test_encoding['input_ids'].shape}")
    if 'pixel_values' in test_encoding:
        print(f"  Pixel values shape: {test_encoding['pixel_values'].shape}")
    print("\n✓ Dataset loading test successful!")
except Exception as e:
    print(f"\n✗ Error: {e}")
    print("\nTrying alternative approach without text preprocessing...")

    # Alternative: Let processor handle everything
    try:
        test_encoding = processor(
            images=test_img,
            text=f"Analyze this pediatric chest X-ray (age group: {sample['age_group']}) and provide a detailed radiology report.",
            return_tensors="pt",
            add_special_tokens=True
        )
        print(f"\n✓ Alternative approach worked!")
        print(f"  Processor automatically added image tokens")
        IMAGE_TOKEN = None  # Let processor handle it
    except Exception as e2:
        print(f"✗ Alternative also failed: {e2}")

Testing dataset loading with a single sample...
Sample data: adolescent, image: images/1.3.12.2.1107.5.3.56.4126.11.202501010244170058.png
✓ Image loaded: (896, 896)

Prompt with image token: <start_of_image>Analyze this pediatric chest X-ray (age group: adolescent) and provide a detailed ra...

✓ Processor output keys: KeysView({'input_ids': tensor([[     2,    108, 255999, 262144, 262144, 262144, 262144, 262144, 262144,
         262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144,
         262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144,
         262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144,
         262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144,
         262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144,
         262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144,
         262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144, 262144,
       

## Step 8: Load Training and Validation Datasets

In [11]:
train_dataset = PediatricXrayDataset(
    jsonl_path=os.path.join(DATASET_PATH, "train.jsonl"),
    dataset_root=DATASET_PATH,
    processor=processor,
    image_token = IMAGE_TOKEN
)

val_dataset = PediatricXrayDataset(
    jsonl_path=os.path.join(DATASET_PATH, "val.jsonl"),
    dataset_root=DATASET_PATH,
    processor=processor,
    image_token = IMAGE_TOKEN
)

print(f"✓ Datasets loaded")
print(f"  Training samples: {len(train_dataset)}")
print(f"  Validation samples: {len(val_dataset)}")

Loaded 168 samples from /content/drive/MyDrive/pediatric_xray_dataset_chest/train.jsonl
Loaded 13 samples from /content/drive/MyDrive/pediatric_xray_dataset_chest/val.jsonl
✓ Datasets loaded
  Training samples: 168
  Validation samples: 13


## Step 9: Configure Training Parameters

In [12]:
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=3,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=2e-4,
    warmup_steps=100,
    logging_steps=10,
    eval_strategy="steps",
    eval_steps=50,
    save_steps=100,
    save_total_limit=2,
    fp16=True,
    dataloader_num_workers=2,
    remove_unused_columns=False,
    report_to="none",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
)

print("✓ Training configuration:")
print(f"  Epochs: {training_args.num_train_epochs}")
print(f"  Batch size: {training_args.per_device_train_batch_size}")
print(f"  Gradient accumulation: {training_args.gradient_accumulation_steps}")
print(f"  Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
print(f"  Learning rate: {training_args.learning_rate}")

✓ Training configuration:
  Epochs: 3
  Batch size: 1
  Gradient accumulation: 8
  Effective batch size: 8
  Learning rate: 0.0002


## Step 10: Initialize Trainer and Start Training

**This will take several hours!** Keep the browser tab open.

In [13]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

print("="*60)
print("Starting training...")
print("="*60)

trainer.train()

print("\n" + "="*60)
print("Training completed!")
print("="*60)

Starting training...


  return fn(*args, **kwargs)
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Step,Training Loss,Validation Loss
50,7.6504,4.181233



Training completed!


## Step 11: Save Fine-Tuned Model

In [14]:
import datetime
import os

# Generate a timestamp for a unique folder name
current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
final_model_dir = os.path.join(OUTPUT_DIR, f"final_model_{current_time}")

# Create the unique output directory if it doesn't exist
os.makedirs(final_model_dir, exist_ok=True)

# Save the fine-tuned LoRA adapters and processor to the unique directory
model.save_pretrained(final_model_dir)
processor.save_pretrained(final_model_dir)

print(f"✓ Model saved to: {final_model_dir}")
print("  You can now download this folder from Google Drive")

✓ Model saved to: /content/drive/MyDrive/medgemma_pediatric_finetuned/final_model_20260109_081117
  You can now download this folder from Google Drive


## Step 12: Test Inference on Sample

In [17]:
import torch
import gc

# Clear GPU cache first
torch.cuda.empty_cache()
gc.collect()

# Load a test image
test_jsonl = os.path.join(DATASET_PATH, "test.jsonl")
with open(test_jsonl, 'r') as f:
    test_sample = json.loads(f.readline())

test_img_path = os.path.join(DATASET_PATH, test_sample['image'])
test_image = Image.open(test_img_path).convert('RGB')

print(f"Testing image: {test_sample['image']}")
print(f"Age group: {test_sample['age_group']}\n")

# Create prompt
prompt = f"{IMAGE_TOKEN}Analyze this pediatric chest X-ray (age group: {test_sample['age_group']}) and provide a detailed radiology report.\n\nReport: "

# Process inputs
inputs = processor(images=test_image, text=prompt, return_tensors="pt", add_special_tokens=True)

# Move inputs to device manually (safer than .to(model.device))
inputs = {k: v.to('cuda') for k, v in inputs.items()}

print("Generating report (this may take a minute)...")

# Generate report with error handling
try:
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=256,
            temperature=0.7,
            do_sample=True,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id
        )

    generated_report = processor.decode(outputs[0], skip_special_tokens=True)

    print("="*60)
    print("TEST INFERENCE")
    print("="*60)
    print(f"\nAge group: {test_sample['age_group']}")
    print(f"\nGround truth report:\n{test_sample['report']}")
    print(f"\nGenerated report:\n{generated_report}")

except RuntimeError as e:
    if "out of memory" in str(e).lower() or "cuda" in str(e).lower():
        print(f"\n⚠️ GPU Error: {str(e)[:100]}...")
        print("\n💡 Solution: Runtime → Restart runtime, then run Step 12.5 instead")
    else:
        print(f"\n❌ Error: {e}")
        raise

# Clean up
del inputs, outputs
torch.cuda.empty_cache()
gc.collect()

AcceleratorError: CUDA error: device-side assert triggered
Search for `cudaErrorAssert' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


## Step 12.5: Inference After Runtime Restart

Run this if Step 12 gives CUDA errors

Instructions: Runtime → Restart runtime → Run ONLY this cell

In [None]:
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
from peft import PeftModel
from PIL import Image
import json
import os

print("🔄 Loading model fresh for inference...")

# Paths - UPDATE THESE IF NEEDED
BASE_MODEL = "google/medgemma-4b-it"
DATASET_PATH = "/content/drive/MyDrive/pediatric_xray_dataset_chest"
OUTPUT_DIR = "/content/drive/MyDrive/medgemma_pediatric_finetuned"

# Mount Drive if not already mounted
try:
    from google.colab import drive
    if not os.path.exists('/content/drive'):
        drive.mount('/content/drive')
except:
    pass

# Find the latest fine-tuned model
model_dirs = [d for d in os.listdir(OUTPUT_DIR) if d.startswith('final_model_')]
if not model_dirs:
    print("❌ No fine-tuned model found!")
    print(f"   Looking in: {OUTPUT_DIR}")
else:
    latest_model = sorted(model_dirs)[-1]
    FINETUNED_MODEL_PATH = os.path.join(OUTPUT_DIR, latest_model)
    print(f"✓ Found model: {latest_model}")

    # Load with 4-bit quantization
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )

    print("Loading base model...")
    base_model = AutoModelForImageTextToText.from_pretrained(
        BASE_MODEL,
        quantization_config=bnb_config,
        device_map="auto"
    )

    print("Loading fine-tuned LoRA adapters...")
    model = PeftModel.from_pretrained(base_model, FINETUNED_MODEL_PATH)

    print("Loading processor...")
    processor = AutoProcessor.from_pretrained(BASE_MODEL)  # Load from base model, not finetuned

    # Get image token
    IMAGE_TOKEN = processor.tokenizer.boi_token if hasattr(processor.tokenizer, 'boi_token') else '<start_of_image>'

    # CRITICAL: Check vocabulary size match
    # Gemma3 stores vocab_size in text_config
    if hasattr(model.config, 'vocab_size'):
        model_vocab_size = model.config.vocab_size
    elif hasattr(model.config, 'text_config') and hasattr(model.config.text_config, 'vocab_size'):
        model_vocab_size = model.config.text_config.vocab_size
    else:
        model_vocab_size = model.get_input_embeddings().weight.shape[0]

    print(f"\nModel vocab size: {model_vocab_size}")
    print(f"Tokenizer vocab size: {len(processor.tokenizer)}")

    # Resize model embeddings if needed
    if len(processor.tokenizer) != model_vocab_size:
        print("⚠️ Resizing model token embeddings to match tokenizer...")
        model.resize_token_embeddings(len(processor.tokenizer))

    print("✓ Model ready for inference!\n")

    # Load test sample
    test_jsonl = os.path.join(DATASET_PATH, "test.jsonl")
    with open(test_jsonl, 'r') as f:
        test_sample = json.loads(f.readline())

    test_img_path = os.path.join(DATASET_PATH, test_sample['image'])
    test_image = Image.open(test_img_path).convert('RGB')

    print(f"Test image: {test_sample['image']}")
    print(f"Age group: {test_sample['age_group']}\n")

    # Create prompt
    prompt = f"{IMAGE_TOKEN}Analyze this pediatric chest X-ray (age group: {test_sample['age_group']}) and provide a detailed radiology report.\n\nReport: "

    # Process inputs
    inputs = processor(images=test_image, text=prompt, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # Debug: Check input token IDs are valid
    print(f"Input token IDs range: {inputs['input_ids'].min().item()} to {inputs['input_ids'].max().item()}")
    print(f"Max valid token ID: {len(processor.tokenizer) - 1}")

    # Check for invalid tokens
    invalid_tokens = inputs['input_ids'] >= len(processor.tokenizer)
    if invalid_tokens.any():
        print(f"⚠️ WARNING: Found {invalid_tokens.sum().item()} invalid token IDs!")
        print("Clamping to valid range...")
        inputs['input_ids'] = torch.clamp(inputs['input_ids'], 0, len(processor.tokenizer) - 1)

    print("\nGenerating report...")

    try:
        with torch.no_grad():
            # Use greedy decoding first (more stable)
            outputs = model.generate(
                **inputs,
                max_new_tokens=256,
                do_sample=False,  # Greedy decoding - more stable
                pad_token_id=processor.tokenizer.pad_token_id if processor.tokenizer.pad_token_id is not None else processor.tokenizer.eos_token_id,
                eos_token_id=processor.tokenizer.eos_token_id,
                bos_token_id=processor.tokenizer.bos_token_id if processor.tokenizer.bos_token_id is not None else 1
            )

        generated_report = processor.decode(outputs[0], skip_special_tokens=True)

        print("\n" + "="*70)
        print("TEST INFERENCE RESULTS")
        print("="*70)
        print(f"\n📋 Age Group: {test_sample['age_group']}")
        print(f"\n✅ Ground Truth Report:")
        print("-" * 70)
        print(test_sample['report'])
        print("\n🤖 Generated Report:")
        print("-" * 70)
        print(generated_report)
        print("\n" + "="*70)

    except RuntimeError as e:
        print(f"\n❌ Error during generation: {e}")
        print("\nTrying alternative approach with temperature sampling...")

        # Alternative: Try with lower temperature and nucleus sampling
        try:
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=128,
                    temperature=0.3,
                    do_sample=True,
                    top_p=0.9,
                    top_k=50,
                    pad_token_id=processor.tokenizer.pad_token_id if processor.tokenizer.pad_token_id is not None else processor.tokenizer.eos_token_id,
                    eos_token_id=processor.tokenizer.eos_token_id
                )

            generated_report = processor.decode(outputs[0], skip_special_tokens=True)
            print("\n✓ Alternative approach worked!")
            print(f"\nGenerated: {generated_report}")

        except Exception as e2:
            print(f"\n❌ Alternative also failed: {e2}")
            print("\nDebugging info:")
            print(f"  Input shape: {inputs['input_ids'].shape}")
            print(f"  Input device: {inputs['input_ids'].device}")
            print(f"  Model device: {next(model.parameters()).device}")

## Step 13: Evaluate on Full Test Set

In [None]:
test_dataset = PediatricXrayDataset(
    jsonl_path=os.path.join(DATASET_PATH, "test.jsonl"),
    dataset_root=DATASET_PATH,
    processor=processor,
    image_token = IMAGE_TOKEN
)

print(f"Evaluating on {len(test_dataset)} test samples...")
test_results = trainer.evaluate(test_dataset)
print(f"\n✓ Test loss: {test_results['eval_loss']:.4f}")

## Diagnostic: Check what's in your saved model

In [2]:
import os
import json
from transformers import AutoProcessor

OUTPUT_DIR = "/content/drive/MyDrive/medgemma_pediatric_finetuned"

# Find latest model
model_dirs = [d for d in os.listdir(OUTPUT_DIR) if d.startswith('final_model_')]
latest_model = sorted(model_dirs)[-1]
FINETUNED_MODEL_PATH = os.path.join(OUTPUT_DIR, latest_model)

print(f"Checking: {FINETUNED_MODEL_PATH}\n")

# Check what files are there
files = os.listdir(FINETUNED_MODEL_PATH)
print("Files in model directory:")
for f in files:
    print(f"  - {f}")

# Check config
if 'adapter_config.json' in files:
    with open(os.path.join(FINETUNED_MODEL_PATH, 'adapter_config.json'), 'r') as f:
        config = json.load(f)
    print("\nLoRA Config:")
    print(f"  Base model: {config.get('base_model_name_or_path', 'N/A')}")
    print(f"  LoRA rank: {config.get('r', 'N/A')}")
    print(f"  Target modules: {config.get('target_modules', 'N/A')}")

# Check if processor was saved
processor_files = ['tokenizer_config.json', 'special_tokens_map.json']
saved_processor = any(f in files for f in processor_files)
print(f"\nProcessor saved: {saved_processor}")

# If processor is there, check it
if saved_processor:
    try:
        proc = AutoProcessor.from_pretrained(FINETUNED_MODEL_PATH)
        print(f"  Vocab size: {len(proc.tokenizer)}")
        print(f"  PAD token: {proc.tokenizer.pad_token} (ID: {proc.tokenizer.pad_token_id})")
        print(f"  EOS token: {proc.tokenizer.eos_token} (ID: {proc.tokenizer.eos_token_id})")
        print(f"  BOS token: {proc.tokenizer.bos_token} (ID: {proc.tokenizer.bos_token_id})")
    except Exception as e:
        print(f"  Error loading processor: {e}")

print("\n" + "="*70)
print("Recommendation:")
print("="*70)
if not saved_processor:
    print("⚠️ Processor was not saved with the model!")
    print("   This is why we load it from the base model instead.")
else:
    print("✓ Processor was saved correctly.")

Checking: /content/drive/MyDrive/medgemma_pediatric_finetuned/final_model_20260109_081117

Files in model directory:
  - README.md
  - adapter_model.safetensors
  - adapter_config.json
  - preprocessor_config.json
  - chat_template.jinja
  - tokenizer_config.json
  - special_tokens_map.json
  - added_tokens.json
  - tokenizer.model
  - tokenizer.json
  - processor_config.json

LoRA Config:
  Base model: google/medgemma-4b-it
  LoRA rank: 8
  Target modules: ['q_proj', 'o_proj', 'up_proj', 'k_proj', 'v_proj', 'down_proj', 'gate_proj']

Processor saved: True
  Vocab size: 262145
  PAD token: <pad> (ID: 0)
  EOS token: <eos> (ID: 1)
  BOS token: <bos> (ID: 2)

Recommendation:
✓ Processor was saved correctly.
