# Sample Notebook for Zero-Shot Inference with CheXzero
This notebook walks through how to use CheXzero to perform zero-shot inference on a chest x-ray image dataset.

## Import Libraries

In [1]:
import os
import numpy as np
import pandas as pd
from pathlib import Path
from typing import List, Tuple, Optional

import sys
# Put the parent directory on the import path; presumably this is where the
# project modules `eval` and `zero_shot` imported below live — confirm.
sys.path.append('../')
import os

# Restrict the visible CUDA devices before the project modules are imported.
# NOTE(review): device index "6" is specific to the original author's machine;
# change it (or remove these two lines) to match your host.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "6"
from eval import evaluate, bootstrap
from zero_shot import make, make_true_labels, run_softmax_eval

%load_ext autoreload
%autoreload 2

## Directories and Constants

In [2]:
# Sanity check: count the rows in CheXpert-v1.0-small/valid.csv.
import pandas as pd

# Load the CSV file.
# NOTE: despite its name, `train_csv_path` points at the validation split here.
train_csv_path = "../data/CheXpert-v1.0-small/valid.csv"  # change to your actual file path
df = pd.read_csv(train_csv_path)

# Print the row count (the Korean message reads "total number of samples (row count)").
print(f"총 샘플 개수 (행 개수): {len(df)}")

총 샘플 개수 (행 개수): 234


In [3]:
# Sanity check: count the rows in CheXpert/test_labels.csv.
import pandas as pd

# Load the CSV file.
# NOTE: despite its name, `train_csv_path` points at the test labels here;
# this also rebinds `df`.
train_csv_path = "../data/CheXpert/test_labels.csv"  # change to your actual file path
df = pd.read_csv(train_csv_path)

# Print the row count (the Korean message reads "total number of samples (row count)").
print(f"총 샘플 개수 (행 개수): {len(df)}")

총 샘플 개수 (행 개수): 668


In [4]:
# Sanity check: count the rows in CheXpert/val_labels.csv.
import pandas as pd

# Load the CSV file.
# NOTE: despite its name, `train_csv_path` points at the validation labels here;
# this also rebinds `df`.
train_csv_path = "../data/CheXpert/val_labels.csv"  # change to your actual file path
df = pd.read_csv(train_csv_path)

# Print the row count (the Korean message reads "total number of samples (row count)").
print(f"총 샘플 개수 (행 개수): {len(df)}")

총 샘플 개수 (행 개수): 234


In [5]:
## Define Zero Shot Labels and Templates
# Central configuration for the run: data / model / output locations, the
# labels to query, the prompt template pair, and the list of checkpoint files.

# ----- DIRECTORIES ------ #
cxr_filepath: str = '../data/chexpert_test.h5' # filepath of chest x-ray images (.h5)
cxr_true_labels_path: Optional[str] = '../data/CheXpert/test_labels.csv' # (optional for evaluation) if labels are provided, provide path
# NOTE(review): the saved output of this cell lists files under
# '../checkpoints/pt-imp-adam2', not the directory configured here — confirm
# which one is intended.
model_dir: str = '../checkpoints/pt-imp-adam' # where pretrained models are saved (.pt)
predictions_dir: Path = Path('../predictions') # where to save predictions
cache_dir: Path = predictions_dir / "cached" # where to cache ensembled predictions

# NOTE(review): not referenced by any other cell visible in this notebook;
# presumably the text-encoder context length — confirm before changing.
context_length: int = 77

# ------- LABELS ------  #
# Define labels to query each image | will return a prediction for each label
cxr_labels: List[str] = ['Atelectasis','Cardiomegaly', 
                                      'Consolidation', 'Edema', 'Enlarged Cardiomediastinum', 'Fracture', 'Lung Lesion',
                                      'Lung Opacity', 'No Finding','Pleural Effusion', 'Pleural Other', 'Pneumonia', 
                                      'Pneumothorax', 'Support Devices']

# ---- TEMPLATES ----- #
# Define set of templates | see Figure 1 for more details
# Each label is formatted into both strings: a positive ("{}") and a negative ("no {}") prompt.
cxr_pair_template: Tuple[str, str] = ("{}", "no {}")

# ----- MODEL PATHS ------ #
# If using ensemble, collect all model paths
# Walks `model_dir` recursively and collects every file found (there is no
# extension filter), in whatever order os.walk yields them. Later cells index
# into this list positionally, so that order matters.
model_paths = []
for subdir, dirs, files in os.walk(model_dir):
    for file in files:
        full_dir = os.path.join(subdir, file)  # full path to the file, despite the name
        model_paths.append(full_dir)

print(model_paths)

['../checkpoints/pt-imp-adam2/checkpoint_125000.pt', '../checkpoints/pt-imp-adam2/checkpoint_140000.pt', '../checkpoints/pt-imp-adam2/checkpoint_80000.pt', '../checkpoints/pt-imp-adam2/checkpoint_65000.pt', '../checkpoints/pt-imp-adam2/checkpoint_150000.pt', '../checkpoints/pt-imp-adam2/checkpoint_75000.pt', '../checkpoints/pt-imp-adam2/checkpoint_170000.pt', '../checkpoints/pt-imp-adam2/checkpoint_95000.pt', '../checkpoints/pt-imp-adam2/checkpoint_155000.pt', '../checkpoints/pt-imp-adam2/checkpoint_70000.pt', '../checkpoints/pt-imp-adam2/checkpoint_40000.pt', '../checkpoints/pt-imp-adam2/checkpoint_85000.pt', '../checkpoints/pt-imp-adam2/checkpoint_105000.pt', '../checkpoints/pt-imp-adam2/checkpoint_30000.pt', '../checkpoints/pt-imp-adam2/checkpoint_145000.pt', '../checkpoints/pt-imp-adam2/checkpoint_55000.pt', '../checkpoints/pt-imp-adam2/checkpoint_100000.pt', '../checkpoints/pt-imp-adam2/checkpoint.pt', '../checkpoints/pt-imp-adam2/checkpoint_20000.pt', '../checkpoints/pt-imp-adam2

In [6]:
import h5py

# Open the HDF5 image file read-only and inspect what it contains.
# This rebinds `cxr_filepath` to the same value set in the configuration cell.
cxr_filepath = "../data/chexpert_test.h5"  # change to your actual file path
with h5py.File(cxr_filepath, "r") as f:
    # Inspect the datasets stored in the file.
    print(list(f.keys()))  # show the dataset names in the file
    dataset_name = list(f.keys())[0]  # pick the first dataset
    
    # Print how many entries that dataset holds.
    print(f"Dataset '{dataset_name}' has length:", len(f[dataset_name]))

['cxr']
Dataset 'cxr' has length: 500


## Run Inference

In [7]:
## Run the model on the data set using ensembled models
def ensemble_models(
    model_paths: List[str], 
    cxr_filepath: str, 
    cxr_labels: List[str], 
    cxr_pair_template: Tuple[str], 
    cache_dir: Optional[str] = None, 
    save_name: Optional[str] = None,
) -> Tuple[List[np.ndarray], np.ndarray]: 
    """
    Run each checkpoint in `model_paths` on the dataset at `cxr_filepath`
    and average the per-model predictions.

    Args:
        model_paths: Checkpoint file paths. They are processed in sorted
            order so the returned list is deterministic.
        cxr_filepath: Path of the image file handed to `make`.
        cxr_labels: Labels passed through to `run_softmax_eval`.
        cxr_pair_template: Prompt template pair passed through to
            `run_softmax_eval`.
        cache_dir: Optional directory (str or path-like) for cached
            predictions. When given, a prediction is loaded from
            `<cache_dir>/[<save_name>_]<checkpoint stem>.npy` if that file
            exists, otherwise it is computed and saved there. When `None`,
            nothing is read from or written to disk.
        save_name: Optional prefix for the cache file names.

    Returns:
        `(predictions, y_pred_avg)`: the list of per-model prediction arrays
        (in sorted-path order) and their element-wise mean over models.

    Raises:
        ValueError: if `model_paths` is empty (previously this produced a
            NaN average and a NumPy warning).

    Note:
        Cache files are keyed by the checkpoint's file stem only, so two
        checkpoints with the same file name in different directories share
        one cache entry unless `save_name` differs.
    """
    # Sort so that the ensemble order (and returned list) is reproducible
    # regardless of how the caller collected the paths.
    model_paths = sorted(model_paths)
    if not model_paths:
        raise ValueError("`model_paths` must contain at least one checkpoint path")

    predictions = []
    for path in model_paths: # for each model
        model_name = Path(path).stem

        # Resolve where this model's prediction is (or will be) cached.
        cache_path = None
        if cache_dir is not None:
            prefix = f"{save_name}_" if save_name is not None else ""
            cache_path = Path(cache_dir) / f"{prefix}{model_name}.npy"

        if cache_path is not None and os.path.exists(cache_path):
            # Cache hit: skip loading the checkpoint altogether.
            print("Loading cached prediction for {}".format(model_name))
            y_pred = np.load(cache_path)
        else:
            # Cache miss (or caching disabled): only now pay the cost of
            # building the model and its data loader.
            print("Inferring model {}".format(path))
            model, loader = make(
                model_path=path, 
                cxr_filepath=cxr_filepath, 
            )
            y_pred = run_softmax_eval(model, loader, cxr_labels, cxr_pair_template)
            if cache_path is not None:
                Path(cache_dir).mkdir(exist_ok=True, parents=True)
                np.save(file=cache_path, arr=y_pred)
        predictions.append(y_pred)

    # Element-wise average across models (axis 0 is the model axis).
    y_pred_avg = np.mean(predictions, axis=0)

    return predictions, y_pred_avg

In [8]:
# Run inference with a single checkpoint; it is wrapped in a list because
# `ensemble_models` takes a list of paths.
# NOTE(review): index 5 is positional in the list built with os.walk, whose
# ordering is not guaranteed — confirm it selects the intended checkpoint.
predictions, y_pred_avg = ensemble_models(
    model_paths=[model_paths[5]], 
    cxr_filepath=cxr_filepath, 
    cxr_labels=cxr_labels, 
    cxr_pair_template=cxr_pair_template, 
    cache_dir=None,  # caching disabled: predictions are always recomputed and never saved
)

Inferring model ../checkpoints/pt-imp-adam2/checkpoint_75000.pt


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]



features mean:  tensor(0.0007)
features var:  tensor(6.4138e-17)


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

features mean:  tensor(0.0007)
features var:  tensor(6.4138e-17)


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

features mean:  tensor(0.0007)
features var:  tensor(6.4138e-17)


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

features mean:  tensor(0.0007)
features var:  tensor(6.4138e-17)


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

features mean:  tensor(0.0007)
features var:  tensor(6.4138e-17)


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

features mean:  tensor(0.0007)
features var:  tensor(6.4138e-17)


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

features mean:  tensor(0.0007)
features var:  tensor(6.4138e-17)


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

features mean:  tensor(0.0007)
features var:  tensor(6.4138e-17)


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

features mean:  tensor(0.0007)
features var:  tensor(6.4138e-17)


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

features mean:  tensor(0.0007)
features var:  tensor(6.4138e-17)


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

features mean:  tensor(0.0007)
features var:  tensor(6.4138e-17)


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

features mean:  tensor(0.0007)
features var:  tensor(6.4138e-17)


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

features mean:  tensor(0.0007)
features var:  tensor(6.4138e-17)


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

features mean:  tensor(0.0007)
features var:  tensor(6.4138e-17)


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

features mean:  tensor(0.0007)
features var:  tensor(6.4138e-17)


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

features mean:  tensor(0.0007)
features var:  tensor(6.4138e-17)


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

features mean:  tensor(0.0007)
features var:  tensor(6.4138e-17)


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

features mean:  tensor(0.0007)
features var:  tensor(6.4138e-17)


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

features mean:  tensor(0.0007)
features var:  tensor(6.4138e-17)


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

features mean:  tensor(0.0007)
features var:  tensor(6.4138e-17)


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

features mean:  tensor(0.0007)
features var:  tensor(6.4138e-17)


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

features mean:  tensor(0.0007)
features var:  tensor(6.4138e-17)


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

features mean:  tensor(0.0007)
features var:  tensor(6.4138e-17)


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

features mean:  tensor(0.0007)
features var:  tensor(6.4138e-17)


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

features mean:  tensor(0.0007)
features var:  tensor(6.4138e-17)


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

features mean:  tensor(0.0007)
features var:  tensor(6.4138e-17)


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

features mean:  tensor(0.0007)
features var:  tensor(6.4138e-17)


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

features mean:  tensor(0.0007)
features var:  tensor(6.4138e-17)


In [9]:
# # save averaged preds
# pred_name = "chexpert_preds.npy" # add name of preds
# predictions_dir = predictions_dir / pred_name
# np.save(file=predictions_dir, arr=y_pred_avg)

## (Optional) Evaluate Results
If ground truth labels are available, compute AUC on each pathology to evaluate the performance of the zero-shot model. 

In [10]:
# Sanity check: length of the averaged prediction array, then the number of
# per-model prediction arrays that went into the average.
print(len(y_pred_avg))
print(len(predictions))

500
1


In [11]:
# Build the ground-truth labels from the labels CSV and pair them with the
# averaged predictions.
# NOTE(review): the saved outputs in this notebook report 668 rows in
# ../data/CheXpert/test_labels.csv but only 500 predictions — confirm that
# `test_true` and `test_pred` are row-aligned before trusting the metrics.
test_pred = y_pred_avg
test_true = make_true_labels(cxr_true_labels_path=cxr_true_labels_path, cxr_labels=cxr_labels)
# (disabled) truncate the labels to the number of predictions:
# test_true = test_true[:len(test_pred)]
# evaluate model
cxr_results = evaluate(test_pred, test_true, cxr_labels)

# bootstrap evaluations for 95% confidence intervals
bootstrap_results = bootstrap(test_pred, test_true, cxr_labels)

  0%|          | 0/1000 [00:00<?, ?it/s]

In [12]:
# display AUC with confidence intervals
# The second element of `bootstrap_results` is the summary table
# (mean / lower / upper per label, per the saved output).
# NOTE(review): the saved table shows AUCs near 0.5 for most labels and the
# inference log prints a feature variance of ~6e-17 — worth confirming that
# the checkpoint loaded correctly and that labels match the images.
bootstrap_results[1]

Unnamed: 0,Atelectasis_auc,Cardiomegaly_auc,Consolidation_auc,Edema_auc,Enlarged Cardiomediastinum_auc,Fracture_auc,Lung Lesion_auc,Lung Opacity_auc,No Finding_auc,Pleural Effusion_auc,Pleural Other_auc,Pneumonia_auc,Pneumothorax_auc,Support Devices_auc
mean,0.4951,0.5427,0.3696,0.5326,0.4755,0.4892,0.5578,0.3987,0.5251,0.469,0.7758,0.5695,0.4931,0.5259
lower,0.4361,0.4974,0.2865,0.4691,0.4277,0.1807,0.3345,0.3487,0.4595,0.4171,0.6265,0.4176,0.2976,0.4773
upper,0.5487,0.5886,0.4539,0.5983,0.5238,0.9078,0.7802,0.4479,0.5928,0.5257,0.993,0.7127,0.6752,0.576


In [13]:
# First element of `bootstrap_results`: per the printed type/columns below, a
# DataFrame with one `<label>_auc` column per label.
# Note this rebinds `df`, which earlier cells used for CSV contents.
df = bootstrap_results[0]

print(type(df))

print(df.columns)

# Keep only five of the label columns.
# NOTE(review): presumably the five CheXpert competition pathologies — confirm.
df = df[['Atelectasis_auc','Cardiomegaly_auc','Consolidation_auc','Edema_auc','Pleural Effusion_auc']]

# describe() summarises each column; transposing and taking the mean then
# averages every summary statistic across the selected columns. Only the
# `mean` row is a mean AUC; the std / min / quantile rows are averages of
# per-label statistics, not statistics of a per-sample mean AUC.
df.describe().T.mean()

<class 'pandas.core.frame.DataFrame'>
Index(['Atelectasis_auc', 'Cardiomegaly_auc', 'Consolidation_auc', 'Edema_auc',
       'Enlarged Cardiomediastinum_auc', 'Fracture_auc', 'Lung Lesion_auc',
       'Lung Opacity_auc', 'No Finding_auc', 'Pleural Effusion_auc',
       'Pleural Other_auc', 'Pneumonia_auc', 'Pneumothorax_auc',
       'Support Devices_auc'],
      dtype='object')


count    1000.000000
mean        0.481777
std         0.031065
min         0.385397
25%         0.460893
50%         0.481367
75%         0.502181
max         0.596797
dtype: float64

In [14]:
# Same summary over all label columns: each describe() statistic is averaged
# across the columns.
# NOTE(review): the saved output's averaged `count` of 998.64 (not a whole
# number) suggests some label columns contain missing values — confirm.
df_2 = bootstrap_results[0]
df_2.describe().T.mean()

count    998.642857
mean       0.515681
std        0.058571
min        0.344134
25%        0.477015
50%        0.515603
75%        0.553076
max        0.699771
dtype: float64