# MedGemma ALL Leukemia ‚Äì Dataset Prep (2 Kaggle datasets)
This notebook:
- Downloads 2 Kaggle datasets
- Loads maximum images (ALL=1, HEM=0)
- Checks image types, sizes, and simple background stats
- Deduplicates by file hash
- Creates stratified train/val/test splits


In [33]:
!pip install -q transformers>=4.45.0
!pip install -q peft>=0.13.0
!pip install -q accelerate>=0.34.0
!pip install -q bitsandbytes>=0.44.0
!pip install -q datasets>=3.0.0
!pip install -q pillow
!pip install -q tqdm
!pip install -q huggingface_hub

print("‚úÖ All packages installed!")

‚úÖ All packages installed!


## üîê Hugging Face login
**Important:** Do NOT hardcode your HF token in notebooks. Use an environment variable instead.
If you already pasted a token in a notebook/chat, revoke it on Hugging Face and create a new one.

In [34]:
from huggingface_hub import login
import os, getpass

# Option A: if you already set HF_TOKEN in Colab secrets / env:
token = os.environ.get("HF_TOKEN", "")

# Option B: prompt securely (recommended)
if not token:
    token = getpass.getpass("Paste your HuggingFace token (input hidden): ")

login(token=token)
print("‚úÖ Logged in!")


Paste your HuggingFace token (input hidden): ¬∑¬∑¬∑¬∑¬∑¬∑¬∑¬∑¬∑¬∑
‚úÖ Logged in!


## üì¶ Download Kaggle datasets (kagglehub)

In [35]:
import kagglehub, os

# Dataset A: C-NMC Leukemia (ALL vs HEM)
path_a = kagglehub.dataset_download("andrewmvd/leukemia-classification")
print("‚úÖ Dataset A downloaded to:", path_a)
print("üìÇ Top-level entries:", os.listdir(path_a))

# Dataset B: Leukemia Image Dataset (ALL vs HEM)
path_b = kagglehub.dataset_download("rakibhasan3948/leukemia-image-dataset")
print("\n‚úÖ Dataset B downloaded to:", path_b)
print("üìÇ Top-level entries:", os.listdir(path_b))


Using Colab cache for faster access to the 'leukemia-classification' dataset.
‚úÖ Dataset A downloaded to: /kaggle/input/leukemia-classification
üìÇ Top-level entries: ['C-NMC_Leukemia']
Using Colab cache for faster access to the 'leukemia-image-dataset' dataset.

‚úÖ Dataset B downloaded to: /kaggle/input/leukemia-image-dataset
üìÇ Top-level entries: ['Leukemia-Image-Dataset']


## üß≠ Load maximum images + label mapping
- **ALL / all** ‚Üí `1` (Leukemia)
- **HEM / hem** ‚Üí `0` (Normal)

We also **deduplicate** by file hash.

In [36]:
import os
from pathlib import Path
from collections import Counter

IMG_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"}

def collect_images(root: Path):
    root = Path(root)
    out = []
    for dp, _, fns in os.walk(root):
        dp = Path(dp)
        for fn in fns:
            if Path(fn).suffix.lower() in IMG_EXTS:
                out.append(dp / fn)
    return out

def find_dir_anywhere(root: Path, target: str):
    """Find a folder named `target` anywhere under root, case-insensitive."""
    target = target.lower()
    for p in root.rglob("*"):
        if p.is_dir() and p.name.lower() == target:
            return p
    return None

def clean_pairs(images, labels, name="dataset"):
    """
    Safety cleaner:
      - trims to min length if mismatch
      - drops missing files
      - de-duplicates by full path (keeps first label; warns on conflicts)
    """
    images = list(images)
    labels = list(labels)

    if len(images) != len(labels):
        print(f"‚ö†Ô∏è {name} mismatch: images={len(images)}, labels={len(labels)} -> trimming to min")
        n = min(len(images), len(labels))
        images, labels = images[:n], labels[:n]

    # drop missing
    kept_imgs, kept_lbls = [], []
    missing = 0
    for p, y in zip(images, labels):
        if Path(p).exists():
            kept_imgs.append(str(p))
            kept_lbls.append(int(y))
        else:
            missing += 1
    if missing:
        print(f"‚ö†Ô∏è {name}: removed missing files: {missing}")

    # dedupe by path
    seen = {}
    dedup_imgs, dedup_lbls = [], []
    conflicts = 0
    for p, y in zip(kept_imgs, kept_lbls):
        if p in seen:
            if seen[p] != y:
                conflicts += 1
            continue
        seen[p] = y
        dedup_imgs.append(p)
        dedup_lbls.append(y)

    if conflicts:
        print(f"‚ö†Ô∏è {name}: duplicate path label conflicts: {conflicts} (kept first label)")

    print(f"‚úÖ {name}: final={len(dedup_imgs)}  Normal={sum(1 for v in dedup_lbls if v==0)}  Leukemia={sum(1 for v in dedup_lbls if v==1)}")
    return dedup_imgs, dedup_lbls

def safe_combine(datasets):
    """datasets = [(name, images, labels), ...] -> combined (images, labels)"""
    all_imgs, all_lbls = [], []
    for name, imgs, lbls in datasets:
        imgs2, lbls2 = clean_pairs(imgs, lbls, name=name)
        all_imgs.extend(imgs2)
        all_lbls.extend(lbls2)

    assert len(all_imgs) == len(all_lbls), "‚ùå Internal error: combined lengths still mismatch."
    print("\n‚úÖ Combined datasets (MAX images, cleaned)")
    print("Total images:", len(all_imgs))
    print("Normal(0):", sum(l==0 for l in all_lbls))
    print("Leukemia(1):", sum(l==1 for l in all_lbls))
    return all_imgs, all_lbls

# -------------------------
# Dataset A: C-NMC Leukemia
# training_data/fold_0,1,2/{all,hem}
# all -> 1 (Leukemia), hem -> 0 (Normal)
# -------------------------
def load_cnm_all_folds(path_a_root: str, folds=("fold_0","fold_1","fold_2")):
    base = Path(path_a_root) / "C-NMC_Leukemia" / "training_data"
    images, labels = [], []

    if not base.exists():
        print("‚ö†Ô∏è C-NMC training_data not found at:", base)
        print("Top-level:", [p.name for p in Path(path_a_root).iterdir()])
        return [], []

    for fold in folds:
        fold_dir = base / fold
        if not fold_dir.exists():
            print(f"‚ö†Ô∏è Missing fold folder: {fold_dir}")
            continue

        for cls, y in [("all", 1), ("hem", 0)]:
            cls_dir = fold_dir / cls
            if cls_dir.exists():
                imgs = collect_images(cls_dir)
                images.extend([str(p) for p in imgs])
                labels.extend([y] * len(imgs))
            else:
                print(f"‚ö†Ô∏è Missing class folder: {cls_dir}")

    print("\n‚úÖ Loaded Dataset A (C-NMC)")
    print("  Total:", len(images),
          " Normal:", sum(l==0 for l in labels),
          " Leukemia:", sum(l==1 for l in labels))
    return images, labels

# -------------------------
# Dataset B: Leukemia Image Dataset (Rakib)
# folders: ALL (cancer), HEM (normal) but may be nested
# ALL -> 1, HEM -> 0
# -------------------------
def load_rakib_all_hem(path_b_root: str):
    root = Path(path_b_root)

    all_dir = find_dir_anywhere(root, "ALL")
    hem_dir = find_dir_anywhere(root, "HEM")

    print("\nüîé Dataset B folder detection")
    print("  ALL dir:", all_dir)
    print("  HEM dir:", hem_dir)

    images, labels = [], []

    if all_dir is not None:
        imgs = collect_images(all_dir)
        images.extend([str(p) for p in imgs])
        labels.extend([1] * len(imgs))

    if hem_dir is not None:
        imgs = collect_images(hem_dir)
        images.extend([str(p) for p in imgs])
        labels.extend([0] * len(imgs))

    print("\n‚úÖ Loaded Dataset B (Rakib Hasan)")
    print("  Total:", len(images),
          " Normal:", sum(l==0 for l in labels),
          " Leukemia:", sum(l==1 for l in labels))

    if len(images) == 0:
        print("‚ö†Ô∏è Dataset B returned 0 images.")
        print("üëâ Top-level entries:", [p.name for p in root.iterdir()])

    return images, labels

# -------------------------
# LOAD BOTH DATASETS (clean + safe)
# -------------------------
a_images, a_labels = load_cnm_all_folds(path_a)
b_images, b_labels = load_rakib_all_hem(path_b)

all_images, all_labels = safe_combine([
    ("DatasetA_CNM", a_images, a_labels),
    ("DatasetB_Rakib", b_images, b_labels),
])



‚úÖ Loaded Dataset A (C-NMC)
  Total: 10661  Normal: 3389  Leukemia: 7272

üîé Dataset B folder detection
  ALL dir: /kaggle/input/leukemia-image-dataset/Leukemia-Image-Dataset/ALL
  HEM dir: /kaggle/input/leukemia-image-dataset/Leukemia-Image-Dataset/HEM

‚úÖ Loaded Dataset B (Rakib Hasan)
  Total: 6778  Normal: 3389  Leukemia: 3389
‚úÖ DatasetA_CNM: final=10661  Normal=3389  Leukemia=7272
‚úÖ DatasetB_Rakib: final=6778  Normal=3389  Leukemia=3389

‚úÖ Combined datasets (MAX images, cleaned)
Total images: 17439
Normal(0): 6778
Leukemia(1): 10661


## üñºÔ∏è Check image type, size, and simple background stats
We sample images for speed and report:
- format distribution
- width/height distribution
- mean border brightness (rough background indicator)


In [37]:
from PIL import Image
import numpy as np
import random
from pathlib import Path
from collections import Counter
from tqdm import tqdm

# Avoid PIL decompression bomb warnings for large images
Image.MAX_IMAGE_PIXELS = None

def border_mean_intensity(img: Image.Image, border=4):
    """
    Mean grayscale intensity of border pixels.
    0=black, 255=white.
    """
    # Convert to 8-bit grayscale safely (handles TIFF / 16-bit cases better)
    arr = np.array(img.convert("L"), dtype=np.uint8)
    h, w = arr.shape
    if h == 0 or w == 0:
        return float("nan")

    b = min(border, max(1, h // 8), max(1, w // 8))  # robust border width

    # Border without double-counting corners too much
    top = arr[:b, :]
    bottom = arr[-b:, :]
    left = arr[b:-b, :b] if h > 2*b else arr[:, :b]
    right = arr[b:-b, -b:] if h > 2*b else arr[:, -b:]

    border_pixels = np.concatenate([top.ravel(), bottom.ravel(), left.ravel(), right.ravel()])
    return float(border_pixels.mean())

def inspect_images(paths, sample_n=800, seed=42, border=4):
    """
    Inspects image format distribution, size distribution, and border intensity.
    """
    paths = list(paths)
    if len(paths) == 0:
        print("‚ö†Ô∏è No image paths provided.")
        return

    random.seed(seed)
    sample = paths if len(paths) <= sample_n else random.sample(paths, sample_n)

    formats, sizes, border_means = [], [], []
    failures = 0

    for p in tqdm(sample, desc=f"Inspecting {len(sample)} images"):
        try:
            p = str(p)
            with Image.open(p) as im:
                fmt = (im.format or Path(p).suffix.replace(".", "") or "unknown").lower()
                formats.append(fmt)
                sizes.append(im.size)  # (w,h)
                border_means.append(border_mean_intensity(im, border=border))
        except Exception as e:
            failures += 1

    if len(sizes) == 0:
        print(f"‚ùå All failed to open. Failures: {failures} / {len(sample)}")
        return

    w_list = [s[0] for s in sizes]
    h_list = [s[1] for s in sizes]
    b_list = [x for x in border_means if np.isfinite(x)]

    print("\n‚úÖ Inspected:", len(sample), " Failures:", failures)
    print("Formats:", Counter(formats).most_common(10))
    print("Width  min/median/max:", min(w_list), int(np.median(w_list)), max(w_list))
    print("Height min/median/max:", min(h_list), int(np.median(h_list)), max(h_list))

    if len(b_list) > 0:
        print("Border mean intensity (0=black,255=white) min/median/max:",
              round(float(np.min(b_list)), 2),
              round(float(np.median(b_list)), 2),
              round(float(np.max(b_list)), 2))
    else:
        print("Border mean intensity: N/A (no valid values)")

# ---- Run on your list of image paths
inspect_images(all_images, sample_n=800, seed=42, border=4)


Inspecting 800 images: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 800/800 [00:01<00:00, 574.71it/s]


‚úÖ Inspected: 800  Failures: 0
Formats: [('bmp', 800)]
Width  min/median/max: 450 450 450
Height min/median/max: 450 450 450
Border mean intensity (0=black,255=white) min/median/max: 0.0 0.0 2.11





## ‚úÇÔ∏è Stratified Train/Val/Test split
Default:
- Train 80%
- Val 10%
- Test 10%

In [38]:
from sklearn.model_selection import train_test_split
from collections import Counter

def split_train_val_test(image_paths, labels, seed=42, val_size=0.10, test_size=0.10):
    assert len(image_paths) == len(labels), "‚ùå image_paths and labels length mismatch!"

    # First split: Train vs (Val+Test)
    X_train, X_tmp, y_train, y_tmp = train_test_split(
        image_paths,
        labels,
        test_size=(val_size + test_size),
        random_state=seed,
        stratify=labels
    )

    # Second split: Val vs Test (split X_tmp into two parts)
    relative_test_size = test_size / (val_size + test_size)
    X_val, X_test, y_val, y_test = train_test_split(
        X_tmp,
        y_tmp,
        test_size=relative_test_size,
        random_state=seed,
        stratify=y_tmp
    )

    def summarize(name, y):
        c = Counter(map(int, y))
        print(f"{name}: {len(y):6d}  Normal(0)={c.get(0,0):6d}  Leukemia(1)={c.get(1,0):6d}")

    print("\n‚úÖ Split summary")
    summarize("Train", y_train)
    summarize("Val  ", y_val)
    summarize("Test ", y_test)

    # Overlap check
    s_train, s_val, s_test = set(X_train), set(X_val), set(X_test)
    print("\nüîç Overlap check")
    print("Train ‚à© Val :", len(s_train & s_val))
    print("Train ‚à© Test:", len(s_train & s_test))
    print("Val   ‚à© Test:", len(s_val & s_test))

    return X_train, y_train, X_val, y_val, X_test, y_test


# ‚úÖ Use your combined lists here:
train_images, train_labels, val_images, val_labels, test_images, test_labels = split_train_val_test(
    all_images, all_labels, seed=42, val_size=0.10, test_size=0.10
)



‚úÖ Split summary
Train:  13951  Normal(0)=  5422  Leukemia(1)=  8529
Val  :   1744  Normal(0)=   678  Leukemia(1)=  1066
Test :   1744  Normal(0)=   678  Leukemia(1)=  1066

üîç Overlap check
Train ‚à© Val : 0
Train ‚à© Test: 0
Val   ‚à© Test: 0


## üíæ Load model

These CSVs feed directly into your fine-tuning notebook.

In [39]:
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText

MODEL_ID = "google/medgemma-1.5-4b-it"

# GPU info
assert torch.cuda.is_available(), "CUDA not available"
print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# Enable TF32 for speed (safe)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Processor
processor = AutoProcessor.from_pretrained(MODEL_ID)

# Model (NO flash attention)
model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map={"": 0},   # force full GPU
)

model.train()

print("\n‚úÖ Model loaded successfully")
print(f"Parameters: {model.num_parameters() / 1e9:.2f} B")


GPU: NVIDIA A100-SXM4-80GB
VRAM: 85.2 GB


Loading weights:   0%|          | 0/883 [00:00<?, ?it/s]


‚úÖ Model loaded successfully
Parameters: 4.30 B


In [40]:
from peft import LoraConfig, get_peft_model, TaskType

lora_config = LoraConfig(
    r=32,                 # stronger than 16
    lora_alpha=64,        # usually 2*r
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    bias="none",
    task_type=TaskType.CAUSAL_LM,
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()


trainable params: 23,797,760 || all params: 4,323,877,232 || trainable%: 0.5504


In [41]:
from torch.utils.data import Dataset
from PIL import Image
import torch

class LeukemiaAnswerOnlyDataset(Dataset):
    def __init__(self, image_paths, labels, processor, max_length=256):
        self.image_paths = list(image_paths)
        self.labels = [int(x) for x in labels]
        self.processor = processor
        self.max_length = max_length
        self.answers = {0: "Normal", 1: "Leukemia"}

        self.user_text = (
            "Classify this blood cell microscopy image.\n"
            "Answer with exactly ONE word: Normal or Leukemia.\n"
            "Answer:"
        )

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert("RGB")
        y = self.labels[idx]
        answer = self.answers[y]

        # full chat with answer
        msgs_full = [
            {"role": "user", "content": [{"type":"image"}, {"type":"text", "text": self.user_text}]},
            {"role": "assistant", "content": [{"type":"text", "text": answer}]},
        ]
        text_full = self.processor.apply_chat_template(msgs_full, add_generation_prompt=False)

        # prompt-only chat (empty answer)
        msgs_prompt = [
            {"role": "user", "content": [{"type":"image"}, {"type":"text", "text": self.user_text}]},
            {"role": "assistant", "content": [{"type":"text", "text": ""}]},
        ]
        text_prompt = self.processor.apply_chat_template(msgs_prompt, add_generation_prompt=False)

        full = self.processor(
            images=image, text=text_full,
            return_tensors="pt", padding="max_length", truncation=True, max_length=self.max_length
        )
        prompt = self.processor(
            images=image, text=text_prompt,
            return_tensors="pt", padding="max_length", truncation=True, max_length=self.max_length
        )

        input_ids = full["input_ids"].squeeze(0)
        attention_mask = full["attention_mask"].squeeze(0)
        pixel_values = full["pixel_values"].squeeze(0)

        # Answer-only loss
        labels = input_ids.clone()
        prompt_len = prompt["input_ids"].shape[1]
        labels[:prompt_len] = -100
        labels[attention_mask == 0] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "pixel_values": pixel_values,
            "labels": labels,
        }


In [51]:
import torch

def gemma_mm_pad_collator(features):
    # Max length based on input_ids
    max_len = max(f["input_ids"].shape[0] for f in features)

    pad_id = getattr(processor.tokenizer, "pad_token_id", 0)
    if pad_id is None:
        pad_id = 0

    def pad_1d(x, pad_value):
        pad_len = max_len - x.shape[0]
        if pad_len <= 0:
            return x
        return torch.cat([x, torch.full((pad_len,), pad_value, dtype=x.dtype, device=x.device)], dim=0)

    def pad_2d(x, pad_value):
        # pad on first dimension only (seq_len, hidden/whatever)
        pad_len = max_len - x.shape[0]
        if pad_len <= 0:
            return x
        pad_rows = torch.full((pad_len, x.shape[1]), pad_value, dtype=x.dtype, device=x.device)
        return torch.cat([x, pad_rows], dim=0)

    batch = {}
    keys = features[0].keys()

    for k in keys:
        vals = [f[k] for f in features]

        # Non-tensors: keep as list
        if not torch.is_tensor(vals[0]):
            batch[k] = vals
            continue

        # If scalar tensors: stack directly
        if vals[0].ndim == 0:
            batch[k] = torch.stack(vals)
            continue

        # If 1D tensors and lengths vary: pad to max_len
        if vals[0].ndim == 1:
            lens = [v.shape[0] for v in vals]
            if len(set(lens)) > 1:
                if k == "input_ids":
                    pad_value = pad_id
                elif k == "attention_mask":
                    pad_value = 0
                elif k == "labels":
                    pad_value = -100
                else:
                    pad_value = 0  # safe default for extra seq fields (position_ids, etc.)
                vals = [pad_1d(v, pad_value) for v in vals]
            batch[k] = torch.stack(vals)
            continue

        # If 2D tensors and first dim varies: pad to max_len on dim0
        if vals[0].ndim == 2:
            lens = [v.shape[0] for v in vals]
            if len(set(lens)) > 1:
                # labels-like 2D is rare; 0 is usually safe
                vals = [pad_2d(v, 0) for v in vals]
            batch[k] = torch.stack(vals)
            continue

        # Higher dims (pixel_values etc.) should already be same shape
        batch[k] = torch.stack(vals)

    return batch


In [63]:
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler, Dataset
from transformers import TrainingArguments, Trainer, default_data_collator
from collections import Counter

# -------------------------
# 1) Balanced sampler (train only)
# -------------------------
def make_balanced_sampler(labels):
    counts = Counter([int(x) for x in labels])
    w0 = 1.0 / max(counts.get(0, 1), 1)
    w1 = 1.0 / max(counts.get(1, 1), 1)
    weights = [w0 if int(y) == 0 else w1 for y in labels]
    return WeightedRandomSampler(
        torch.DoubleTensor(weights),
        num_samples=len(weights),
        replacement=True
    )

train_sampler = make_balanced_sampler(train_labels)
print("‚úÖ Sampler ready (balanced batches)")

# -------------------------
# 2) Dataset (FIXED: proper answer-only loss masking)
# -------------------------
class LeukemiaAnswerOnlyDataset(Dataset):
    def __init__(self, images, labels, processor):
        self.images = images
        self.labels = [int(x) for x in labels]
        self.processor = processor
        self.answe0rs = {0: "Normal", 1: "Leukemia"}

        self.user_text = (
            "Classify this blood cell microscopy image.\n"
            "Answer with exactly ONE word: Normal or Leukemia.\n"
            "Answer:"
        )

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        image = self.images[idx]
        y = self.labels[idx]
        answer = self.answers[y]

        # ‚úÖ Full message WITH answer
        msgs_full = [
            {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": self.user_text}]},
            {"role": "assistant", "content": [{"type": "text", "text": answer}]},
        ]

        # ‚úÖ FIX: Prompt only (NO assistant turn, let add_generation_prompt add it)
        msgs_prompt = [
            {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": self.user_text}]},
        ]

        # ‚úÖ FIX: add_generation_prompt=True for prompt marks where answer starts
        text_full = self.processor.apply_chat_template(msgs_full, add_generation_prompt=False)
        text_prompt = self.processor.apply_chat_template(msgs_prompt, add_generation_prompt=True)  # ‚Üê KEY FIX

        full = self.processor(
            images=image,
            text=text_full,
            return_tensors="pt",
            padding=False,
            truncation=False,
        )
        prompt = self.processor(
            images=image,
            text=text_prompt,
            return_tensors="pt",
            padding=False,
            truncation=False,
        )

        # Squeeze batch dim
        full = {k: v.squeeze(0) for k, v in full.items()}

        # ‚úÖ FIX: Use shape[1] for prompt length (before squeeze)
        prompt_len = prompt["input_ids"].shape[1]

        # Answer-only loss: mask everything up to prompt length
        labels = full["input_ids"].clone()
        labels[:prompt_len] = -100
        full["labels"] = labels

        return full

train_dataset = LeukemiaAnswerOnlyDataset(train_images, train_labels, processor)
val_dataset   = LeukemiaAnswerOnlyDataset(val_images, val_labels, processor)

print("‚úÖ Datasets ready")
print("Train:", len(train_dataset), " Val:", len(val_dataset))

# ‚úÖ Verify fix
sample = train_dataset[0]
trainable = sample["labels"][sample["labels"] != -100]
print(f"‚úÖ Trainable tokens: {len(trainable)}")
print(f"‚úÖ Decoded: {processor.decode(trainable)}")

# -------------------------
# 3) Custom Trainer to inject sampler
# -------------------------
class BalancedTrainer(Trainer):
    def get_train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.args.per_device_train_batch_size,
            sampler=train_sampler,
            collate_fn=gemma_mm_pad_collator,
            num_workers=2,
            pin_memory=True,
        )

# -------------------------
# 4) TrainingArguments
# -------------------------
OUTPUT_DIR = "/content/medgemma_lora_run_v2"

args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=8,

    num_train_epochs=5,
    learning_rate=2e-5,
    warmup_ratio=0.05,
    weight_decay=0.01,

    bf16=True,

    logging_steps=25,
    eval_strategy="steps",
    eval_steps=150,
    save_strategy="steps",
    save_steps=150,
    save_total_limit=3,

    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,

    remove_unused_columns=False,
    report_to="none",
)

# -------------------------
# 5) Build trainer
# -------------------------
trainer = BalancedTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=gemma_mm_pad_collator,
)

print("‚úÖ Trainer ready - run trainer.train()")

warmup_ratio is deprecated and will be removed in v5.2. Use `warmup_steps` instead.


‚úÖ Sampler ready (balanced batches)
‚úÖ Datasets ready
Train: 13951  Val: 1744
‚úÖ Trainable tokens: 5
‚úÖ Decoded: Leukemia<end_of_turn>

‚úÖ Trainer ready - run trainer.train()


In [64]:
# -------------------------
# 6) Train
# -------------------------
print("\nüöÄ Starting LoRA fine-tuning...")
print("Checkpoints will be saved to:", OUTPUT_DIR)
print("-" * 60)

trainer.train()

print("\n‚úÖ Training complete!")
print("Best checkpoint:", trainer.state.best_model_checkpoint)




üöÄ Starting LoRA fine-tuning...
Checkpoints will be saved to: /content/medgemma_lora_run_v2
------------------------------------------------------------


Step,Training Loss,Validation Loss
150,0.136213,0.122787
300,0.114943,0.087362
450,0.063766,0.066102
600,0.05537,0.052734
750,0.050667,0.055197
900,0.043557,0.050898
1050,0.027168,0.043081
1200,0.045397,0.057755
1350,0.028797,0.034892
1500,0.022488,0.034721



‚úÖ Training complete!
Best checkpoint: /content/medgemma_lora_run_v2/checkpoint-1800


In [65]:
# -------------------------
# 7) Save final model + processor
# -------------------------
final_dir = "/content/medgemma_lora_final"
trainer.save_model(final_dir)
processor.save_pretrained(final_dir)

print("‚úÖ Saved final LoRA adapter + processor to:", final_dir)

‚úÖ Saved final LoRA adapter + processor to: /content/medgemma_lora_final


In [66]:
# OPTIONAL: Push to Hugging Face Hub
# Uncomment to upload your model

HF_REPO = "good2idnan/medgemma-1.5-4b-it-leukemia-lora"
model.push_to_hub(HF_REPO)
processor.push_to_hub(HF_REPO)
print(f"‚úÖ Pushed to: https://huggingface.co/good2idnan/medgemma-1.5-4b-it-leukemia-lora")

Processing Files (0 / 0)      : |          |  0.00B /  0.00B            

New Data Upload               : |          |  0.00B /  0.00B            

  ...adapter_model.safetensors:   1%|          |  563kB / 95.3MB            

README.md: 0.00B [00:00, ?B/s]

Processing Files (0 / 0)      : |          |  0.00B /  0.00B            

New Data Upload               : |          |  0.00B /  0.00B            

  ...mpsjz33r2e/tokenizer.json:  23%|##2       | 7.61MB / 33.4MB            

‚úÖ Pushed to: https://huggingface.co/good2idnan/medgemma-1.5-4b-it-leukemia-lora


In [2]:
from sklearn.metrics import classification_report, f1_score, confusion_matrix
import torch

model.eval()
predictions, truths = [], []

# Evaluate on validation set
for i in range(len(val_images)):
    msgs = [{"role": "user", "content": [
        {"type": "image"},
        {"type": "text", "text": "Classify this blood cell microscopy image.\nAnswer with exactly ONE word: Normal or Leukemia.\nAnswer:"}
    ]}]
    text = processor.apply_chat_template(msgs, add_generation_prompt=True)
    inputs = processor(images=val_images[i], text=text, return_tensors="pt").to(model.device)

    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=5, do_sample=False)

    response = processor.decode(output[0], skip_special_tokens=True).lower()
    pred = 1 if "leukemia" in response else 0

    predictions.append(pred)
    truths.append(int(val_labels[i]))

    if (i + 1) % 200 == 0:
        print(f"Evaluated {i+1}/{len(val_images)}")

# Print results
print("\n" + "="*60)
print("FINAL EVALUATION RESULTS")
print("="*60)
print(classification_report(truths, predictions, target_names=["Normal", "Leukemia"]))
print(f"üéØ Weighted F1 Score: {f1_score(truths, predictions, average='weighted'):.4f}")
print(f"\nConfusion Matrix:\n{confusion_matrix(truths, predictions)}")

NameError: name 'model' is not defined