# Sample Notebook for Zero-Shot Inference with CheXzero
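This notebook walks through how to use CheXzero to perform zero-shot inference on a chest X-ray image dataset: it scores each checkpoint individually, ensembles all checkpoints by averaging their predictions, and (optionally) evaluates AUC against ground-truth labels.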

## Import Libraries

In [1]:
import os
import re
from pathlib import Path
from typing import List, Optional, Tuple

import numpy as np
import pandas as pd

import sys
sys.path.append('../')

from eval import evaluate, bootstrap
from zero_shot import make, make_true_labels, run_softmax_eval

%load_ext autoreload
%autoreload 2

## Directories and Constants

In [2]:
## Define directories, zero-shot labels and templates

# ----- DIRECTORIES ------ #
# Keep exactly one dataset block active; the others are alternatives.
# # Padchest
# cxr_filepath: str = '/home/ec2-user/all_raw_data/padchest/images/44_cxr.h5' # filepath of chest x-ray images (.h5)
# cxr_true_labels_path: Optional[str] = '/home/ec2-user/all_raw_data/padchest/44_cxr_labels.csv' # (optional for evaluation) if labels are provided, provide path

# CheXzero test
cxr_filepath: str = '/home/ec2-user/CHEXLOCALIZE/CheXpert/test.h5' # filepath of chest x-ray images (.h5)
cxr_true_labels_path: Optional[str] = '/home/ec2-user/CHEXLOCALIZE/CheXpert/test_labels_view1.csv' # (optional for evaluation) if labels are provided, provide path

# # CheXzero val
# cxr_filepath: str = '/home/ec2-user/all_raw_data/chexpert/CheXpert-v1.0-small/valid/chexpert_val.h5' # filepath of chest x-ray images (.h5)
# cxr_true_labels_path: Optional[str] = '/home/ec2-user/all_raw_data/chexpert/CheXpert-v1.0-small/valid_view1.csv' # (optional for evaluation) if labels are provided, provide path

model_dir: str = '../checkpoints/chexzero_baseline'  # where pretrained models are saved (.pt)
predictions_dir: Path = Path('../predictions')  # where to save predictions
cache_dir: Path = predictions_dir / "cached"  # where to cache per-checkpoint predictions

context_length: int = 77

# ------- LABELS ------  #
# Define labels to query each image | will return a prediction for each label
cxr_labels: List[str] = ['No Finding', 'Enlarged Cardiomediastinum', 'Cardiomegaly',
                         'Lung Opacity', 'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia',
                         'Atelectasis', 'Pneumothorax', 'Pleural Effusion', 'Pleural Other',
                         'Fracture', 'Support Devices']

# ---- TEMPLATES ----- #
# Define (positive, negative) template pair | see Figure 1 for more details
cxr_pair_template: Tuple[str, str] = ("{}", "no {}")

# ----- MODEL PATHS ------ #
# If using ensemble, collect all model paths
checkpoint_stride: int = 1  # keep every k-th checkpoint; 1 keeps all of them


def _natural_key(path: str) -> list:
    """Sort key that compares embedded integers numerically (checkpoint_5000 < checkpoint_10000)."""
    # re.split with a capturing group alternates text / digit tokens, so tokens at
    # the same index always have the same type and the lists compare safely.
    return [int(tok) if tok.isdigit() else tok for tok in re.split(r"(\d+)", path)]


def collect_model_paths(root: str, suffix: str = ".pt") -> List[str]:
    """Recursively collect checkpoint files ending in `suffix` under `root`.

    Non-checkpoint files are skipped and the result is naturally sorted, so the
    order no longer depends on the filesystem's directory listing order.
    A missing `root` yields an empty list (os.walk yields nothing).
    """
    found = [
        os.path.join(subdir, name)
        for subdir, _dirs, files in os.walk(root)
        for name in files
        if name.endswith(suffix)
    ]
    return sorted(found, key=_natural_key)


model_paths = collect_model_paths(model_dir)[::checkpoint_stride]
print(model_paths)

['../checkpoints/chexzero_baseline/checkpoint_5000.pt', '../checkpoints/chexzero_baseline/checkpoint_10000.pt', '../checkpoints/chexzero_baseline/checkpoint_15000.pt', '../checkpoints/chexzero_baseline/checkpoint_20000.pt', '../checkpoints/chexzero_baseline/checkpoint_25000.pt', '../checkpoints/chexzero_baseline/checkpoint_30000.pt', '../checkpoints/chexzero_baseline/checkpoint_35000.pt', '../checkpoints/chexzero_baseline/checkpoint_40000.pt', '../checkpoints/chexzero_baseline/checkpoint_45000.pt', '../checkpoints/chexzero_baseline/checkpoint.pt']


## Run Inference

In [3]:
## Run the model on the data set using ensembled models
def ensemble_models(
    model_paths: List[str], 
    cxr_filepath: str, 
    cxr_labels: List[str], 
    cxr_pair_template: Tuple[str, str], 
    cache_dir: Optional[str] = None, 
    save_name: Optional[str] = None,
    change_text_encoder: bool = False,
) -> Tuple[List[np.ndarray], np.ndarray]: 
    """
    Run zero-shot inference with every checkpoint in `model_paths` and average
    the per-checkpoint predictions into an ensemble.

    Args:
        model_paths: checkpoint files to evaluate. They are processed in sorted
            order so results and caching do not depend on the caller's order.
        cxr_filepath: image file (.h5) handed to `make` to build the data loader.
        cxr_labels: labels to query; forwarded to `run_softmax_eval`.
        cxr_pair_template: (positive, negative) prompt templates, e.g. ("{}", "no {}").
        cache_dir: if given, each checkpoint's predictions are saved here as .npy
            and reloaded on later calls instead of being recomputed.
        save_name: optional prefix for cache file names. When omitted, the
            checkpoint's parent directory name is used as the prefix.
        change_text_encoder: forwarded to `make` and `run_softmax_eval`.

    Returns:
        (predictions, y_pred_avg): one prediction array per checkpoint, and
        their element-wise mean (the single array itself for one checkpoint).

    Raises:
        ValueError: if `model_paths` is empty.
    """
    if not model_paths:
        raise ValueError("ensemble_models needs at least one model path")

    predictions = []
    for path in sorted(model_paths):  # sort so the ensemble is reproducible
        model_name = Path(path).stem

        # path to the cached prediction (None when caching is disabled)
        cache_path = None
        if cache_dir is not None:
            # Checkpoint names such as checkpoint_5000.pt repeat across training
            # runs, so the default key includes the parent directory; otherwise
            # another run's cached predictions would be silently reused.
            run_name = Path(path).parent.name
            if save_name is not None:
                cache_name = f"{save_name}_{model_name}.npy"
            elif run_name:
                cache_name = f"{run_name}_{model_name}.npy"
            else:
                cache_name = f"{model_name}.npy"
            cache_path = Path(cache_dir) / cache_name

        # if prediction already cached, don't rebuild the model or recompute
        if cache_path is not None and cache_path.exists():
            print("Loading cached prediction for {}".format(model_name))
            y_pred = np.load(cache_path)
        else:
            # load in model and `torch.DataLoader` only when inference is needed
            model, loader = make(
                model_path=path, 
                cxr_filepath=cxr_filepath, 
                change_text_encoder=change_text_encoder,
            )
            print("Inferring model {}".format(path))
            y_pred = run_softmax_eval(model, loader, cxr_labels, cxr_pair_template, change_text_encoder=change_text_encoder)
            if cache_path is not None:
                cache_path.parent.mkdir(exist_ok=True, parents=True)
                np.save(file=cache_path, arr=y_pred)
        predictions.append(y_pred)

    # compute average predictions (single model: return its array unchanged)
    if len(predictions) > 1:
        y_pred_avg = np.mean(predictions, axis=0)
    else:
        y_pred_avg = predictions[0]

    return predictions, y_pred_avg

In [5]:
# Evaluate every checkpoint on its own (no ensembling) to track zero-shot
# performance over training.

# The five CheXpert competition pathologies used for the headline mean AUC
competition_labels = ["Atelectasis", "Cardiomegaly", "Consolidation", "Edema", "Pleural Effusion"]
competition_auc_cols = [f"{label}_auc" for label in competition_labels]

# ground-truth labels, built once and reused for every checkpoint
test_true = make_true_labels(cxr_true_labels_path=cxr_true_labels_path, cxr_labels=cxr_labels)

result = []
for path in model_paths:
    # a one-element "ensemble" is just that checkpoint's predictions
    _, test_pred = ensemble_models(
        model_paths=[path], 
        cxr_filepath=cxr_filepath, 
        cxr_labels=cxr_labels, 
        cxr_pair_template=cxr_pair_template, 
        cache_dir=cache_dir,
        change_text_encoder=False,
    )

    # bootstrap evaluations for 95% confidence intervals;
    # the first row of bootstrap_results[1] is the bootstrap mean
    bootstrap_results = bootstrap(test_pred, test_true, cxr_labels)
    mAUC = np.mean(bootstrap_results[1][competition_auc_cols].iloc[0, :])

    print("Model:{},  AUC: {}".format(path, mAUC))
    result.append([path, mAUC])

# one row per checkpoint, displayed as a table
pd.DataFrame(result, columns=["model_path", "mean_auc"])
    

../checkpoints/chexzero_baseline/checkpoint_5000.pt using an online model  
Inferring model ../checkpoints/chexzero_baseline/checkpoint_5000.pt


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

Model:../checkpoints/chexzero_baseline/checkpoint_5000.pt,  AUC: 0.6293799999999999
../checkpoints/chexzero_baseline/checkpoint_10000.pt using an online model  
Inferring model ../checkpoints/chexzero_baseline/checkpoint_10000.pt


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

Model:../checkpoints/chexzero_baseline/checkpoint_10000.pt,  AUC: 0.74666
../checkpoints/chexzero_baseline/checkpoint_15000.pt using an online model  
Inferring model ../checkpoints/chexzero_baseline/checkpoint_15000.pt


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

Model:../checkpoints/chexzero_baseline/checkpoint_15000.pt,  AUC: 0.81214
../checkpoints/chexzero_baseline/checkpoint_20000.pt using an online model  
Inferring model ../checkpoints/chexzero_baseline/checkpoint_20000.pt


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

Model:../checkpoints/chexzero_baseline/checkpoint_20000.pt,  AUC: 0.8242600000000001
../checkpoints/chexzero_baseline/checkpoint_25000.pt using an online model  
Inferring model ../checkpoints/chexzero_baseline/checkpoint_25000.pt


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

Model:../checkpoints/chexzero_baseline/checkpoint_25000.pt,  AUC: 0.82028
../checkpoints/chexzero_baseline/checkpoint_30000.pt using an online model  
Inferring model ../checkpoints/chexzero_baseline/checkpoint_30000.pt


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

Model:../checkpoints/chexzero_baseline/checkpoint_30000.pt,  AUC: 0.8367199999999999
../checkpoints/chexzero_baseline/checkpoint_35000.pt using an online model  
Inferring model ../checkpoints/chexzero_baseline/checkpoint_35000.pt


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

Model:../checkpoints/chexzero_baseline/checkpoint_35000.pt,  AUC: 0.80548
../checkpoints/chexzero_baseline/checkpoint_40000.pt using an online model  
Inferring model ../checkpoints/chexzero_baseline/checkpoint_40000.pt


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

Model:../checkpoints/chexzero_baseline/checkpoint_40000.pt,  AUC: 0.81648
../checkpoints/chexzero_baseline/checkpoint_45000.pt using an online model  
Inferring model ../checkpoints/chexzero_baseline/checkpoint_45000.pt


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

Model:../checkpoints/chexzero_baseline/checkpoint_45000.pt,  AUC: 0.7944199999999999
../checkpoints/chexzero_baseline/checkpoint.pt using an online model  
Inferring model ../checkpoints/chexzero_baseline/checkpoint.pt


  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/14 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

Model:../checkpoints/chexzero_baseline/checkpoint.pt,  AUC: 0.8130199999999999
[['../checkpoints/chexzero_baseline/checkpoint_5000.pt', 0.6293799999999999], ['../checkpoints/chexzero_baseline/checkpoint_10000.pt', 0.74666], ['../checkpoints/chexzero_baseline/checkpoint_15000.pt', 0.81214], ['../checkpoints/chexzero_baseline/checkpoint_20000.pt', 0.8242600000000001], ['../checkpoints/chexzero_baseline/checkpoint_25000.pt', 0.82028], ['../checkpoints/chexzero_baseline/checkpoint_30000.pt', 0.8367199999999999], ['../checkpoints/chexzero_baseline/checkpoint_35000.pt', 0.80548], ['../checkpoints/chexzero_baseline/checkpoint_40000.pt', 0.81648], ['../checkpoints/chexzero_baseline/checkpoint_45000.pt', 0.7944199999999999], ['../checkpoints/chexzero_baseline/checkpoint.pt', 0.8130199999999999]]


In [62]:
# Ensemble every checkpoint in `model_paths` by averaging their predictions.
# NOTE(review): change_text_encoder=True differs from the per-checkpoint loop
# above, which used False for the same `model_paths`; the recorded RuntimeError
# came from a cxr-bert checkpoint — confirm this flag matches `model_dir`.
ensemble_kwargs = dict(
    cxr_filepath=cxr_filepath,
    cxr_labels=cxr_labels,
    cxr_pair_template=cxr_pair_template,
    cache_dir=cache_dir,
    change_text_encoder=True,
)
predictions, y_pred_avg = ensemble_models(model_paths=model_paths, **ensemble_kwargs)

../checkpoints/cxr-bert/checkpoint_105000.pt using an online model  
Argument error. Set pretrained = True. <class 'RuntimeError'>


RuntimeError: Error(s) in loading state_dict for CLIP:
	Unexpected key(s) in state_dict: "text_model.bert.embeddings.position_ids", "text_model.bert.embeddings.word_embeddings.weight", "text_model.bert.embeddings.position_embeddings.weight", "text_model.bert.embeddings.token_type_embeddings.weight", "text_model.bert.embeddings.LayerNorm.weight", "text_model.bert.embeddings.LayerNorm.bias", "text_model.bert.encoder.layer.0.attention.self.query.weight", "text_model.bert.encoder.layer.0.attention.self.query.bias", "text_model.bert.encoder.layer.0.attention.self.key.weight", "text_model.bert.encoder.layer.0.attention.self.key.bias", "text_model.bert.encoder.layer.0.attention.self.value.weight", "text_model.bert.encoder.layer.0.attention.self.value.bias", "text_model.bert.encoder.layer.0.attention.output.dense.weight", "text_model.bert.encoder.layer.0.attention.output.dense.bias", "text_model.bert.encoder.layer.0.attention.output.LayerNorm.weight", "text_model.bert.encoder.layer.0.attention.output.LayerNorm.bias", "text_model.bert.encoder.layer.0.intermediate.dense.weight", "text_model.bert.encoder.layer.0.intermediate.dense.bias", "text_model.bert.encoder.layer.0.output.dense.weight", "text_model.bert.encoder.layer.0.output.dense.bias", "text_model.bert.encoder.layer.0.output.LayerNorm.weight", "text_model.bert.encoder.layer.0.output.LayerNorm.bias", "text_model.bert.encoder.layer.1.attention.self.query.weight", "text_model.bert.encoder.layer.1.attention.self.query.bias", "text_model.bert.encoder.layer.1.attention.self.key.weight", "text_model.bert.encoder.layer.1.attention.self.key.bias", "text_model.bert.encoder.layer.1.attention.self.value.weight", "text_model.bert.encoder.layer.1.attention.self.value.bias", "text_model.bert.encoder.layer.1.attention.output.dense.weight", "text_model.bert.encoder.layer.1.attention.output.dense.bias", "text_model.bert.encoder.layer.1.attention.output.LayerNorm.weight", "text_model.bert.encoder.layer.1.attention.output.LayerNorm.bias", 
"text_model.bert.encoder.layer.1.intermediate.dense.weight", "text_model.bert.encoder.layer.1.intermediate.dense.bias", "text_model.bert.encoder.layer.1.output.dense.weight", "text_model.bert.encoder.layer.1.output.dense.bias", "text_model.bert.encoder.layer.1.output.LayerNorm.weight", "text_model.bert.encoder.layer.1.output.LayerNorm.bias", "text_model.bert.encoder.layer.2.attention.self.query.weight", "text_model.bert.encoder.layer.2.attention.self.query.bias", "text_model.bert.encoder.layer.2.attention.self.key.weight", "text_model.bert.encoder.layer.2.attention.self.key.bias", "text_model.bert.encoder.layer.2.attention.self.value.weight", "text_model.bert.encoder.layer.2.attention.self.value.bias", "text_model.bert.encoder.layer.2.attention.output.dense.weight", "text_model.bert.encoder.layer.2.attention.output.dense.bias", "text_model.bert.encoder.layer.2.attention.output.LayerNorm.weight", "text_model.bert.encoder.layer.2.attention.output.LayerNorm.bias", "text_model.bert.encoder.layer.2.intermediate.dense.weight", "text_model.bert.encoder.layer.2.intermediate.dense.bias", "text_model.bert.encoder.layer.2.output.dense.weight", "text_model.bert.encoder.layer.2.output.dense.bias", "text_model.bert.encoder.layer.2.output.LayerNorm.weight", "text_model.bert.encoder.layer.2.output.LayerNorm.bias", "text_model.bert.encoder.layer.3.attention.self.query.weight", "text_model.bert.encoder.layer.3.attention.self.query.bias", "text_model.bert.encoder.layer.3.attention.self.key.weight", "text_model.bert.encoder.layer.3.attention.self.key.bias", "text_model.bert.encoder.layer.3.attention.self.value.weight", "text_model.bert.encoder.layer.3.attention.self.value.bias", "text_model.bert.encoder.layer.3.attention.output.dense.weight", "text_model.bert.encoder.layer.3.attention.output.dense.bias", "text_model.bert.encoder.layer.3.attention.output.LayerNorm.weight", "text_model.bert.encoder.layer.3.attention.output.LayerNorm.bias", 
"text_model.bert.encoder.layer.3.intermediate.dense.weight", "text_model.bert.encoder.layer.3.intermediate.dense.bias", "text_model.bert.encoder.layer.3.output.dense.weight", "text_model.bert.encoder.layer.3.output.dense.bias", "text_model.bert.encoder.layer.3.output.LayerNorm.weight", "text_model.bert.encoder.layer.3.output.LayerNorm.bias", "text_model.bert.encoder.layer.4.attention.self.query.weight", "text_model.bert.encoder.layer.4.attention.self.query.bias", "text_model.bert.encoder.layer.4.attention.self.key.weight", "text_model.bert.encoder.layer.4.attention.self.key.bias", "text_model.bert.encoder.layer.4.attention.self.value.weight", "text_model.bert.encoder.layer.4.attention.self.value.bias", "text_model.bert.encoder.layer.4.attention.output.dense.weight", "text_model.bert.encoder.layer.4.attention.output.dense.bias", "text_model.bert.encoder.layer.4.attention.output.LayerNorm.weight", "text_model.bert.encoder.layer.4.attention.output.LayerNorm.bias", "text_model.bert.encoder.layer.4.intermediate.dense.weight", "text_model.bert.encoder.layer.4.intermediate.dense.bias", "text_model.bert.encoder.layer.4.output.dense.weight", "text_model.bert.encoder.layer.4.output.dense.bias", "text_model.bert.encoder.layer.4.output.LayerNorm.weight", "text_model.bert.encoder.layer.4.output.LayerNorm.bias", "text_model.bert.encoder.layer.5.attention.self.query.weight", "text_model.bert.encoder.layer.5.attention.self.query.bias", "text_model.bert.encoder.layer.5.attention.self.key.weight", "text_model.bert.encoder.layer.5.attention.self.key.bias", "text_model.bert.encoder.layer.5.attention.self.value.weight", "text_model.bert.encoder.layer.5.attention.self.value.bias", "text_model.bert.encoder.layer.5.attention.output.dense.weight", "text_model.bert.encoder.layer.5.attention.output.dense.bias", "text_model.bert.encoder.layer.5.attention.output.LayerNorm.weight", "text_model.bert.encoder.layer.5.attention.output.LayerNorm.bias", 
"text_model.bert.encoder.layer.5.intermediate.dense.weight", "text_model.bert.encoder.layer.5.intermediate.dense.bias", "text_model.bert.encoder.layer.5.output.dense.weight", "text_model.bert.encoder.layer.5.output.dense.bias", "text_model.bert.encoder.layer.5.output.LayerNorm.weight", "text_model.bert.encoder.layer.5.output.LayerNorm.bias", "text_model.bert.encoder.layer.6.attention.self.query.weight", "text_model.bert.encoder.layer.6.attention.self.query.bias", "text_model.bert.encoder.layer.6.attention.self.key.weight", "text_model.bert.encoder.layer.6.attention.self.key.bias", "text_model.bert.encoder.layer.6.attention.self.value.weight", "text_model.bert.encoder.layer.6.attention.self.value.bias", "text_model.bert.encoder.layer.6.attention.output.dense.weight", "text_model.bert.encoder.layer.6.attention.output.dense.bias", "text_model.bert.encoder.layer.6.attention.output.LayerNorm.weight", "text_model.bert.encoder.layer.6.attention.output.LayerNorm.bias", "text_model.bert.encoder.layer.6.intermediate.dense.weight", "text_model.bert.encoder.layer.6.intermediate.dense.bias", "text_model.bert.encoder.layer.6.output.dense.weight", "text_model.bert.encoder.layer.6.output.dense.bias", "text_model.bert.encoder.layer.6.output.LayerNorm.weight", "text_model.bert.encoder.layer.6.output.LayerNorm.bias", "text_model.bert.encoder.layer.7.attention.self.query.weight", "text_model.bert.encoder.layer.7.attention.self.query.bias", "text_model.bert.encoder.layer.7.attention.self.key.weight", "text_model.bert.encoder.layer.7.attention.self.key.bias", "text_model.bert.encoder.layer.7.attention.self.value.weight", "text_model.bert.encoder.layer.7.attention.self.value.bias", "text_model.bert.encoder.layer.7.attention.output.dense.weight", "text_model.bert.encoder.layer.7.attention.output.dense.bias", "text_model.bert.encoder.layer.7.attention.output.LayerNorm.weight", "text_model.bert.encoder.layer.7.attention.output.LayerNorm.bias", 
"text_model.bert.encoder.layer.7.intermediate.dense.weight", "text_model.bert.encoder.layer.7.intermediate.dense.bias", "text_model.bert.encoder.layer.7.output.dense.weight", "text_model.bert.encoder.layer.7.output.dense.bias", "text_model.bert.encoder.layer.7.output.LayerNorm.weight", "text_model.bert.encoder.layer.7.output.LayerNorm.bias", "text_model.bert.encoder.layer.8.attention.self.query.weight", "text_model.bert.encoder.layer.8.attention.self.query.bias", "text_model.bert.encoder.layer.8.attention.self.key.weight", "text_model.bert.encoder.layer.8.attention.self.key.bias", "text_model.bert.encoder.layer.8.attention.self.value.weight", "text_model.bert.encoder.layer.8.attention.self.value.bias", "text_model.bert.encoder.layer.8.attention.output.dense.weight", "text_model.bert.encoder.layer.8.attention.output.dense.bias", "text_model.bert.encoder.layer.8.attention.output.LayerNorm.weight", "text_model.bert.encoder.layer.8.attention.output.LayerNorm.bias", "text_model.bert.encoder.layer.8.intermediate.dense.weight", "text_model.bert.encoder.layer.8.intermediate.dense.bias", "text_model.bert.encoder.layer.8.output.dense.weight", "text_model.bert.encoder.layer.8.output.dense.bias", "text_model.bert.encoder.layer.8.output.LayerNorm.weight", "text_model.bert.encoder.layer.8.output.LayerNorm.bias", "text_model.bert.encoder.layer.9.attention.self.query.weight", "text_model.bert.encoder.layer.9.attention.self.query.bias", "text_model.bert.encoder.layer.9.attention.self.key.weight", "text_model.bert.encoder.layer.9.attention.self.key.bias", "text_model.bert.encoder.layer.9.attention.self.value.weight", "text_model.bert.encoder.layer.9.attention.self.value.bias", "text_model.bert.encoder.layer.9.attention.output.dense.weight", "text_model.bert.encoder.layer.9.attention.output.dense.bias", "text_model.bert.encoder.layer.9.attention.output.LayerNorm.weight", "text_model.bert.encoder.layer.9.attention.output.LayerNorm.bias", 
"text_model.bert.encoder.layer.9.intermediate.dense.weight", "text_model.bert.encoder.layer.9.intermediate.dense.bias", "text_model.bert.encoder.layer.9.output.dense.weight", "text_model.bert.encoder.layer.9.output.dense.bias", "text_model.bert.encoder.layer.9.output.LayerNorm.weight", "text_model.bert.encoder.layer.9.output.LayerNorm.bias", "text_model.bert.encoder.layer.10.attention.self.query.weight", "text_model.bert.encoder.layer.10.attention.self.query.bias", "text_model.bert.encoder.layer.10.attention.self.key.weight", "text_model.bert.encoder.layer.10.attention.self.key.bias", "text_model.bert.encoder.layer.10.attention.self.value.weight", "text_model.bert.encoder.layer.10.attention.self.value.bias", "text_model.bert.encoder.layer.10.attention.output.dense.weight", "text_model.bert.encoder.layer.10.attention.output.dense.bias", "text_model.bert.encoder.layer.10.attention.output.LayerNorm.weight", "text_model.bert.encoder.layer.10.attention.output.LayerNorm.bias", "text_model.bert.encoder.layer.10.intermediate.dense.weight", "text_model.bert.encoder.layer.10.intermediate.dense.bias", "text_model.bert.encoder.layer.10.output.dense.weight", "text_model.bert.encoder.layer.10.output.dense.bias", "text_model.bert.encoder.layer.10.output.LayerNorm.weight", "text_model.bert.encoder.layer.10.output.LayerNorm.bias", "text_model.bert.encoder.layer.11.attention.self.query.weight", "text_model.bert.encoder.layer.11.attention.self.query.bias", "text_model.bert.encoder.layer.11.attention.self.key.weight", "text_model.bert.encoder.layer.11.attention.self.key.bias", "text_model.bert.encoder.layer.11.attention.self.value.weight", "text_model.bert.encoder.layer.11.attention.self.value.bias", "text_model.bert.encoder.layer.11.attention.output.dense.weight", "text_model.bert.encoder.layer.11.attention.output.dense.bias", "text_model.bert.encoder.layer.11.attention.output.LayerNorm.weight", "text_model.bert.encoder.layer.11.attention.output.LayerNorm.bias", 
"text_model.bert.encoder.layer.11.intermediate.dense.weight", "text_model.bert.encoder.layer.11.intermediate.dense.bias", "text_model.bert.encoder.layer.11.output.dense.weight", "text_model.bert.encoder.layer.11.output.dense.bias", "text_model.bert.encoder.layer.11.output.LayerNorm.weight", "text_model.bert.encoder.layer.11.output.LayerNorm.bias", "text_model.cls.predictions.bias", "text_model.cls.predictions.transform.dense.weight", "text_model.cls.predictions.transform.dense.bias", "text_model.cls.predictions.transform.LayerNorm.weight", "text_model.cls.predictions.transform.LayerNorm.bias", "text_model.cls.predictions.decoder.weight", "text_model.cls.predictions.decoder.bias", "text_model.cls_projection_head.dense_to_hidden.weight", "text_model.cls_projection_head.dense_to_hidden.bias", "text_model.cls_projection_head.LayerNorm.weight", "text_model.cls_projection_head.LayerNorm.bias", "text_model.cls_projection_head.dense_to_output.weight", "text_model.cls_projection_head.dense_to_output.bias", "text_model_linear.weight", "text_model_linear.bias". 

In [7]:
# Save the averaged predictions inside predictions_dir. Build a separate file
# path instead of rebinding predictions_dir, so re-running this cell does not
# nest the file name (predictions/x.npy/x.npy).
pred_name = "chexpert_preds_bert.npy"  # add name of preds
pred_path = predictions_dir / pred_name
pred_path.parent.mkdir(parents=True, exist_ok=True)  # np.save does not create directories
np.save(file=pred_path, arr=y_pred_avg)

## (Optional) Evaluate Results
If ground-truth labels are available, compute the AUC for each pathology to evaluate the zero-shot model's performance; 95% confidence intervals are estimated by bootstrapping.

In [40]:
# Predictions of the ensemble and the matching ground-truth label matrix
test_pred = y_pred_avg
test_true = make_true_labels(cxr_true_labels_path=cxr_true_labels_path, cxr_labels=cxr_labels)

# Both evaluations score the same (predictions, labels, label names) triple
eval_args = (test_pred, test_true, cxr_labels)

# per-label AUC point estimates
cxr_results = evaluate(*eval_args)

# bootstrap resampling for 95% confidence intervals
bootstrap_results = bootstrap(*eval_args)

  0%|          | 0/1000 [00:00<?, ?it/s]

In [48]:
# Bootstrap mean and 95% CI bounds for the five competition pathologies
competition_auc_cols = [
    f"{name}_auc"
    for name in ("Atelectasis", "Cardiomegaly", "Consolidation", "Edema", "Pleural Effusion")
]
bootstrap_results[1][competition_auc_cols]

Unnamed: 0,Atelectasis_auc,Cardiomegaly_auc,Consolidation_auc,Edema_auc,Pleural Effusion_auc
mean,0.5341,0.7555,0.6241,0.7631,0.5203
lower,0.4791,0.709,0.5144,0.7115,0.4555
upper,0.5873,0.7973,0.7251,0.8132,0.584


In [49]:
# Mean AUC over the top 5 competition pathologies, taken from the first
# (bootstrap mean) row of the summary table
top5_auc_means = bootstrap_results[1][
    ["Atelectasis_auc", "Cardiomegaly_auc", "Consolidation_auc", "Edema_auc", "Pleural Effusion_auc"]
].iloc[0, :]
print("Mean AUC: {}".format(np.mean(top5_auc_means)))

Mean AUC: 0.63942


In [11]:
# Persist the bootstrap summary table (mean / lower / upper per label) as CSV
bootstrap_summary = pd.DataFrame(bootstrap_results[1])
bootstrap_summary.to_csv('chexzero_chexpert.csv')