# 🔬 ExplainMyXray - Model Testing Notebook

Test your fine-tuned PaliGemma model on chest X-ray images.

**Checkpoints Available:**
- `interrupted_checkpoint` - Latest training progress
- `checkpoint-250` - Step 250 checkpoint

---

## 1️⃣ Install Dependencies

In [None]:
# Install the libraries used by the cells below: transformers, peft and
# bitsandbytes for loading the 4-bit base model with its LoRA adapter,
# pillow and matplotlib for reading and displaying the images.
# NOTE(review): accelerate is presumably required for device_map="auto" - confirm.
!pip install -q transformers accelerate bitsandbytes peft pillow matplotlib

## 2️⃣ Mount Google Drive & Setup Paths

In [None]:
from google.colab import drive
drive.mount('/content/drive')

import os

# Locations of the saved checkpoints on the mounted Drive
DRIVE_CHECKPOINT = "/content/drive/MyDrive/ExplainMyXray_Models/interrupted_checkpoint"
CHECKPOINT_250 = "/content/drive/MyDrive/ExplainMyXray_Models/checkpoint-250"  # If you saved it

# Report which of the two checkpoints are present
print("üìÅ Checking available checkpoints...")
if not os.path.exists(DRIVE_CHECKPOINT):
    print("‚ùå interrupted_checkpoint not found")
else:
    print("‚úÖ interrupted_checkpoint found")
    files = os.listdir(DRIVE_CHECKPOINT)
    # Only the first five entries are listed when the folder holds more
    if len(files) > 5:
        print(f"   Files: {files[:5]}...")
    else:
        print(f"   Files: {files}")

if not os.path.exists(CHECKPOINT_250):
    print("‚ö†Ô∏è checkpoint-250 not found on Drive")
else:
    print("‚úÖ checkpoint-250 found")

## 3️⃣ HuggingFace Authentication

In [None]:
import os
from huggingface_hub import login

# WARNING: do NOT hardcode your token here!
# Export HF_TOKEN in the environment before running this cell (a .env file is
# only picked up if something loads it into os.environ), or rely on the
# interactive login below / `huggingface-cli login`.
#
# An unset or empty variable is normalised to None rather than "": later cells
# pass `token=HF_TOKEN` to from_pretrained(), where an explicit string is used
# as the token, while None makes the hub client fall back to the credentials
# stored by login().
HF_TOKEN = os.environ.get("HF_TOKEN") or None
if HF_TOKEN:
    login(token=HF_TOKEN)
    print("‚úÖ Logged in via HF_TOKEN environment variable")
else:
    login()  # Interactive login
    print("‚úÖ HuggingFace authentication successful!")

## 4️⃣ Load Model & Processor

In [None]:
import torch
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration, BitsAndBytesConfig
from peft import PeftModel

# Base model name
# Hub id of the model the adapter in CHECKPOINT_PATH is applied on top of.
BASE_MODEL = "google/paligemma-3b-pt-224"

# Choose which checkpoint to load
# DRIVE_CHECKPOINT / CHECKPOINT_250 are defined in the Drive setup cell.
CHECKPOINT_PATH = DRIVE_CHECKPOINT  # Change to CHECKPOINT_250 if needed

print(f"üì• Loading from: {CHECKPOINT_PATH}")

# Quantization config (same as training)
# 4-bit NF4 weights with double quantization; compute dtype is float16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

# Load processor
# The processor comes from the base model id, not from the checkpoint folder.
# NOTE(review): HF_TOKEN is set in the authentication cell; if it is an empty
# string it is passed as-is here - confirm that is intended.
print("Loading processor...")
processor = AutoProcessor.from_pretrained(BASE_MODEL, token=HF_TOKEN)

# Load base model
# device_map="auto" leaves device placement to the library.
print("Loading base model (4-bit quantized)...")
base_model = PaliGemmaForConditionalGeneration.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb_config,
    device_map="auto",
    token=HF_TOKEN,
    torch_dtype=torch.float16,
)

# Load LoRA adapter
# Wraps the quantized base model with the adapter stored at CHECKPOINT_PATH and
# switches the result to evaluation mode; `model` and `processor` are the
# globals used by the inference functions defined later.
print("Loading LoRA adapter...")
model = PeftModel.from_pretrained(base_model, CHECKPOINT_PATH)
model.eval()

print("‚úÖ Model loaded successfully!")
# Device of the first parameter only; with device_map="auto" other parameters
# may live elsewhere.
print(f"   Device: {next(model.parameters()).device}")

## 5️⃣ Define Inference Function

In [None]:
from PIL import Image
import matplotlib.pyplot as plt

def explain_xray(image_path, prompt="Explain this chest X-ray:", max_new_tokens=256):
    """
    Generate an explanation for a chest X-ray image.

    Relies on the module-level ``processor`` and ``model`` created in the
    model-loading cell.

    Args:
        image_path: Path to the X-ray image
        prompt: Text prompt for the model
        max_new_tokens: Maximum tokens to generate

    Returns:
        Tuple ``(generated_text, image)``: the generated explanation without
        the prompt, and the RGB ``PIL.Image`` that was given to the model.
    """
    # Load image. convert() returns an independent copy, so the file handle
    # can be released as soon as the block exits.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")

    # Prepare inputs
    inputs = processor(
        text=prompt,
        images=image,
        return_tensors="pt"
    ).to(model.device)

    # Generate. Sampling is enabled, so repeated calls on the same image can
    # produce different text.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.1,
        )

    # generate() returns the input tokens followed by the newly generated
    # ones. Drop the input part by position instead of searching the decoded
    # text for the prompt: a string split would discard real output whenever
    # the model repeats the prompt inside its answer.
    input_length = inputs["input_ids"].shape[-1]
    generated_text = processor.decode(
        outputs[0][input_length:],
        skip_special_tokens=True,
    ).strip()

    return generated_text, image


def display_result(image_path, prompt="Explain this chest X-ray:"):
    """Show the X-ray, print the generated explanation, and return that text."""
    text, xray = explain_xray(image_path, prompt)

    # One 8x8 inch axes with the image, a title and no axis decorations
    _, axes = plt.subplots(figsize=(8, 8))
    axes.imshow(xray, cmap='gray')
    axes.set_title("Chest X-Ray", fontsize=14)
    axes.axis('off')
    plt.show()

    # Explanation framed by 60-character rules
    divider = "=" * 60
    print("\n" + divider)
    print("üî¨ MODEL EXPLANATION:")
    print(divider)
    print(text)
    print(divider)

    return text

print("‚úÖ Inference functions ready!")

## 6️⃣ Download Test Images (Kaggle Dataset)

In [None]:
import os
import json

# Setup Kaggle credentials
os.makedirs('/root/.kaggle', exist_ok=True)

kaggle_creds = {
    "username": "your_kaggle_username",  # Replace with your username
    "key": "your_kaggle_key"  # Replace with your key
}

# Or use the userdata approach if you have secrets set up
try:
    from google.colab import userdata
    kaggle_creds = {
        "username": userdata.get('KAGGLE_USERNAME'),
        "key": userdata.get('KAGGLE_KEY')
    }
except:
    pass

with open('/root/.kaggle/kaggle.json', 'w') as f:
    json.dump(kaggle_creds, f)
os.chmod('/root/.kaggle/kaggle.json', 0o600)

# Download a small subset for testing
!kaggle datasets download -d paultimothymooney/chest-xray-pneumonia -p /content/test_data --unzip

print("\n‚úÖ Test data downloaded!")

## 7️⃣ Test on Sample Images

In [None]:
import glob
import random

# Test split of the downloaded dataset
test_dir = "/content/test_data/chest_xray/test"

# One list of .jpeg paths per class sub-folder
normal_images = glob.glob(test_dir + "/NORMAL/*.jpeg")
pneumonia_images = glob.glob(test_dir + "/PNEUMONIA/*.jpeg")

print("Found", len(normal_images), "normal images")
print("Found", len(pneumonia_images), "pneumonia images")

In [None]:
# Run the model on one randomly chosen NORMAL image
print("\n" + "üü¢" * 30, "TESTING ON NORMAL X-RAY", "üü¢" * 30, sep="\n")

if not normal_images:
    print("No normal images found!")
else:
    test_normal = random.choice(normal_images)
    print(f"\nImage: {os.path.basename(test_normal)}")
    explanation = display_result(test_normal)

In [None]:
# Run the model on one randomly chosen PNEUMONIA image
print("\n" + "üî¥" * 30, "TESTING ON PNEUMONIA X-RAY", "üî¥" * 30, sep="\n")

if not pneumonia_images:
    print("No pneumonia images found!")
else:
    test_pneumonia = random.choice(pneumonia_images)
    print(f"\nImage: {os.path.basename(test_pneumonia)}")
    explanation = display_result(test_pneumonia)

## 8️⃣ Test with Custom Image (Upload)

In [None]:
from google.colab import files

print("üì§ Upload a chest X-ray image to test:")
uploaded = files.upload()

for filename in uploaded.keys():
    print(f"\nTesting: {filename}")
    explanation = display_result(filename)

## 9️⃣ Batch Testing & Evaluation

In [None]:
def batch_test(image_paths, label="Test", num_samples=5):
    """
    Run explain_xray on a random sample of images and collect the outcomes.

    Args:
        image_paths: Candidate image paths to sample from
        label: Value stored under "label" in every result
        num_samples: Upper bound on how many images are sampled

    Returns:
        List of dicts with keys "image" (file name), "label" and
        "explanation". When processing an image raises, its "explanation" is
        the text "ERROR: <exception>".
    """
    sample_size = min(num_samples, len(image_paths))
    chosen = random.sample(image_paths, sample_size)
    total = len(chosen)

    def make_record(file_name, text):
        # Shape shared by successful and failed entries
        return {"image": file_name, "label": label, "explanation": text}

    results = []
    for position, path in enumerate(chosen, start=1):
        file_name = os.path.basename(path)
        print(f"\n[{position}/{total}] Processing: {file_name}")
        try:
            text, _ = explain_xray(path)
            results.append(make_record(file_name, text))
            print(f"   ‚úÖ Generated {len(text.split())} words")
        except Exception as err:
            # A failing image is recorded rather than aborting the batch
            print(f"   ‚ùå Error: {err}")
            results.append(make_record(file_name, f"ERROR: {err}"))

    return results

# Evaluate three sampled images per class and pool the results
print("üî¨ Running batch evaluation...\n")
normal_results = batch_test(normal_images, "NORMAL", 3)
pneumonia_results = batch_test(pneumonia_images, "PNEUMONIA", 3)

all_results = [*normal_results, *pneumonia_results]
print(f"\n\n‚úÖ Batch test complete! Processed {len(all_results)} images.")

In [None]:
# Display all results
import pandas as pd

# Naming the columns keeps the frame usable when all_results is empty (for
# example when no test images were found): without them an empty frame has no
# 'explanation' column and the lookup below raises KeyError.
df = pd.DataFrame(all_results, columns=["image", "label", "explanation"])
# Word count of each explanation (whitespace-separated)
df['explanation_length'] = df['explanation'].apply(lambda x: len(x.split()))

print("\nüìä BATCH TEST RESULTS:")
print("="*80)
for _, row in df.iterrows():
    full_text = row['explanation']
    # Explanations longer than 300 characters are cut and marked with "..."
    preview = f"{full_text[:300]}..." if len(full_text) > 300 else full_text

    print(f"\nüñºÔ∏è  Image: {row['image']}")
    print(f"üè∑Ô∏è  Label: {row['label']}")
    print(f"üìù Explanation ({row['explanation_length']} words):")
    print(f"   {preview}")
    print("-"*80)

## 🔟 Compare Different Prompts

In [None]:
# Test different prompts on the same image
prompts = [
    "Explain this chest X-ray:",
    "Describe the findings in this chest X-ray:",
    "What abnormalities are visible in this X-ray?",
    "Provide a radiological report for this chest X-ray:",
    "Is this X-ray normal or abnormal? Explain:"
]

# Prefer a pneumonia image and fall back to a normal one
candidate_images = pneumonia_images or normal_images

if not candidate_images:
    # Both lists empty: random.choice would raise IndexError, so report it the
    # same way the single-image test cells do.
    print("No test images found!")
else:
    test_image = random.choice(candidate_images)

    print(f"üñºÔ∏è Testing image: {os.path.basename(test_image)}")
    print("="*80)

    # Display the image once
    img = Image.open(test_image)
    plt.figure(figsize=(6, 6))
    plt.imshow(img, cmap='gray')
    plt.axis('off')
    plt.title("Test Image")
    plt.show()

    # Test each prompt; generation samples, so responses vary between runs
    for prompt in prompts:
        print(f"\nüìù Prompt: \"{prompt}\"")
        print("-"*60)
        explanation, _ = explain_xray(test_image, prompt=prompt)
        print(f"Response: {explanation}")

---
## 📊 Summary

This notebook allows you to:
1. ✅ Load your fine-tuned PaliGemma model
2. ✅ Test on Kaggle chest X-ray dataset
3. ✅ Upload and test custom images
4. ✅ Batch evaluate multiple images
5. ✅ Compare different prompts

**Next Steps:**
- Continue training if results are not satisfactory
- Try different prompts for better explanations
- Export model for deployment