In [83]:
import os
import pandas as pd
from pathlib import Path
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader

from medclip import MedCLIPModel, MedCLIPProcessor, PromptClassifier
from medclip.prompts import generate_chexpert_class_prompts
from medclip.dataset import ZeroShotImageDataset, ZeroShotImageCollator
from medclip.evaluator import Evaluator

from notebooks.Beisong_prompts import my_prompts

In [3]:
# Root folder that holds the dataset files used by this notebook.
# Set the SMALL_DATA_DIR environment variable to point at another copy of the
# data; when it is unset, the original hard-coded location is used unchanged.
path = os.environ.get(
    "SMALL_DATA_DIR",
    "/Users/liubeisong/Desktop/2025_Fall/Small_Data/data",
)

In [4]:
# Locations inside the data folder: the normalized image directory and the
# projections / reports CSV files.
IMG_PATH, PROJECTIONS_PATH, REPORTS_PATH = (
    os.path.join(path, relative)
    for relative in (
        "images/images_normalized",
        "indiana_projections.csv",
        "indiana_reports.csv",
    )
)

In [5]:
# Read the projections table and the reports table into DataFrames.
df_proj, df_rep = (
    pd.read_csv(csv_file) for csv_file in (PROJECTIONS_PATH, REPORTS_PATH)
)

In [6]:
# Inner join on `uid`: keep only rows whose uid is present in both tables.
df_merged = df_proj.merge(df_rep, how="inner", on="uid")

In [7]:
# Display the column labels of the merged table.
df_merged.columns

Index(['uid', 'filename', 'projection', 'MeSH', 'Problems', 'image',
       'indication', 'comparison', 'findings', 'impression'],
      dtype='object')

In [8]:
# Size summary: distinct uid values vs. total rows, then distinct filenames.
n_uids = df_merged["uid"].nunique()
n_rows = len(df_merged)
n_files = df_merged["filename"].nunique()
print(n_uids, "unique patients", n_rows, "total rows")
print(n_files, "unique images")

3851 unique patients 7466 total rows
7466 unique images


In [9]:
# Turn the free-text `Problems` column into one clean list per row:
# lower-case the text, split it on any of | , ; /, strip whitespace from each
# piece, drop empty pieces, and remove duplicates while keeping the order in
# which the pieces first appear (dict.fromkeys preserves insertion order).
SEP_PATTERN = r"[|,;/]"
df_merged["Problem_List"] = (
    df_merged["Problems"]
    # A missing value would become NaN after .str.split and make the lambda
    # below fail while iterating a float; treat it as "no problems" instead.
    .fillna("")
    .astype(str)
    .str.lower()
    # SEP_PATTERN is a character class, so state explicitly that it is a regex.
    .str.split(SEP_PATTERN, regex=True)
    .apply(lambda parts: list(dict.fromkeys(p.strip() for p in parts if p.strip())))
)

In [24]:
# Target findings, and for each row the number of them that are mentioned.
# A finding counts when its name is a substring of at least one entry of the
# row's Problem_List.
CLASSES = ['atelectasis', 'cardiomegaly', 'consolidation', 'edema', 'pleural effusion']


def _n_target_diseases(problems):
    """Return how many CLASSES names occur inside at least one entry of `problems`."""
    hits = 0
    for disease in CLASSES:
        if any(disease in entry for entry in problems):
            hits += 1
    return hits


df_merged["disease_count"] = df_merged["Problem_List"].apply(_n_target_diseases)

In [25]:
# Keep only the rows that mention exactly one target finding; the copy makes
# later column assignments independent of df_merged.
exactly_one = df_merged["disease_count"].eq(1)
single_disease_df = df_merged.loc[exactly_one].copy()

In [26]:
def _first_matching_class(problems):
    """Return the first CLASSES name (in CLASSES order) contained in any entry, else None."""
    for disease in CLASSES:
        for entry in problems:
            if disease in entry:
                return disease
    return None


# Label each row with the target finding it mentions.
single_disease_df["Disease"] = single_disease_df["Problem_List"].apply(_first_matching_class)

In [27]:
# Number of rows per assigned finding.
single_disease_df["Disease"].value_counts()

Disease
cardiomegaly        442
atelectasis         433
pleural effusion    117
consolidation        23
edema                12
Name: count, dtype: int64

In [28]:
# Save the single-finding subset as a CSV in the current working directory,
# leaving out the DataFrame index.
single_disease_df.to_csv("single_disease_df.csv", index=False)

In [75]:
### Multiclass zero-shot (5 CheXpert classes)

MAX_PER_CLASS = 50
SAMPLE_SEED = 42  # fixed so the per-class draw is reproducible

# Map to paths and class names (lowercase to match prompts)
df_mc = single_disease_df.copy()
df_mc['Disease'] = df_mc['Disease'].str.lower()
df_mc['imgpath'] = df_mc['filename'].map(lambda f: f if os.path.isabs(f) else os.path.join(IMG_PATH, f))

# Cap each class at MAX_PER_CLASS rows (take the whole class if it is smaller).
# Each group is sampled on its own and the pieces are concatenated, instead of
# calling groupby(...).apply(...): apply also operates on the grouping column,
# which recent pandas versions warn about (presumably the stray
# `.apply(lambda g: ...)` line in this cell's earlier output). Groups are still
# visited in sorted key order and sampled with the same seed, so the selected
# rows and their order are the same as before.
sampled_groups = [
    g.sample(n=min(len(g), MAX_PER_CLASS), random_state=SAMPLE_SEED)
    for _, g in df_mc.groupby('Disease')
]
if sampled_groups:
    df_bal = pd.concat(sampled_groups, ignore_index=True)
else:
    # pd.concat rejects an empty list; keep the columns, with zero rows.
    df_bal = df_mc.iloc[0:0].reset_index(drop=True)

# One-hot in CLASSES order
for c in CLASSES:
    df_bal[c] = (df_bal['Disease'] == c).astype(int)

meta_mc = df_bal[['imgpath'] + CLASSES]

# Save for ZeroShotImageDataset.
# NOTE(review): the index column is written on purpose, assuming the dataset
# loader reads the CSV with an index column - confirm against medclip.dataset.
out_csv = Path('local_data/chexpert-multiclass-capped-meta.csv')
out_csv.parent.mkdir(parents=True, exist_ok=True)
meta_mc.to_csv(out_csv)

print("Per-class counts after capping:")
print(df_bal['Disease'].value_counts().reindex(CLASSES))
print(f"Saved to: {out_csv} | shape={meta_mc.shape}")


Per-class counts after capping:
Disease
atelectasis         50
cardiomegaly        50
consolidation       23
edema               12
pleural effusion    50
Name: count, dtype: int64
Saved to: local_data/chexpert-multiclass-capped-meta.csv | shape=(185, 6)


  .apply(lambda g: g.sample(n=min(len(g), MAX_PER_CLASS), random_state=42))


In [93]:
# 2) Sample 8 prompts per class from the CheXpert prompt generator
chex_prompts = generate_chexpert_class_prompts(n=8)

# Re-key by the lowercase names in CLASSES; the generated dict is looked up
# with the title-cased name (e.g. 'pleural effusion' -> 'Pleural Effusion').
cls_prompts_mc = dict(
    zip(CLASSES, (chex_prompts[name.title()] for name in CLASSES))
)


prompt_counts = {cls_name: len(prompts) for cls_name, prompts in cls_prompts_mc.items()}
print("Prompt counts per class:")
print(prompt_counts)


sample 8 num of prompts for Atelectasis from total 210
sample 8 num of prompts for Cardiomegaly from total 15
sample 8 num of prompts for Consolidation from total 192
sample 8 num of prompts for Edema from total 18
sample 8 num of prompts for Pleural Effusion from total 54
Prompt counts per class:
{'atelectasis': 8, 'cardiomegaly': 8, 'consolidation': 8, 'edema': 8, 'pleural effusion': 8}


In [84]:
# Same per-class count for the imported `my_prompts` dict.
print("Prompt counts per class:")
print(dict(zip(my_prompts.keys(), map(len, my_prompts.values()))))

Prompt counts per class:
{'atelectasis': 5, 'cardiomegaly': 5, 'consolidation': 5, 'edema': 5, 'pleural effusion': 5}


In [90]:
# 3) Create multiclass dataset/loader

# The DataLoader below starts worker processes (num_workers=2). When the
# HuggingFace tokenizer has already used its internal parallelism before the
# fork, every worker logs "The current process just got forked ... Disabling
# parallelism to avoid deadlocks". That message asks for TOKENIZERS_PARALLELISM
# to be set explicitly, so pin it here before the prompts are tokenized;
# setdefault keeps any value the user has already exported.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

# Reads local_data/chexpert-multiclass-capped-meta.csv (see the loader's log line).
mc_dataset = ZeroShotImageDataset(
    datalist=['chexpert-multiclass-capped'],
    class_names=CLASSES
)

# Collator that attaches the per-class prompt inputs to every batch.
mc_collator = ZeroShotImageCollator(
    mode='multiclass',
    cls_prompts=cls_prompts_mc
)

# shuffle=False keeps the evaluation order identical to the metadata file.
mc_loader = DataLoader(
    mc_dataset,
    batch_size=32,
    shuffle=False,
    collate_fn=mc_collator,
    num_workers=2
)

# Smoke test: pull one batch and show its tensor shapes and prompt keys.
print('Batches:', len(mc_loader))
sample = next(iter(mc_loader))
print(sample['pixel_values'].shape, sample['labels'].shape)
print('Prompt keys:', list(sample['prompt_inputs'].keys()))


load data from ./local_data/chexpert-multiclass-capped-meta.csv




Batches: 6
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
torch.Size([32, 3, 224, 224]) torch.Size([32])
Prompt keys: ['atelectasis', 'cardiomegaly', 'consolidation', 'edema', 'pleural effusion']


In [94]:
# 4) Build the ViT-based model and wrap it in a prompt-ensemble classifier (CPU)
# NOTE(review): `processor` is not referenced by any other visible cell.
processor = MedCLIPProcessor()
model = MedCLIPModel.from_pretrained(device='cpu', vision_model='vit')
clf_mc = PromptClassifier(model, ensemble=True)
clf_mc = clf_mc.to('cpu')
# clf_mc.to('mps')
clf_mc.eval()
print('Multiclass classifier ready')


Some weights of the model checkpoint at microsoft/swin-tiny-patch4-window7-224 were not used when initializing SwinModel: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing SwinModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing SwinModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.se

Model moved to cpu
load model weight from: pretrained/medclip-vit
Multiclass classifier ready


In [95]:
# %% # 5) Evaluate (multiclass)
evaluator_mc = Evaluator(
    medclip_clf=clf_mc,
    eval_dataloader=mc_loader,
    mode='multiclass',
)
results_mc = evaluator_mc.evaluate()

banner = "=" * 50
print("\n" + banner)
print("EVALUATION RESULTS (Multiclass - 5 CheXpert)")
print(banner)

# Print every entry except 'pred' and 'labels'; floats get four decimals.
for metric, value in results_mc.items():
    if metric in ('pred', 'labels'):
        continue
    if isinstance(value, float):
        print(f"{metric:20s}: {value:.4f}")
    else:
        print(f"{metric:20s}: {value}")

Evaluation:   0%|          | 0/6 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Evaluation: 100%|██████████| 6/6 [00:48<00:00,  8.11s/it]


EVALUATION RESULTS (Multiclass - 5 CheXpert)
acc                 : 0.6108
precision           : 0.6084
recall              : 0.5243
f1-score            : 0.5414





In [96]:
import numpy as np
import pandas as pd
import torch
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, roc_auc_score

CLASSES = ['atelectasis','cardiomegaly','consolidation','edema','pleural effusion']

# Raw scores -> softmax probabilities -> predicted class index per sample.
logits = torch.tensor(results_mc['pred'])
probs = torch.softmax(logits, dim=1).numpy()
y_pred = probs.argmax(axis=1)

# True labels as class indices; a 2-D label array is collapsed with argmax.
labels = np.array(results_mc['labels'])
if labels.ndim == 1:
    y_true = labels
else:
    y_true = labels.argmax(axis=1)

# Per-class precision / recall / F1 / support taken from sklearn's report.
rep = classification_report(
    y_true, y_pred, target_names=CLASSES, digits=4, output_dict=True
)
report_table = pd.DataFrame(rep).T
per_class_df = report_table.loc[CLASSES, ['precision','recall','f1-score','support']]
per_class_df = per_class_df.reset_index()
per_class_df = per_class_df.rename(columns={'index':'class'})

# One-vs-rest AUC per class, computed against a one-hot copy of the truth.
# If sklearn raises for any reason, every AUC is reported as NaN.
y_true_oh = np.zeros_like(probs)
y_true_oh[np.arange(len(y_true)), y_true] = 1
try:
    auc_per_class = roc_auc_score(y_true_oh, probs, multi_class='ovr', average=None)
except Exception:
    auc_per_class = np.array([np.nan]*len(CLASSES))
auc_df = pd.DataFrame({'class': CLASSES, 'auc_ovr': auc_per_class})

# Correctly classified samples per class = diagonal of the confusion matrix.
cm = confusion_matrix(y_true, y_pred, labels=range(len(CLASSES)))
correct_counts = np.diag(cm)

# Join everything while the rows are still in CLASSES order, so `correct`
# lines up with its class, and only then rank the rows by F1.
metrics_df = per_class_df.merge(auc_df, on='class', how='left')
metrics_df['correct'] = correct_counts
metrics_df = metrics_df[['class','precision','recall','f1-score','support','correct','auc_ovr']]
metrics_df = metrics_df.sort_values('f1-score', ascending=False).reset_index(drop=True)

# Headline numbers.
overall = {
    'accuracy': accuracy_score(y_true, y_pred),
    'macro_f1': rep['macro avg']['f1-score'],
    'weighted_f1': rep['weighted avg']['f1-score']
}

print("Overall:", {name: round(score, 4) for name, score in overall.items()})
print("\nPer-class metrics:")
print(metrics_df.to_string(index=False))


Overall: {'accuracy': 0.6108, 'macro_f1': 0.5414, 'weighted_f1': 0.6075}

Per-class metrics:
           class  precision   recall  f1-score  support  correct  auc_ovr
pleural effusion   0.660377 0.700000  0.679612     50.0       35 0.817630
    cardiomegaly   0.569231 0.740000  0.643478     50.0       37 0.886370
   consolidation   1.000000 0.434783  0.606061     23.0       10 0.745303
     atelectasis   0.630435 0.580000  0.604167     50.0       29 0.804444
           edema   0.181818 0.166667  0.173913     12.0        2 0.750482


In [None]:
import numpy as np
import torch
import pandas as pd
from sklearn.metrics import (
    classification_report, confusion_matrix, roc_auc_score, accuracy_score
)

CLASSES = ['atelectasis','cardiomegaly','consolidation','edema','pleural effusion']

# --- Get predictions & labels ---
logits = torch.tensor(results_mc['pred'])           # [N, C]
probs  = torch.softmax(logits, dim=1).numpy()       # [N, C]
y_pred = probs.argmax(axis=1)                       # [N]

# Labels may arrive as class indices or as a 2-D one-hot array.
labels = np.array(results_mc['labels'])
if labels.ndim == 2:
    y_true = labels.argmax(axis=1)
else:
    y_true = labels

# --- Per-class precision/recall/F1/support ---
report_dict = classification_report(
    y_true, y_pred, target_names=CLASSES, digits=4, output_dict=True
)
per_class_df = (
    pd.DataFrame(report_dict)
      .T.loc[CLASSES, ['precision','recall','f1-score','support']]
      .reset_index()
      .rename(columns={'index':'class'})
)

# --- Overall & macro/weighted summaries ---
overall = {
    'accuracy': accuracy_score(y_true, y_pred),
    'macro_f1': report_dict['macro avg']['f1-score'],
    'weighted_f1': report_dict['weighted avg']['f1-score']
}
# A bare `overall` in the middle of a cell displays nothing; print it instead.
print("Overall:", overall)

# --- Confusion matrix (counts & row-normalized) ---
cm = confusion_matrix(y_true, y_pred, labels=range(len(CLASSES)))
cm_df = pd.DataFrame(cm, index=[f"true:{c}" for c in CLASSES],
                        columns=[f"pred:{c}" for c in CLASSES])

# clip(min=1) avoids dividing by zero for a class with no true samples.
cm_norm = (cm.astype(float) / cm.sum(axis=1, keepdims=True).clip(min=1))
cm_norm_df = pd.DataFrame(cm_norm, index=[f"true:{c}" for c in CLASSES],
                                   columns=[f"pred:{c}" for c in CLASSES])

# Correctly classified samples per class, in CLASSES order.
correct_counts = np.diag(cm)

# --- Per-class ROC-AUC (one-vs-rest) ---
# Need one-hot y_true for AUC
y_true_oh = np.zeros_like(probs)
y_true_oh[np.arange(len(y_true)), y_true] = 1

# returns array of AUC per class in CLASSES order
try:
    auc_per_class = roc_auc_score(y_true_oh, probs, multi_class='ovr', average=None)
    auc_df = pd.DataFrame({'class': CLASSES, 'auc_ovr': auc_per_class})
except Exception as e:
    # e.g., if a class has no positive samples in y_true.
    # Still best-effort, but say why the column is empty instead of hiding it.
    print(f"AUC could not be computed ({e}); filling auc_ovr with NaN")
    auc_df = pd.DataFrame({'class': CLASSES, 'auc_ovr': [np.nan]*len(CLASSES)})

# --- Nice combined table ---
# `correct` must be attached BEFORE sorting: correct_counts is in CLASSES
# order, and so is the merged frame at this point. Assigning it after
# sort_values/reset_index (as this cell used to) pairs each count with
# whichever class happens to sit at that rank, i.e. the wrong row.
metrics_df = (
    per_class_df
      .merge(auc_df, on='class', how='left')
      .assign(correct=correct_counts)
      .sort_values('f1-score', ascending=False)
      .reset_index(drop=True)
)

# Reorder for clarity
metrics_df = metrics_df[['class', 'precision', 'recall', 'f1-score', 'support', 'correct', 'auc_ovr']]

print("\nPer-class metrics (with correct counts):")
print(metrics_df.to_string(index=False))




Overall: {'accuracy': 0.6108108108108108, 'macro_f1': 0.5414460455121072, 'weighted_f1': 0.6075090776336027}

Per-class metrics:
           class  precision   recall  f1-score  support  auc_ovr
pleural effusion   0.660377 0.700000  0.679612     50.0 0.818963
    cardiomegaly   0.569231 0.740000  0.643478     50.0 0.886370
   consolidation   1.000000 0.434783  0.606061     23.0 0.757381
     atelectasis   0.630435 0.580000  0.604167     50.0 0.806667
           edema   0.181818 0.166667  0.173913     12.0 0.750000


we only want cases with unique

In [11]:
# Positive class for the binary "<class> vs Normal" cells that follow.
# Available names: ['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Pleural Effusion']
target_class = "Pleural Effusion"

In [12]:
# Positives: rows whose Problem_List has an entry exactly equal to the
# lower-cased target class name.
is_positive = df_merged["Problem_List"].apply(
    lambda problems: target_class.lower() in problems
)
pos_df = df_merged[is_positive].copy()

# Absolute filenames are kept as they are; relative ones are resolved under IMG_PATH.
pos_df["imgpath"] = pos_df["filename"].map(
    lambda fname: fname if os.path.isabs(fname) else os.path.join(IMG_PATH, fname)
)

# Label columns: 1 for the target class, 0 for Normal.
pos_df[target_class] = 1
pos_df["Normal"] = 0
print(len(pos_df))

286


In [13]:
# Negatives: rows whose raw Problems value is exactly "normal".
is_normal = df_merged['Problems'].eq("normal")
neg_df = df_merged.loc[is_normal].copy()
neg_df["imgpath"] = neg_df["filename"].map(
    lambda fname: fname if os.path.isabs(fname) else os.path.join(IMG_PATH, fname)
)

# Downsample (seed 42) to at most as many rows as there are positives.
n_keep = min(len(neg_df), len(pos_df))
neg_df = neg_df.sample(n=n_keep, random_state=42)

# Label columns: 0 for the target class, 1 for Normal.
neg_df[target_class] = 0
neg_df["Normal"] = 1
print(len(neg_df))

286


In [14]:

# Stack positives on top of negatives, keeping only the image path and the
# two label columns, and renumber the rows from zero.
label_cols = ['imgpath', target_class, 'Normal']
meta_df = pd.concat([pos_df[label_cols], neg_df[label_cols]], axis=0)
meta_df = meta_df.reset_index(drop=True)

# Quick sanity: count image paths that do not exist on disk
file_exists = meta_df['imgpath'].apply(os.path.exists)
missing = (~file_exists).sum()
print(f"Total images: {len(meta_df)} (pos={len(pos_df)}, neg={len(neg_df)}). Missing files: {missing}")

# Write the metadata (index column included) under local_data/.
output_path = Path(f'local_data/{target_class}-test-meta.csv')
output_path.parent.mkdir(exist_ok=True)
meta_df.to_csv(output_path)
print(f"\nSaved metadata to: {output_path}")

Total images: 572 (pos=286, neg=286). Missing files: 0

Saved metadata to: local_data/Pleural Effusion-test-meta.csv


In [15]:
# Prepare prompts: target class vs Normal
# Positive prompts are sampled from the CheXpert prompt generator; the Normal
# class uses a short list of "no finding" style sentences.

chex_prompts = generate_chexpert_class_prompts(n=7)
target_prompts = chex_prompts[target_class]

# `neg_prompts` is not defined by any cell of this notebook, so on a fresh
# kernel the lookup `neg_prompts[target_class]` raised NameError. A dict of
# that name is still honoured when it exists in the session; otherwise the
# generic normal-study sentences below are used.
# NOTE(review): the recorded run reports 3 Normal prompts whose text is not
# visible here - replace these defaults with the originals if they can be
# recovered, since the prompt wording affects the scores.
DEFAULT_NORMAL_PROMPTS = [
    "No acute cardiopulmonary abnormality.",
    "The lungs are clear and the heart size is normal.",
    "Normal chest radiograph with no finding.",
]
try:
    normal_prompts = neg_prompts[target_class]
except NameError:
    normal_prompts = list(DEFAULT_NORMAL_PROMPTS)


# Prompt lists keyed by class name, as passed to the collator.
cls_prompts_dict = {
    target_class: target_prompts,
    'Normal': normal_prompts,
}

print(f'{target_class} prompts:', len(target_prompts))
print('Normal prompts:', len(normal_prompts))


sample 7 num of prompts for Atelectasis from total 210
sample 7 num of prompts for Cardiomegaly from total 15
sample 7 num of prompts for Consolidation from total 192
sample 7 num of prompts for Edema from total 18
sample 7 num of prompts for Pleural Effusion from total 54
Pleural Effusion prompts: 7
Normal prompts: 3


In [None]:
# Show each sampled prompt for the target class on its own line
for prompt_text in target_prompts:
    print(prompt_text)

In [None]:
# Build dataset and dataloader for the target class vs Normal
class_names = [target_class, 'Normal']

# The DataLoader below starts worker processes (num_workers=2). If the
# HuggingFace tokenizer has already used its internal parallelism before the
# fork, each worker logs a "process just got forked ... Disabling parallelism"
# warning; that warning asks for TOKENIZERS_PARALLELISM to be set explicitly.
# setdefault keeps any value the user has already exported.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

# Create dataset from our metadata
dataset = ZeroShotImageDataset(
    datalist=[f'{target_class}-test'],  
    class_names=class_names
)

# Collator with prompts (binary)
collator = ZeroShotImageCollator(
    mode='binary',
    cls_prompts=cls_prompts_dict
)

# shuffle=False keeps the evaluation order identical to the metadata file.
batch_size = 32
loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collator,
    num_workers=2
)

print(f"DataLoader ready: {len(loader)} batches")
# Test one batch
sample_batch = next(iter(loader))
print(sample_batch['pixel_values'].shape, sample_batch['labels'].shape)
print('Prompt inputs keys:', list(sample_batch['prompt_inputs'].keys()))


In [None]:
# Initialize model and classifier
# The device used to be hard-coded to 'mps', which fails on machines or
# PyTorch builds without the MPS backend. Use MPS when it is available and
# fall back to CPU otherwise. getattr guards older PyTorch versions that have
# no torch.backends.mps module at all.
_mps_backend = getattr(torch.backends, "mps", None)
device = 'mps' if (_mps_backend is not None and _mps_backend.is_available()) else 'cpu'
print('Device:', device)

# NOTE(review): `processor` is not referenced by any other visible cell.
processor = MedCLIPProcessor()
# Use resnet by default; switch to 'vit' if preferred
model = MedCLIPModel.from_pretrained(vision_model='resnet', device=device)
# ensemble=False here, unlike the multiclass classifier above.
clf = PromptClassifier(model, ensemble=False)
clf.to(device)
clf.eval()
print('Model and classifier ready')


In [None]:
# Evaluate
# Binary classification (target class vs Normal)
evaluator = Evaluator(medclip_clf=clf, eval_dataloader=loader, mode='binary')
results = evaluator.evaluate()

rule = "=" * 50
print("\n" + rule)
print(f"EVALUATION RESULTS ({target_class} vs Normal)")
print(rule)

# Print every entry except 'pred' and 'labels'; floats get four decimals.
for metric, value in results.items():
    if metric in ('pred', 'labels'):
        continue
    if isinstance(value, float):
        print(f"{metric:20s}: {value:.4f}")
    else:
        print(f"{metric:20s}: {value}")


In [None]:
# Apply a sigmoid to the raw scores, take the higher-scoring column per row,
# and print how many rows fall in column 0 and in column 1.
pred = results['pred']
pred_scores = torch.tensor(pred).sigmoid().numpy()
pred_labels = pred_scores.argmax(1)
print('Predicted class counts:', (pred_labels == 0).sum(), (pred_labels == 1).sum())