In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from pathlib import Path

import pandas as pd
import numpy as np
import torch
from torch.utils.data import DataLoader
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

from medclip import MedCLIPModel, MedCLIPProcessor, PromptClassifier
from medclip.prompts import generate_tb_class_prompts
from medclip.dataset import ZeroShotImageDataset, ZeroShotImageCollator
from medclip.evaluator import Evaluator



In [3]:
tb_dir = Path('data/tuberculosis/TB_Chest_Radiography_Database/')

tb_imgs = list((tb_dir / 'Tuberculosis').glob('*.png'))
normal_imgs = list((tb_dir / 'Normal').glob('*.png'))

print(f"Found {len(tb_imgs)} TB images")
print(f"Found {len(normal_imgs)} Normal images")

num_sample_per_class = 700
np.random.seed(42)

# Sample images
tb_imgs_sampled = np.random.choice(tb_imgs, size=num_sample_per_class, replace=False)
normal_imgs_sampled = np.random.choice(normal_imgs, size=num_sample_per_class, replace=False)

print(f"\nSampled {len(tb_imgs_sampled)} TB images")
print(f"Sampled {len(normal_imgs_sampled)} Normal images")

# Create DataFrame
data = []
for img_path in tb_imgs_sampled:
    data.append({
        'imgpath': str(img_path),
        'TB': 1,
        'Normal': 0,
    })

for img_path in normal_imgs_sampled:
    data.append({
        'imgpath': str(img_path),
        'TB': 0,
        'Normal': 1,
    })

df = pd.DataFrame(data)

# Save sampled csv
output_path = Path('local_data/tb-test-meta.csv')
output_path.parent.mkdir(exist_ok=True)
df.to_csv(output_path)
print(f"\nSaved metadata to: {output_path}")

Found 700 TB images
Found 3500 Normal images

Sampled 700 TB images
Sampled 700 Normal images

Saved metadata to: local_data/tb-test-meta.csv


In [4]:
class_names = ['TB', 'Normal']

dataset = ZeroShotImageDataset(
    datalist=['tb-test'],  # will load from local_data/tb-test-meta.csv
    class_names=class_names
)

print(f"Dataset size: {len(dataset)}")
print(f"Class names: {class_names}")

# Check a sample
img, label = dataset[0]
print(f"\nSample image shape: {img.shape}")
print(f"Sample label:")
print(label)

load data from ./local_data/tb-test-meta.csv
Dataset size: 1400
Class names: ['TB', 'Normal']

Sample image shape: torch.Size([1, 1, 224, 224])
Sample label:
  TB Normal
0  1      0


In [25]:
# Generate TB class prompts
tb_prompts = generate_tb_class_prompts(n=10)
print(f"\nGenerated prompts for classes: {list(tb_prompts.keys())}")
print(f"Number of prompts per class: {[len(v) for v in tb_prompts.values()]}")

print(f"\nAvailable TB prompts:")
for cls, prompts_list in tb_prompts.items():
    print(f"  {cls}: {len(prompts_list)} prompts")
    print(f"    Examples: {prompts_list[:2]}")

sample 10 num of prompts for Tuberculosis from total 480

Generated prompts for classes: ['Tuberculosis']
Number of prompts per class: [10]

Available TB prompts:
  Tuberculosis: 10 prompts
    Examples: ['reticulonodular lung lesion with volume loss in the upper lung zones', 'reticulonodular lung opacity with volume loss in the upper lung zones']


In [26]:
tb_prompts['Tuberculosis'] = [f'this x-ray image describes {prompt}' for prompt in tb_prompts['Tuberculosis']]

In [28]:
cls_prompts_dict = {
    'TB': tb_prompts['Tuberculosis'],
    'Normal': [
        'no findings',
        'no evidence of pneumonia',
        'normal chest x-ray',
        'clear lungs',
        'no acute disease',
        'no radiographic abnormality',
        'healthy chest radiograph',
        'unremarkable chest x-ray',
        'no pathological findings',
        'normal cardiomediastinal silhouette',
        # 'no infiltrates',
        # 'normal pulmonary vasculature'
    ]
}

cls_prompts_dict['Normal'] = [f'this x-ray image describes {prompt}' for prompt in cls_prompts_dict['Normal']]

print(f"Class prompts prepared:")
for cls, prompts in cls_prompts_dict.items():
    print(f"  {cls}: {len(prompts)} prompts")

Class prompts prepared:
  TB: 10 prompts
  Normal: 10 prompts


In [29]:
# Create collator with prompts
collator = ZeroShotImageCollator(
    mode='binary',  # binary classification: TB vs Normal
    cls_prompts=cls_prompts_dict
)

# Create DataLoader
batch_size = 32
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collator,
    num_workers=2
)

print(f"DataLoader created with batch_size={batch_size}")
print(f"Number of batches: {len(dataloader)}")

# Test the dataloader
sample_batch = next(iter(dataloader))
print(f"\nSample batch keys: {sample_batch.keys()}")
print(f"Pixel values shape: {sample_batch['pixel_values'].shape}")
print(f"Labels shape: {sample_batch['labels'].shape}")
print(f"Prompt inputs keys: {sample_batch['prompt_inputs'].keys()}")



DataLoader created with batch_size=32
Number of batches: 44
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)

Sample batch keys: dict_keys(['pixel_values', 'prompt_inputs', 'labels'])
Pixel values shape: torch.Size([32, 3, 224, 224])
Labels shape: torch.Size([32])
Prompt inputs keys: dict_keys(['TB', 'Normal'])


In [30]:
processor = MedCLIPProcessor()
model = MedCLIPModel.from_pretrained(vision_model='resnet', device='mps')
clf = PromptClassifier(model, ensemble=True)
clf.to('mps')
clf.eval()

print("Model and classifier ready for evaluation")

  return torch.load(checkpoint_file, map_location="cpu")
Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
  state_dict = torch.load(os.path.join(input_dir, constants.WEIGHTS_NAME), map_locat

Model moved to mps
load model weight from: pretrained/medclip-resnet
Model and classifier ready for evaluation


In [31]:
evaluator = Evaluator(
    medclip_clf=clf,
    eval_dataloader=dataloader,
    mode='binary'
)

print("Evaluator created, starting evaluation...")

Evaluator created, starting evaluation...


In [32]:
# Run evaluation
results = evaluator.evaluate()

# Display results
print("\n" + "="*50)
print("EVALUATION RESULTS")
print("="*50)
for metric, value in results.items():
    if metric not in ['pred', 'labels']:
        print(f"{metric:20s}: {value:.4f}")
print("="*50)


Evaluation:   0%|          | 0/44 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Evaluation: 100%|██████████| 44/44 [00:21<00:00,  2.05it/s]


EVALUATION RESULTS
acc                 : 0.5493
precision           : 0.6414
recall              : 0.5493
f1-score            : 0.4616



