# Sample Notebook for Zero-Shot Inference with CheXzero
This notebook walks through how to use CheXzero to perform zero-shot inference on a chest x-ray image dataset.

## Import Libraries

In [1]:
import os
import numpy as np
import pandas as pd
from pathlib import Path
from typing import List, Tuple, Optional

import sys
sys.path.append('../')

from eval import evaluate, bootstrap
from zero_shot import make, make_true_labels, run_softmax_eval

%load_ext autoreload
%autoreload 2

## Directories and Constants

In [2]:
## Define Zero Shot Labels and Templates

# ----- DIRECTORIES ------ #
# # Padchest
# cxr_filepath: str = '/home/ec2-user/all_raw_data/padchest/images/44_cxr.h5' # filepath of chest x-ray images (.h5)
# cxr_true_labels_path: Optional[str] = '/home/ec2-user/all_raw_data/padchest/44_cxr_labels.csv' # (optional for evaluation) if labels are provided, provide path

# CheXzero test
cxr_filepath: str = '/home/ec2-user/CHEXLOCALIZE/CheXpert/test.h5' # filepath of chest x-ray images (.h5)
cxr_true_labels_path: Optional[str] = '/home/ec2-user/CHEXLOCALIZE/CheXpert/test_labels_view1.csv' # (optional for evaluation) if labels are provided, provide path

# # CheXzero val
# cxr_filepath: str = '/home/ec2-user/all_raw_data/chexpert/CheXpert-v1.0-small/valid/chexpert_val.h5' # filepath of chest x-ray images (.h5)
# cxr_true_labels_path: Optional[str] = '/home/ec2-user/all_raw_data/chexpert/CheXpert-v1.0-small/valid_view1.csv' # (optional for evaluation) if labels are provided, provide path


model_dir: str = '../checkpoints/cxr-bert' # where pretrained models are saved (.pt) 
predictions_dir: Path = Path('../predictions') # where to save predictions
cache_dir: str = predictions_dir / "cached" # where to cache ensembled predictions

context_length: int = 77

# ------- LABELS ------  #
# Define labels to query each image | will return a prediction for each label
# cxr_labels: List[str] = ['Atelectasis','Cardiomegaly', 
#                                       'Consolidation', 'Edema', 'Enlarged Cardiomediastinum', 'Fracture', 'Lung Lesion',
#                                       'Lung Opacity', 'No Finding','Pleural Effusion', 'Pleural Other', 'Pneumonia', 
#                                       'Pneumothorax', 'Support Devices']
cxr_labels = ['No Finding', 'Enlarged Cardiomediastinum', 'Cardiomegaly',
                'Lung Opacity', 'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia',
                'Atelectasis', 'Pneumothorax', 'Pleural Effusion', 'Pleural Other',
                'Fracture', 'Support Devices']
# cxr_labels = [label.lower() for label in cxr_labels]

# ---- TEMPLATES ----- # 
# Define set of templates | see Figure 1 for more details                        
cxr_pair_template: Tuple[str] = ("{}", "no {}")

# ----- MODEL PATHS ------ #
# If using ensemble, collect all model paths
model_paths = []
for subdir, dirs, files in os.walk(model_dir):
    for file in files:
        full_dir = os.path.join(subdir, file)
        model_paths.append(full_dir)

n = len(model_paths)

model_paths = model_paths[-3:]
print(model_paths)

['../checkpoints/cxr-bert/checkpoint_42000.pt', '../checkpoints/cxr-bert/checkpoint_45000.pt', '../checkpoints/cxr-bert/checkpoint.pt']


## Run Inference

In [3]:
## Run the model on the data set using ensembled models
def ensemble_models(
    model_paths: List[str], 
    cxr_filepath: str, 
    cxr_labels: List[str], 
    cxr_pair_template: Tuple[str], 
    cache_dir: str = None, 
    save_name: str = None,
    change_text_encoder: bool = False,
) -> Tuple[List[np.ndarray], np.ndarray]: 
    """
    Given a list of `model_paths`, ensemble model and return
    predictions. Caches predictions at `cache_dir` if location provided.

    Returns a list of each model's predictions and the averaged
    set of predictions.
    """

    predictions = []
    model_paths = sorted(model_paths) # ensure consistency of 
    for path in model_paths: # for each model
        model_name = Path(path).stem

        # load in model and `torch.DataLoader`
        model, loader = make(
            model_path=path, 
            cxr_filepath=cxr_filepath, 
            change_text_encoder=change_text_encoder,
        ) 
        
        # path to the cached prediction
        if cache_dir is not None:
            if save_name is not None: 
                cache_path = Path(cache_dir) / f"{save_name}_{model_name}.npy"
            else: 
                cache_path = Path(cache_dir) / f"{model_name}.npy"

        # if prediction already cached, don't recompute prediction
        if cache_dir is not None and os.path.exists(cache_path): 
            print("Loading cached prediction for {}".format(model_name))
            y_pred = np.load(cache_path)
        else: # cached prediction not found, compute preds
            print("Inferring model {}".format(path))
            y_pred = run_softmax_eval(model, loader, cxr_labels, cxr_pair_template)
            if cache_dir is not None: 
                Path(cache_dir).mkdir(exist_ok=True, parents=True)
                np.save(file=cache_path, arr=y_pred)
        predictions.append(y_pred)
    
    # compute average predictions
    y_pred_avg = np.mean(predictions, axis=0)
    
    return predictions, y_pred_avg

In [4]:
predictions, y_pred_avg = ensemble_models(
    model_paths=model_paths, 
    cxr_filepath=cxr_filepath, 
    cxr_labels=cxr_labels, 
    cxr_pair_template=cxr_pair_template, 
    cache_dir=cache_dir,
    change_text_encoder=True, # don't use CXR BERT
)



../checkpoints/cxr-bert/checkpoint.pt using an online model  
Argument error. Set pretrained = True. <class 'RuntimeError'>


RuntimeError: Error(s) in loading state_dict for CLIP:
	Unexpected key(s) in state_dict: "encode_text.bert.embeddings.position_ids", "encode_text.bert.embeddings.word_embeddings.weight", "encode_text.bert.embeddings.position_embeddings.weight", "encode_text.bert.embeddings.token_type_embeddings.weight", "encode_text.bert.embeddings.LayerNorm.weight", "encode_text.bert.embeddings.LayerNorm.bias", "encode_text.bert.encoder.layer.0.attention.self.query.weight", "encode_text.bert.encoder.layer.0.attention.self.query.bias", "encode_text.bert.encoder.layer.0.attention.self.key.weight", "encode_text.bert.encoder.layer.0.attention.self.key.bias", "encode_text.bert.encoder.layer.0.attention.self.value.weight", "encode_text.bert.encoder.layer.0.attention.self.value.bias", "encode_text.bert.encoder.layer.0.attention.output.dense.weight", "encode_text.bert.encoder.layer.0.attention.output.dense.bias", "encode_text.bert.encoder.layer.0.attention.output.LayerNorm.weight", "encode_text.bert.encoder.layer.0.attention.output.LayerNorm.bias", "encode_text.bert.encoder.layer.0.intermediate.dense.weight", "encode_text.bert.encoder.layer.0.intermediate.dense.bias", "encode_text.bert.encoder.layer.0.output.dense.weight", "encode_text.bert.encoder.layer.0.output.dense.bias", "encode_text.bert.encoder.layer.0.output.LayerNorm.weight", "encode_text.bert.encoder.layer.0.output.LayerNorm.bias", "encode_text.bert.encoder.layer.1.attention.self.query.weight", "encode_text.bert.encoder.layer.1.attention.self.query.bias", "encode_text.bert.encoder.layer.1.attention.self.key.weight", "encode_text.bert.encoder.layer.1.attention.self.key.bias", "encode_text.bert.encoder.layer.1.attention.self.value.weight", "encode_text.bert.encoder.layer.1.attention.self.value.bias", "encode_text.bert.encoder.layer.1.attention.output.dense.weight", "encode_text.bert.encoder.layer.1.attention.output.dense.bias", "encode_text.bert.encoder.layer.1.attention.output.LayerNorm.weight", "encode_text.bert.encoder.layer.1.attention.output.LayerNorm.bias", "encode_text.bert.encoder.layer.1.intermediate.dense.weight", "encode_text.bert.encoder.layer.1.intermediate.dense.bias", "encode_text.bert.encoder.layer.1.output.dense.weight", "encode_text.bert.encoder.layer.1.output.dense.bias", "encode_text.bert.encoder.layer.1.output.LayerNorm.weight", "encode_text.bert.encoder.layer.1.output.LayerNorm.bias", "encode_text.bert.encoder.layer.2.attention.self.query.weight", "encode_text.bert.encoder.layer.2.attention.self.query.bias", "encode_text.bert.encoder.layer.2.attention.self.key.weight", "encode_text.bert.encoder.layer.2.attention.self.key.bias", "encode_text.bert.encoder.layer.2.attention.self.value.weight", "encode_text.bert.encoder.layer.2.attention.self.value.bias", "encode_text.bert.encoder.layer.2.attention.output.dense.weight", "encode_text.bert.encoder.layer.2.attention.output.dense.bias", "encode_text.bert.encoder.layer.2.attention.output.LayerNorm.weight", "encode_text.bert.encoder.layer.2.attention.output.LayerNorm.bias", "encode_text.bert.encoder.layer.2.intermediate.dense.weight", "encode_text.bert.encoder.layer.2.intermediate.dense.bias", "encode_text.bert.encoder.layer.2.output.dense.weight", "encode_text.bert.encoder.layer.2.output.dense.bias", "encode_text.bert.encoder.layer.2.output.LayerNorm.weight", "encode_text.bert.encoder.layer.2.output.LayerNorm.bias", "encode_text.bert.encoder.layer.3.attention.self.query.weight", "encode_text.bert.encoder.layer.3.attention.self.query.bias", "encode_text.bert.encoder.layer.3.attention.self.key.weight", "encode_text.bert.encoder.layer.3.attention.self.key.bias", "encode_text.bert.encoder.layer.3.attention.self.value.weight", "encode_text.bert.encoder.layer.3.attention.self.value.bias", "encode_text.bert.encoder.layer.3.attention.output.dense.weight", "encode_text.bert.encoder.layer.3.attention.output.dense.bias", "encode_text.bert.encoder.layer.3.attention.output.LayerNorm.weight", "encode_text.bert.encoder.layer.3.attention.output.LayerNorm.bias", "encode_text.bert.encoder.layer.3.intermediate.dense.weight", "encode_text.bert.encoder.layer.3.intermediate.dense.bias", "encode_text.bert.encoder.layer.3.output.dense.weight", "encode_text.bert.encoder.layer.3.output.dense.bias", "encode_text.bert.encoder.layer.3.output.LayerNorm.weight", "encode_text.bert.encoder.layer.3.output.LayerNorm.bias", "encode_text.bert.encoder.layer.4.attention.self.query.weight", "encode_text.bert.encoder.layer.4.attention.self.query.bias", "encode_text.bert.encoder.layer.4.attention.self.key.weight", "encode_text.bert.encoder.layer.4.attention.self.key.bias", "encode_text.bert.encoder.layer.4.attention.self.value.weight", "encode_text.bert.encoder.layer.4.attention.self.value.bias", "encode_text.bert.encoder.layer.4.attention.output.dense.weight", "encode_text.bert.encoder.layer.4.attention.output.dense.bias", "encode_text.bert.encoder.layer.4.attention.output.LayerNorm.weight", "encode_text.bert.encoder.layer.4.attention.output.LayerNorm.bias", "encode_text.bert.encoder.layer.4.intermediate.dense.weight", "encode_text.bert.encoder.layer.4.intermediate.dense.bias", "encode_text.bert.encoder.layer.4.output.dense.weight", "encode_text.bert.encoder.layer.4.output.dense.bias", "encode_text.bert.encoder.layer.4.output.LayerNorm.weight", "encode_text.bert.encoder.layer.4.output.LayerNorm.bias", "encode_text.bert.encoder.layer.5.attention.self.query.weight", "encode_text.bert.encoder.layer.5.attention.self.query.bias", "encode_text.bert.encoder.layer.5.attention.self.key.weight", "encode_text.bert.encoder.layer.5.attention.self.key.bias", "encode_text.bert.encoder.layer.5.attention.self.value.weight", "encode_text.bert.encoder.layer.5.attention.self.value.bias", "encode_text.bert.encoder.layer.5.attention.output.dense.weight", "encode_text.bert.encoder.layer.5.attention.output.dense.bias", "encode_text.bert.encoder.layer.5.attention.output.LayerNorm.weight", "encode_text.bert.encoder.layer.5.attention.output.LayerNorm.bias", "encode_text.bert.encoder.layer.5.intermediate.dense.weight", "encode_text.bert.encoder.layer.5.intermediate.dense.bias", "encode_text.bert.encoder.layer.5.output.dense.weight", "encode_text.bert.encoder.layer.5.output.dense.bias", "encode_text.bert.encoder.layer.5.output.LayerNorm.weight", "encode_text.bert.encoder.layer.5.output.LayerNorm.bias", "encode_text.bert.encoder.layer.6.attention.self.query.weight", "encode_text.bert.encoder.layer.6.attention.self.query.bias", "encode_text.bert.encoder.layer.6.attention.self.key.weight", "encode_text.bert.encoder.layer.6.attention.self.key.bias", "encode_text.bert.encoder.layer.6.attention.self.value.weight", "encode_text.bert.encoder.layer.6.attention.self.value.bias", "encode_text.bert.encoder.layer.6.attention.output.dense.weight", "encode_text.bert.encoder.layer.6.attention.output.dense.bias", "encode_text.bert.encoder.layer.6.attention.output.LayerNorm.weight", "encode_text.bert.encoder.layer.6.attention.output.LayerNorm.bias", "encode_text.bert.encoder.layer.6.intermediate.dense.weight", "encode_text.bert.encoder.layer.6.intermediate.dense.bias", "encode_text.bert.encoder.layer.6.output.dense.weight", "encode_text.bert.encoder.layer.6.output.dense.bias", "encode_text.bert.encoder.layer.6.output.LayerNorm.weight", "encode_text.bert.encoder.layer.6.output.LayerNorm.bias", "encode_text.bert.encoder.layer.7.attention.self.query.weight", "encode_text.bert.encoder.layer.7.attention.self.query.bias", "encode_text.bert.encoder.layer.7.attention.self.key.weight", "encode_text.bert.encoder.layer.7.attention.self.key.bias", "encode_text.bert.encoder.layer.7.attention.self.value.weight", "encode_text.bert.encoder.layer.7.attention.self.value.bias", "encode_text.bert.encoder.layer.7.attention.output.dense.weight", "encode_text.bert.encoder.layer.7.attention.output.dense.bias", "encode_text.bert.encoder.layer.7.attention.output.LayerNorm.weight", "encode_text.bert.encoder.layer.7.attention.output.LayerNorm.bias", "encode_text.bert.encoder.layer.7.intermediate.dense.weight", "encode_text.bert.encoder.layer.7.intermediate.dense.bias", "encode_text.bert.encoder.layer.7.output.dense.weight", "encode_text.bert.encoder.layer.7.output.dense.bias", "encode_text.bert.encoder.layer.7.output.LayerNorm.weight", "encode_text.bert.encoder.layer.7.output.LayerNorm.bias", "encode_text.bert.encoder.layer.8.attention.self.query.weight", "encode_text.bert.encoder.layer.8.attention.self.query.bias", "encode_text.bert.encoder.layer.8.attention.self.key.weight", "encode_text.bert.encoder.layer.8.attention.self.key.bias", "encode_text.bert.encoder.layer.8.attention.self.value.weight", "encode_text.bert.encoder.layer.8.attention.self.value.bias", "encode_text.bert.encoder.layer.8.attention.output.dense.weight", "encode_text.bert.encoder.layer.8.attention.output.dense.bias", "encode_text.bert.encoder.layer.8.attention.output.LayerNorm.weight", "encode_text.bert.encoder.layer.8.attention.output.LayerNorm.bias", "encode_text.bert.encoder.layer.8.intermediate.dense.weight", "encode_text.bert.encoder.layer.8.intermediate.dense.bias", "encode_text.bert.encoder.layer.8.output.dense.weight", "encode_text.bert.encoder.layer.8.output.dense.bias", "encode_text.bert.encoder.layer.8.output.LayerNorm.weight", "encode_text.bert.encoder.layer.8.output.LayerNorm.bias", "encode_text.bert.encoder.layer.9.attention.self.query.weight", "encode_text.bert.encoder.layer.9.attention.self.query.bias", "encode_text.bert.encoder.layer.9.attention.self.key.weight", "encode_text.bert.encoder.layer.9.attention.self.key.bias", "encode_text.bert.encoder.layer.9.attention.self.value.weight", "encode_text.bert.encoder.layer.9.attention.self.value.bias", "encode_text.bert.encoder.layer.9.attention.output.dense.weight", "encode_text.bert.encoder.layer.9.attention.output.dense.bias", "encode_text.bert.encoder.layer.9.attention.output.LayerNorm.weight", "encode_text.bert.encoder.layer.9.attention.output.LayerNorm.bias", "encode_text.bert.encoder.layer.9.intermediate.dense.weight", "encode_text.bert.encoder.layer.9.intermediate.dense.bias", "encode_text.bert.encoder.layer.9.output.dense.weight", "encode_text.bert.encoder.layer.9.output.dense.bias", "encode_text.bert.encoder.layer.9.output.LayerNorm.weight", "encode_text.bert.encoder.layer.9.output.LayerNorm.bias", "encode_text.bert.encoder.layer.10.attention.self.query.weight", "encode_text.bert.encoder.layer.10.attention.self.query.bias", "encode_text.bert.encoder.layer.10.attention.self.key.weight", "encode_text.bert.encoder.layer.10.attention.self.key.bias", "encode_text.bert.encoder.layer.10.attention.self.value.weight", "encode_text.bert.encoder.layer.10.attention.self.value.bias", "encode_text.bert.encoder.layer.10.attention.output.dense.weight", "encode_text.bert.encoder.layer.10.attention.output.dense.bias", "encode_text.bert.encoder.layer.10.attention.output.LayerNorm.weight", "encode_text.bert.encoder.layer.10.attention.output.LayerNorm.bias", "encode_text.bert.encoder.layer.10.intermediate.dense.weight", "encode_text.bert.encoder.layer.10.intermediate.dense.bias", "encode_text.bert.encoder.layer.10.output.dense.weight", "encode_text.bert.encoder.layer.10.output.dense.bias", "encode_text.bert.encoder.layer.10.output.LayerNorm.weight", "encode_text.bert.encoder.layer.10.output.LayerNorm.bias", "encode_text.bert.encoder.layer.11.attention.self.query.weight", "encode_text.bert.encoder.layer.11.attention.self.query.bias", "encode_text.bert.encoder.layer.11.attention.self.key.weight", "encode_text.bert.encoder.layer.11.attention.self.key.bias", "encode_text.bert.encoder.layer.11.attention.self.value.weight", "encode_text.bert.encoder.layer.11.attention.self.value.bias", "encode_text.bert.encoder.layer.11.attention.output.dense.weight", "encode_text.bert.encoder.layer.11.attention.output.dense.bias", "encode_text.bert.encoder.layer.11.attention.output.LayerNorm.weight", "encode_text.bert.encoder.layer.11.attention.output.LayerNorm.bias", "encode_text.bert.encoder.layer.11.intermediate.dense.weight", "encode_text.bert.encoder.layer.11.intermediate.dense.bias", "encode_text.bert.encoder.layer.11.output.dense.weight", "encode_text.bert.encoder.layer.11.output.dense.bias", "encode_text.bert.encoder.layer.11.output.LayerNorm.weight", "encode_text.bert.encoder.layer.11.output.LayerNorm.bias", "encode_text.cls.predictions.bias", "encode_text.cls.predictions.transform.dense.weight", "encode_text.cls.predictions.transform.dense.bias", "encode_text.cls.predictions.transform.LayerNorm.weight", "encode_text.cls.predictions.transform.LayerNorm.bias", "encode_text.cls.predictions.decoder.weight", "encode_text.cls.predictions.decoder.bias", "encode_text.cls_projection_head.dense_to_hidden.weight", "encode_text.cls_projection_head.dense_to_hidden.bias", "encode_text.cls_projection_head.LayerNorm.weight", "encode_text.cls_projection_head.LayerNorm.bias", "encode_text.cls_projection_head.dense_to_output.weight", "encode_text.cls_projection_head.dense_to_output.bias". 

In [None]:
# save averaged preds
pred_name = "chexpert_preds_bert.npy" # add name of preds
predictions_dir = predictions_dir / pred_name
np.save(file=predictions_dir, arr=y_pred_avg)

## (Optional) Evaluate Results
If ground truth labels are available, compute AUC on each pathology to evaluate the performance of the zero-shot model. 

In [None]:
# make test_true
test_pred = y_pred_avg
test_true = make_true_labels(cxr_true_labels_path=cxr_true_labels_path, cxr_labels=cxr_labels)

# evaluate model
cxr_results = evaluate(test_pred, test_true, cxr_labels)

# boostrap evaluations for 95% confidence intervals
bootstrap_results = bootstrap(test_pred, test_true, cxr_labels)

  0%|          | 0/1000 [00:00<?, ?it/s]

In [None]:
# display AUC with confidence intervals
bootstrap_results[1]

Unnamed: 0,No Finding_auc,Enlarged Cardiomediastinum_auc,Cardiomegaly_auc,Lung Opacity_auc,Lung Lesion_auc,Edema_auc,Consolidation_auc,Pneumonia_auc,Atelectasis_auc,Pneumothorax_auc,Pleural Effusion_auc,Pleural Other_auc,Fracture_auc,Support Devices_auc
mean,0.1951,0.8855,0.9121,0.9202,0.7636,0.8971,0.8793,0.7792,0.7744,0.6676,0.9284,0.5975,0.4756,0.7099
lower,0.1369,0.8564,0.8823,0.8945,0.6056,0.864,0.8032,0.5734,0.7315,0.5041,0.9012,0.4356,0.2319,0.6621
upper,0.2594,0.9124,0.9355,0.9415,0.9133,0.9257,0.9425,0.9528,0.812,0.8394,0.9505,0.8438,0.7944,0.753


In [None]:
print("Mean AUC: {}".format(np.mean(bootstrap_results[1].iloc[0, :])))

Mean AUC: 0.7418214285714286


In [None]:
pd.DataFrame(bootstrap_results[1]).to_csv('chexzero_chexpert.csv')