# Sample Notebook for Zero-Shot Inference with CheXzero
This notebook walks through how to use CheXzero to perform zero-shot inference on a chest x-ray image dataset.

In [1]:
# Print the kernel's working directory; the relative '../' paths used in later cells resolve against it.
!pwd

/home/dk58319/private/CheXzero/notebooks


## Import Libraries

In [2]:
import os
import numpy as np
import pandas as pd
from pathlib import Path
from typing import List, Tuple, Optional

import sys
# Make the parent directory importable so that `eval` and `zero_shot` (imported below)
# resolve — presumably they live one level above this notebook; confirm.
sys.path.append('../')

import os  # NOTE(review): duplicate of the import at the top of this cell
# GPU selection. These are set before `eval` / `zero_shot` are imported, presumably so that
# any CUDA initialisation triggered by those imports already sees them — confirm.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
# NOTE(review): device index "8" looks specific to the author's machine; adjust on other hosts.
os.environ["CUDA_VISIBLE_DEVICES"] = "8"
import random  # used by get_pair_template; NOTE(review): no seed is set in the visible cells
from eval import evaluate, bootstrap
from zero_shot import make, make_true_labels, run_softmax_eval

# Automatically reload imported modules before each cell runs, so edits to the project
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Directories and Constants

In [3]:
def get_pair_template(rng: Optional[random.Random] = None) -> Tuple[str, str]:
    """Return a (positive, negative) prompt-template pair.

    The positive template is always ``"{}"``. The negative template is drawn
    uniformly at random from a fixed set of four negation phrasings.

    Args:
        rng: Optional ``random.Random`` instance used for the draw. Pass a
            seeded instance to make the choice reproducible. When omitted the
            module-level ``random`` generator is used, as before, so
            ``random.seed(...)`` still controls the result.

    Returns:
        ``(pos, neg)`` where both items are format strings containing a single
        ``{}`` placeholder for the label name.
    """
    negative_templates = ["not {}", "no evidence of {}", "no signs of {}", "no {}"]

    # Fall back to the global generator to keep the zero-argument call unchanged.
    chooser = random if rng is None else rng

    pos = "{}"
    neg = chooser.choice(negative_templates)

    return (pos, neg)

In [4]:
# Draw one template pair as a quick check of get_pair_template (the negative half is chosen at random).
test_template = get_pair_template()
print(test_template)

('{}', 'no {}')


In [5]:
## Define Zero Shot Labels and Templates

# ----- DIRECTORIES ------ #
# NOTE(review): every path below is relative, so it resolves against the kernel's working directory.
cxr_filepath: str = '../data/test/cxr_test.h5' # filepath of chest x-ray images (.h5)
cxr_true_labels_path: Optional[str] = '../data/cheXpert/chexlocalize/CheXpert/test_labels.csv' # (optional for evaluation) if labels are provided, provide path
model_dir: str = '../checkpoints/chexzero_weights' # where pretrained models are saved (.pt) 
predictions_dir: Path = Path('../predictions') # where to save predictions
cache_dir: Path = predictions_dir / "cached" # where to cache ensembled predictions (a Path, because predictions_dir is a Path)


# NOTE(review): not referenced by any other cell visible in this notebook.
context_length: int = 77

# ------- LABELS ------  #
# Define labels to query each image | will return a prediction for each label
cxr_labels: List[str] = ['Atelectasis','Cardiomegaly', 
                                      'Consolidation', 'Edema', 'Enlarged Cardiomediastinum', 'Fracture', 'Lung Lesion',
                                      'Lung Opacity', 'No Finding','Pleural Effusion', 'Pleural Other', 'Pneumonia', 
                                      'Pneumothorax', 'Support Devices']

# ---- TEMPLATES ----- # 
# Define set of templates | see Figure 1 for more details                        
# cxr_pair_template: Tuple[str] = ("{}", "no {}")
# NOTE(review): the negative template is drawn at random each time this cell runs, while
# ensemble_models names its cache files by model name (and optional save_name) only — a
# cached prediction may therefore stem from a different template than the one drawn here.
cxr_pair_template: Tuple[str, str] = get_pair_template()


# ----- MODEL PATHS ------ #
# If using ensemble, collect all model paths
# Every file found anywhere under model_dir is collected (there is no extension filter).
# NOTE(review): os.walk does not guarantee any particular file order, yet later cells pick
# model_paths[6] by position — the selected checkpoint may differ between machines.
model_paths = []
for subdir, dirs, files in os.walk(model_dir):
    for file in files:
        full_dir = os.path.join(subdir, file)  # full path of the file, despite the variable name
        model_paths.append(full_dir)
        
print(model_paths)

['../checkpoints/chexzero_weights/best_64_5e-05_original_16000_0.858.pt', '../checkpoints/chexzero_weights/best_64_0.0001_original_16000_0.861.pt', '../checkpoints/chexzero_weights/best_128_5e-05_original_22000_0.855.pt', '../checkpoints/chexzero_weights/best_128_0.0002_original_8000_0.857.pt', '../checkpoints/chexzero_weights/best_64_0.0002_original_23000_0.854.pt', '../checkpoints/chexzero_weights/best_64_0.0001_original_17000_0.863.pt', '../checkpoints/chexzero_weights/best_64_5e-05_original_22000_0.864.pt', '../checkpoints/chexzero_weights/best_128_0.0002_original_15000_0.859.pt', '../checkpoints/chexzero_weights/best_64_0.0001_original_35000_0.864.pt', '../checkpoints/chexzero_weights/best_64_5e-05_original_18000_0.862.pt']


In [6]:
# NOTE(review): absolute, user-specific paths; neither constant is referenced by any other
# cell visible in this notebook. The relative paths in the configuration cell are what the
# inference and evaluation cells actually use.
CHEXPERT_TEST_CSV_PATH = Path("/home/dk58319/shared/hdd_ext/nvme1/public/medical/classification/chest/cheXpert/chexlocalize/CheXpert/test_labels.csv")
H5_PATH = Path("/home/dk58319/private/CheXzero/data/test/cxr_test.h5")
# NOTE(review): re-defines cxr_labels with the same 14 labels as the configuration cell.
cxr_labels: List[str] = ['Atelectasis','Cardiomegaly',
                                      'Consolidation', 'Edema', 'Enlarged Cardiomediastinum', 'Fracture', 'Lung Lesion',
                                      'Lung Opacity', 'No Finding','Pleural Effusion', 'Pleural Other', 'Pneumonia',
                                      'Pneumothorax', 'Support Devices']

def get_labels_from_csv(csv_path: Path) -> Tuple[List[str], List[List[str]]]:
    """Read a label CSV and list, per row, the columns whose value equals 1.

    The CSV must contain a ``Path`` column. Every column after the first one
    is treated as a label column.

    Args:
        csv_path: Location of the CSV file (anything ``pandas.read_csv`` accepts).

    Returns:
        ``(paths, labels)`` where ``paths`` holds the ``Path`` value of every
        row and ``labels[i]`` is the list of label-column names whose value in
        row ``i`` equals 1, in column order. Rows without any such column get
        an empty list.
    """
    df = pd.read_csv(csv_path)
    label_columns = list(df.columns[1:])

    paths = df['Path'].tolist()

    # One boolean row per CSV row; missing values (NaN) compare unequal to 1.
    is_positive = (df[label_columns] == 1).to_numpy()
    labels = [
        [column for column, hit in zip(label_columns, row) if hit]
        for row in is_positive
    ]
    return paths, labels

def get_labels_from_hdf5(hdf5_path: Path)->List[str]:
    """Read the whole 'labels' dataset from an HDF5 file and return it.

    The file is opened read-only and closed before returning.

    NOTE(review): `h5py` is not imported in any cell visible in this notebook,
    so calling this raises NameError unless it is imported elsewhere — confirm.
    NOTE(review): the result is whatever slicing the dataset with `[:]` yields
    (presumably a NumPy array rather than the annotated List[str]) — confirm.
    """
    with h5py.File(hdf5_path, 'r') as h5_file:
        return h5_file['labels'][:]

def make_positive_negative_prompt(labels:List) -> Tuple[List, List]:
    """Build positive and negative text prompts for every label.

    Args:
        labels: An iterable of label groups, e.g. one list of label names per
            image. A bare string is accepted as a single-label group, so a flat
            list of label names works as well.

    Returns:
        ``(pos_query, neg_query)``. For each label name, in input order, four
        positive prompts are appended to ``pos_query`` and four negative
        prompts to ``neg_query``. Only fully formatted prompts are returned;
        the raw ``{}`` templates are not part of the result.
    """
    pos_templates = (
        'Findings consistent with {}',
        'Findings suggesting {}',
        'This opacity can represent {}',
        'Findings are most compatible with {}',
    )
    neg_templates = (
        'There is no {}',
        'No evidence of {}',
        'No evidence of acute {}',
        'No signs of {}',
    )

    # Results are collected in separate lists so the unformatted templates
    # never leak into the returned prompts.
    pos_query: List[str] = []
    neg_query: List[str] = []
    for label_group in labels:
        # Iterating a bare string would yield single characters, not label names.
        names = [label_group] if isinstance(label_group, str) else label_group
        for name in names:
            pos_query.extend(template.format(name) for template in pos_templates)
            neg_query.extend(template.format(name) for template in neg_templates)

    return pos_query, neg_query


In [7]:
# Show which checkpoint sits at position 6 of model_paths.
# NOTE(review): the position depends on the order in which os.walk listed the files.
print(model_paths[6])

../checkpoints/chexzero_weights/best_64_5e-05_original_22000_0.864.pt


## Run Inference

In [8]:
## Run the model on the data set using ensembled models
def ensemble_models(
    model_paths: List[str], 
    cxr_filepath: str, 
    cxr_labels: List[str], 
    cxr_pair_template: Tuple[str], 
    cache_dir: Optional[str] = None, 
    save_name: Optional[str] = None,
) -> Tuple[List[np.ndarray], np.ndarray]: 
    """
    Run every model in `model_paths` and average their predictions.

    Args:
        model_paths: Checkpoint paths; they are processed in sorted order.
        cxr_filepath: Path of the image file handed to `make`.
        cxr_labels: Labels handed to `run_softmax_eval`.
        cxr_pair_template: Template pair handed to `run_softmax_eval`.
        cache_dir: Directory for per-model prediction caches (`.npy`). When
            None, nothing is read from or written to disk.
        save_name: Optional prefix for the cache file names.

    Returns:
        A list with each model's predictions (in sorted-path order) and the
        element-wise mean of those predictions.

    The model and data loader are only built when no cached prediction exists
    for a checkpoint, so fully cached runs do not load any weights.

    NOTE(review): cache files are named after the checkpoint stem (plus
    `save_name`) only. A cached prediction is reused even if `cxr_filepath`,
    `cxr_labels` or `cxr_pair_template` differ from the run that produced it;
    pass a distinct `save_name` per configuration to avoid stale results.
    """

    predictions = []
    # Sort so the order of `predictions` does not depend on the order of the input list.
    for path in sorted(model_paths): # for each model
        model_name = Path(path).stem

        # path to the cached prediction (None when caching is disabled)
        cache_path = None
        if cache_dir is not None:
            file_stem = model_name if save_name is None else f"{save_name}_{model_name}"
            cache_path = Path(cache_dir) / f"{file_stem}.npy"

        # if prediction already cached, don't recompute prediction
        if cache_path is not None and cache_path.exists():
            print("Loading cached prediction for {}".format(model_name))
            y_pred = np.load(cache_path)
        else: # cached prediction not found, compute preds
            print("Inferring model {}".format(path))
            # load in model and `torch.DataLoader` only when inference is needed
            model, loader = make(
                model_path=path,
                cxr_filepath=cxr_filepath,
            )
            y_pred = run_softmax_eval(model, loader, cxr_labels, cxr_pair_template)
            if cache_path is not None:
                Path(cache_dir).mkdir(exist_ok=True, parents=True)
                np.save(file=cache_path, arr=y_pred)
        predictions.append(y_pred)

    # compute average predictions
    y_pred_avg = np.mean(predictions, axis=0)

    return predictions, y_pred_avg

In [9]:
# List the files in the current working directory.
!ls

zero_shot.ipynb


In [10]:
# Run inference with a single checkpoint (position 6 of model_paths) rather than the whole
# list; with one model, y_pred_avg equals that model's predictions.
# NOTE(review): if a cached .npy for this checkpoint already exists in cache_dir it is loaded
# as-is, regardless of the cxr_labels / cxr_pair_template passed here.
predictions, y_pred_avg = ensemble_models(
    model_paths=[model_paths[6]], 
    cxr_filepath=cxr_filepath, 
    cxr_labels=cxr_labels, 
    cxr_pair_template=cxr_pair_template, 
    cache_dir=cache_dir,
)



Loading cached prediction for best_64_5e-05_original_22000_0.864


In [11]:
# save averaged preds
pred_name = "chexpert_preds.npy" # add name of preds
# Keep `predictions_dir` pointing at the directory. Rebinding it to the file path made this
# cell fail when run a second time, because the file name was then appended to itself.
predictions_dir.mkdir(parents=True, exist_ok=True)  # make sure the target directory exists
pred_path = predictions_dir / pred_name
np.save(file=pred_path, arr=y_pred_avg)

## (Optional) Evaluate Results
If ground truth labels are available, compute AUC on each pathology to evaluate the performance of the zero-shot model. 

In [12]:
# make test_true
test_pred = y_pred_avg

# Ground truth for the queried labels, built from the labels CSV by the project helper.
test_true = make_true_labels(cxr_true_labels_path=cxr_true_labels_path, cxr_labels=cxr_labels)
# print(test_true)
# evaluate model
cxr_results = evaluate(test_pred, test_true, cxr_labels)

# bootstrap evaluations for 95% confidence intervals
bootstrap_results = bootstrap(test_pred, test_true, cxr_labels)

  0%|          | 0/1000 [00:00<?, ?it/s]

In [13]:
# display AUC with confidence intervals
# display AUC with confidence intervals
# NOTE(review): in the saved output 'No Finding_auc' is about 0.08, far below chance (0.5),
# while the other labels are above it — looks like the 'No Finding' score is inverted; confirm.
bootstrap_results[1]

Unnamed: 0,Atelectasis_auc,Cardiomegaly_auc,Consolidation_auc,Edema_auc,Enlarged Cardiomediastinum_auc,Fracture_auc,Lung Lesion_auc,Lung Opacity_auc,No Finding_auc,Pleural Effusion_auc,Pleural Other_auc,Pneumonia_auc,Pneumothorax_auc,Support Devices_auc
mean,0.7991,0.8953,0.8865,0.9028,0.8556,0.5413,0.7108,0.9148,0.0798,0.9279,0.5672,0.7794,0.6201,0.8345
lower,0.7577,0.8655,0.8224,0.8702,0.825,0.2368,0.5571,0.8868,0.0559,0.9008,0.4558,0.5749,0.4145,0.7992
upper,0.8364,0.9224,0.9389,0.9322,0.887,0.8236,0.8552,0.9377,0.1082,0.9508,0.7384,0.9458,0.8343,0.8672


In [14]:
# Sanity-check the ground truth: number of rows, array shape, and the first label vector.
print(len(test_true))
print(test_true.shape)
print(test_true[0])

500
(500, 14)
[0. 1. 0. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 1.]


In [15]:
# NOTE(review): shows the same confidence-interval table that an earlier cell already displayed.
bootstrap_results[1]

Unnamed: 0,Atelectasis_auc,Cardiomegaly_auc,Consolidation_auc,Edema_auc,Enlarged Cardiomediastinum_auc,Fracture_auc,Lung Lesion_auc,Lung Opacity_auc,No Finding_auc,Pleural Effusion_auc,Pleural Other_auc,Pneumonia_auc,Pneumothorax_auc,Support Devices_auc
mean,0.7991,0.8953,0.8865,0.9028,0.8556,0.5413,0.7108,0.9148,0.0798,0.9279,0.5672,0.7794,0.6201,0.8345
lower,0.7577,0.8655,0.8224,0.8702,0.825,0.2368,0.5571,0.8868,0.0559,0.9008,0.4558,0.5749,0.4145,0.7992
upper,0.8364,0.9224,0.9389,0.9322,0.887,0.8236,0.8552,0.9377,0.1082,0.9508,0.7384,0.9458,0.8343,0.8672


In [16]:
# bootstrap_results[0] is a DataFrame with one AUC column per label (1000 rows in the saved
# output — presumably one row per bootstrap sample; confirm).
df = bootstrap_results[0]

print(type(df))

print(df.columns)

# Keep only five of the pathologies. Note that `df` is rebound to the subset here.
df = df[['Atelectasis_auc','Cardiomegaly_auc','Consolidation_auc','Edema_auc','Pleural Effusion_auc']]

# Averages each describe() statistic (count, mean, std, min, quartiles, max) across the five
# selected columns.
# NOTE(review): this is not the distribution of the per-row mean AUC; that would be
# df.mean(axis=1).describe() — confirm which one is intended.
df.describe().T.mean()

<class 'pandas.core.frame.DataFrame'>
Index(['Atelectasis_auc', 'Cardiomegaly_auc', 'Consolidation_auc', 'Edema_auc',
       'Enlarged Cardiomediastinum_auc', 'Fracture_auc', 'Lung Lesion_auc',
       'Lung Opacity_auc', 'No Finding_auc', 'Pleural Effusion_auc',
       'Pleural Other_auc', 'Pneumonia_auc', 'Pneumothorax_auc',
       'Support Devices_auc'],
      dtype='object')


count    1000.000000
mean        0.882335
std         0.018791
min         0.810422
25%         0.870346
50%         0.883396
75%         0.895206
max         0.939561
dtype: float64