# NIH Few-Shot (One-Shot) Classification with MedCLIP

This notebook fine-tunes a lightweight linear head on top of a frozen MedCLIP image encoder using one/few-shot samples per class and evaluates on the held-out test split generated earlier.

In [17]:
# ========= 1) Setup & Config =========
import os
from pathlib import Path
from typing import Dict

import numpy as np
import pandas as pd
from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, classification_report
from IPython.display import display

from medclip import MedCLIPModel, MedCLIPProcessor
from medclip.modeling_medclip import SuperviseClassifier

# Prefer Apple Silicon GPU or CUDA when available
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

THIS_DIR = Path(__file__).resolve().parent if "__file__" in globals() else Path.cwd().resolve()
TRAIN_CSV = (THIS_DIR / "local_data" / "nih-sampled-meta-train.csv").resolve()
TEST_CSV = (THIS_DIR / "local_data" / "nih-sampled-meta-test.csv").resolve()
IMAGE_ROOT = THIS_DIR  # imgpath column already stores repo-relative paths

CHEXPERT5 = ["Atelectasis", "Cardiomegaly", "Consolidation", "Edema", "Pleural Effusion"]

VISION_MODEL = "vit"                      # or "resnet"
SHOT_PER_CLASS = 5                          # one-shot by default; increase (e.g., 5) for few-shot
BATCH_SIZE = 4
NUM_EPOCHS = 50
LEARNING_RATE = 5e-4
WEIGHT_DECAY = 1e-4
RANDOM_STATE = 42

print("Torch:", torch.__version__)
print("Device:", DEVICE)
print("Train CSV:", TRAIN_CSV)
print("Test  CSV:", TEST_CSV)


Torch: 2.5.1
Device: mps
Train CSV: /Users/zitongluo/Library/Mobile Documents/com~apple~CloudDocs/硕士相关/2025Fall/Learning from small data/MedCLIP_eval/local_data/nih-sampled-meta-train.csv
Test  CSV: /Users/zitongluo/Library/Mobile Documents/com~apple~CloudDocs/硕士相关/2025Fall/Learning from small data/MedCLIP_eval/local_data/nih-sampled-meta-test.csv


In [18]:
# ========= 2) Load train/test metadata =========
def resolve_path(p: str) -> Path:
    raw = Path(p)
    return raw.resolve() if raw.is_absolute() else (IMAGE_ROOT / raw).resolve()

train_df = pd.read_csv(TRAIN_CSV)
test_df = pd.read_csv(TEST_CSV)

for name, df in (("train", train_df), ("test", test_df)):
    if "disease" not in df.columns or "imgpath" not in df.columns:
        raise KeyError(f"{name} CSV must contain 'disease' and 'imgpath' columns.")
    df["disease"] = df["disease"].astype(str)
    df["img_abs_path"] = df["imgpath"].apply(lambda p: str(resolve_path(p)))
    exists = df["img_abs_path"].apply(lambda p: Path(p).exists())
    missing = len(df) - exists.sum()
    if missing:
        print(f"[WARN] {missing} files missing in {name} split; they will be dropped.")
        display(df.loc[~exists, ["Image Index", "img_abs_path"]].head(10))
    df.drop(index=df.index[~exists], inplace=True)
    df.reset_index(drop=True, inplace=True)

# Keep disease-only five-class subset
train_df = train_df.loc[train_df["disease"].isin(CHEXPERT5)].reset_index(drop=True)
test_df = test_df.loc[test_df["disease"].isin(CHEXPERT5)].reset_index(drop=True)

print("Train size after filtering:", len(train_df))
print(train_df["disease"].value_counts().sort_index())
print("Test size after filtering:", len(test_df))
print(test_df["disease"].value_counts().sort_index())


Train size after filtering: 2000
disease
Atelectasis         400
Cardiomegaly        400
Consolidation       400
Edema               400
Pleural Effusion    400
Name: count, dtype: int64
Test size after filtering: 5000
disease
Atelectasis         1000
Cardiomegaly        1000
Consolidation       1000
Edema               1000
Pleural Effusion    1000
Name: count, dtype: int64


In [19]:
# ========= 3) Sample one/few-shot training subset =========
def build_few_shot(df: pd.DataFrame, shots: int, seed: int) -> pd.DataFrame:
    if shots < 1:
        raise ValueError("shots must be >= 1")
    sampled_frames = []
    for label, group in df.groupby("disease"):
        take = min(shots, len(group))
        sampled = group.sample(n=take, random_state=seed, replace=False)
        sampled_frames.append(sampled)
    few = pd.concat(sampled_frames, ignore_index=True)
    few = few.sample(frac=1.0, random_state=seed).reset_index(drop=True)
    return few

few_shot_train_df = build_few_shot(train_df, SHOT_PER_CLASS, RANDOM_STATE)
if few_shot_train_df.empty:
    raise RuntimeError("Few-shot training set is empty; check SHOT_PER_CLASS and data filtering.")

print("Few-shot train size:", len(few_shot_train_df))
print(few_shot_train_df["disease"].value_counts().sort_index())


Few-shot train size: 25
disease
Atelectasis         5
Cardiomegaly        5
Consolidation       5
Edema               5
Pleural Effusion    5
Name: count, dtype: int64


In [20]:
# ========= 4) Dataset & Dataloader helpers =========
class NIHSingleLabelDataset(Dataset):
    def __init__(self, dataframe: pd.DataFrame, class_to_idx: Dict[str, int], processor: MedCLIPProcessor):
        self.df = dataframe.reset_index(drop=True)
        self.class_to_idx = class_to_idx
        self.processor = processor

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx: int):
        row = self.df.iloc[idx]
        with Image.open(row["img_abs_path"]) as img:
            image = img.convert("RGB")
        processed = self.processor(images=image, return_tensors="pt")
        pixel_values = processed["pixel_values"].squeeze(0)
        label_idx = self.class_to_idx[row["disease"]]
        return {
            "pixel_values": pixel_values,
            "label": torch.tensor(label_idx, dtype=torch.long),
            "path": row["img_abs_path"],
        }

def collate_fn(batch):
    pixel_values = torch.stack([item["pixel_values"] for item in batch])
    labels = torch.stack([item["label"] for item in batch])
    paths = [item["path"] for item in batch]
    return {"pixel_values": pixel_values, "labels": labels, "paths": paths}

processor = MedCLIPProcessor()

CLASS_TO_IDX = {cls: idx for idx, cls in enumerate(CHEXPERT5)}
IDX_TO_CLASS = {idx: cls for cls, idx in CLASS_TO_IDX.items()}

train_dataset = NIHSingleLabelDataset(few_shot_train_df, CLASS_TO_IDX, processor)
test_dataset = NIHSingleLabelDataset(test_df, CLASS_TO_IDX, processor)

train_batch_size = max(1, min(BATCH_SIZE, len(train_dataset)))
train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

print("Train samples:", len(train_dataset), "| Batches:", len(train_loader))
print("Test samples:", len(test_dataset), "| Batches:", len(test_loader))




Train samples: 25 | Batches: 7
Test samples: 5000 | Batches: 1250


In [21]:
# ========= 5) Load MedCLIP & build supervised head =========
base_model = MedCLIPModel.from_pretrained(vision_model=VISION_MODEL, device=DEVICE)

vision_encoder = base_model.vision_model
vision_encoder.to(DEVICE)
setattr(vision_encoder, "device", torch.device(DEVICE))

classifier_input_dim = 768 if VISION_MODEL == "vit" else 512
supervised = SuperviseClassifier(
    vision_model=vision_encoder,
    num_class=len(CHEXPERT5),
    input_dim=classifier_input_dim,
    mode="multiclass",
).to(DEVICE)

for param in supervised.model.parameters():
    param.requires_grad = False
supervised.model.eval()

supervised.fc.reset_parameters()

optimizer = torch.optim.Adam(supervised.fc.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

print("Trainable parameters in head:", sum(p.numel() for p in supervised.fc.parameters()))


Some weights of the model checkpoint at microsoft/swin-tiny-patch4-window7-224 were not used when initializing SwinModel: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing SwinModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing SwinModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
  return torch.load(checkpoint_file, map_location="cpu")
Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.d

Model moved to mps
load model weight from: pretrained/medclip-vit
Trainable parameters in head: 3845


In [22]:
# ========= 6) Fine-tune linear head (few-shot) =========
torch.manual_seed(RANDOM_STATE)
if DEVICE == "cuda":
    torch.cuda.manual_seed_all(RANDOM_STATE)

loss_history = []
for epoch in range(1, NUM_EPOCHS + 1):
    supervised.train()
    supervised.model.eval()
    running_loss = 0.0
    count = 0
    for batch in train_loader:
        pixel_values = batch["pixel_values"].to(DEVICE)
        labels = batch["labels"].to(DEVICE)
        optimizer.zero_grad()
        outputs = supervised(pixel_values=pixel_values, labels=labels, return_loss=True)
        loss = outputs["loss_value"]
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * pixel_values.size(0)
        count += pixel_values.size(0)
    epoch_loss = running_loss / max(1, count)
    loss_history.append(epoch_loss)
    print(f"Epoch {epoch:02d}/{NUM_EPOCHS} | loss={epoch_loss:.4f}")

print("Training done.")


Epoch 01/50 | loss=1.6638
Epoch 02/50 | loss=1.5094
Epoch 03/50 | loss=1.3824
Epoch 04/50 | loss=1.2939
Epoch 05/50 | loss=1.2062
Epoch 06/50 | loss=1.1311
Epoch 07/50 | loss=1.0592
Epoch 08/50 | loss=0.9972
Epoch 09/50 | loss=0.9454
Epoch 10/50 | loss=0.8945
Epoch 11/50 | loss=0.8503
Epoch 12/50 | loss=0.7997
Epoch 13/50 | loss=0.7691
Epoch 14/50 | loss=0.7410
Epoch 15/50 | loss=0.7039
Epoch 16/50 | loss=0.6692
Epoch 17/50 | loss=0.6429
Epoch 18/50 | loss=0.6182
Epoch 19/50 | loss=0.5963
Epoch 20/50 | loss=0.5726
Epoch 21/50 | loss=0.5512
Epoch 22/50 | loss=0.5294
Epoch 23/50 | loss=0.5114
Epoch 24/50 | loss=0.4939
Epoch 25/50 | loss=0.4748
Epoch 26/50 | loss=0.4588
Epoch 27/50 | loss=0.4452
Epoch 28/50 | loss=0.4277
Epoch 29/50 | loss=0.4164
Epoch 30/50 | loss=0.4010
Epoch 31/50 | loss=0.3905
Epoch 32/50 | loss=0.3784
Epoch 33/50 | loss=0.3638
Epoch 34/50 | loss=0.3524
Epoch 35/50 | loss=0.3441
Epoch 36/50 | loss=0.3328
Epoch 37/50 | loss=0.3262
Epoch 38/50 | loss=0.3149
Epoch 39/50 

In [23]:
# ========= 7) Evaluation =========
def predict(dataloader):
    supervised.eval()
    supervised.model.eval()
    logits_list = []
    labels_list = []
    paths = []
    with torch.no_grad():
        for batch in dataloader:
            pixel_values = batch["pixel_values"].to(DEVICE)
            labels = batch["labels"].to(DEVICE)
            outputs = supervised(pixel_values=pixel_values, labels=None, return_loss=False)
            logits_list.append(outputs["logits"].detach().cpu())
            labels_list.append(labels.detach().cpu())
            paths.extend(batch["paths"])
    if not logits_list:
        return np.empty((0, len(CHEXPERT5))), np.empty((0,), dtype=int), paths
    logits = torch.cat(logits_list, dim=0).numpy()
    labels = torch.cat(labels_list, dim=0).numpy()
    return logits, labels, paths

train_logits, train_labels, _ = predict(train_loader)
test_logits, test_labels, test_paths = predict(test_loader)


def summarize(split: str, logits: np.ndarray, labels: np.ndarray):
    if logits.shape[0] == 0:
        print(f"{split}: no samples to evaluate.")
        return None
    preds = logits.argmax(axis=1)
    acc = accuracy_score(labels, preds)
    print(f"{split} accuracy: {acc:.4f} ({len(labels)} samples)")
    report = classification_report(labels, preds, target_names=CHEXPERT5, digits=4, zero_division=0)
    print(report)
    return preds

train_preds = summarize("Train", train_logits, train_labels)
test_preds = summarize("Test", test_logits, test_labels)


Train accuracy: 1.0000 (25 samples)
                  precision    recall  f1-score   support

     Atelectasis     1.0000    1.0000    1.0000         5
    Cardiomegaly     1.0000    1.0000    1.0000         5
   Consolidation     1.0000    1.0000    1.0000         5
           Edema     1.0000    1.0000    1.0000         5
Pleural Effusion     1.0000    1.0000    1.0000         5

        accuracy                         1.0000        25
       macro avg     1.0000    1.0000    1.0000        25
    weighted avg     1.0000    1.0000    1.0000        25

Test accuracy: 0.4804 (5000 samples)
                  precision    recall  f1-score   support

     Atelectasis     0.4963    0.4700    0.4828      1000
    Cardiomegaly     0.6348    0.6310    0.6329      1000
   Consolidation     0.3059    0.2190    0.2552      1000
           Edema     0.4611    0.4980    0.4788      1000
Pleural Effusion     0.4624    0.5840    0.5161      1000

        accuracy                         0.4804     

In [9]:
# ========= 8) Preview predictions =========
if test_preds is not None and len(test_preds) > 0:
    probs = torch.softmax(torch.tensor(test_logits), dim=1).numpy()
    records = []
    for idx, (path, true_idx, pred_idx) in enumerate(zip(test_paths, test_labels, test_preds)):
        row = {
            "img_path": path,
            "true_label": IDX_TO_CLASS[int(true_idx)],
            "pred_label": IDX_TO_CLASS[int(pred_idx)],
        }
        for class_idx, class_name in enumerate(CHEXPERT5):
            row[f"prob_{class_name.replace(' ', '_')}"] = float(probs[idx, class_idx])
        records.append(row)
    preview_df = pd.DataFrame(records)
    display(preview_df.head(10))
else:
    print("No test predictions available to preview.")


Unnamed: 0,img_path,true_label,pred_label,prob_Atelectasis,prob_Cardiomegaly,prob_Consolidation,prob_Edema,prob_Pleural_Effusion
0,/Users/zitongluo/.cache/kagglehub/datasets/nih...,Pleural Effusion,Pleural Effusion,0.251663,0.152476,0.127819,0.184181,0.28386
1,/Users/zitongluo/.cache/kagglehub/datasets/nih...,Consolidation,Edema,0.056459,0.041152,0.403942,0.433222,0.065224
2,/Users/zitongluo/.cache/kagglehub/datasets/nih...,Edema,Consolidation,0.08083,0.089066,0.435175,0.32316,0.071768
3,/Users/zitongluo/.cache/kagglehub/datasets/nih...,Pleural Effusion,Edema,0.132853,0.125199,0.208922,0.305372,0.227654
4,/Users/zitongluo/.cache/kagglehub/datasets/nih...,Cardiomegaly,Cardiomegaly,0.089542,0.580331,0.101627,0.143153,0.085347
5,/Users/zitongluo/.cache/kagglehub/datasets/nih...,Cardiomegaly,Edema,0.131688,0.117355,0.228223,0.367664,0.155069
6,/Users/zitongluo/.cache/kagglehub/datasets/nih...,Atelectasis,Cardiomegaly,0.186596,0.279761,0.174641,0.167292,0.191709
7,/Users/zitongluo/.cache/kagglehub/datasets/nih...,Consolidation,Pleural Effusion,0.121993,0.086548,0.201292,0.290766,0.299401
8,/Users/zitongluo/.cache/kagglehub/datasets/nih...,Pleural Effusion,Pleural Effusion,0.347894,0.094248,0.083082,0.104594,0.370182
9,/Users/zitongluo/.cache/kagglehub/datasets/nih...,Edema,Edema,0.189799,0.180819,0.220937,0.240289,0.168155
