# Indiana

In [46]:
# ========= 1) Setup & Config (MPS first, disease-only) =========
import os
import re
from pathlib import Path
from collections import OrderedDict
from typing import List

import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm

import torch
from sklearn.metrics import accuracy_score

# MedCLIP
from medclip import MedCLIPModel, MedCLIPProcessor
from medclip import PromptClassifier
try:
    from medclip.prompts import generate_chexpert_class_prompts, process_class_prompts
    HAS_PROMPT_UTILS = True
except Exception:
    HAS_PROMPT_UTILS = False

# --- Device picking (prefer MPS on Apple Silicon) ---
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# --- Config ---
THIS_DIR   = Path(__file__).resolve().parent if "__file__" in globals() else Path.cwd().resolve()
CSV_PATH   = (THIS_DIR / "indiana_sampled_500_images" / "indiana_sampled_500_images.csv").resolve()
IMAGE_ROOT = (THIS_DIR / "indiana_sampled_500_images" / "images_normalized").resolve()

VISION_MODEL = "vit"                       # "vit" or "resnet"
BATCH_SIZE   = 16
NUM_PROMPTS_PER_CLASS = 10                 # used if medclip.prompts is available

CHEXPERT5 = ["Atelectasis", "Cardiomegaly", "Consolidation", "Edema", "Pleural Effusion"]  # disease-only

print("Torch:", torch.__version__)
print("Using device:", DEVICE)
print("CSV_PATH:", CSV_PATH)
print("IMAGE_ROOT:", IMAGE_ROOT)

Torch: 2.5.1
Using device: mps
CSV_PATH: /Users/zitongluo/Library/Mobile Documents/com~apple~CloudDocs/硕士相关/2025Fall/Learning from small data/MedCLIP_eval/indiana_sampled_500_images/indiana_sampled_500_images.csv
IMAGE_ROOT: /Users/zitongluo/Library/Mobile Documents/com~apple~CloudDocs/硕士相关/2025Fall/Learning from small data/MedCLIP_eval/indiana_sampled_500_images/images_normalized


In [47]:
# ========= 2) Load CSV & Build Image Paths =========
from IPython.display import display

df = pd.read_csv(CSV_PATH)
if "filename" not in df.columns:
    raise ValueError("CSV must contain a 'filename' column (e.g., 10_IM-0002-1001.dcm.png).")

# Join IMAGE_ROOT/filename for each row
df["img_path"] = df["filename"].astype(str).apply(lambda x: str((IMAGE_ROOT / x).resolve()))

# Keep only existing images
exists_mask = df["img_path"].apply(lambda p: Path(p).exists())
print(f"[Path check] {exists_mask.sum()}/{len(df)} files exist; missing={(~exists_mask).sum()}")
if (~exists_mask).any():
    display(df.loc[~exists_mask, ["filename", "img_path"]].head(10))

df = df.loc[exists_mask].reset_index(drop=True)
print("Using images:", len(df))

[Path check] 500/500 files exist; missing=0
Using images: 500


In [48]:
# ========= 3) Derive / Normalize Labels (CheXpert-5 + Normal for derivation) =========
def _lower(s):
    return "" if not isinstance(s, str) else s.lower()

def parse_labels_field(x):
    """Parse a label field into a list of strings:
       - list/tuple/ndarray -> strip each
       - 'A|B' string -> ['A','B']
       - other -> []"""
    if isinstance(x, (list, tuple, np.ndarray)):
        return [str(s).strip() for s in x if str(s).strip()]
    if isinstance(x, str):
        return [t.strip() for t in x.split("|") if t.strip()]
    return []

# Keyword mapping for CheXpert-5 (lowercased)
mapping_5 = {
    "Atelectasis": [
        r"\batelectasis\b", r"\batelectatic\b",
        r"\b(subsegmental|plate[- ]?like|linear|discoid|band[- ]?like|bibasilar|basilar|dependent|lingular|rml|right middle lobe)\s+atelectasis\b",
        r"\b(lobar|segmental|partial)\s+collapse\b",
        r"\bvolume\s+loss\b",
        r"\b(streaky|bandlike|subsegmental)\s+opacities?\s+(at|in)\s+(the\s+)?bases?\b",
    ],
    "Cardiomegaly": [
        r"\bcardiomegaly\b",
        r"\benlarged\s+(cardiac|cardiomediastinal)\s+silhouette\b",
        r"\bcardiac\s+enlargement\b",
        r"\bcardiomediastinal\s+(enlargement|widening|prominen(ce|t))\b",
        r"\b(cardiac|cardio(thoracic)?)\s+silhouette\s+is\s+enlarged\b",
        r"\bheart\s+size\s+(is\s+)?(mildly|moderately|severely)?\s*enlarged\b",
        r"\bcardiothoracic\s+ratio\s+(increased|>?\s*0\.[56-9]|>\s*55\%)\b",
    ],
    "Consolidation": [
        r"\b(consolidation|consolidative)\b",
        r"\bpneumonia\b",
        r"\b(infiltrate|infiltrates)\b",
        r"\bair\s*space\s+(disease|opacity|opacities)\b",
        r"\bparenchymal\s+opacity\b",
        r"\bair\s*bronchogram(s)?\b",
        r"\b(focal|patchy|multifocal)\s+opacity(ies)?\b",
        r"\bopacification(s)?\b",
        r"\balveolar\s+(opacity|opacities|consolidation)\b",
        r"\blobar\s+(consolidation|pneumonia)\b",
        r"\b(lingular|rml|right middle lobe|lobe)\s+opacity\b",
        r"\bopacity(ies)?\b",
    ],
    "Edema": [
        r"\bpulmonary\s+ed(ema|oema)\b",
        r"\b(interstitial|alveolar)\s+ed(ema|oema)\b",
        r"\b(pulmonary|venous|vascular)\s+congestion\b",
        r"\bperihilar\s+(haze|opacit(y|ies))\b",
        r"\bkerley\s*(a|b)\s*lines?\b",
        r"\bfluid\s+overload\b",
        r"\bcongestive\s+heart\s+failure\b",
        r"\bbat[- ]?wing\s+pattern\b",
        r"\bcephalization\b",
        r"\bpulmonary\s+venous\s+hypertension\b",
        r"\binterstitial\s+markings\s+increased\b",
        r"\bperibronchial\s+cuff(ing)?\b",
    ],
    "Pleural Effusion": [
        r"\bpleural\s+effusion(s)?\b",
        r"\b(hydrothorax|pleural\s+fluid)\b",
        r"\b(blunted|blunting)\s+(costophrenic|cp)\s+angles?\b",
        r"\bmeniscus\s+sign\b",
        r"\b(?:layering|small|trace|minimal|tiny)\s+effusion(s)?\b",
        r"\b(bilateral|left|right)\s+pleural\s+effusions?\b",
        r"\bpleural\s+thickening\s+with\s+effusion\b",
    ],
}
normal_patterns = [r"^\s*normal\s*$", r"\bnormal\b"]  # used only for derivation; will be excluded later

def derive_chexpert5_normal(row):
    """Return a list of matched labels from CheXpert-5; if none matched, try 'Normal' in Problems."""
    probs = _lower(row.get("Problems", ""))
    mesh  = _lower(row.get("MeSH", ""))
    text_all = f"{probs} {mesh}"
    labels = []
    for cls, pats in mapping_5.items():
        if any(re.search(p, text_all) for p in pats):
            labels.append(cls)
    if not labels:  # only if no abnormal class matched
        if any(re.search(p, probs) for p in normal_patterns):
            labels.append("Normal")
    return labels

# Decide whether CSV already provides labels
HAS_SINGLE = ("label" in df.columns) and df["label"].notna().any()
HAS_MULTI  = (("labels" in df.columns) and df["labels"].astype(str).str.len().gt(0).any()) \
             or (("labels_list" in df.columns) and df["labels_list"].apply(lambda x: isinstance(x,(list,tuple,np.ndarray)) and len(x)>0).any())

if not (HAS_SINGLE or HAS_MULTI):
    labs = df.apply(derive_chexpert5_normal, axis=1)
    df["labels_list"] = labs
    df["labels"] = df["labels_list"].apply(lambda xs: "|".join(xs))
    print("[Info] Derived labels into 'labels' / 'labels_list' (CheXpert-5 + Normal).")
else:
    if "labels_list" not in df.columns and "labels" in df.columns:
        df["labels_list"] = df["labels"].apply(parse_labels_field)
    if "labels" not in df.columns and "labels_list" in df.columns:
        df["labels"] = df["labels_list"].apply(lambda xs: "|".join(xs))

print("Label columns now:", [c for c in ["label", "labels", "labels_list"] if c in df.columns])

[Info] Derived labels into 'labels' / 'labels_list' (CheXpert-5 + Normal).
Label columns now: ['labels', 'labels_list']


In [49]:
# ========= 4) Keep single-label only, then disease-only (exclude 'Normal') =========
# Ensure there is a labels_list column
if "labels_list" not in df.columns:
    if "labels" in df.columns:
        df["labels_list"] = df["labels"].apply(parse_labels_field)
    elif "label" in df.columns:
        df["labels_list"] = df["label"].fillna("").astype(str).apply(lambda s: [s.strip()] if s.strip() else [])
    else:
        df["labels_list"] = [[] for _ in range(len(df))]

def _len_nonempty(x):
    if isinstance(x, (list, tuple, np.ndarray)):
        return len([t for t in x if str(t).strip()])
    return 0

# Keep exactly one label per sample
df["num_labels"] = df["labels_list"].apply(_len_nonempty)
before = len(df)
df = df.loc[df["num_labels"] == 1].reset_index(drop=True)
after = len(df)
print(f"[Filter] Kept single-label rows: {after}/{before} (dropped {before - after}).")

# Normalize to 'label' (single string) and 'labels' (same as label)
df["label"]  = df["labels_list"].apply(lambda xs: xs[0] if isinstance(xs, (list, tuple, np.ndarray)) and len(xs) == 1 else "")
df["labels"] = df["label"]
df.drop(columns=["num_labels"], inplace=True)

# Keep disease-only: drop 'Normal'
before2 = len(df)
df = df[df["label"].isin(CHEXPERT5)].reset_index(drop=True)
after2 = len(df)
print(f"[Filter] Kept disease-only rows (exclude 'Normal'): {after2}/{before2} (dropped {before2 - after2}).")

if len(df) == 0:
    raise RuntimeError("No diseased single-label data remains after filtering.")

[Filter] Kept single-label rows: 237/500 (dropped 263).
[Filter] Kept disease-only rows (exclude 'Normal'): 63/237 (dropped 174).


In [50]:
# ========= 5) Load MedCLIP (model + classifier) =========
processor = MedCLIPProcessor()
try:
    model = MedCLIPModel.from_pretrained(vision_model=VISION_MODEL, device=DEVICE)
except TypeError:
    name_map = {"vit": "medclip-vit", "resnet": "medclip-resnet50"}
    model = MedCLIPModel.from_pretrained(name_map.get(VISION_MODEL.lower(), VISION_MODEL))
# Ensure the model is moved to the chosen device (important for MPS)
model = model.to(DEVICE)
model.eval()

clf = PromptClassifier(model, ensemble=True).to(DEVICE)

Some weights of the model checkpoint at microsoft/swin-tiny-patch4-window7-224 were not used when initializing SwinModel: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing SwinModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing SwinModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
  return torch.load(checkpoint_file, map_location="cpu")
Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertModel: ['cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weig

Model moved to mps
load model weight from: pretrained/medclip-vit


In [51]:
# ========= 6) Build Prompts (CheXpert-5 only; no Normal) =========
# Keep prompts as dict: class_name -> tokenized tensors (format expected by PromptClassifier)

if HAS_PROMPT_UTILS:
    cls_prompts_5 = process_class_prompts(generate_chexpert_class_prompts(n=NUM_PROMPTS_PER_CLASS))
    if not isinstance(cls_prompts_5, dict):
        raise TypeError(f"process_class_prompts should return dict, got {type(cls_prompts_5)}")
else:
    base = [
        "A chest X-ray showing {label}.",
        "The radiograph demonstrates {label}.",
        "CXR with finding: {label}.",
        "This image indicates {label}.",
    ]
    cls_prompts_5 = {}
    for lbl in CHEXPERT5:
        texts = [t.format(label=lbl) for t in base]
        cls_prompts_5[lbl] = processor(text=texts, return_tensors="pt", padding=True)

# Ordered mapping with stable class order (5 classes only)
prompt_inputs = OrderedDict()
for c in CHEXPERT5:
    if c not in cls_prompts_5:
        raise KeyError(f"Missing prompts for class: {c}")
    prompt_inputs[c] = cls_prompts_5[c]

user_class_names = list(prompt_inputs.keys())  # == CHEXPERT5
print("Prompts built for classes:", user_class_names)
print("Per-class prompt counts:", [prompt_inputs[c]["input_ids"].shape[0] for c in user_class_names])

sample 10 num of prompts for Atelectasis from total 210
sample 10 num of prompts for Cardiomegaly from total 15
sample 10 num of prompts for Consolidation from total 192
sample 10 num of prompts for Edema from total 18
sample 10 num of prompts for Pleural Effusion from total 54
Prompts built for classes: ['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Pleural Effusion']
Per-class prompt counts: [10, 10, 10, 10, 10]


In [52]:
# ========= 7) Inference (Top-1 over 5 classes) — no file export =========
from IPython.display import display

def batched(seq, n):
    for i in range(0, len(seq), n):
        yield seq[i:i+n], list(range(i, min(i+n, len(seq))))

img_paths = df["img_path"].tolist()
all_logits = []
processed_indices = []
skipped = []

total_batches = (len(img_paths) + BATCH_SIZE - 1) // BATCH_SIZE
for batch_files, idxs in tqdm(batched(img_paths, BATCH_SIZE), total=total_batches):
    images = []
    kept_local = []
    for j, p in enumerate(batch_files):
        try:
            with Image.open(p) as im:
                images.append(im.convert("RGB"))
            kept_local.append(idxs[j])
        except Exception as e:
            skipped.append((p, repr(e)))

    if not images:
        continue

    inputs = processor(images=images, return_tensors="pt")
    inputs["prompt_inputs"] = prompt_inputs
    inputs = {k: (v.to(DEVICE) if isinstance(v, torch.Tensor) else v) for k, v in inputs.items()}

    with torch.no_grad():
        out = clf(**inputs)
        batch_logits = out["logits"].detach().cpu().numpy()  # shape: [b, 5]

    if 'CLASS_NAMES' not in globals():
        CLASS_NAMES = out.get("class_names", CHEXPERT5)  # FYI
        print("Model's class_names (5):", CLASS_NAMES)

    all_logits.append(batch_logits)
    processed_indices.extend(kept_local)

if len(all_logits) == 0:
    raise RuntimeError("No images were processed. Check IMAGE_ROOT/filenames.")

logits = np.concatenate(all_logits, axis=0)   # [N_used, 5]
df_infer = df.iloc[processed_indices].reset_index(drop=True)

print("Inference done. Logits shape:", logits.shape)
print("Aligned rows:", len(df_infer))
print("Skipped:", len(skipped))
if skipped:
    print("Skipped examples (first 5):", skipped[:5])

# Top-1 prediction and probabilities
pred_idx   = logits.argmax(axis=1)
pred_label = [user_class_names[i] for i in pred_idx]
probs      = 1.0 / (1.0 + np.exp(-logits))

# Quick preview
display(df_infer.assign(pred_label=pred_label).head(10))

100%|██████████| 4/4 [00:07<00:00,  1.90s/it]

Inference done. Logits shape: (63, 5)
Aligned rows: 63
Skipped: 0





Unnamed: 0,uid,filename,projection,MeSH,Problems,image,indication,comparison,findings,impression,copied_path,img_path,labels_list,labels,label,pred_label
0,3789,3789_IM-1902-2001.dcm.png,Lateral,Cardiomegaly/mild;Surgical Instruments/abdomen...,Cardiomegaly;Surgical Instruments,"Chest, 2 views, XXXX XXXX",XXXX,XXXX,Mild cardiomegaly is unchanged. Stable superio...,Stable appearance of the chest. No acute cardi...,./indiana_sampled_500_images/3789_IM-1902-2001...,/Users/zitongluo/Library/Mobile Documents/com~...,[Cardiomegaly],Cardiomegaly,Cardiomegaly,Edema
1,3142,3142_IM-1477-1001.dcm.png,Lateral,Cardiomegaly/mild;Aorta/tortuous,Cardiomegaly;Aorta,Frontal and lateral views of the chest was ob...,possible seizure,,There is stable mild cardiomegaly without sign...,No acute process. Stable cardiomegaly.,./indiana_sampled_500_images/3142_IM-1477-1001...,/Users/zitongluo/Library/Mobile Documents/com~...,[Cardiomegaly],Cardiomegaly,Cardiomegaly,Cardiomegaly
2,3863,3863_IM-1957-2001.dcm.png,Lateral,Cardiac Shadow/enlarged;Cardiomegaly,Cardiac Shadow;Cardiomegaly,"PA and lateral views of the chest XXXX, XXXX a...",XXXX-year-old XXXX with XXXX.,"XXXX, XXXX.","The lungs are clear, and without focal airspac...","Cardiomegaly, but no focal consolidation.",./indiana_sampled_500_images/3863_IM-1957-2001...,/Users/zitongluo/Library/Mobile Documents/com~...,[Cardiomegaly],Cardiomegaly,Cardiomegaly,Cardiomegaly
3,665,665_IM-2240-1001.dcm.png,Frontal,Cardiomegaly,Cardiomegaly,"PA and lateral chest XXXX, XXXX at XXXX compar...",History of chest pain,,,Cardiomegaly stable. Lungs clear. No edema or ...,./indiana_sampled_500_images/665_IM-2240-1001....,/Users/zitongluo/Library/Mobile Documents/com~...,[Cardiomegaly],Cardiomegaly,Cardiomegaly,Cardiomegaly
4,2286,2286_IM-0871-1001.dcm.png,Frontal,Opacity/multiple/chronic;Emphysema,Opacity;Emphysema,"Chest x-XXXX, 2 views dated XXXX COMPARISXXXX/...",XXXX-year-old male with weakness,,Chronic-appearing XXXX opacities are unchanged...,No acute cardiopulmonary abnormalities.,./indiana_sampled_500_images/2286_IM-0871-1001...,/Users/zitongluo/Library/Mobile Documents/com~...,[Consolidation],Consolidation,Consolidation,Atelectasis
5,2859,2859_IM-1266-1001.dcm.png,Frontal,Opacity/lung/upper lobe/left/round;Density/lun...,Opacity;Density;Pneumonia,Xray Chest PA and Lateral,Pain and difficulty breathing.,,There is a rounded dense opacity in the latera...,Opacity XXXX representing left upper lobe pneu...,./indiana_sampled_500_images/2859_IM-1266-1001...,/Users/zitongluo/Library/Mobile Documents/com~...,[Consolidation],Consolidation,Consolidation,Consolidation
6,408,408_IM-2054-1001.dcm.png,Frontal,Lung/hypoinflation;Pleural Effusion/bilateral/...,Lung;Pleural Effusion,Xray Chest PA and Lateral,,None Indication Cirrhosis Evaluate pre liver t...,The heart is normal in size. The mediastinum i...,Small bilateral pleural effusions.,./indiana_sampled_500_images/408_IM-2054-1001....,/Users/zitongluo/Library/Mobile Documents/com~...,[Pleural Effusion],Pleural Effusion,Pleural Effusion,Pleural Effusion
7,619,619_IM-2202-1002.dcm.png,Frontal,Aorta/tortuous;Cardiomegaly,Aorta;Cardiomegaly,PA and lateral chest radiographs XXXX at XXXX ...,XXXX-year-old female with right-sided chest pa...,PA and lateral chest radiographs XXXX.,There has been interval sternotomy with intact...,"Cardiomegaly, however no acute cardiopulmonary...",./indiana_sampled_500_images/619_IM-2202-1002....,/Users/zitongluo/Library/Mobile Documents/com~...,[Cardiomegaly],Cardiomegaly,Cardiomegaly,Cardiomegaly
8,2318,2318_IM-0891-1001.dcm.png,Frontal,Lucency/thorax/left;Pneumothorax/left/large;Pu...,Lucency;Pneumothorax;Pulmonary Atelectasis,"CHEST 2V FRONTAL/LATERAL RADXXXX XXXX, XXXX XX...",XXXX,,,No comparison chest x-XXXX XXXX lungs. Lucency...,./indiana_sampled_500_images/2318_IM-0891-1001...,/Users/zitongluo/Library/Mobile Documents/com~...,[Atelectasis],Atelectasis,Atelectasis,Pleural Effusion
9,268,268_IM-1153-1001.dcm.png,Frontal,Opacity/lung/upper lobe/right;Pneumonia/upper ...,Opacity;Pneumonia,"Chest radiograph, frontal and lateral views",,,There is a right upper lobe opacity. Cardiomed...,Right upper lobe pneumonia.,./indiana_sampled_500_images/268_IM-1153-1001....,/Users/zitongluo/Library/Mobile Documents/com~...,[Consolidation],Consolidation,Consolidation,Consolidation


In [53]:
# ========= 8) Metrics (single-label, disease-only) — no file export =========
# Standard Top-1 accuracy on single-label diseased set:
y_true  = df_infer["label"].astype(str).tolist()
y_pred  = pred_label
mask    = [yt.strip() != "" for yt in y_true]
acc     = accuracy_score([y_true[i] for i in range(len(y_true)) if mask[i]],
                         [y_pred[i] for i in range(len(y_pred)) if mask[i]]) if any(mask) else float("nan")
print(f"\nTop-1 Accuracy (disease-only, single-label): {acc:.4f} (valid={sum(mask)}/{len(mask)})")

# Per-class Top-1 accuracy (among samples whose true label == that class)
per_class = {}
for c in CHEXPERT5:
    idxs = [i for i, yt in enumerate(y_true) if mask[i] and yt == c]
    if len(idxs) == 0:
        per_class[c] = {"support": 0, "acc": None}
    else:
        correct = sum(1 for i in idxs if y_pred[i] == c)
        per_class[c] = {"support": len(idxs), "acc": correct / len(idxs)}

print("\nPer-class Top-1 accuracy (disease-only):")
for k, v in per_class.items():
    acc_str = "NA" if v["acc"] is None else f"{v['acc']:.3f}"
    print(f"{k:20s} | support={v['support']:4d} | acc={acc_str}")

# Optional: show a small probability table preview (no saving)
prob_cols = [f"prob_{c.replace(' ', '_')}" for c in CHEXPERT5]
df_preview = df_infer.copy()
df_preview["pred_label"] = y_pred
for i, c in enumerate(prob_cols):
    df_preview[c] = probs[:, i]

try:
    from IPython.display import display
    display(df_preview[["filename", "label", "pred_label"] + prob_cols].head(10))
except Exception:
    print(df_preview[["filename", "label", "pred_label"] + prob_cols].head(10))


Top-1 Accuracy (disease-only, single-label): 0.3810 (valid=63/63)

Per-class Top-1 accuracy (disease-only):
Atelectasis          | support=  12 | acc=0.250
Cardiomegaly         | support=  19 | acc=0.737
Consolidation        | support=  27 | acc=0.111
Edema                | support=   1 | acc=0.000
Pleural Effusion     | support=   4 | acc=1.000


Unnamed: 0,filename,label,pred_label,prob_Atelectasis,prob_Cardiomegaly,prob_Consolidation,prob_Edema,prob_Pleural_Effusion
0,3789_IM-1902-2001.dcm.png,Cardiomegaly,Edema,0.569164,0.64499,0.558347,0.669012,0.530014
1,3142_IM-1477-1001.dcm.png,Cardiomegaly,Cardiomegaly,0.558651,0.576304,0.551623,0.568198,0.536779
2,3863_IM-1957-2001.dcm.png,Cardiomegaly,Cardiomegaly,0.561605,0.573848,0.545676,0.552219,0.532259
3,665_IM-2240-1001.dcm.png,Cardiomegaly,Cardiomegaly,0.574959,0.682429,0.551513,0.574511,0.594878
4,2286_IM-0871-1001.dcm.png,Consolidation,Atelectasis,0.578988,0.562414,0.553803,0.547224,0.573786
5,2859_IM-1266-1001.dcm.png,Consolidation,Consolidation,0.524626,0.544463,0.681051,0.541836,0.548886
6,408_IM-2054-1001.dcm.png,Pleural Effusion,Pleural Effusion,0.596465,0.545307,0.552254,0.540983,0.757146
7,619_IM-2202-1002.dcm.png,Cardiomegaly,Cardiomegaly,0.551233,0.577705,0.546647,0.553537,0.557296
8,2318_IM-0891-1001.dcm.png,Atelectasis,Pleural Effusion,0.576367,0.534423,0.539693,0.52554,0.626853
9,268_IM-1153-1001.dcm.png,Consolidation,Consolidation,0.536138,0.541632,0.679257,0.55445,0.538622


# NIH

In [1]:
import pandas as pd

# ==== 配置 ====
INPUT_CSV = "local_data/nih-sampled-meta.csv"   # 你的源文件
DISEASE_COL = "disease"                        # 分组列名
N_TRAIN_PER_CLASS = 400                        # 每类训练集数量
RANDOM_STATE = 42                              # 随机种子（可改）

# ==== 读取 ====
df = pd.read_csv(INPUT_CSV)

# 基本检查（五类、每类1400）
counts = df[DISEASE_COL].value_counts()
assert len(counts) == 5, f"发现 {len(counts)} 个类别，期望 5 个。实际统计：\n{counts}"
assert (counts == 1400).all(), f"每类样本数不是 1400：\n{counts}"

# ==== 分层随机抽样 ====
# 每个疾病抽 N_TRAIN_PER_CLASS 条作为训练集
train_df = (
    df.groupby(DISEASE_COL, group_keys=False)
      .sample(n=N_TRAIN_PER_CLASS, random_state=RANDOM_STATE)
)

# 测试集为剩余样本
test_df = df.drop(train_df.index)

# 可选：整体打乱行顺序
train_df = train_df.sample(frac=1, random_state=RANDOM_STATE).reset_index(drop=True)
test_df  = test_df.sample(frac=1, random_state=RANDOM_STATE).reset_index(drop=True)

# ==== 校验 ====
train_counts = train_df[DISEASE_COL].value_counts().sort_index()
test_counts  = test_df[DISEASE_COL].value_counts().sort_index()
assert (train_counts == N_TRAIN_PER_CLASS).all(), f"训练集每类应为 {N_TRAIN_PER_CLASS}：\n{train_counts}"
assert (test_counts == (counts.iloc[0] - N_TRAIN_PER_CLASS)).all(), f"测试集每类应为 {counts.iloc[0]-N_TRAIN_PER_CLASS}：\n{test_counts}"

print("Train per class:\n", train_counts)
print("Test  per class:\n", test_counts)
print(f"Train total: {len(train_df)}  |  Test total: {len(test_df)}")

# ==== 导出 ====
train_df.to_csv("local_data/nih-sampled-meta-train.csv", index=False)
test_df.to_csv("local_data/nih-sampled-meta-test.csv", index=False)
print("Saved to local_data")

Train per class:
 disease
Atelectasis         400
Cardiomegaly        400
Consolidation       400
Edema               400
Pleural Effusion    400
Name: count, dtype: int64
Test  per class:
 disease
Atelectasis         1000
Cardiomegaly        1000
Consolidation       1000
Edema               1000
Pleural Effusion    1000
Name: count, dtype: int64
Train total: 2000  |  Test total: 5000
Saved to local_data


In [10]:
# ========= 1) Setup & Config (MPS first, disease-only) =========
import re
from pathlib import Path
from collections import OrderedDict

import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm

import torch
from sklearn.metrics import accuracy_score

# MedCLIP
from medclip import MedCLIPModel, MedCLIPProcessor
from medclip import PromptClassifier
try:
    from medclip.prompts import generate_chexpert_class_prompts, process_class_prompts
    HAS_PROMPT_UTILS = True
except Exception:
    HAS_PROMPT_UTILS = False

# --- Device (prefer MPS on Apple Silicon) ---
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# --- Paths ---
THIS_DIR = Path(__file__).resolve().parent if "__file__" in globals() else Path.cwd().resolve()
CSV_PATH = (THIS_DIR / "local_data" / "nih-sampled-meta-test.csv").resolve()  # <- set to your nih-sampled-meta.csv
IMAGE_ROOT = (THIS_DIR / "data" / "nih").resolve()          # <- set to the folder containing NIH images

# --- MedCLIP & eval config ---
VISION_MODEL = "vit"         # "vit" or "resnet"
BATCH_SIZE = 16
NUM_PROMPTS_PER_CLASS = 10
CHEXPERT5 = ["Atelectasis", "Cardiomegaly", "Consolidation", "Edema", "Pleural Effusion"]  # disease-only

print("Using device:", DEVICE)
print("CSV_PATH:", CSV_PATH)
print("IMAGE_ROOT:", IMAGE_ROOT)

Using device: mps
CSV_PATH: /Users/zitongluo/Library/Mobile Documents/com~apple~CloudDocs/硕士相关/2025Fall/Learning from small data/MedCLIP_eval/local_data/nih-sampled-meta-test.csv
IMAGE_ROOT: /Users/zitongluo/.cache/kagglehub/datasets/nih-chest-xrays/data/versions/3


In [11]:
# ========= 2) Load NIH CSV & Resolve Sharded Image Paths =========
from IPython.display import display
from pathlib import Path
import os

df = pd.read_csv(CSV_PATH)

# Expect the columns: 'Image Index' (filename) and 'disease' (your 5-way label)
req_cols = {"Image Index"}
missing = req_cols - set(df.columns)
if missing:
    raise ValueError(f"CSV is missing required columns: {missing}")

# Normalize a 'filename' column
df["filename"] = df["Image Index"].astype(str).str.strip()

# --- Helper: build a mapping {filename_lower -> full_path} by scanning shards ---
def index_sharded_paths(root: Path, wanted_names):
    """
    Resolve NIH paths when images are sharded under:
        root / images_001/images/*.png
        root / images_002/images/*.png
        ...
    Strategy:
      1) Fast pass: iterate each shard directory once, index files we need.
      2) Fallback for any leftovers: recursive rglob by filename (slower but robust).
    Returns:
      path_map_lower: dict[str_lower -> str_fullpath]
      not_found: set of original names that were not resolved
    """
    root = Path(root)
    wanted_lower = {name.lower(): name for name in set(wanted_names)}
    remaining = set(wanted_lower.keys())
    path_map_lower = {}

    # 1) Fast pass through shards
    for shard in sorted(root.glob("images_*")):
        img_dir = shard / "images"
        if not img_dir.is_dir():
            continue
        # Iterate files inside this shard
        for p in img_dir.iterdir():
            if not p.is_file():
                continue
            key = p.name.lower()
            if key in remaining:
                path_map_lower[key] = str(p.resolve())
                remaining.remove(key)
        if not remaining:
            break  # all found

    # 2) Fallback: recursive search for still-missing files
    #    Only run if something is still missing (can be slow)
    if remaining:
        for key in list(remaining):
            # rglob by exact filename
            matches = list(root.rglob(wanted_lower[key]))
            if matches:
                path_map_lower[key] = str(matches[0].resolve())
                remaining.remove(key)

    # Map back to original names for reporting
    not_found = {wanted_lower[k] for k in remaining}
    return path_map_lower, not_found

# Build mapping for all unique filenames we need
path_map_lower, not_found = index_sharded_paths(IMAGE_ROOT, df["filename"].tolist())

# Attach resolved paths to df
df["img_path"] = df["filename"].str.lower().map(path_map_lower)

# Keep only existing-resolved images
exist_mask = df["img_path"].map(
    lambda p: isinstance(p, (str, os.PathLike)) and len(str(p)) > 0 and Path(p).exists()
)
print(f"[Path check] {exist_mask.sum()}/{len(df)} files resolved & exist; missing={(~exist_mask).sum()}")
if (~exist_mask).any():
    # Show a few examples that failed to resolve
    display(df.loc[~exist_mask, ["filename"]].head(10))
    if not_found:
        print("Examples not found (first 10):", list(sorted(not_found))[:10])

df = df.loc[exist_mask].reset_index(drop=True)
print("Using images:", len(df))

[Path check] 5000/5000 files resolved & exist; missing=0
Using images: 5000


In [5]:
# ========= 3) Use `disease` as the single label (disease-only filtering) =========
# Clean label strings
df["label"] = df["disease"].astype(str).str.strip()

# Keep disease-only (must be one of CHEXPERT5; drop anything else, e.g., 'No Finding')
before = len(df)
df = df[df["label"].isin(CHEXPERT5)].reset_index(drop=True)
after = len(df)
print(f"[Filter] Kept disease-only rows: {after}/{before} (dropped {before - after}).")

# Sanity check: distribution
print("\nLabel distribution:")
print(df["label"].value_counts().reindex(CHEXPERT5, fill_value=0))

[Filter] Kept disease-only rows: 5000/5000 (dropped 0).

Label distribution:
label
Atelectasis         1000
Cardiomegaly        1000
Consolidation       1000
Edema               1000
Pleural Effusion    1000
Name: count, dtype: int64


In [6]:
# ========= 4) Load MedCLIP (model + classifier) =========
processor = MedCLIPProcessor()
try:
    model = MedCLIPModel.from_pretrained(vision_model=VISION_MODEL, device=DEVICE)
except TypeError:
    name_map = {"vit": "medclip-vit", "resnet": "medclip-resnet50"}
    model = MedCLIPModel.from_pretrained(name_map.get(VISION_MODEL.lower(), VISION_MODEL))
model = model.to(DEVICE)
model.eval()

clf = PromptClassifier(model, ensemble=True).to(DEVICE)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
Some weights of the model checkpoint at microsoft/swin-tiny-patch4-window7-224 were not used when initializing SwinModel: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing SwinModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing SwinModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
  return torch.load(checkpoint_file, map_location="cpu")
Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.seq_relatio

Model moved to mps
load model weight from: pretrained/medclip-vit


In [7]:
# ========= 5) Build prompts (CheXpert-5 only; no Normal) =========
# Keep prompts as dict: class_name -> tokenized tensors (required by PromptClassifier)

if HAS_PROMPT_UTILS:
    cls_prompts_5 = process_class_prompts(generate_chexpert_class_prompts(n=NUM_PROMPTS_PER_CLASS))
    if not isinstance(cls_prompts_5, dict):
        raise TypeError(f"process_class_prompts should return dict, got {type(cls_prompts_5)}")
else:
    base = [
        "A chest X-ray showing {label}.",
        "The radiograph demonstrates {label}.",
        "CXR with finding: {label}.",
        "This image indicates {label}.",
    ]
    cls_prompts_5 = {}
    for lbl in CHEXPERT5:
        texts = [t.format(label=lbl) for t in base]
        cls_prompts_5[lbl] = processor(text=texts, return_tensors="pt", padding=True)

# Ordered mapping with stable class order (5 classes)
prompt_inputs = OrderedDict()
for c in CHEXPERT5:
    if c not in cls_prompts_5:
        raise KeyError(f"Missing prompts for class: {c}")
    prompt_inputs[c] = cls_prompts_5[c]

user_class_names = list(prompt_inputs.keys())  # == CHEXPERT5
print("Prompts built for classes:", user_class_names)
print("Per-class prompt counts:", [prompt_inputs[c]["input_ids"].shape[0] for c in user_class_names])

sample 10 num of prompts for Atelectasis from total 210
sample 10 num of prompts for Cardiomegaly from total 15
sample 10 num of prompts for Consolidation from total 192
sample 10 num of prompts for Edema from total 18
sample 10 num of prompts for Pleural Effusion from total 54
Prompts built for classes: ['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Pleural Effusion']
Per-class prompt counts: [10, 10, 10, 10, 10]


In [8]:
# ========= 6) Inference (Top-1 over 5 classes; no export) =========
from IPython.display import display

def batched(seq, n):
    for i in range(0, len(seq), n):
        yield seq[i:i+n], list(range(i, min(i+n, len(seq))))

img_paths = df["img_path"].tolist()
all_logits = []
processed_indices = []
skipped = []

total_batches = (len(img_paths) + BATCH_SIZE - 1) // BATCH_SIZE
for batch_files, idxs in tqdm(batched(img_paths, BATCH_SIZE), total=total_batches):
    images = []
    kept_local = []
    for j, p in enumerate(batch_files):
        try:
            with Image.open(p) as im:
                images.append(im.convert("RGB"))
            kept_local.append(idxs[j])
        except Exception as e:
            skipped.append((p, repr(e)))

    if not images:
        continue

    inputs = processor(images=images, return_tensors="pt")
    inputs["prompt_inputs"] = prompt_inputs
    inputs = {k: (v.to(DEVICE) if isinstance(v, torch.Tensor) else v) for k, v in inputs.items()}

    with torch.no_grad():
        out = clf(**inputs)
        batch_logits = out["logits"].detach().cpu().numpy()  # shape: [b, 5]

    if 'CLASS_NAMES' not in globals():
        CLASS_NAMES = out.get("class_names", CHEXPERT5)  # FYI
        print("Model's class_names (5):", CLASS_NAMES)

    all_logits.append(batch_logits)
    processed_indices.extend(kept_local)

if len(all_logits) == 0:
    raise RuntimeError("No images were processed. Check IMAGE_ROOT/filenames.")

logits = np.concatenate(all_logits, axis=0)   # [N_used, 5]
df_infer = df.iloc[processed_indices].reset_index(drop=True)

print("Inference done. Logits shape:", logits.shape)
print("Aligned rows:", len(df_infer))
print("Skipped:", len(skipped))
if skipped:
    print("Skipped examples (first 5):", skipped[:5])

# Top-1 prediction and probs
pred_idx   = logits.argmax(axis=1)
pred_label = [user_class_names[i] for i in pred_idx]
probs      = 1.0 / (1.0 + np.exp(-logits))

# Quick preview
display(df_infer.assign(pred_label=pred_label).head(10))

  0%|          | 0/313 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


  0%|          | 1/313 [00:01<09:40,  1.86s/it]

Model's class_names (5): ['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Pleural Effusion']


100%|██████████| 313/313 [06:58<00:00,  1.34s/it]

Inference done. Logits shape: (5000, 5)
Aligned rows: 5000
Skipped: 0





Unnamed: 0,Image Index,Finding Labels,Follow-up #,Patient ID,Patient Age,Patient Gender,View Position,OriginalImage[Width,Height],OriginalImagePixelSpacing[x,...,imgpath,Edema,Atelectasis,Pleural Effusion,Cardiomegaly,Consolidation,filename,img_path,label,pred_label
0,00007321_013.png,Effusion|Nodule,13,7321,58,M,PA,2500,2048,0.168,...,data/nih/images_004/images/00007321_013.png,0,0,1,0,0,00007321_013.png,/Users/zitongluo/.cache/kagglehub/datasets/nih...,Pleural Effusion,Pleural Effusion
1,00021035_015.png,Consolidation|Infiltration|Mass|Nodule,15,21035,34,M,AP,3056,2544,0.139,...,data/nih/images_010/images/00021035_015.png,0,0,0,0,1,00021035_015.png,/Users/zitongluo/.cache/kagglehub/datasets/nih...,Consolidation,Consolidation
2,00026848_028.png,Edema|Infiltration|Pneumonia,28,26848,22,M,AP,3056,2544,0.139,...,data/nih/images_011/images/00026848_028.png,1,0,0,0,0,00026848_028.png,/Users/zitongluo/.cache/kagglehub/datasets/nih...,Edema,Edema
3,00000402_010.png,Effusion|Fibrosis,10,402,61,M,AP,2500,2048,0.168,...,data/nih/images_001/images/00000402_010.png,0,0,1,0,0,00000402_010.png,/Users/zitongluo/.cache/kagglehub/datasets/nih...,Pleural Effusion,Pleural Effusion
4,00015286_001.png,Cardiomegaly,1,15286,54,F,PA,2862,2709,0.143,...,data/nih/images_007/images/00015286_001.png,0,0,0,1,0,00015286_001.png,/Users/zitongluo/.cache/kagglehub/datasets/nih...,Cardiomegaly,Cardiomegaly
5,00013615_054.png,Cardiomegaly|Infiltration,54,13615,11,F,AP,2500,2048,0.168,...,data/nih/images_006/images/00013615_054.png,0,0,0,1,0,00013615_054.png,/Users/zitongluo/.cache/kagglehub/datasets/nih...,Cardiomegaly,Pleural Effusion
6,00019694_001.png,Atelectasis,1,19694,62,F,PA,1941,2021,0.194311,...,data/nih/images_009/images/00019694_001.png,0,1,0,0,0,00019694_001.png,/Users/zitongluo/.cache/kagglehub/datasets/nih...,Atelectasis,Atelectasis
7,00006819_002.png,Consolidation|Nodule,2,6819,64,F,PA,2048,2500,0.171,...,data/nih/images_004/images/00006819_002.png,0,0,0,0,1,00006819_002.png,/Users/zitongluo/.cache/kagglehub/datasets/nih...,Consolidation,Pleural Effusion
8,00017606_027.png,Effusion,27,17606,56,M,AP,2500,2048,0.168,...,data/nih/images_008/images/00017606_027.png,0,0,1,0,0,00017606_027.png,/Users/zitongluo/.cache/kagglehub/datasets/nih...,Pleural Effusion,Pleural Effusion
9,00014785_006.png,Edema,6,14785,20,M,AP,2500,2048,0.168,...,data/nih/images_007/images/00014785_006.png,1,0,0,0,0,00014785_006.png,/Users/zitongluo/.cache/kagglehub/datasets/nih...,Edema,Edema


In [9]:
# ========= 7) Metrics (single-label, disease-only; no export) =========
y_true  = df_infer["label"].astype(str).tolist()
y_pred  = pred_label
mask    = [yt.strip() != "" for yt in y_true]

acc = accuracy_score([y_true[i] for i in range(len(y_true)) if mask[i]],
                     [y_pred[i] for i in range(len(y_pred)) if mask[i]]) if any(mask) else float("nan")
print(f"\nTop-1 Accuracy (disease-only): {acc:.4f} (valid={sum(mask)}/{len(mask)})")

# Per-class Top-1 accuracy
per_class = {}
for c in CHEXPERT5:
    idxs = [i for i, yt in enumerate(y_true) if mask[i] and yt == c]
    if len(idxs) == 0:
        per_class[c] = {"support": 0, "acc": None}
    else:
        correct = sum(1 for i in idxs if y_pred[i] == c)
        per_class[c] = {"support": len(idxs), "acc": correct / len(idxs)}

print("\nPer-class Top-1 accuracy:")
for k, v in per_class.items():
    acc_str = "NA" if v["acc"] is None else f"{v['acc']:.3f}"
    print(f"{k:20s} | support={v['support']:4d} | acc={acc_str}")

# Optional: small probability table preview
prob_cols = [f"prob_{c.replace(' ', '_')}" for c in CHEXPERT5]
df_preview = df_infer.copy()
df_preview["pred_label"] = y_pred
for i, c in enumerate(prob_cols):
    df_preview[c] = probs[:, i]

try:
    from IPython.display import display
    display(df_preview[["filename", "label", "pred_label"] + prob_cols].head(10))
except Exception:
    print(df_preview[["filename", "label", "pred_label"] + prob_cols].head(10))


Top-1 Accuracy (disease-only): 0.5286 (valid=5000/5000)

Per-class Top-1 accuracy:
Atelectasis          | support=1000 | acc=0.514
Cardiomegaly         | support=1000 | acc=0.610
Consolidation        | support=1000 | acc=0.219
Edema                | support=1000 | acc=0.520
Pleural Effusion     | support=1000 | acc=0.780


Unnamed: 0,filename,label,pred_label,prob_Atelectasis,prob_Cardiomegaly,prob_Consolidation,prob_Edema,prob_Pleural_Effusion
0,00007321_013.png,Pleural Effusion,Pleural Effusion,0.680292,0.553198,0.605651,0.546557,0.691927
1,00021035_015.png,Consolidation,Consolidation,0.529955,0.528914,0.692851,0.536791,0.615774
2,00026848_028.png,Edema,Edema,0.5344,0.559519,0.618772,0.633023,0.603653
3,00000402_010.png,Pleural Effusion,Pleural Effusion,0.50272,0.529124,0.581051,0.550608,0.654939
4,00015286_001.png,Cardiomegaly,Cardiomegaly,0.549851,0.740179,0.550209,0.632719,0.535289
5,00013615_054.png,Cardiomegaly,Pleural Effusion,0.583479,0.598634,0.647083,0.680046,0.693352
6,00019694_001.png,Atelectasis,Atelectasis,0.601996,0.591236,0.550058,0.553124,0.54968
7,00006819_002.png,Consolidation,Pleural Effusion,0.568555,0.537909,0.662748,0.54956,0.768223
8,00017606_027.png,Pleural Effusion,Pleural Effusion,0.643161,0.552168,0.580876,0.601246,0.753519
9,00014785_006.png,Edema,Edema,0.540597,0.538873,0.556977,0.601053,0.556809
