# Sample Notebook for Zero-Shot Inference with CheXzero
This notebook walks through how to use CheXzero to perform zero-shot inference on a chest x-ray image dataset.

## Import Libraries

In [1]:
import os
import numpy as np
import pandas as pd
from pathlib import Path
from typing import List, Tuple, Optional

import sys
sys.path.append('../')

from eval import evaluate, bootstrap
from zero_shot import make, make_true_labels, run_softmax_eval

%load_ext autoreload
%autoreload 2

## Directories and Constants

In [2]:
## Define Zero Shot Labels and Templates
# Configuration cell: dataset locations, output locations, the label set to
# query, the prompt template pair, and the list of checkpoints to ensemble.

# ----- DIRECTORIES ------ #
# NOTE: the dataset paths below are absolute paths on the original author's
# machine; point them at your own copies before running. Keep exactly one
# dataset block active (uncommented) at a time.
# # Padchest
# cxr_filepath: str = '/home/ec2-user/all_raw_data/padchest/images/44_cxr.h5' # filepath of chest x-ray images (.h5)
# cxr_true_labels_path: Optional[str] = '/home/ec2-user/all_raw_data/padchest/44_cxr_labels.csv' # (optional for evaluation) if labels are provided, provide path

# CheXzero test
cxr_filepath: str = '/home/ec2-user/CHEXLOCALIZE/CheXpert/test.h5' # filepath of chest x-ray images (.h5)
cxr_true_labels_path: Optional[str] = '/home/ec2-user/CHEXLOCALIZE/CheXpert/test_labels_view1.csv' # (optional for evaluation) if labels are provided, provide path

# # CheXzero val
# cxr_filepath: str = '/home/ec2-user/all_raw_data/chexpert/CheXpert-v1.0-small/valid/chexpert_val.h5' # filepath of chest x-ray images (.h5)
# cxr_true_labels_path: Optional[str] = '/home/ec2-user/all_raw_data/chexpert/CheXpert-v1.0-small/valid_view1.csv' # (optional for evaluation) if labels are provided, provide path


model_dir: str = '../checkpoints/cxr-bert' # where pretrained models are saved (.pt)
predictions_dir: Path = Path('../predictions') # where to save predictions
cache_dir: Path = predictions_dir / "cached" # where to cache ensembled predictions (a Path, not a str)

# Not referenced by the other cells of this notebook.
# NOTE(review): presumably the text-encoder context length -- confirm before relying on it.
context_length: int = 77

# ------- LABELS ------  #
# Define labels to query each image | will return a prediction for each label
# cxr_labels: List[str] = ['Atelectasis','Cardiomegaly', 
#                                       'Consolidation', 'Edema', 'Enlarged Cardiomediastinum', 'Fracture', 'Lung Lesion',
#                                       'Lung Opacity', 'No Finding','Pleural Effusion', 'Pleural Other', 'Pneumonia', 
#                                       'Pneumothorax', 'Support Devices']
cxr_labels: List[str] = ['No Finding', 'Enlarged Cardiomediastinum', 'Cardiomegaly',
                'Lung Opacity', 'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia',
                'Atelectasis', 'Pneumothorax', 'Pleural Effusion', 'Pleural Other',
                'Fracture', 'Support Devices']
# cxr_labels = [label.lower() for label in cxr_labels]

# ---- TEMPLATES ----- # 
# Define set of templates | see Figure 1 for more details
# Two format strings, each filled with a label name: ("{}", "no {}").
cxr_pair_template: Tuple[str, str] = ("{}", "no {}")

# ----- MODEL PATHS ------ #
# If using ensemble, collect all model paths
# Only `.pt` files are collected so that stray files in the checkpoint
# directory (logs, configs, OS metadata) are never treated as checkpoints.
model_paths: List[str] = [
    os.path.join(subdir, file)
    for subdir, _dirs, files in os.walk(model_dir)
    for file in files
    if file.endswith('.pt')
]

n = len(model_paths) # total number of checkpoints found, before sub-selection

# Keep only the last three checkpoints for the ensemble.
# NOTE(review): "last" follows os.walk's traversal order, which comes from the
# directory listing and is not guaranteed to be sorted -- confirm the printed
# list below is the set of checkpoints you intend to ensemble.
model_paths = model_paths[-3:]
print(model_paths)

['../checkpoints/cxr-bert/checkpoint_13000.pt', '../checkpoints/cxr-bert/checkpoint_14000.pt', '../checkpoints/cxr-bert/checkpoint.pt']


## Run Inference

In [3]:
## Run the model on the data set using ensembled models
def ensemble_models(
    model_paths: List[str], 
    cxr_filepath: str, 
    cxr_labels: List[str], 
    cxr_pair_template: Tuple[str], 
    cache_dir: Optional[str] = None, 
    save_name: Optional[str] = None,
    change_text_encoder: bool = False,
) -> Tuple[List[np.ndarray], np.ndarray]: 
    """
    Run every checkpoint in `model_paths` on the dataset and ensemble the
    results by averaging.

    Parameters
    ----------
    model_paths : paths of the checkpoints to ensemble. They are processed in
        sorted order, so the order of the returned per-model predictions does
        not depend on the order in which the caller collected the paths.
    cxr_filepath : path of the image dataset, forwarded to `make`.
    cxr_labels : labels to query, forwarded to `run_softmax_eval`.
    cxr_pair_template : prompt template pair, forwarded to `run_softmax_eval`.
    cache_dir : optional directory (str or path-like) for per-model prediction
        caches. Each model's predictions are stored as `<model stem>.npy`, or
        `<save_name>_<model stem>.npy` when `save_name` is given. If such a
        file already exists it is loaded instead of running inference.
        The cache file name does not encode the dataset or the labels, so use
        a distinct `save_name` (or `cache_dir`) per dataset / label set to
        avoid picking up stale predictions.
    save_name : optional prefix for the cache file names.
    change_text_encoder : forwarded to `make`.

    Returns
    -------
    (predictions, y_pred_avg) : the list of each model's predictions (in
        sorted-path order) and their element-wise mean over models.
    """

    predictions = []
    # sort to ensure a consistent model order regardless of how paths were collected
    for path in sorted(model_paths): # for each model
        model_name = Path(path).stem

        # path to the cached prediction (None when caching is disabled)
        cache_path = None
        if cache_dir is not None:
            prefix = f"{save_name}_" if save_name is not None else ""
            cache_path = Path(cache_dir) / f"{prefix}{model_name}.npy"

        # if prediction already cached, don't recompute prediction
        if cache_path is not None and cache_path.exists(): 
            print("Loading cached prediction for {}".format(model_name))
            y_pred = np.load(cache_path)
        else: # cached prediction not found, compute preds
            # Load the model and `torch.DataLoader` only on a cache miss, so
            # that a cache hit does not pay the cost of loading a checkpoint
            # and opening the dataset.
            model, loader = make(
                model_path=path, 
                cxr_filepath=cxr_filepath, 
                change_text_encoder=change_text_encoder,
            ) 
            print("Inferring model {}".format(path))
            y_pred = run_softmax_eval(model, loader, cxr_labels, cxr_pair_template)
            if cache_path is not None: 
                cache_path.parent.mkdir(exist_ok=True, parents=True)
                np.save(file=cache_path, arr=y_pred)
        predictions.append(y_pred)
    
    # compute average predictions (mean over the model axis)
    y_pred_avg = np.mean(predictions, axis=0)
    
    return predictions, y_pred_avg

In [4]:
# Run (or load from cache) every selected checkpoint and average the results.
# `predictions` holds one array per model; `y_pred_avg` is their mean.
predictions, y_pred_avg = ensemble_models(
    model_paths,
    cxr_filepath,
    cxr_labels,
    cxr_pair_template,
    cache_dir=cache_dir,
    change_text_encoder=True,  # use CXR BERT
)



../checkpoints/cxr-bert/checkpoint.pt using an online model  
Inferring model ../checkpoints/cxr-bert/checkpoint.pt


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

../checkpoints/cxr-bert/checkpoint_13000.pt using an online model  
Inferring model ../checkpoints/cxr-bert/checkpoint_13000.pt


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

../checkpoints/cxr-bert/checkpoint_14000.pt using an online model  
Inferring model ../checkpoints/cxr-bert/checkpoint_14000.pt


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

In [5]:
# save averaged preds
pred_name = "chexpert_preds_bert.npy" # add name of preds
# Use a separate variable for the output file instead of rebinding
# `predictions_dir`: rebinding made this cell non-idempotent (a second run
# would target `<file>.npy/<file>.npy`) and left `predictions_dir` pointing
# at a file rather than a directory.
predictions_path = predictions_dir / pred_name
predictions_dir.mkdir(parents=True, exist_ok=True) # ensure the output directory exists
np.save(file=predictions_path, arr=y_pred_avg)

## (Optional) Evaluate Results
If ground-truth labels are available, compute the AUC for each pathology to evaluate the performance of the zero-shot model.

In [6]:
# make test_true
# `test_pred` is the ensemble-averaged prediction array computed above.
test_pred = y_pred_avg
# Build the ground-truth labels from the labels CSV for the same label list.
# NOTE(review): `cxr_true_labels_path` is declared Optional in the config cell;
# this cell presumably requires it to be set -- confirm against `make_true_labels`.
test_true = make_true_labels(cxr_true_labels_path=cxr_true_labels_path, cxr_labels=cxr_labels)

# evaluate model
cxr_results = evaluate(test_pred, test_true, cxr_labels)

# bootstrap evaluations for 95% confidence intervals
bootstrap_results = bootstrap(test_pred, test_true, cxr_labels)

  0%|          | 0/1000 [00:00<?, ?it/s]

In [7]:
# display AUC with confidence intervals
# Element 1 of `bootstrap_results` is the summary table rendered below:
# one column per label, with `mean`, `lower` and `upper` rows.
bootstrap_results[1]

Unnamed: 0,No Finding_auc,Enlarged Cardiomediastinum_auc,Cardiomegaly_auc,Lung Opacity_auc,Lung Lesion_auc,Edema_auc,Consolidation_auc,Pneumonia_auc,Atelectasis_auc,Pneumothorax_auc,Pleural Effusion_auc,Pleural Other_auc,Fracture_auc,Support Devices_auc
mean,0.3805,0.5984,0.3303,0.7954,0.7021,0.5963,0.7881,0.7834,0.4515,0.4986,0.3689,0.3499,0.4015,0.4003
lower,0.3098,0.5496,0.2806,0.7564,0.4924,0.5361,0.6951,0.6608,0.3923,0.3582,0.3009,0.2046,0.1562,0.3507
upper,0.4539,0.6492,0.3819,0.8317,0.883,0.6519,0.8678,0.8831,0.5074,0.6381,0.4312,0.493,0.7249,0.4487


In [8]:
# Average the first row of the summary table (the `mean` row) across all labels.
print(f"Mean AUC: {np.mean(bootstrap_results[1].iloc[0])}")

Mean AUC: 0.5317999999999999


In [9]:
# Write the bootstrap AUC summary table to a CSV file; the path is relative,
# so the file is created in the current working directory.
pd.DataFrame(bootstrap_results[1]).to_csv('chexzero_chexpert_bert.csv')