# ExplainMyXray - MedGemma Training

## Cell 1: Install

In [None]:
# Install the fine-tuning stack. `-U` upgrades these packages to their latest
# releases, so their versions are not pinned.
!pip install -q -U transformers peft accelerate bitsandbytes datasets
# Pillow and pandas are pinned to exact versions.
!pip install -q pillow==10.4.0 pandas==2.2.2
import torch
# Report the name of CUDA device 0, or 'None' when CUDA is unavailable.
print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None'}")

## Cell 2: Login

In [None]:
import os
from huggingface_hub import login

# ⚠️ DO NOT hardcode your token here!
# Option 1: Set HF_TOKEN in a .env file (see .env.example)
#           NOTE(review): nothing visible in this notebook loads .env, so the
#           variable must already be exported into the process environment.
# Option 2: Run 'huggingface-cli login' in terminal first
# Option 3: The login() call below will prompt you interactively

# HF_TOKEN is always a string (empty when the variable is unset) because later
# cells reference it. Surrounding whitespace, e.g. a trailing newline from a
# copy/paste, is stripped so it does not become part of the token.
HF_TOKEN = os.environ.get("HF_TOKEN", "").strip()
if HF_TOKEN:
    login(token=HF_TOKEN)
    print("✅ Logged in via HF_TOKEN environment variable")
else:
    login()  # Interactive login
    print("✅ Logged in interactively")

## Cell 3: Config

In [None]:
import torch
from transformers import BitsAndBytesConfig
from peft import LoraConfig, TaskType

# Hub id of the base model and the local directory used for training output.
MODEL_ID = "google/paligemma-3b-pt-224"
OUTPUT_DIR = "./medgemma_lora_adapters"

# 4-bit NF4 quantization with double quantization; compute runs in bfloat16.
BNB_CONFIG = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    load_in_4bit=True,
)

# LoRA adapters on the four attention projection layers.
LORA_CONFIG = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)

# Optimisation hyperparameters, one per line.
BATCH_SIZE = 2
GRADIENT_ACCUMULATION = 4
LEARNING_RATE = 2e-4
NUM_EPOCHS = 3
MAX_LENGTH = 512
print("Config ready")

## Cell 4: Load Model

In [None]:
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
from peft import get_peft_model, prepare_model_for_kbit_training, PeftModel
import os

# ========== CHECKPOINT FOR RESUMING (with fallback) ==========
CHECKPOINT_PRIMARY = "/content/drive/MyDrive/ExplainMyXray_Models/interrupted_checkpoint"
CHECKPOINT_FALLBACK = "/content/drive/MyDrive/medgemma_advanced_lora/checkpoint-250"

# Candidates are probed in priority order; the first existing path wins.
# CHECKPOINT_PATH stays None when neither path exists.
CHECKPOINT_PATH = None
for label, candidate in (
    ("primary", CHECKPOINT_PRIMARY),
    ("fallback", CHECKPOINT_FALLBACK),
):
    if os.path.exists(candidate):
        CHECKPOINT_PATH = candidate
        print(f"✅ Found {label} checkpoint: {CHECKPOINT_PATH}")
        break
else:
    # Loop finished without `break`: no candidate exists.
    print("⚠️ No checkpoint found - will start fresh")
# ==============================================================

# HF_TOKEN is "" after an interactive login. Passing "" would hand the Hub
# client an explicit (empty) token; None lets it fall back to the credential
# stored by login().
hub_token = HF_TOKEN or None

print("Loading processor...")
processor = AutoProcessor.from_pretrained(MODEL_ID, token=hub_token)

print("Loading model in 4-bit (2-3 min)...")
# Weights are quantized at load time according to BNB_CONFIG.
model = PaliGemmaForConditionalGeneration.from_pretrained(
    MODEL_ID,
    quantization_config=BNB_CONFIG,
    token=hub_token,
)

model = prepare_model_for_kbit_training(model)

# Prefer adapters saved in a previous run; any failure to load them falls
# back to freshly initialised adapters instead of aborting the notebook.
adapters_restored = False
if CHECKPOINT_PATH:
    print(f"🔄 Loading LoRA adapters from checkpoint: {CHECKPOINT_PATH}")
    try:
        # is_trainable=True keeps the restored adapter weights trainable.
        model = PeftModel.from_pretrained(model, CHECKPOINT_PATH, is_trainable=True)
    except Exception as e:
        # Deliberately broad: a bad checkpoint must not block training.
        print(f"❌ Failed to load checkpoint: {e}")
        print("🆕 Falling back to fresh LoRA adapters")
    else:
        adapters_restored = True
        print("✅ Checkpoint adapters loaded successfully!")
else:
    print("🆕 Starting with fresh LoRA adapters")

if not adapters_restored:
    model = get_peft_model(model, LORA_CONFIG)

model.print_trainable_parameters()
print("Model ready!")

## Cell 5: Load CSV

In [None]:
import pandas as pd
# Read the label/report table and preview its first rows.
csv_path = "/content/chest_x_ray_images_labels_sample.csv"
df = pd.read_csv(csv_path)
print("Loaded {} samples".format(len(df)))
df.head()

## Cell 6: Upload Images

In [None]:
import os
from google.colab import files
os.makedirs("/content/images", exist_ok=True)
print("Upload X-ray images:")
# files.upload() yields a mapping of uploaded file name -> raw bytes.
for name, content in files.upload().items():
    # The context manager flushes and closes each file deterministically
    # instead of leaving the handle to the garbage collector.
    with open(f"/content/images/{name}", "wb") as out_file:
        out_file.write(content)
print("Done")

## Cell 7: Create Dataset (SIMPLIFIED)

In [None]:
from PIL import Image
from torch.utils.data import Dataset, random_split
from pathlib import Path
import warnings
# Suppress every Python warning from this point on, including those raised
# later by other cells.
warnings.filterwarnings("ignore")

class XrayDataset(Dataset):
    """Image/report pairs prepared for prompt-plus-suffix fine-tuning.

    Rows missing ``ImageID`` or ``Report`` are dropped, and so are rows whose
    image cannot be located under ``img_dir``. Each image path is resolved
    once at construction time and reused by ``__getitem__``.

    Args:
        df: DataFrame with at least ``ImageID`` and ``Report`` columns.
        img_dir: Directory containing the image files.
        proc: Processor called with ``text``, ``images`` and ``suffix``.
        max_length: Length handed to the processor for padding/truncation.
    """

    def __init__(self, df, img_dir, proc, max_length=512):
        self.img_dir = Path(img_dir)
        self.proc = proc
        self.max_length = max_length

        frame = df.dropna(subset=["ImageID", "Report"]).reset_index(drop=True)
        # Keep only rows whose image exists, remembering the resolved path so
        # the directory is not searched again for every sample access.
        kept_positions, kept_paths = [], []
        for pos, img_id in enumerate(frame["ImageID"]):
            path = self._resolve_image(str(img_id))
            if path is not None:
                kept_positions.append(pos)
                kept_paths.append(path)
        self.df = frame.iloc[kept_positions].reset_index(drop=True)
        self._paths = kept_paths
        print(f"Valid samples with images: {len(self.df)}")

    def _resolve_image(self, img_id):
        """Return the file for ``img_id`` under ``img_dir``, or None if absent.

        The exact name is tried first. Otherwise the ID is matched against
        files that share its name but carry a different (or an additional)
        extension. Only the final extension of the ID is stripped, so
        ``scan.001.png`` never matches ``scan.002.png``.
        """
        exact = self.img_dir / img_id
        if exact.exists():
            return exact
        # First treat the whole ID as a stem (ID given without extension),
        # then the ID minus its own extension; dict.fromkeys removes the
        # duplicate when the ID has no extension at all.
        stems = dict.fromkeys((img_id, str(Path(img_id).with_suffix(""))))
        for stem in stems:
            # sorted() makes the choice deterministic when several files match.
            matches = sorted(self.img_dir.glob(f"{stem}.*"))
            if matches:
                return matches[0]
        return None

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        """Return the processor output for sample ``i`` without the batch axis."""
        row = self.df.iloc[i]
        # convert() returns a new image, so the file can be closed right away.
        with Image.open(self._paths[i]) as handle:
            img = handle.convert("RGB")

        # Simple prompt + answer format
        prompt = "describe this chest xray:"
        answer = str(row["Report"]).strip()

        # Process with suffix (the answer we want to predict)
        model_inputs = self.proc(
            text=prompt,
            images=img,
            suffix=answer,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
        )

        # The processor returns batched tensors; drop the leading axis.
        return {k: v.squeeze(0) for k, v in model_inputs.items()}

ds = XrayDataset(df, "/content/images", processor)
if len(ds) == 0:
    # Fail here with an actionable message instead of an IndexError further down.
    raise RuntimeError(
        "No usable samples: no CSV row has both a report and a matching image "
        "in /content/images"
    )

if len(ds) > 1:
    train_size = max(1, int(0.9 * len(ds)))
    val_size = len(ds) - train_size
    # A fixed seed keeps the train/val membership identical across runs, so a
    # resumed run does not evaluate on samples it previously trained on.
    split_rng = torch.Generator().manual_seed(42)
    train_ds, val_ds = random_split(ds, [train_size, val_size], generator=split_rng)
else:
    # A single sample cannot be split; it is used for both roles.
    train_ds = val_ds = ds
print(f"Train: {len(train_ds)}, Val: {len(val_ds)}")

# Test one sample
sample = ds[0]
print(f"Sample keys: {sample.keys()}")
print(f"input_ids shape: {sample['input_ids'].shape}")

## Cell 8: Training

In [None]:
from dataclasses import dataclass
from typing import Dict, List
from transformers import Trainer, TrainingArguments
import os

@dataclass
class Collator:
    """Batch per-sample tensor dicts by stacking the values under each key."""

    def __call__(self, f: List[Dict]) -> Dict[str, torch.Tensor]:
        # The first sample decides which keys appear in the batch.
        batch: Dict[str, torch.Tensor] = {}
        for key in f[0]:
            per_sample = [example[key] for example in f]
            batch[key] = torch.stack(per_sample)
        return batch

# ========== RESUME FROM CHECKPOINT CONFIG (with fallback) ==========
CHECKPOINT_PRIMARY = "/content/drive/MyDrive/ExplainMyXray_Models/interrupted_checkpoint"
CHECKPOINT_FALLBACK = "/content/drive/MyDrive/medgemma_advanced_lora/checkpoint-250"

# Candidates are probed in priority order; the first existing path wins.
# RESUME_CHECKPOINT stays None when neither path exists.
RESUME_CHECKPOINT = None
for label, candidate in (
    ("primary", CHECKPOINT_PRIMARY),
    ("fallback", CHECKPOINT_FALLBACK),
):
    if os.path.exists(candidate):
        RESUME_CHECKPOINT = candidate
        print(f"✅ Will resume from {label} checkpoint: {RESUME_CHECKPOINT}")
        break
else:
    # Loop finished without `break`: no candidate exists.
    print("⚠️ No checkpoint found - starting fresh training")
# ====================================================================

# Hyperparameters come from the config cell; checkpoints go to OUTPUT_DIR.
args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=BATCH_SIZE,
    # Evaluate with the same batch size as training rather than the larger
    # library default, so evaluation does not need more memory than a
    # training step.
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION,
    num_train_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    bf16=True,
    save_strategy="steps",
    save_steps=100,  # Save more frequently to avoid losing progress
    logging_steps=10,
    # Evaluation uses the same step interval as checkpoint saving.
    eval_strategy="steps",
    eval_steps=100,
    load_best_model_at_end=True,
    report_to="none",
    dataloader_pin_memory=False,
    gradient_checkpointing=True,
    # Keep every key the dataset returns; the collator stacks all of them.
    remove_unused_columns=False,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    data_collator=Collator(),
)

# Announce before the blocking call so the message appears while training
# runs rather than after it has finished.
print("Training...")
trainer.train(resume_from_checkpoint=RESUME_CHECKPOINT)

## Cell 9: Save & Download

In [None]:
# Write the model first, then the processor, into the same output directory.
for artifact in (model, processor):
    artifact.save_pretrained(OUTPUT_DIR)
import shutil
from google.colab import files

# Zip the output directory and hand the archive to the browser.
archive_stem = "medgemma_lora"
shutil.make_archive(archive_stem, "zip", OUTPUT_DIR)
files.download(f"{archive_stem}.zip")
print("Downloaded!")

## Cell 10: Test

In [None]:
from PIL import Image
from pathlib import Path
# Sorted so the same image is picked on every run.
imgs = sorted(Path("/content/images").glob("*.*"))
if imgs:
    # Switch to inference mode so dropout (including LoRA dropout) is off.
    model.eval()
    # convert() returns a new image, so the file can be closed right away.
    with Image.open(imgs[0]) as handle:
        img = handle.convert("RGB")
    # Place the inputs on whatever device the model lives on instead of
    # assuming "cuda".
    inp = processor(images=img, text="describe this chest xray:", return_tensors="pt").to(model.device)
    with torch.no_grad():
        out = model.generate(**inp, max_new_tokens=50, do_sample=True)
    print(processor.decode(out[0], skip_special_tokens=True))