# ExplainMyXray v2 — MedGemma-4B + PadChest (RTX 4080 Laptop)

**Optimized for RTX 4080 Laptop 12 GB VRAM** (compute capability 8.9 → bfloat16).

### Architecture
| Spec | Value |
|------|-------|
| Base Model | google/medgemma-4b-it (SigLIP encoder + Gemma 3 decoder) |
| Fine-tuning | QLoRA NF4; LoRA r=64, alpha=128 on all linear layers; lm_head and embed_tokens fully trained |
| Precision | **bf16** (RTX 4080 supports bfloat16 natively) |
| Dataset | BIMCV PadChest — 160K+ CXR, 174 findings, 104 locations |
| Image size | 512x512 (padded to square, LANCZOS, CLAHE) |
| Data access | Google Drive for Desktop (streams on-demand) or local |
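
A quick sanity check on the `r=64, alpha=128` row above. LoRA learns a low-rank update `B @ A` for a frozen `d_out x d_in` weight, with `A` of shape `r x d_in` and `B` of shape `d_out x r`. Each adapted linear layer therefore adds `r * (d_in + d_out)` trainable parameters, and the update is scaled by `alpha / r` in standard LoRA (PEFT's `use_rslora` option uses `alpha / sqrt(r)` instead). The sketch below only does that arithmetic. The helper names are hypothetical, and the 2560 x 2560 layer shape is an illustrative assumption, not read from the MedGemma config.

```python
def lora_params(d_in: int, d_out: int, r: int) -> int:
    """Trainable parameters LoRA adds to one linear layer: A (r x d_in) + B (d_out x r)."""
    return r * d_in + d_out * r


def lora_scale(alpha: float, r: int) -> float:
    """Scaling applied to B @ A in standard (non-rsLoRA) LoRA."""
    return alpha / r


# Hypothetical square projection; real MedGemma layer shapes may differ.
d = 2560
for r in (16, 64):
    added = lora_params(d, d, r)
    frac = added / (d * d)
    print(f"r={r}: {added:,} params per layer ({frac:.1%} of the frozen weight)")

print("scale for r=64, alpha=128:", lora_scale(128, 64))
```

Doubling `alpha` along with `r` keeps the effective scale at 2.0, so raising the rank adds capacity without changing the update magnitude.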

### Preprocessing Pipeline (from analysis of 200+ PadChest images)
1. 16-bit (mode I;16) -> 8-bit via 1st/99th-percentile windowing
2. Auto-crop dark edges (artifacts in ~33% of the analyzed images)
3. Pad to square (aspect ratios vary 0.8-1.22)
4. LANCZOS resize to 512x512
5. CLAHE (local contrast, clip=2.0)
6. Sharpen 1.2x
7. Convert to RGB (for SigLIP encoder)
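
Step 1 windows to the 1st/99th percentiles rather than the full min/max range. Here is a rough sketch of why, using a synthetic 16-bit array rather than a PadChest image. A single saturated pixel sets the min/max range and squeezes the rest of the image into a handful of gray levels. Percentile windowing ignores that pixel. Both helpers are hypothetical simplifications of `_convert_16bit_to_8bit` in Cell 7.

```python
import numpy as np


def minmax_to_8bit(arr: np.ndarray) -> np.ndarray:
    """Naive full-range scaling: a single extreme pixel sets the whole range."""
    a = arr.astype(np.float32)
    lo, hi = a.min(), a.max()
    if hi - lo < 1:
        return np.zeros(a.shape, dtype=np.uint8)
    return ((a - lo) / (hi - lo) * 255).astype(np.uint8)


def percentile_to_8bit(arr: np.ndarray, p_lo: float = 1, p_hi: float = 99) -> np.ndarray:
    """Window to the [p_lo, p_hi] percentiles, then clip to 0..255."""
    a = arr.astype(np.float32)
    lo, hi = np.percentile(a, [p_lo, p_hi])
    if hi - lo < 1:
        return np.zeros(a.shape, dtype=np.uint8)
    return np.clip((a - lo) / (hi - lo) * 255, 0, 255).astype(np.uint8)


# Synthetic 16-bit "image": a smooth ramp in 0..4000 plus one saturated pixel.
img = np.linspace(0, 4000, 10_000).astype(np.uint16).reshape(100, 100)
img[0, 0] = 65535

naive = minmax_to_8bit(img)
robust = percentile_to_8bit(img)
print("min/max:    body max =", naive.ravel()[1:].max())
print("percentile: median =", np.median(robust), "max =", robust.max())
```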

## Phase 1 — Environment Setup

In [None]:
# ============================================================
# Cell 1: Install Dependencies (RTX 4080 compatible)
# ============================================================
import subprocess, sys, os, pathlib, platform

packages = [
    "transformers>=4.52.0",    # MedGemma support
    "trl>=0.17.0",             # SFTTrainer
    "peft>=0.15.0",            # LoRA/QLoRA
    "accelerate>=1.5.0",       # device_map
    "bitsandbytes>=0.44.0",    # 4-bit quantization
    "datasets>=3.5.0",         # HuggingFace datasets
    "evaluate",                # Metrics
    "scikit-learn",            # Sklearn metrics
    "Pillow>=10.0",            # Image loading
    "gdown",                   # Google Drive downloads
    "opencv-python",           # CLAHE (with GUI for desktop)
    "tensorboard",             # Training logs
]

# One pip call; stdout is silenced but stderr is kept so install errors remain visible
subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", *packages],
                      stdout=subprocess.DEVNULL)

import torch, transformers, trl, peft
print(f'torch={torch.__version__}, transformers={transformers.__version__}, '
      f'trl={trl.__version__}, peft={peft.__version__}')
print(f'CUDA: {torch.cuda.is_available()}, GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else "N/A"}')
print("Dependencies installed.")

In [None]:
# ============================================================
# Cell 2: Auto-Detect Google Drive for Desktop (data source)
# ============================================================
# PadChest data can come from:
#   1. Google Drive for Desktop (recommended — streams on-demand)
#   2. Local download (needs ~300 GB)
#   3. Partial download via gdown
# ============================================================

def _scan_gdrive_desktop():
    """Find PadChest on Google Drive for Desktop (Windows/macOS/Linux)."""
    system = platform.system()
    if system == 'Windows':
        for letter in 'GDEFHIJKLMNOPQRSTUVWXYZ':
            for name in ['Google Drive/My Drive/Padchest',
                         'My Drive/Padchest',
                         'GoogleDrive/My Drive/Padchest']:
                candidate = f'{letter}:/{name}'
                if os.path.isdir(candidate): return candidate
    elif system == 'Darwin':
        cloud = pathlib.Path.home() / 'Library' / 'CloudStorage'
        if cloud.exists():
            for folder in sorted(cloud.iterdir()):
                if folder.name.startswith('GoogleDrive'):
                    c = folder / 'My Drive' / 'Padchest'
                    if c.is_dir(): return str(c)
        legacy = '/Volumes/GoogleDrive/My Drive/Padchest'
        if os.path.isdir(legacy): return legacy
    elif system == 'Linux':
        home = pathlib.Path.home()
        for cand in [
            home / 'Google Drive' / 'My Drive' / 'Padchest',
            home / 'google-drive' / 'My Drive' / 'Padchest',
            home / 'gdrive' / 'My Drive' / 'Padchest',
            pathlib.Path('/mnt/google-drive/My Drive/Padchest'),
        ]:
            if cand.is_dir(): return str(cand)
    # Check Colab
    colab = '/content/drive/MyDrive/Padchest'
    if os.path.isdir(colab): return colab
    return None

GDRIVE_PADCHEST = _scan_gdrive_desktop()
if GDRIVE_PADCHEST:
    print(f'Google Drive detected: {GDRIVE_PADCHEST}')
    print('Data streams on-demand — no local download needed.')
else:
    print('Google Drive for Desktop not found.')
    print('Options: 1) Install GDrive Desktop, 2) Set paths manually in Config')
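
`_scan_gdrive_desktop` is a chain of "return the first candidate path that exists" checks. If more mount points need to be added, one option is to factor that pattern into a helper. `first_existing_dir` below is a hypothetical name, and the temporary directories only stand in for real Drive mounts.

```python
import os
import tempfile
from pathlib import Path
from typing import Iterable, Optional, Union

PathLike = Union[str, os.PathLike]


def first_existing_dir(candidates: Iterable[PathLike]) -> Optional[str]:
    """Return the first candidate that is an existing directory, else None."""
    for cand in candidates:
        p = Path(cand).expanduser()
        if p.is_dir():
            return str(p)
    return None


# Usage with throwaway directories standing in for Drive mount points.
with tempfile.TemporaryDirectory() as root:
    present = Path(root) / "My Drive" / "Padchest"
    present.mkdir(parents=True)
    missing = Path(root) / "GoogleDrive" / "My Drive" / "Padchest"
    found = first_existing_dir([missing, present])
    print(found)
```

The per-OS candidate lists would stay as they are. Only the loop and `is_dir` check move into one place.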

In [None]:
# ============================================================
# Cell 3: Imports
# ============================================================
import os, ast, random, warnings, gc, time
from pathlib import Path
from typing import Any, Optional, Tuple
from collections import Counter
from dataclasses import dataclass

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.patches import Rectangle
from PIL import Image, ImageFilter, ImageEnhance, ImageOps
import cv2

from transformers import (
    AutoProcessor,
    AutoModelForImageTextToText,
    BitsAndBytesConfig,
    pipeline,
)
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer
from datasets import Dataset, DatasetDict

warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning, module="transformers")
warnings.filterwarnings("ignore", message=".*PaliGemmaProcessor.*")

print(f"PyTorch: {torch.__version__}, CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}, VRAM: {torch.cuda.get_device_properties(0).total_memory/1024**3:.1f} GB")

In [None]:
# ============================================================
# Cell 4: HuggingFace Authentication
# ============================================================
from huggingface_hub import login
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    try:
        from kaggle_secrets import UserSecretsClient
        hf_token = UserSecretsClient().get_secret("HF_TOKEN")
    except Exception: pass  # not on Kaggle, or no HF_TOKEN secret configured
if not hf_token:
    from huggingface_hub import notebook_login
    notebook_login()
else:
    login(token=hf_token)
    print("Logged in to HuggingFace.")

In [None]:
# ============================================================
# Cell 5: GPU Configuration — RTX 4080 (bf16 native)
# ============================================================
if not torch.cuda.is_available(): raise RuntimeError('No GPU!')

cc = torch.cuda.get_device_capability(0)
USE_BF16 = cc[0] >= 8  # RTX 4080 = cc 8.9 → bf16 OK
COMPUTE_DTYPE = torch.bfloat16 if USE_BF16 else torch.float16

print(f'GPU: {torch.cuda.get_device_name(0)}')
print(f'CC: {cc[0]}.{cc[1]}, Precision: {"bf16" if USE_BF16 else "fp16"}')
print(f'VRAM: {torch.cuda.get_device_properties(0).total_memory/1024**3:.1f} GB')

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

SEED = 42
random.seed(SEED); np.random.seed(SEED)
torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)

## Phase 2 — Configuration

In [None]:
# ============================================================
# Cell 6: Master Configuration — RTX 4080 12 GB VRAM
# ============================================================

@dataclass
class Config:
    model_id: str = "google/medgemma-4b-it"
    use_full_padchest: bool = True

    # ---- Paths (auto-detected from Google Drive or manual) ----
    padchest_csv: str = ""
    padchest_images: str = ""
    output_dir: str = "./explainmyxray-v2-medgemma-padchest"

    # ---- QLoRA ----
    lora_r: int = 64       # High rank for 174 findings x 104 locations
    lora_alpha: int = 128
    lora_dropout: float = 0.05
    load_in_4bit: bool = True
    bnb_4bit_quant_type: str = "nf4"
    bnb_4bit_use_double_quant: bool = True

    # ---- Training (tuned for 12 GB VRAM single GPU) ----
    num_train_epochs: int = 8
    per_device_train_batch_size: int = 1
    per_device_eval_batch_size: int = 1
    gradient_accumulation_steps: int = 16   # effective batch = 1 (per device) x 16 (accumulation) = 16
    learning_rate: float = 5e-5
    warmup_ratio: float = 0.1
    max_grad_norm: float = 0.3
    lr_scheduler_type: str = "cosine_with_restarts"
    logging_steps: int = 10
    eval_steps: int = 50
    save_steps: int = 100
    max_seq_length: int = 768
    label_smoothing_factor: float = 0.05

    # ---- Image Preprocessing (same as Kaggle notebook) ----
    image_size: int = 512
    apply_clahe: bool = True
    clahe_clip_limit: float = 2.0
    clahe_grid_size: int = 8
    auto_crop_edges: bool = True
    edge_crop_threshold: float = 0.05
    pad_to_square: bool = True
    pad_color: int = 0

    # ---- Splits ----
    train_ratio: float = 0.90
    val_ratio: float = 0.05
    test_ratio: float = 0.05
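    # 0 = use every row. Cell 13 keeps each preprocessed 512x512 RGB image in RAM
    # (512*512*3 bytes, ~0.75 MiB each), so cap this on a laptop.
    max_samples: int = 0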
    use_curriculum: bool = True

cfg = Config()

# ---- Auto-set paths from Google Drive detection ----
if GDRIVE_PADCHEST and cfg.use_full_padchest:
    csv_candidates = [
        os.path.join(GDRIVE_PADCHEST, 'PADCHEST_chest_x_ray_images_labels_160K.csv'),
        os.path.join(GDRIVE_PADCHEST, 'padchest_labels.csv'),
    ]
    for c in csv_candidates:
        if os.path.isfile(c):
            cfg.padchest_csv = c; break
    img_candidates = [
        os.path.join(GDRIVE_PADCHEST, 'images'),
        os.path.join(GDRIVE_PADCHEST, 'PadChest', 'images'),
    ]
    for c in img_candidates:
        if os.path.isdir(c):
            cfg.padchest_images = c; break

# ---- Fallback: local dataset ----
if not cfg.padchest_csv:
    # Check common local paths (adjust to your setup)
    local_candidates = [
        './data/padchest/PADCHEST_chest_x_ray_images_labels_160K.csv',
        os.path.expanduser('~/Datasets/Padchest/PADCHEST_chest_x_ray_images_labels_160K.csv'),
        os.path.expanduser('~/data/padchest/PADCHEST_chest_x_ray_images_labels_160K.csv'),
    ]
    for local_csv in local_candidates:
        if os.path.isfile(local_csv):
            cfg.padchest_csv = local_csv
            cfg.padchest_images = os.path.join(os.path.dirname(local_csv), 'images')
            print(f'Using local dataset: {local_csv}')
            break

print(f'CSV: {cfg.padchest_csv} (exists: {os.path.isfile(cfg.padchest_csv) if cfg.padchest_csv else False})')
print(f'Images: {cfg.padchest_images} (exists: {os.path.isdir(cfg.padchest_images) if cfg.padchest_images else False})')
print(f'LoRA r={cfg.lora_r}, LR={cfg.learning_rate}, Epochs={cfg.num_train_epochs}')
print(f'Effective batch: {cfg.per_device_train_batch_size} x {cfg.gradient_accumulation_steps} = {cfg.per_device_train_batch_size * cfg.gradient_accumulation_steps}')
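
To see what the config implies in optimizer steps, the sketch below does the arithmetic for a hypothetical 10,000-example training split (substitute `len(dataset["train"])` once it exists). It assumes one GPU, one optimizer step per `gradient_accumulation_steps` batches, and warmup of `ceil(warmup_ratio * total_steps)`. The Trainer's own rounding at epoch boundaries may differ by a step, so treat the numbers as estimates. `schedule_summary` is a hypothetical helper.

```python
import math


def schedule_summary(n_train, per_device_bs, grad_accum, epochs, warmup_ratio, eval_steps):
    """Approximate optimizer-step counts for a single-GPU Trainer run."""
    effective_batch = per_device_bs * grad_accum
    batches_per_epoch = math.ceil(n_train / per_device_bs)
    steps_per_epoch = max(batches_per_epoch // grad_accum, 1)
    total_steps = steps_per_epoch * epochs
    return {
        "effective_batch": effective_batch,
        "steps_per_epoch": steps_per_epoch,
        "total_steps": total_steps,
        "warmup_steps": math.ceil(total_steps * warmup_ratio),
        "evaluations": total_steps // eval_steps,
    }


# Hypothetical split size; the other values mirror the Config defaults above.
summary = schedule_summary(n_train=10_000, per_device_bs=1, grad_accum=16,
                           epochs=8, warmup_ratio=0.1, eval_steps=50)
for k, v in summary.items():
    print(f"{k}: {v}")
```

With `eval_steps=50` and the early-stopping patience of 5 used in Cell 18, training would stop once 5 consecutive evaluations (250 optimizer steps) fail to improve eval loss by more than the threshold.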

## Phase 3 — Data-Driven Image Preprocessing Pipeline

In [None]:
# ============================================================
# Cell 7: Medical X-Ray Preprocessing Pipeline (7 stages)
# ============================================================
# Same pipeline as Kaggle notebook — designed from analysis of
# 200+ PadChest images. This ensures consistent preprocessing
# regardless of where the model is trained.
# ============================================================

def _convert_16bit_to_8bit(img):
    """Convert 16-bit images to 8-bit via percentile normalization."""
    arr = np.array(img, dtype=np.float32)
    if arr.ndim == 3: arr = arr[:, :, 0]
    p1, p99 = np.percentile(arr, [1, 99])
    if p99 - p1 < 1: p1, p99 = arr.min(), arr.max()
    if p99 - p1 < 1: return np.zeros(arr.shape, dtype=np.uint8)
    return np.clip((arr - p1) / (p99 - p1) * 255, 0, 255).astype(np.uint8)

def _auto_crop_dark_edges(gray, threshold_ratio=0.05):
    """Crop dark edge artifacts (33% of PadChest images)."""
    h, w = gray.shape
    if h < 100 or w < 100: return gray
    center = gray[h//4:3*h//4, w//4:3*w//4]
    center_mean = center.mean()
    if center_mean < 10: return gray
    threshold = center_mean * threshold_ratio
    top, bottom, left, right = 0, h, 0, w
    for row in range(h // 6):
        if gray[row, w//4:3*w//4].mean() < threshold: top = row + 1
        else: break
    for row in range(h - 1, h - h // 6, -1):
        if gray[row, w//4:3*w//4].mean() < threshold: bottom = row
        else: break
    for col in range(w // 6):
        if gray[h//4:3*h//4, col].mean() < threshold: left = col + 1
        else: break
    for col in range(w - 1, w - w // 6, -1):
        if gray[h//4:3*h//4, col].mean() < threshold: right = col
        else: break
    if (bottom - top) < h * 0.6 or (right - left) < w * 0.6: return gray
    return gray[top:bottom, left:right]

def _pad_to_square(gray, pad_value=0):
    """Pad to square preserving aspect ratio."""
    h, w = gray.shape
    if h == w: return gray
    t = max(h, w)
    padded = np.full((t, t), pad_value, dtype=gray.dtype)
    padded[(t-h)//2:(t-h)//2+h, (t-w)//2:(t-w)//2+w] = gray
    return padded

def preprocess_medical_image(image_path, cfg=cfg):
    """Full 7-stage preprocessing pipeline."""
    try:
        img = Image.open(image_path)
        if img.mode in ('I;16', 'I', 'I;16B', 'I;16L'):
            gray = _convert_16bit_to_8bit(img)
        elif img.mode == 'L': gray = np.array(img, dtype=np.uint8)
        elif img.mode in ('RGB', 'RGBA'): gray = np.array(img.convert('L'), dtype=np.uint8)
        else: gray = _convert_16bit_to_8bit(img)
        if cfg.auto_crop_edges: gray = _auto_crop_dark_edges(gray, cfg.edge_crop_threshold)
        if cfg.pad_to_square: gray = _pad_to_square(gray, cfg.pad_color)
        img_pil = Image.fromarray(gray, mode='L').resize((cfg.image_size, cfg.image_size), Image.LANCZOS)
        gray = np.array(img_pil, dtype=np.uint8)
        if cfg.apply_clahe:
            clahe = cv2.createCLAHE(clipLimit=cfg.clahe_clip_limit,
                                   tileGridSize=(cfg.clahe_grid_size, cfg.clahe_grid_size))
            gray = clahe.apply(gray)
        img_pil = Image.fromarray(gray, mode='L')
        img_pil = ImageEnhance.Sharpness(img_pil).enhance(1.2)
        return img_pil.convert('RGB')
    except Exception as e:
        print(f'[WARN] Preprocess failed: {os.path.basename(str(image_path))}: {e}')
        return None

print('Pipeline: 16bit->8bit -> crop -> pad -> resize -> CLAHE -> sharpen -> RGB')
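
Padding before the resize is what preserves the aspect ratio: resizing a non-square image straight to 512x512 would stretch it. The hypothetical `pad_to_square_np` below expresses the same centering as `_pad_to_square` with `np.pad`. The small examples show that when the size difference is odd, the extra row goes after the image, as in the original slicing.

```python
import numpy as np


def pad_to_square_np(gray: np.ndarray, pad_value: int = 0) -> np.ndarray:
    """Center a 2-D array on a square canvas (same split as _pad_to_square)."""
    h, w = gray.shape
    t = max(h, w)
    top, left = (t - h) // 2, (t - w) // 2
    return np.pad(gray, ((top, t - h - top), (left, t - w - left)),
                  mode="constant", constant_values=pad_value)


# A 3x5 "image" of ones: rows are padded, columns are untouched.
x = np.ones((3, 5), dtype=np.uint8)
sq = pad_to_square_np(x)
print(sq.shape)
print(sq)
```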

In [None]:
# ============================================================
# Cell 8: Test Preprocessing Pipeline
# ============================================================
test_dir = None
if cfg.padchest_images and os.path.isdir(cfg.padchest_images):
    for sub in sorted(os.listdir(cfg.padchest_images)):
        sp = os.path.join(cfg.padchest_images, sub)
        if os.path.isdir(sp) and len(os.listdir(sp)) > 0:
            test_dir = sp; break
    if test_dir is None and any(f.endswith('.png') for f in os.listdir(cfg.padchest_images)):
        test_dir = cfg.padchest_images

if test_dir:
    test_files = sorted([f for f in os.listdir(test_dir) if f.lower().endswith(('.png','.jpg','.jpeg'))])[:6]
    if test_files:
        ncols = min(6, len(test_files))
        fig, axes = plt.subplots(2, ncols, figsize=(4*ncols, 8))
        if ncols == 1: axes = axes.reshape(2, 1)
        fig.suptitle('Preprocessing: Raw vs Processed', fontsize=14, fontweight='bold')
        for i, f in enumerate(test_files):
            path = os.path.join(test_dir, f)
            raw = Image.open(path)
            raw_arr = np.array(raw, dtype=np.float32)
            axes[0, i].imshow(raw_arr, cmap='gray')
            axes[0, i].set_title(f'{raw.size[0]}x{raw.size[1]} {raw.mode}', fontsize=8)
            axes[0, i].axis('off')
            processed = preprocess_medical_image(path)
            if processed is not None:
                axes[1, i].imshow(processed)
                axes[1, i].set_title(f'{processed.size[0]}x{processed.size[1]} RGB', fontsize=8)
            axes[1, i].axis('off')
        plt.tight_layout(); plt.show()
else:
    print('No images found to test pipeline.')

In [None]:
# ============================================================
# Cell 9: Load PadChest CSV & Parse Labels
# ============================================================
df = pd.read_csv(cfg.padchest_csv)
print(f"Raw: {len(df)} rows, {len(df.columns)} columns")

def safe_parse_list(val):
    if pd.isna(val) or str(val).strip() in ["","[]","nan","None"]: return []
    try:
        parsed = ast.literal_eval(str(val))
        if isinstance(parsed, list):
            flat = []
            for item in parsed:
                if isinstance(item, list): flat.extend(str(x).strip() for x in item)
                else: flat.append(str(item).strip())
            return [f for f in flat if f and f != 'nan']
        return [str(parsed).strip()]
    except Exception: return [str(val).strip()]  # not a Python literal: keep the raw string

df["labels_parsed"] = df["Labels"].apply(safe_parse_list)
df["localizations_parsed"] = df["Localizations"].apply(safe_parse_list)
df["labels_locs_parsed"] = df["LabelsLocalizationsBySentence"].apply(safe_parse_list)

def split_findings_locations(items):
    findings, locations = [], []
    for item in items:
        c = item.strip()
        if c.startswith("loc "): locations.append(c[4:].strip())
        elif c not in ["exclude","","nan"]: findings.append(c)
    return findings, locations

df["findings"], df["locations"] = zip(*df["labels_locs_parsed"].apply(split_findings_locations))
df["num_findings"] = df["findings"].apply(len)
print(f"Parsed: {len(df)} rows, findings range {df['num_findings'].min()}-{df['num_findings'].max()}")
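
Cell 9 flattens `LabelsLocalizationsBySentence`. Judging by the parser, that column holds a stringified list of per-sentence groups: plain items are findings, and items prefixed with `loc ` are locations. The sketch below keeps the grouping so the structure is visible. The input string is a hypothetical value written for illustration, not copied from the CSV, and the helper names are hypothetical.

```python
import ast


def parse_groups(raw: str) -> list:
    """Parse a stringified nested list; return [] if it is not a list literal."""
    try:
        parsed = ast.literal_eval(raw)
    except (ValueError, SyntaxError):
        return []
    return parsed if isinstance(parsed, list) else []


def split_group(group):
    """Split one sentence group into (findings, locations) via the 'loc ' prefix."""
    items = [str(g).strip() for g in group]
    findings = [i for i in items if not i.startswith("loc ") and i not in ("exclude", "nan", "")]
    locations = [i[4:].strip() for i in items if i.startswith("loc ")]
    return findings, locations


# Hypothetical cell value, written for illustration.
raw = "[['pleural effusion', 'loc left', 'loc costophrenic angle'], ['cardiomegaly']]"
groups = [split_group(g) for g in parse_groups(raw)]
for findings, locations in groups:
    print(findings, locations)
```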

In [None]:
# ============================================================
# Cell 10: Filter Valid Images
# ============================================================
def resolve_image_path(row):
    img_name = row["ImageID"]
    if cfg.use_full_padchest:
        img_dir_num = row.get("ImageDir")
        if pd.notna(img_dir_num):
            return os.path.join(cfg.padchest_images, str(int(img_dir_num)), img_name)
        for sub in range(38):
            c = os.path.join(cfg.padchest_images, str(sub), img_name)
            if os.path.exists(c): return c
    return os.path.join(cfg.padchest_images, img_name)

df["image_path"] = df.apply(resolve_image_path, axis=1)
df["image_exists"] = df["image_path"].apply(os.path.exists)
df_valid = df[df["image_exists"]].copy()
print(f"Available: {len(df_valid)} / {len(df)}")

if len(df_valid) == 0:
    raise FileNotFoundError("No images found! Check paths in Config.")

df_valid = df_valid[df_valid["num_findings"] > 0].copy()
print(f"With findings: {len(df_valid)}")
if cfg.max_samples > 0:
    df_valid = df_valid.sample(n=min(cfg.max_samples, len(df_valid)), random_state=SEED)

all_findings = [f for fl in df_valid["findings"] for f in fl]
finding_counts = Counter(all_findings)
print(f"Unique findings: {len(finding_counts)}")
for f, c in finding_counts.most_common(15):
    print(f"  {f}: {c}")

In [None]:
# ============================================================
# Cell 11: Structured Prompt Engineering
# ============================================================
SYSTEM_PROMPT = (
    "You are an expert board-certified radiologist AI analyzing chest X-rays "
    "from the BIMCV PadChest dataset. Produce a structured radiology report "
    "following this exact format:\n\n"
    "FINDINGS:\n"
    "- State each finding on a separate line\n"
    "- Include anatomical location in parentheses when known\n"
    "- Be specific: use standard radiological terminology\n"
    "- If no abnormality: state 'No significant abnormalities detected'\n\n"
    "LOCATIONS:\n"
    "- List all affected anatomical regions\n\n"
    "IMPRESSION:\n"
    "- Provide a concise clinical summary\n"
    "- Note if correlation with prior studies is recommended\n\n"
    "Be systematic: check lung fields, mediastinum, cardiac silhouette, "
    "diaphragm, pleural spaces, and bony thorax."
)

def build_user_prompt(view_position='PA'):
    v = view_position if pd.notna(view_position) and view_position else "unknown"
    return f"Analyze this chest X-ray (projection: {v}). Provide FINDINGS, LOCATIONS, and IMPRESSION."

def build_assistant_response(findings, locations):
    fu = list(dict.fromkeys(findings))
    lu = list(dict.fromkeys(locations))
    abn = [f for f in fu if f.lower() not in ["normal","unchanged","exclude","nan",""]]
    nl = chr(10)
    if not abn:
        fs = '- No significant abnormalities detected'
        imp = 'Normal chest X-ray. No acute cardiopulmonary disease.'
    else:
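        lines = []
        # Cell 9 flattens the per-sentence groups, so locations cannot be tied to
        # a specific finding; each finding gets (up to 3 of) the study's locations.
        study_locs = [l for l in lu if l]
        loc_joined = ', '.join(study_locs[:3])
        loc_str = f' ({loc_joined})' if study_locs else ''
        for f in abn:
            lines.append(f'- {f.capitalize()}{loc_str}')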
        fs = nl.join(lines)
        if len(abn) == 1:
            imp = f'{abn[0].capitalize()} identified. Clinical correlation recommended.'
        else:
            top = ', '.join(a.capitalize() for a in abn[:4])
            imp = f'Multiple findings: {top}. Clinical correlation and follow-up recommended.'
    r = f'FINDINGS:{nl}{fs}{nl}{nl}'
    locs_str = ', '.join(lu) if lu else 'Not specified'
    r += f'LOCATIONS:{nl}{locs_str}{nl}{nl}'
    r += f'IMPRESSION:{nl}{imp}'
    return r

s = df_valid.iloc[0]
print("=== Sample ===")
print(f"Findings: {s['findings']}")
print(f"\nOutput:\n{build_assistant_response(s['findings'], s['locations'])}")
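
Because the groups are flattened before `build_assistant_response` runs, every finding in a study is annotated with the same study-level locations. If finding-specific locations are wanted, one option is to build the bullet lines directly from the per-sentence groups, attaching only the locations that appear in the same group. `findings_with_locations` is a hypothetical helper sketched under that assumption. Wiring it in would mean passing the unflattened groups (for example, `ast.literal_eval` on the raw column) instead of `findings`/`locations`.

```python
NORMAL_LABELS = {"normal", "unchanged", "exclude", "nan", ""}


def findings_with_locations(groups, max_locs=3):
    """Report bullet lines; each finding gets only the locations from its own sentence group.

    Returns [] when nothing abnormal is present, so the caller can fall back to
    '- No significant abnormalities detected'.
    """
    lines, seen = [], set()
    for group in groups:
        items = [str(g).strip() for g in group]
        locs = [i[4:].strip() for i in items if i.startswith("loc ")]
        locs = [l for l in locs if l]
        for item in items:
            if item.startswith("loc ") or item.lower() in NORMAL_LABELS or item in seen:
                continue
            seen.add(item)
            suffix = f" ({', '.join(locs[:max_locs])})" if locs else ""
            lines.append(f"- {item.capitalize()}{suffix}")
    return lines


# Hypothetical sentence groups (same shape as a parsed LabelsLocalizationsBySentence value)
groups = [["pleural effusion", "loc left", "loc costophrenic angle"], ["cardiomegaly"], ["normal"]]
print("\n".join(findings_with_locations(groups)))
```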

In [None]:
# ============================================================
# Cell 12: Curriculum Learning
# ============================================================
def compute_difficulty(row):
    score = 0
    findings = row["findings"]
    normal_labels = {"normal","unchanged","exclude","nan",""}
    abn = [f for f in findings if f.lower() not in normal_labels]
    score += len(abn) * 2
    score += len(row["locations"])
    for f in abn:
        freq = finding_counts.get(f, 0)
        if freq <= 5: score += 5
        elif freq <= 20: score += 3
        elif freq <= 50: score += 1
    return score

df_valid["difficulty"] = df_valid.apply(compute_difficulty, axis=1)
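if cfg.use_curriculum:
    # Note: the HF Trainer draws training batches with a RandomSampler by default,
    # so this easy -> hard order is likely shuffled away unless the sampler is overridden.
    df_valid = df_valid.sort_values("difficulty").reset_index(drop=True)
    print('Curriculum: easy -> hard')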
else:
    df_valid = df_valid.sample(frac=1, random_state=SEED).reset_index(drop=True)
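
Here is a worked example of the difficulty score. Each abnormal finding adds 2 and each location adds 1. Rarer findings get a bonus: +5 if seen at most 5 times, +3 if at most 20, and +1 if at most 50. The sketch restates `compute_difficulty` on plain lists so it runs without the DataFrame. The label counts are invented for illustration.

```python
from collections import Counter

NORMAL_LABELS = {"normal", "unchanged", "exclude", "nan", ""}


def difficulty(findings, locations, counts):
    """Restates compute_difficulty's scoring rule on plain lists."""
    abnormal = [f for f in findings if f.lower() not in NORMAL_LABELS]
    score = 2 * len(abnormal) + len(locations)
    for f in abnormal:
        freq = counts.get(f, 0)
        if freq <= 5:
            score += 5
        elif freq <= 20:
            score += 3
        elif freq <= 50:
            score += 1
    return score


# Invented frequencies; "rare finding" is a placeholder label.
counts = Counter({"cardiomegaly": 5000, "pleural effusion": 40, "rare finding": 3})
easy = difficulty(["normal"], [], counts)
medium = difficulty(["cardiomegaly"], ["cardiac"], counts)
hard = difficulty(["pleural effusion", "rare finding"], ["left", "right"], counts)
print(easy, medium, hard)
```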

In [None]:
# ============================================================
# Cell 13: Convert to HuggingFace Dataset (with preprocessing)
# ============================================================
def row_to_example(row):
    try:
        image = preprocess_medical_image(row["image_path"], cfg)
        if image is None: return None
    except Exception as e:
        print(f"  Skip: {e}")
        return None
    v = row.get("Projection", "PA")
    messages = [
        {"role":"system","content":[{"type":"text","text":SYSTEM_PROMPT}]},
        {"role":"user","content":[{"type":"image"},{"type":"text","text":build_user_prompt(v)}]},
        {"role":"assistant","content":[{"type":"text","text":build_assistant_response(row["findings"],row["locations"])}]},
    ]
    return {"image":image,"messages":messages,"findings":row["findings"],"locations":row["locations"],"image_id":row["ImageID"]}

print("Processing images...")
t0 = time.time()
examples, failed = [], 0
for i, (_, row) in enumerate(df_valid.iterrows()):
    ex = row_to_example(row)
    if ex: examples.append(ex)
    else: failed += 1
    if (i+1) % 500 == 0:
        print(f'  {i+1}/{len(df_valid)} | ok={len(examples)} fail={failed}')

print(f'Done: {len(examples)}/{len(df_valid)} in {time.time()-t0:.0f}s ({failed} skipped)')

In [None]:
# ============================================================
# Cell 14: Train / Val / Test Split
# ============================================================
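n = len(examples)
n_val = int(n * cfg.val_ratio)
n_test = int(n * cfg.test_ratio)
# Pick val/test at random: slicing the difficulty-sorted list would leave
# only the hardest examples for validation and test.
perm = list(range(n))
random.Random(SEED).shuffle(perm)
val_idx, test_idx = set(perm[:n_val]), set(perm[n_val:n_val + n_test])
val_ex = [examples[i] for i in perm[:n_val]]
test_ex = [examples[i] for i in perm[n_val:n_val + n_test]]
# Train keeps the original (curriculum) order of `examples`
train_ex = [e for i, e in enumerate(examples) if i not in val_idx and i not in test_idx]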

def to_ds(exs):
    return Dataset.from_dict({
        "image":[e["image"] for e in exs],
        "messages":[e["messages"] for e in exs],
        "findings":[e["findings"] for e in exs],
        "locations":[e["locations"] for e in exs],
        "image_id":[e["image_id"] for e in exs],
    })

dataset = DatasetDict({"train":to_ds(train_ex),"validation":to_ds(val_ex),"test":to_ds(test_ex)})
for split, ds in dataset.items(): print(f'  {split}: {len(ds)}')

## Phase 4 — Model Loading (MedGemma-4B QLoRA)

In [None]:
# ============================================================
# Cell 15: Load MedGemma-4B-it with QLoRA
# ============================================================
# RTX 4080 = cc 8.9 → use bfloat16
bnb_config = BitsAndBytesConfig(
    load_in_4bit=cfg.load_in_4bit,
    bnb_4bit_use_double_quant=cfg.bnb_4bit_use_double_quant,
    bnb_4bit_quant_type=cfg.bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=COMPUTE_DTYPE,
    bnb_4bit_quant_storage=COMPUTE_DTYPE,
)

print(f'Loading {cfg.model_id} with 4-bit NF4 ({COMPUTE_DTYPE})...')
model = AutoModelForImageTextToText.from_pretrained(
    cfg.model_id,
    quantization_config=bnb_config,
    attn_implementation="eager",  # eager attention is recommended for Gemma 3 training
    torch_dtype=COMPUTE_DTYPE,
    device_map="auto",
)
processor = AutoProcessor.from_pretrained(cfg.model_id)
processor.tokenizer.padding_side = "right"
print(f'Model loaded: {torch.cuda.memory_allocated()/1e9:.2f} GB VRAM')
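
For a rough sense of the VRAM printed above, remember that weights alone take `n_params * bits / 8` bytes. The sketch uses a nominal 4e9 parameters, though the actual count differs. Real usage will be higher than the NF4 figure for three reasons. First, bitsandbytes quantizes only the linear layers, so weights such as the embeddings stay in the compute dtype. Second, double quantization adds a small overhead. Third, training adds activations, LoRA weights, and optimizer state. `weight_gib` is a hypothetical helper.

```python
def weight_gib(n_params: float, bits_per_param: float) -> float:
    """Memory for the weights alone in GiB (no activations, adapters, or optimizer state)."""
    return n_params * bits_per_param / 8 / 1024**3


N_PARAMS = 4e9  # nominal "4B"; the real count differs
for name, bits in [("bf16", 16), ("nf4", 4)]:
    print(f"{name}: {weight_gib(N_PARAMS, bits):.2f} GiB")
```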

In [None]:
# ============================================================
# Cell 16: LoRA Config — r=64
# ============================================================
peft_config = LoraConfig(
    r=cfg.lora_r,
    lora_alpha=cfg.lora_alpha,
    lora_dropout=cfg.lora_dropout,
    bias="none",
    target_modules="all-linear",
    task_type="CAUSAL_LM",
    modules_to_save=["lm_head","embed_tokens"],
)
print(f'LoRA: r={cfg.lora_r}, alpha={cfg.lora_alpha}')

In [None]:
# ============================================================
# Cell 17: Data Collator
# ============================================================
def collate_fn(examples):
    texts, images = [], []
    for ex in examples:
        img = ex["image"]
        if not isinstance(img, Image.Image):
            img = Image.open(img).convert("RGB")
        images.append([img.convert("RGB")])
        text = processor.apply_chat_template(
            ex["messages"], add_generation_prompt=False, tokenize=False
        ).strip()
        texts.append(text)
    batch = processor(text=texts, images=images, return_tensors="pt",
                      padding=True, truncation=True, max_length=cfg.max_seq_length)
    labels = batch["input_ids"].clone()
    pid = processor.tokenizer.pad_token_id
    if pid is not None: labels[labels == pid] = -100
    labels[labels == 262144] = -100  # Gemma 3 image token id: no loss on image tokens
    batch["labels"] = labels
    return batch

tb = collate_fn([dataset["train"][0]])
print(f"Collator OK. Input: {tb['input_ids'].shape}")
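
Here is what the label masking in `collate_fn` does, shown on plain lists. Any position holding the pad id or the image token id becomes `-100`, which PyTorch's cross-entropy ignores by default (`ignore_index=-100`), so those positions contribute no loss. Note that system and user prompt tokens are not masked in `collate_fn`, so the loss also covers the prompt. The pad id of 0 and the token sequence are assumptions for illustration, and `mask_labels` is a hypothetical helper.

```python
IGNORE_INDEX = -100  # PyTorch cross-entropy skips this target by default


def mask_labels(input_ids, masked_ids):
    """Return a copy of input_ids with every id in masked_ids replaced by IGNORE_INDEX."""
    return [[IGNORE_INDEX if tok in masked_ids else tok for tok in row] for row in input_ids]


PAD_ID = 0          # assumption for illustration
IMAGE_ID = 262144   # the id masked in collate_fn above
input_ids = [[2, IMAGE_ID, IMAGE_ID, 15, 27, PAD_ID, PAD_ID]]  # made-up token sequence
labels = mask_labels(input_ids, {PAD_ID, IMAGE_ID})
print(labels)
```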

## Phase 5 — Training

In [None]:
# ============================================================
# Cell 18: Training Arguments (RTX 4080, 12 GB VRAM)
# ============================================================
from transformers import EarlyStoppingCallback

training_args = SFTConfig(
    output_dir=cfg.output_dir,
    num_train_epochs=cfg.num_train_epochs,
    per_device_train_batch_size=cfg.per_device_train_batch_size,
    per_device_eval_batch_size=cfg.per_device_eval_batch_size,
    gradient_accumulation_steps=cfg.gradient_accumulation_steps,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    optim="paged_adamw_8bit",
    learning_rate=cfg.learning_rate,
    warmup_ratio=cfg.warmup_ratio,
    max_grad_norm=cfg.max_grad_norm,
    lr_scheduler_type=cfg.lr_scheduler_type,
    weight_decay=0.01,
    label_smoothing_factor=cfg.label_smoothing_factor,
    bf16=USE_BF16,
    fp16=not USE_BF16,
    logging_steps=cfg.logging_steps,
    eval_strategy="steps",
    eval_steps=cfg.eval_steps,
    save_strategy="steps",
    save_steps=cfg.save_steps,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    report_to="tensorboard",
    logging_dir=os.path.join(cfg.output_dir, "logs"),
    dataset_kwargs={"skip_prepare_dataset": True},
    remove_unused_columns=False,
    label_names=["labels"],
    dataloader_pin_memory=True,
    dataloader_num_workers=4,  # on Windows, workers may fail to pickle notebook-defined collate_fn; use 0 there
    max_length=cfg.max_seq_length,  # SFTConfig's current name for the old max_seq_length argument
)

early_stop = EarlyStoppingCallback(early_stopping_patience=5, early_stopping_threshold=0.001)
eff = cfg.per_device_train_batch_size * cfg.gradient_accumulation_steps
print(f'Epochs: {cfg.num_train_epochs}, Effective batch: {eff}, LR: {cfg.learning_rate}')
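
The config requests `cosine_with_restarts`. In transformers, this schedule appears to be linear warmup followed by cosine decay with `num_cycles` hard restarts. Because `num_cycles` defaults to 1, it reduces to plain cosine unless `lr_scheduler_kwargs` sets more cycles. The sketch below re-derives that shape in plain Python so the warmup and decay can be checked without a model. `lr_at` is hypothetical, and 5000/500 steps are placeholder values.

```python
import math


def lr_at(step, base_lr, warmup_steps, total_steps, num_cycles=1):
    """Linear warmup, then cosine decay with `num_cycles` hard restarts (sketch)."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    if progress >= 1.0:
        return 0.0
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * ((num_cycles * progress) % 1.0)))


total, warmup = 5000, 500  # hypothetical step counts
for s in (0, 250, 500, 2750, 4999, 5000):
    print(s, f"{lr_at(s, 5e-5, warmup, total):.2e}")
```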

In [None]:
# ============================================================
# Cell 19: Build Trainer & Start Training
# ============================================================
trainer = SFTTrainer(
    model=model, args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    peft_config=peft_config,
    processing_class=processor,
    data_collator=collate_fn,
    callbacks=[early_stop],
)

# PEFT's helper counts 4-bit packed weights correctly (plain numel() undercounts them)
trainer.model.print_trainable_parameters()
print(f'VRAM: {torch.cuda.memory_allocated()/1e9:.2f} GB')

print('Starting training...')
train_result = trainer.train()
print(f'Done! Loss: {train_result.training_loss:.4f}, Steps: {train_result.global_step}')

In [None]:
# ============================================================
# Cell 20: Save Model
# ============================================================
trainer.save_model(cfg.output_dir)
processor.save_pretrained(cfg.output_dir)
size_mb = sum(os.path.getsize(os.path.join(cfg.output_dir, f))
    for f in os.listdir(cfg.output_dir)
    if os.path.isfile(os.path.join(cfg.output_dir, f))) / 1e6
print(f'Saved: {cfg.output_dir} ({size_mb:.1f} MB)')
del model, trainer
torch.cuda.empty_cache(); gc.collect()
print(f'VRAM freed: {torch.cuda.memory_allocated()/1e9:.2f} GB')

## Phase 6 — Evaluation & Inference

In [None]:
# ============================================================
# Cell 21: Load Fine-Tuned Model for Inference
# ============================================================
ft_pipe = pipeline('image-text-to-text', model=cfg.output_dir,
                   torch_dtype=COMPUTE_DTYPE, device_map='auto')
ft_pipe.model.generation_config.do_sample = False
ft_pipe.model.generation_config.max_new_tokens = 384
ft_pipe.tokenizer.padding_side = 'left'
print('Fine-tuned model loaded.')

In [None]:
# ============================================================
# Cell 22: Run Inference on Test Set
# ============================================================
def run_inference(pipe, ds, max_n=None):
    results = []
    n = min(len(ds), max_n) if max_n else len(ds)
    for i in range(n):
        ex = ds[i]
        msgs = [
            {"role":"system","content":[{"type":"text","text":SYSTEM_PROMPT}]},
            {"role":"user","content":[{"type":"image"},
             {"type":"text","text":build_user_prompt("PA")}]},
        ]
        try:
            out = pipe(text=msgs, images=[ex['image']], return_full_text=False)
            pred = out[0]['generated_text']
        except Exception as e: pred = f'ERROR: {e}'
        results.append({
            'image_id': ex['image_id'], 'gt_findings': ex['findings'],
            'gt_locations': ex['locations'], 'prediction': pred,
        })
        if (i+1) % 5 == 0: print(f'  {i+1}/{n}')
    return results

print(f'Inference on {len(dataset["test"])} test samples...')
test_results = run_inference(ft_pipe, dataset['test'])

In [None]:
# ============================================================
# Cell 23: Comprehensive Evaluation
# ============================================================
def extract_findings_from_report(text):
    findings, in_findings = [], False
    for line in text.split('\n'):
        line = line.strip()
        if line.upper().startswith('FINDINGS'): in_findings = True; continue
        if line.upper().startswith(('LOCATIONS','IMPRESSION')): in_findings = False; continue
        if in_findings and line.startswith('- '):
            f = line.lstrip('- ').split('(')[0].strip().lower()
            if f and f not in ['nan','']: findings.append(f)
        elif in_findings and 'no significant' in line.lower():
            findings.append('normal')
    return findings if findings else ['normal']

from collections import Counter

exact_match, soft_match, total = 0, 0, 0
tp, fp, fn = Counter(), Counter(), Counter()
per_finding_tp, per_finding_total = Counter(), Counter()

for r in test_results:
    gt = set(f.lower().strip() for f in r['gt_findings'] if f.strip())
    pr = set(extract_findings_from_report(r['prediction']))
    if gt == pr: exact_match += 1
    # Soft match: at least half of the ground-truth findings are recovered.
    # False positives are not penalized here; precision and F1 below capture them.
    if gt:
        overlap = len(gt & pr) / len(gt)
        if overlap >= 0.5: soft_match += 1
    else:
        if not pr or pr == {'normal'}: soft_match += 1
    total += 1
    for f in gt & pr: tp[f] += 1
    for f in pr - gt: fp[f] += 1
    for f in gt - pr: fn[f] += 1
    for f in gt: per_finding_total[f] += 1
    for f in gt & pr: per_finding_tp[f] += 1

ea = exact_match / total if total else 0
sa = soft_match / total if total else 0
ttp, tfp, tfn = sum(tp.values()), sum(fp.values()), sum(fn.values())
prec = ttp / (ttp + tfp) if ttp + tfp else 0
rec = ttp / (ttp + tfn) if ttp + tfn else 0
f1 = 2 * prec * rec / (prec + rec) if prec + rec else 0

print('='*60)
print(f'  RESULTS ({total} test samples)')
print('='*60)
print(f'  Exact match:  {exact_match}/{total} ({ea*100:.1f}%)')
print(f'  Soft match:   {soft_match}/{total} ({sa*100:.1f}%)')
print(f'  Precision:    {prec:.3f}')
print(f'  Recall:       {rec:.3f}')
print(f'  F1:           {f1:.3f}')
print('='*60)
if sa >= 0.95: print(f'  SOFT-MATCH TARGET MET: {sa*100:.1f}% >= 95%')
elif sa >= 0.90: print(f'  CLOSE: {sa*100:.1f}% soft match ({(0.95-sa)*100:.1f} points below 95%)')
else: print(f'  Soft match is {(0.95-sa)*100:.1f} points below the 95% target')

print('\nPer-finding recall (20 most frequent ground-truth findings):')
for f in sorted(per_finding_total, key=lambda x: per_finding_total[x], reverse=True)[:20]:
    t = per_finding_tp.get(f, 0)
    n = per_finding_total[f]
    acc = t/n if n else 0
    bar = chr(9608) * int(acc * 20)
    print(f'  {f:30s}: {t:3d}/{n:3d} ({acc*100:5.1f}%) {bar}')
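A small worked example makes the difference between these metrics concrete. The function below repackages the Cell 23 arithmetic into a self-contained form (`set_metrics` is an illustrative name, not defined elsewhere in the notebook) and applies it to two invented samples. Because soft match only checks how many ground-truth findings were recovered, a model that over-predicts can reach 100% soft match while micro precision stays low, so the 95% target is best read alongside precision and F1.

```python
def set_metrics(pairs):
    """Exact match, soft match and micro P/R/F1 over (ground_truth, prediction) set pairs.

    Mirrors Cell 23: a sample soft-matches when at least half of its ground-truth
    findings appear in the prediction; false positives do not affect soft match.
    """
    exact = soft = tp = fp = fn = 0
    for gt, pr in pairs:
        if gt == pr:
            exact += 1
        if gt:
            if len(gt & pr) / len(gt) >= 0.5:
                soft += 1
        elif not pr or pr == {"normal"}:
            soft += 1
        tp += len(gt & pr)   # findings both present and predicted
        fp += len(pr - gt)   # predicted but absent
        fn += len(gt - pr)   # present but missed
    n = len(pairs)
    prec = tp / (tp + fp) if tp + fp else 0.0
    rec = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * prec * rec / (prec + rec) if prec + rec else 0.0
    return {"exact": exact / n if n else 0.0, "soft": soft / n if n else 0.0,
            "precision": prec, "recall": rec, "f1": f1}

# Invented samples: the first recalls half its findings, the second over-predicts.
toy = [
    ({"cardiomegaly", "pleural effusion"}, {"cardiomegaly"}),
    ({"pneumonia"}, {"pneumonia", "nodule", "atelectasis", "scoliosis"}),
]
print(set_metrics(toy))
```

Here tp = 2, fp = 3 and fn = 1, so precision is 2/5, recall is 2/3 and F1 is 0.5, while both samples count as soft matches.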

In [None]:
# ============================================================
# Cell 24: Visualize Predictions
# ============================================================
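import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

# Approximate, hand-set boxes as normalized (x, y, width, height) fractions of the image.
# PA convention: the patient's right side appears on the left of the image.
ANAT_REGIONS = {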
    'right upper lobe': (0.05,0.05,0.35,0.30),
    'right middle lobe': (0.05,0.30,0.35,0.20),
    'right lower lobe': (0.05,0.50,0.35,0.30),
    'left upper lobe':  (0.55,0.05,0.35,0.30),
    'left lower lobe':  (0.55,0.50,0.35,0.30),
    'cardiac silhouette': (0.30,0.30,0.35,0.40),
    'cardiac': (0.30,0.30,0.35,0.40),
    'bilateral': (0.05,0.15,0.85,0.65),
    'hilar': (0.30,0.20,0.35,0.30),
    'mediastinum': (0.30,0.10,0.35,0.60),
}
COLORS = ['#FF4444','#4488FF','#FF8800','#FFCC00','#44FF88','#FF44FF','#88FFFF']

for i in range(min(5, len(test_results))):
    r = test_results[i]
    img = dataset['test'][i]['image']
    fig, axes = plt.subplots(1, 3, figsize=(20, 7))
    axes[0].imshow(img); axes[0].set_title('Input'); axes[0].axis('off')
    axes[1].imshow(img)
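    # PIL images expose .size as (width, height). NumPy arrays also have .size
    # (their element count), so check the type rather than hasattr.
    w, h = img.size if isinstance(img, Image.Image) else (img.shape[1], img.shape[0])
    for j, loc in enumerate(r['gt_locations']):
        k = loc.lower().strip()
        if k in ANAT_REGIONS:
            rx, ry, rw, rh = ANAT_REGIONS[k]
            c = COLORS[j % len(COLORS)]
            # Solid outline with a translucent fill ('26' hex alpha is about 0.15 opacity).
            axes[1].add_patch(Rectangle((rx*w, ry*h), rw*w, rh*h, lw=2, ec=c, fc=c + '26'))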
    findings_str = ', '.join(r['gt_findings'][:3])
    axes[1].set_title(f'Ground Truth: {findings_str}'); axes[1].axis('off')
    axes[2].text(0.05,0.95,r['prediction'],transform=axes[2].transAxes,fontsize=8,
                va='top',fontfamily='monospace',wrap=True)
    axes[2].set_title('Prediction'); axes[2].axis('off')
    fig.suptitle(f'Test #{i+1}: {r["image_id"]}', fontsize=13, fontweight='bold')
    plt.tight_layout(); plt.show()
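The shaded regions in the middle panel come from a fixed lookup keyed by the ground-truth location labels; they are coarse anatomical approximations, not localizations predicted by the model. The sketch below isolates that lookup so it can be checked without matplotlib. `REGIONS` is a hypothetical two-entry subset of `ANAT_REGIONS`, and `region_to_pixels` is an illustrative helper, not part of the notebook.

```python
# Hypothetical subset of the Cell 24 lookup: normalized (x, y, width, height).
REGIONS = {
    "right upper lobe": (0.05, 0.05, 0.35, 0.30),  # image-left = patient right (PA view)
    "left lower lobe": (0.55, 0.50, 0.35, 0.30),
}

def region_to_pixels(location, width, height, regions=REGIONS):
    """Return (x0, y0, w, h) in pixels for a location label, or None if it is unmapped."""
    key = location.lower().strip()  # same normalization as Cell 24
    if key not in regions:
        return None
    rx, ry, rw, rh = regions[key]
    return (rx * width, ry * height, rw * width, rh * height)

print(region_to_pixels("Right Upper Lobe", 512, 512))
print(region_to_pixels("costophrenic angle", 512, 512))  # unmapped label -> None
```

Any PadChest location label without an exact key match is silently skipped, so checking how many `gt_locations` map to a box is a quick way to see how much of the overlay is actually drawn.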

In [None]:
# ============================================================
# Cell 25: Interactive Prediction
# ============================================================
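def predict_xray(image_path, view='PA'):
    """Analyze any chest X-ray with the fine-tuned model; returns (report, image)."""
    img = preprocess_medical_image(image_path, cfg)
    # The fallback skips the 7-stage pipeline (no 16-bit windowing, cropping, CLAHE
    # or 512x512 resize), so its reports may differ from those on preprocessed images.
    if img is None: img = Image.open(image_path).convert('RGB')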
    msgs = [
        {'role':'system','content':[{'type':'text','text':SYSTEM_PROMPT}]},
        {'role':'user','content':[{'type':'image'},
         {'type':'text','text':build_user_prompt(view)}]},
    ]
    out = ft_pipe(text=msgs, images=[img], return_full_text=False)
    report = out[0]['generated_text']
    print(report)
    return report, img

print('predict_xray() ready.')
print('Usage: report, img = predict_xray("path/to/xray.png")')
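`predict_xray` returns the raw report text. If downstream code needs structured fields, a small parser can split the text on the section headers that the Cell 23 evaluator looks for (`FINDINGS`, `LOCATIONS`, `IMPRESSION`). This is a sketch under assumptions: it expects each header at the start of a line, optionally followed by a colon and inline text; `split_sections` is a hypothetical helper, and the sample report is invented.

```python
import re

# Header at line start, optional colon, optional text on the same line.
SECTION_RE = re.compile(r"^\s*(FINDINGS|LOCATIONS|IMPRESSION)\b\s*:?\s*(.*)$", re.IGNORECASE)

def split_sections(report):
    """Split a generated report into {SECTION: [non-empty lines]} (hypothetical helper)."""
    sections, current = {}, None
    for raw in report.splitlines():
        m = SECTION_RE.match(raw)
        if m:
            current = m.group(1).upper()
            sections.setdefault(current, [])
            if m.group(2).strip():  # keep text written on the header line itself
                sections[current].append(m.group(2).strip())
            continue
        line = raw.strip()
        if current and line:  # ignore preamble before the first header
            sections[current].append(line)
    return sections

sample = """FINDINGS:
- cardiomegaly (moderate)
- pleural effusion
LOCATIONS:
- cardiac
- right lower lobe
IMPRESSION: Cardiomegaly with small right effusion."""
print(split_sections(sample))
```

A line such as "Findings are unremarkable" would also be read as a header, so stricter matching may be needed if the model writes free-form prose.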

## Summary

| Component | RTX 4080 | Kaggle T4x2 |
|-----------|----------|-------------|
| GPU | RTX 4080 12GB (bf16) | 2x T4 15GB (fp16) |
| LoRA | r=64, alpha=128 | r=64, alpha=128 |
| Batch | 1 x 16 accum = 16 | 1 x 2 GPU x 8 accum = 16 |
| Preprocessing | 7-stage identical | 7-stage identical |
| Data | GDrive Desktop / local | GDrive via gdown / Kaggle Dataset |
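Both setups reach the same effective batch size of 16, the number of samples that contribute to each optimizer step: per-device batch times number of GPUs times gradient-accumulation steps. The arithmetic below checks the two table rows. The 10,000-sample training set is a made-up figure, not the notebook's actual split size, and the steps-per-epoch line is a simple ceiling estimate; the trainer's exact count can differ by one depending on how it handles a final partial accumulation.

```python
import math

def effective_batch(per_device_batch, n_gpus, grad_accum_steps):
    """Samples contributing to one optimizer step under data-parallel training."""
    return per_device_batch * n_gpus * grad_accum_steps

def optimizer_steps_per_epoch(n_samples, eff_batch):
    """Ceiling estimate of optimizer steps needed to see every sample once."""
    return math.ceil(n_samples / eff_batch)

rtx4080 = effective_batch(per_device_batch=1, n_gpus=1, grad_accum_steps=16)
kaggle_t4x2 = effective_batch(per_device_batch=1, n_gpus=2, grad_accum_steps=8)
HYPOTHETICAL_TRAIN_SIZE = 10_000  # illustration only, not the real PadChest split
print(rtx4080, kaggle_t4x2, optimizer_steps_per_epoch(HYPOTHETICAL_TRAIN_SIZE, rtx4080))
```

Keeping the effective batch equal is what lets the two configurations share the same learning-rate schedule.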