# Sample Notebook for Zero-Shot Inference with BioViL
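This notebook walks through how to use the BioViL image–text model (via the `health_multimodal` package) to perform zero-shot inference on a chest X-ray image dataset, using CheXzero's evaluation utilities (`eval`, `zero_shot`) to score the predictions.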

## Import Libraries

In [1]:
import os
import numpy as np
import pandas as pd
from pathlib import Path
from typing import List, Tuple, Optional

import sys
sys.path.append('../')

from eval import evaluate, bootstrap
from zero_shot import make, make_true_labels, run_softmax_eval

%load_ext autoreload
%autoreload 2

## Directories and Constants

In [2]:
## Define Zero Shot Labels and Templates

# ----- DIRECTORIES ------ #
# NOTE(review): these are absolute, machine-specific paths; adjust them for your setup.
# h5 chest x-ray images. `~` is expanded explicitly because Python file APIs do not expand it.
# NOTE(review): not read by the BioViL inference below, which loads images from cxr_png_folder.
cxr_filepath = os.path.expanduser('~/all_raw_data/padchest/images/44_cxr.h5')
# Root image folder; the `Path` entries of the labels CSV are resolved relative to it.
cxr_png_folder = '/home/ec2-user/CHEXLOCALIZE/CheXpert/'
# Ground-truth labels CSV (also defines which images are scored, and in what order).
cxr_true_labels_path: Optional[str] = '/home/ec2-user/CHEXLOCALIZE/CheXpert/test_labels.csv'
model_dir = None  # unused: BioViL weights are loaded by the health_multimodal helpers
predictions_dir = Path('./predictions/')  # predictions
cache_dir = predictions_dir / "cached"  # cache of predictions
# Text context length. NOTE(review): not referenced by the BioViL code in this notebook.
context_length: int = 77

# ------- LABELS ------  #
# Labels to query for each image; predictions are returned one per label, in this order.
cxr_labels: List[str] = ['No Finding', 'Enlarged Cardiomediastinum', 'Cardiomegaly',
                         'Lung Opacity', 'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia',
                         'Atelectasis', 'Pneumothorax', 'Pleural Effusion', 'Pleural Other',
                         'Fracture', 'Support Devices']

# ---- TEMPLATES ----- #
# (positive, negative) prompt templates; each `{}` is filled with a label name.
cxr_pair_template: Tuple[str, str] = ("{}", "no {}")

# Cheap sanity checks on the configuration above.
assert len(cxr_pair_template) == 2, "expected a (positive, negative) template pair"
assert len(set(cxr_labels)) == len(cxr_labels), "duplicate label names"

# Image preprocessing sizes. NOTE(review): not referenced by the cells in this notebook.
RESIZE = 512
CENTER_CROP_SIZE = 512

## Run Inference

In [3]:

from pathlib import Path

from health_multimodal.text import get_cxr_bert_inference
from health_multimodal.image import get_biovil_resnet_inference
from health_multimodal.vlp import ImageTextInferenceEngine
from health_multimodal.common.visualization import plot_phrase_grounding_similarity_map

from PIL import Image

In [4]:
# === Runs softmax eval for the biovil dataset
# NOTE: this definition shadows the `run_softmax_eval` imported from `zero_shot` in the first cell.
def run_softmax_eval(
    model,
    loader,
    cxr_true_labels_path,
    eval_labels: list,
    pair_template: tuple,
    png_folder: Optional[str] = None,
    log_every: int = 100,
):
    """
    Run pairwise-softmax zero-shot evaluation with a BioViL image-text engine.

    For every image listed in the ground-truth CSV and every label, the image is
    scored against a positive prompt (``pair_template[0].format(label)``) and a
    negative prompt (``pair_template[1].format(label)``). The softmax over the two
    similarity scores is the predicted probability that the label is present.

    Inputs:
        model: ImageTextInferenceEngine, or anything exposing
            ``get_similarity_score_from_raw_data(image_path=..., query_text=...)``.
        loader: unused; kept so the signature matches CheXzero's ``run_softmax_eval``.
        cxr_true_labels_path: path to a CSV whose ``Path`` column lists image paths
            relative to ``png_folder``.
        eval_labels: list(str) of label names to query.
        pair_template: tuple(str, str) of ``str.format`` templates, e.g. ("{}", "no {}").
        png_folder: root folder of the images. Defaults to the notebook-level
            ``cxr_png_folder`` when None.
        log_every: print a progress line every ``log_every`` images (0 disables it).
    Outputs:
        preds: list(list) indexed as preds[row][label]; each value is the softmax
        'probability' of that condition.
    """
    if png_folder is None:
        png_folder = cxr_png_folder  # notebook-level global from the config cell

    # The original ignored these and hardcoded "{label}" / "no {label}".
    pos_template, neg_template = pair_template

    gt = pd.read_csv(cxr_true_labels_path)

    preds = []
    for n_done, rel_path in enumerate(gt['Path'], start=1):
        image_path = Path(f"{png_folder}/{rel_path}")

        pred_labels = []
        for label in eval_labels:
            positive_score = model.get_similarity_score_from_raw_data(
                image_path=image_path,
                query_text=pos_template.format(label))
            negative_score = model.get_similarity_score_from_raw_data(
                image_path=image_path,
                query_text=neg_template.format(label))

            # softmax over (pos, neg) == sigmoid(pos - neg); this form cannot overflow.
            prob = 1.0 / (1.0 + np.exp(negative_score - positive_score))
            pred_labels.append(prob)
        preds.append(pred_labels)

        if log_every and n_done % log_every == 0:
            print(f"Finished {n_done} images")
    return preds

In [5]:
import torch

## Run a single BioViL model on the data set (despite the name, no ensembling is performed)
def ensemble_models(
    cxr_filepath: str,
    cxr_true_labels_path: str,
    cxr_labels: List[str],
    cxr_pair_template: Tuple[str, str],
) -> List[List[float]]:
    """
    Build the BioViL image-text inference engine and run softmax zero-shot evaluation.

    The name is kept for compatibility with the CheXzero notebook; only one model
    is loaded and no ensembling takes place.

    Input:
        cxr_filepath: (str) path to an h5 file of images. Unused by this function:
            ``run_softmax_eval`` reads images from disk using the ``Path`` column
            of the labels CSV.
        cxr_true_labels_path: (str) path to the ground-truth labels CSV; it also
            determines which images are scored and in what order.
        cxr_labels: list(str) of label names to query.
        cxr_pair_template: tuple(str, str) of (positive, negative) prompt templates.
    Output:
        pred: list(list) of predictions, indexed as pred[image][label].
    """
    # Get the biovil models
    text_inference = get_cxr_bert_inference()
    image_inference = get_biovil_resnet_inference()
    image_text_inference = ImageTextInferenceEngine(
        image_inference_engine=image_inference,
        text_inference_engine=text_inference,
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    image_text_inference.to(device)

    # `None` fills the unused `loader` argument of run_softmax_eval.
    y_pred = run_softmax_eval(image_text_inference, None, cxr_true_labels_path, cxr_labels, cxr_pair_template)

    return y_pred

pred = ensemble_models(
    cxr_filepath=cxr_filepath,
    cxr_true_labels_path=cxr_true_labels_path,
    cxr_labels=cxr_labels,
    cxr_pair_template=cxr_pair_template,
)

Using downloaded and verified file: /tmp/biovil_image_resnet50_proj_size_128.pt


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


0123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369

## (Optional) Evaluate Results
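If ground-truth labels are available, compute the AUC for each pathology to evaluate the zero-shot model, with 95% bootstrap confidence intervals. Note: with the negative template `"no {}"`, the *No Finding* label is scored against the awkward prompt "no No Finding". This may explain its below-chance AUC.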

In [10]:
# Stack per-image predictions into an (n_images, n_labels) array and load the
# ground-truth matrix for the same labels.
test_pred = np.array(pred)
test_true = make_true_labels(cxr_true_labels_path=cxr_true_labels_path, cxr_labels=cxr_labels)
print('positives:', test_true.sum(), 'shapes:', test_true.shape, test_pred.shape)

# Rows come from the same labels CSV and columns from the same cxr_labels list,
# so a shape mismatch means predictions and labels are out of sync.
assert test_pred.shape == test_true.shape, (test_pred.shape, test_true.shape)

# evaluate model
cxr_results = evaluate(test_pred, test_true, cxr_labels)

# bootstrap evaluations for 95% confidence intervals
bootstrap_results = bootstrap(test_pred, test_true, cxr_labels)

path /home/ec2-user/CHEXLOCALIZE/CheXpert/test_labels.csv labs ['No Finding', 'Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung Opacity', 'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia', 'Atelectasis', 'Pneumothorax', 'Pleural Effusion', 'Pleural Other', 'Fracture', 'Support Devices']
1677.0 668 (668, 14) (668, 14)


  0%|          | 0/1000 [00:00<?, ?it/s]

In [11]:
# Show per-label AUC with its bootstrap confidence interval (mean / lower / upper rows).
auc_ci_table = bootstrap_results[1]
auc_ci_table

Unnamed: 0,No Finding_auc,Enlarged Cardiomediastinum_auc,Cardiomegaly_auc,Lung Opacity_auc,Lung Lesion_auc,Edema_auc,Consolidation_auc,Pneumonia_auc,Atelectasis_auc,Pneumothorax_auc,Pleural Effusion_auc,Pleural Other_auc,Fracture_auc,Support Devices_auc
mean,0.1283,0.8092,0.8228,0.8904,0.7086,0.805,0.6598,0.7879,0.6753,0.827,0.8801,0.5989,0.5463,0.6113
lower,0.0951,0.7755,0.7876,0.8634,0.5888,0.7487,0.5734,0.611,0.6326,0.6574,0.8476,0.4669,0.2236,0.5606
upper,0.1656,0.8421,0.8562,0.9149,0.8243,0.8529,0.7373,0.9206,0.7149,0.9199,0.9093,0.709,0.817,0.6559


In [15]:
# Persist the AUC / confidence-interval table next to the notebook.
auc_csv_path = 'biovil_chexpert.csv'
auc_frame = pd.DataFrame(bootstrap_results[1])
auc_frame.to_csv(auc_csv_path)