In [10]:
# Reload edited Python modules (e.g. the medclip package) before each cell executes,
# so edits to helpers such as medclip.prototyping take effect without a kernel restart.
# Re-running %load_ext in a kernel where it is already loaded only prints a notice.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [24]:
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.metrics import confusion_matrix
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from medclip import MedCLIPModel, constants
from medclip.dataset import ZeroShotImageCollator, ZeroShotImageDataset
from medclip.prompts import generate_chexpert_class_prompts
from medclip.prototyping import (
    classify_with_prototypes,
    compute_embeddings_over_loader,
    construct_prototypes,
    tokenize_all_prompts,
)

In [13]:
# Precompute the fused embeddings for the calibration and test splits (requires a model forward pass).

# Seed every RNG before anything stochastic runs. generate_chexpert_class_prompts()
# randomly samples `prompts_per_class` prompts per class (see its "sample 10 num of
# prompts ..." log), so without a seed the prompt set -- and every downstream number --
# changes between runs. NOTE(review): it presumably samples via Python's `random`;
# numpy and torch are seeded too so other stochastic steps are covered as well.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

BATCH_SIZE = 32

# Prefer CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Load model (inference only)
model = MedCLIPModel.from_pretrained(vision_model="vit", device=device)
model.eval()

# Disease classes; their order defines the class indices used for labels and logits.
disease_classes = ["Atelectasis", "Cardiomegaly", "Consolidation", "Edema", "Pleural Effusion"]

# Sample a fixed number of text prompts per class (seeded above).
prompts_per_class = 10
cls_prompts = generate_chexpert_class_prompts(n=prompts_per_class)

tokenizer = AutoTokenizer.from_pretrained(constants.BERT_TYPE)
tokenizer.model_max_length = 77
all_tokenized = tokenize_all_prompts(tokenizer, cls_prompts, disease_classes, device)


def build_zero_shot_loader(data_name):
    """Build the (dataset, collator, loader) triple for one data split.

    Both splits are loaded the same way -- multiclass collation over the shared
    `cls_prompts` and shuffle=False -- so batches are produced in dataset order.

    Args:
        data_name: dataset identifier passed to ZeroShotImageDataset's `datalist`.

    Returns:
        (dataset, collator, loader)
    """
    dataset = ZeroShotImageDataset(datalist=[data_name], class_names=disease_classes)
    collator = ZeroShotImageCollator(mode="multiclass", cls_prompts=cls_prompts)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, collate_fn=collator, shuffle=False)
    return dataset, collator, loader


def embed_split(loader, collect_for_calibration):
    """Run compute_embeddings_over_loader with the settings shared by every split.

    Args:
        loader: DataLoader produced by build_zero_shot_loader.
        collect_for_calibration: forwarded unchanged; the calibration split passes True,
            the test split passes False.

    Returns:
        The results dict returned by compute_embeddings_over_loader.
    """
    return compute_embeddings_over_loader(
        model=model,
        dataloader=loader,
        all_tokenized=all_tokenized,
        num_classes=len(disease_classes),
        prompts_per_class=prompts_per_class,
        device=device,
        collect_for_calibration=collect_for_calibration,
        concat=True,
    )


# Step 1: calibration data
calib_data_name = "nih-sampled-calib"
calib_dataset, calib_collator, calib_loader = build_zero_shot_loader(calib_data_name)
calib_results = embed_split(calib_loader, collect_for_calibration=True)

all_image_embeddings = calib_results["image_embeddings"]
all_fused_embeddings = calib_results["fused_embeddings"]
all_labels = calib_results["labels"]
calib_class_logits = calib_results["class_logits"]

# Every per-sample tensor must describe the same samples.
n_calib = all_fused_embeddings.shape[0]
assert all_image_embeddings.shape[0] == n_calib, "image/fused embedding counts differ"
assert all_labels.shape[0] == n_calib, "label/embedding counts differ"
assert calib_class_logits.shape[0] == n_calib, "logit/embedding counts differ"
assert calib_class_logits.shape[-1] == len(disease_classes), "one logit per class expected"

print(f"Image embeddings shape: {all_image_embeddings.shape}")
print(f"Fused embeddings shape: {all_fused_embeddings.shape}")
print(f"Labels shape: {all_labels.shape}")
print(f"Class logits shape: {calib_class_logits.shape}")
print(f"\nCalibration Fusion complete: {all_fused_embeddings.shape[0]} samples processed")


# Step 2: Test data
print("Processing test data...")
test_data_name = "nih-sampled-test"
test_dataset, test_collator, test_loader = build_zero_shot_loader(test_data_name)
test_results = embed_split(test_loader, collect_for_calibration=False)

test_fused_embeddings = test_results["fused_embeddings"]
test_labels = test_results["labels"]

assert test_labels.shape[0] == test_fused_embeddings.shape[0], "label/embedding counts differ"
# Prototypes built from calibration embeddings are compared with test embeddings,
# so both splits must share the same embedding width.
assert test_fused_embeddings.shape[-1] == all_fused_embeddings.shape[-1], "embedding widths differ"

print(f"Test fused embeddings shape: {test_fused_embeddings.shape}")
print(f"Test labels shape: {test_labels.shape}")
print(f"\nTest Fusion complete: {test_fused_embeddings.shape[0]} samples processed")

Some weights of the model checkpoint at microsoft/swin-tiny-patch4-window7-224 were not used when initializing SwinModel: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing SwinModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing SwinModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
  return torch.load(checkpoint_file, map_location="cpu")
Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.Laye

Model moved to mps
load model weight from: pretrained/medclip-vit
sample 10 num of prompts for Atelectasis from total 210
sample 10 num of prompts for Cardiomegaly from total 15
sample 10 num of prompts for Consolidation from total 192
sample 10 num of prompts for Edema from total 18
sample 10 num of prompts for Pleural Effusion from total 54
load data from ./local_data/nih-sampled-calib-meta.csv




Image embeddings shape: torch.Size([2000, 512])
Fused embeddings shape: torch.Size([2000, 1024])
Labels shape: torch.Size([2000])
Class logits shape: torch.Size([2000, 5])

Calibration Fusion complete: 2000 samples processed
Processing test data...
load data from ./local_data/nih-sampled-test-meta.csv




Test fused embeddings shape: torch.Size([5000, 1024])
Test labels shape: torch.Size([5000])

Test Fusion complete: 5000 samples processed


In [25]:
# Step 3: sweep the number of calibration shots (top_k) used to build each class
# prototype, then evaluate every resulting prototype classifier on the test embeddings.
shots = list(range(1, 51)) + list(range(60, 201, 10))  # every k in 1..50, then 60..200 in steps of 10

prototypes_list = [
    construct_prototypes(
        calib_results=calib_results,
        disease_classes=disease_classes,
        top_k=k,
    )
    for k in shots
]

# Keep each k's results (not just the printout) so the accuracy-vs-shots curve can be
# plotted or tabulated later without re-running the sweep.
sweep_results = {}
for k, prototypes in zip(shots, prototypes_list):
    predictions, accuracy, per_class_acc, cm = classify_with_prototypes(
        test_fused_embeddings, test_labels, prototypes, disease_classes
    )
    sweep_results[k] = {
        "accuracy": accuracy,
        "per_class_acc": per_class_acc,
        "cm": cm,  # fourth return value of classify_with_prototypes
    }
    print(f"Top {k} shots:")
    print(f"\nPrototype-based classifier accuracy: {accuracy:.4f}")
    print("Per-class accuracy:")
    for name, acc in per_class_acc.items():
        print(f"  {name:20s}: {acc:.4f}")

Top 1 shots:

Prototype-based classifier accuracy: 0.5322
Per-class accuracy:
  Atelectasis         : 0.6090
  Cardiomegaly        : 0.5440
  Consolidation       : 0.1240
  Edema               : 0.7050
  Pleural Effusion    : 0.6790
Top 2 shots:

Prototype-based classifier accuracy: 0.5416
Per-class accuracy:
  Atelectasis         : 0.5920
  Cardiomegaly        : 0.6420
  Consolidation       : 0.2080
  Edema               : 0.6120
  Pleural Effusion    : 0.6540
Top 3 shots:

Prototype-based classifier accuracy: 0.5328
Per-class accuracy:
  Atelectasis         : 0.4820
  Cardiomegaly        : 0.6910
  Consolidation       : 0.2380
  Edema               : 0.5870
  Pleural Effusion    : 0.6660
Top 4 shots:

Prototype-based classifier accuracy: 0.5328
Per-class accuracy:
  Atelectasis         : 0.4880
  Cardiomegaly        : 0.6850
  Consolidation       : 0.2310
  Edema               : 0.6050
  Pleural Effusion    : 0.6550
Top 5 shots:

Prototype-based classifier accuracy: 0.5342
Per-class 