In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from pathlib import Path

import pandas as pd
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

from medclip import MedCLIPModel, MedCLIPProcessor, PromptClassifier
from medclip.prompts import process_class_prompts, generate_chexpert_class_prompts
from medclip.dataset import ZeroShotImageDataset, ZeroShotImageCollator
from medclip.evaluator import Evaluator

In [27]:
class_names = ['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Pleural Effusion']

# Create dataset
dataset = ZeroShotImageDataset(
    datalist=['nih-sampled'],  # will load from local_data/nih-sampled-meta.csv
    class_names=class_names
)

print(f"Dataset size: {len(dataset)}")
print(f"Class names: {class_names}")

# Check a sample
img, label = dataset[0]
print(f"\nSample image shape: {img.shape}")
print(f"Sample label:")
print(label)

load data from ./local_data/nih-sampled-meta.csv
Dataset size: 7000
Class names: ['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Pleural Effusion']

Sample image shape: torch.Size([1, 1, 224, 224])
Sample label:
  Atelectasis Cardiomegaly Consolidation Edema Pleural Effusion
0           0            0             0     1                0


In [28]:
# Generate 5 diagnoses class prompts
chexpert_classes = generate_chexpert_class_prompts(n=10)
print(f"\nGenerated prompts for classes: {list(chexpert_classes.keys())}")
print(f"Number of prompts per class: {[len(v) for v in chexpert_classes.values()]}")

for cls, prompts_list in chexpert_classes.items():
    print(f"  {cls}: {len(prompts_list)} prompts")
    print(f"    Examples: {prompts_list[:2]}")

sample 10 num of prompts for Atelectasis from total 210
sample 10 num of prompts for Cardiomegaly from total 15
sample 10 num of prompts for Consolidation from total 192
sample 10 num of prompts for Edema from total 18
sample 10 num of prompts for Pleural Effusion from total 54

Generated prompts for classes: ['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Pleural Effusion']
Number of prompts per class: [10, 10, 10, 10, 10]
  Atelectasis: 10 prompts
    Examples: ['mild linear atelectasis at the left lung zone', 'minimal trace atelectasis at the bilateral lung bases']
  Cardiomegaly: 10 prompts
    Examples: [' cardiac silhouette size is upper limits of normal ', ' cardiomegaly which is unchanged ']
  Consolidation: 10 prompts
    Examples: ['increased airspace consolidation at the right uppper lobe', 'apperance of retrocardiac consolidation at the left lower lobe']
  Edema: 10 prompts
    Examples: ['decreased pulmonary edema ', 'presistent pulmonary edema ']
  Pleural Effu

In [29]:
# Create collator with prompts
collator = ZeroShotImageCollator(
    mode='multiclass',
    cls_prompts=chexpert_classes
)

# Create DataLoader
batch_size = 32
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collator,
    num_workers=2
)

print(f"DataLoader created with batch_size={batch_size}")
print(f"Number of batches: {len(dataloader)}")

# Test the dataloader
sample_batch = next(iter(dataloader))
print(f"\nSample batch keys: {sample_batch.keys()}")
print(f"Pixel values shape: {sample_batch['pixel_values'].shape}")
print(f"Labels shape: {sample_batch['labels'].shape}")
print(f"Prompt inputs keys: {sample_batch['prompt_inputs'].keys()}")



DataLoader created with batch_size=32
Number of batches: 219
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)

Sample batch keys: dict_keys(['pixel_values', 'prompt_inputs', 'labels'])
Pixel values shape: torch.Size([32, 3, 224, 224])
Labels shape: torch.Size([32])
Prompt inputs keys: dict_keys(['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Pleural Effusion'])


In [30]:
# Initialize model and classifier (reusing from earlier cells)
processor = MedCLIPProcessor()
model = MedCLIPModel.from_pretrained(vision_model='vit', device='mps')
clf = PromptClassifier(model, ensemble=True)
clf.to('mps')
clf.eval()

print("Model and classifier ready for evaluation")

Some weights of the model checkpoint at microsoft/swin-tiny-patch4-window7-224 were not used when initializing SwinModel: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing SwinModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing SwinModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
  return torch.load(checkpoint_file, map_location="cpu")
Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertModel: ['cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weig

Model moved to mps
load model weight from: pretrained/medclip-vit
Model and classifier ready for evaluation


In [31]:
# Create evaluator
evaluator = Evaluator(
    medclip_clf=clf,
    eval_dataloader=dataloader,
    mode='multiclass'
)

print("Evaluator created, starting evaluation...")

Evaluator created, starting evaluation...


In [32]:
# Run evaluation
results = evaluator.evaluate()

# Display results
print("\n" + "="*50)
print("EVALUATION RESULTS")
print("="*50)
for metric, value in results.items():
    if metric not in ['pred', 'labels']:
        print(f"{metric:20s}: {value:.4f}")
print("="*50)


Evaluation:   0%|          | 0/219 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Evaluation: 100%|██████████| 219/219 [05:25<00:00,  1.49s/it]


EVALUATION RESULTS
acc                 : 0.5359
precision           : 0.5432
recall              : 0.5359
f1-score            : 0.5198





In [23]:
df.head()

Unnamed: 0,Image Index,Finding Labels,Follow-up #,Patient ID,Patient Age,Patient Gender,View Position,OriginalImage[Width,Height],OriginalImagePixelSpacing[x,...,Unnamed: 11,diseases,disease_count,disease,imgpath,Edema,Atelectasis,Pleural Effusion,Cardiomegaly,Consolidation
0,00022470_009.png,Edema|Nodule,9,22470,46,M,AP,3056,2544,0.139,...,,['Edema'],1,Edema,data/nih/images_010/images/00022470_009.png,1,0,0,0,0
1,00004858_056.png,Atelectasis,56,4858,45,F,PA,2992,2991,0.143,...,,['Atelectasis'],1,Atelectasis,data/nih/images_003/images/00004858_056.png,0,1,0,0,0
2,00014626_023.png,Effusion|Infiltration,23,14626,44,F,AP,2692,2544,0.139,...,,['Effusion'],1,Pleural Effusion,data/nih/images_007/images/00014626_023.png,0,0,1,0,0
3,00016414_002.png,Cardiomegaly,2,16414,39,M,PA,2704,2781,0.143,...,,['Cardiomegaly'],1,Cardiomegaly,data/nih/images_008/images/00016414_002.png,0,0,0,1,0
4,00019805_005.png,Consolidation|Infiltration|Mass|Nodule,5,19805,29,M,AP,3056,2544,0.139,...,,['Consolidation'],1,Consolidation,data/nih/images_009/images/00019805_005.png,0,0,0,0,1
