# Sample Notebook for Zero-Shot Inference with BioViL
This notebook walks through how to use CheXzero to perform zero-shot inference on a chest x-ray image dataset.

## Import Libraries

In [10]:
import os
import numpy as np
import pandas as pd
from pathlib import Path
from typing import List, Tuple, Optional

import sys
sys.path.append('../')

from eval import evaluate, bootstrap
from zero_shot import make, make_true_labels, run_softmax_eval

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Directories and Constants

In [25]:
## Define Zero Shot Labels and Templates

# ----- DIRECTORIES ------ #
# Must... for some godforsaken reason be relative. 
cxr_filepath = '~/all_raw_data/padchest/images/44_cxr.h5' # h5 chest x-ray images
cxr_png_folder = '~/all_raw_data/padchest/images/' # folder with pngs
cxr_true_labels_path: Optional[str] = '~/all_raw_data/padchest/44_cxr_labels.csv' # labels
model_dir = None # No model_dir
predictions_dir = Path('./predictions/') # predictions
cache_dir = predictions_dir / "cached" # cache of ensembled predictions
context_length: int = 77

# ------- LABELS ------  #
# Define labels to query each image | will return a prediction for each label
cxr_labels: List[str] = ['Atelectasis','Cardiomegaly', 
                         'Consolidation', 'Edema', 'Enlarged Cardiomediastinum', 'Fracture', 'Lung Lesion',
                         'Lung Opacity', 'No Finding','Pleural Effusion', 'Pleural Other', 'Pneumonia', 
                         'Pneumothorax', 'Support Devices']

# ---- TEMPLATES ----- # 
# Define set of templates | see Figure 1 for more details                        
cxr_pair_template: Tuple[str] = ("{}", "no {}")

# Sanity check to make sure our global variables are good
print('cxrs:', cxr_filepath)
print('labs:', cxr_true_labels_path)
print('model_dir:', 'None, for now')
print('predictions_dir:', predictions_dir)
print('predications_dir_cached:', cache_dir)
print('label names:', cxr_labels)
print('context_length:', context_length)
print('cxr_png_folder:', cxr_png_folder)

RESIZE = 512
CENTER_CROP_SIZE = 512

cxrs: ~/all_raw_data/padchest/images/44_cxr.h5
labs: ~/all_raw_data/padchest/44_cxr_labels.csv
model_dir: None, for now
predictions_dir: predictions
predications_dir_cached: predictions/cached
label names: ['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Enlarged Cardiomediastinum', 'Fracture', 'Lung Lesion', 'Lung Opacity', 'No Finding', 'Pleural Effusion', 'Pleural Other', 'Pneumonia', 'Pneumothorax', 'Support Devices']
context_length: 77
cxr_png_folder: ~/all_raw_data/padchest/images/


## Run Inference

In [26]:

from pathlib import Path

from health_multimodal.text import get_cxr_bert_inference
from health_multimodal.image import get_biovil_resnet_inference
from health_multimodal.vlp import ImageTextInferenceEngine
from health_multimodal.common.visualization import plot_phrase_grounding_similarity_map

from PIL import Image

In [27]:
# === Runs softmax eval for the biovil dataset
def run_softmax_eval(model, loader, eval_labels: list, pair_template: tuple): 
    """
    Run softmax evaluation to obtain a single prediction from the model.
    Inputs:
        Model: ImageTextInferenceEngine
        loader: (todo), DataSetLoader
        eval_labels: list(str)
        pair_template: (todo), tuple(str, str)
    Outputs:
        preds: list(list) where the first index is the row, and the second index is the label,
        the value stored there is the softmax 'probability' of that condition
    """
     # get pos and neg phrases
    pos = pair_template[0]
    neg = pair_template[1]
    
    eval_labels = pd.read_csv(eval_labels)
    
    preds= []
    
    for i, row in eval_labels.iterrows():
        old_path = f"{cxr_png_folder}/{row['ImageID']}"
        new_path = f"{old_path.replace('.png', '.jpg')}"
        Image.open(old_path).convert('RGB').save(new_path)
        
        pred_labels = []
        
        for label in eval_labels.columns[2:5]:    
            
            positive_score = model.get_similarity_score_from_raw_data(
                image_path=Path(new_path),
                query_text=f'{label}')
            
            negative_score = model.get_similarity_score_from_raw_data(
                image_path=Path(new_path),
                query_text=f'no {label}')
            
            sum_pred = np.exp(positive_score) + np.exp(negative_score)
            prob = np.exp(positive_score) / sum_pred
            
            pred_labels.append(prob)
        preds.append(pred_labels)
        
        print(i, end='')
        if i > 5:
            return preds
    return preds



In [28]:

## Run the model on the data set using ensembled models
def ensemble_models(
    cxr_filepath: str, 
    cxr_labels: List[str], 
    cxr_pair_template: Tuple[str], 
) -> Tuple[List[np.ndarray], np.ndarray]: 
    """
    Input: 
        cxr_filepath: (str) path to h5
        cxr_labels: list(str) path to labels
        cxr_pair_template: tuple (template prompt based on the labels)
    Output:
        pred: list(list) of preds for the cxr_labels
    """

    # Get the biovil models
    text_inference = get_cxr_bert_inference()
    image_inference = get_biovil_resnet_inference()
    image_text_inference = ImageTextInferenceEngine(
        image_inference_engine=image_inference,
        text_inference_engine=text_inference,)
    
    
    y_pred = run_softmax_eval(image_text_inference, None, cxr_labels, cxr_pair_template)
       
    return y_pred

pred = ensemble_models(cxr_filepath, cxr_true_labels_path, cxr_pair_template)

print(pred)

Using downloaded and verified file: /tmp/biovil_image_resnet50_proj_size_128.pt


FileNotFoundError: [Errno 2] No such file or directory: '~/all_raw_data/padchest/images//216840111366964012989926673512011103121309246_00-185-142.png'

In [18]:
for p in pred:
    print(p)

NameError: name 'pred' is not defined

## (Optional) Evaluate Results
If ground truth labels are available, compute AUC on each pathology to evaluate the performance of the zero-shot model. 

In [None]:
# make test_true
test_pred = y_pred_avg
print('path', cxr_true_labels_path, 'labs', cxr_labels)

test_true = make_true_labels(cxr_true_labels_path=cxr_true_labels_path, cxr_labels=cxr_labels)
print(test_true.sum(), len(test_true), test_true.shape, test_pred.shape)

# evaluate model
cxr_results = evaluate(test_pred, test_true, cxr_labels)

# boostrap evaluations for 95% confidence intervals
bootstrap_results = bootstrap(test_pred, test_true, cxr_labels)

path /home/ubuntu/cs197_initial_code_submission/CheXzero/data/padchest/2_cxr_labels.csv labs ['humeral fracture']
4 2968 (2968, 1) ()


IndexError: tuple index out of range

In [None]:
# display AUC with confidence intervals
bootstrap_results[1]

NameError: name 'bootstrap_results' is not defined