# Sample Notebook for Zero-Shot Inference with BioViL
This notebook walks through how to use CheXzero to perform zero-shot inference on a chest x-ray image dataset.

## Import Libraries

In [4]:
import os
import numpy as np
import pandas as pd
from pathlib import Path
from typing import List, Tuple, Optional

import sys
sys.path.append('../')

from eval import evaluate, bootstrap
from zero_shot import make, make_true_labels, run_softmax_eval

%load_ext autoreload
%autoreload 2

## Directories and Constants

In [77]:
## Define Zero Shot Labels and Templates

# ----- DIRECTORIES ------ #
# Must... for some godforsaken reason be relative. 
cxr_filepath = '../../../../../../cs197_initial_code_submission/CheXzero/data/padchest/images/44_cxr.h5' # h5 chest x-ray images
cxr_png_folder = '../../../../../../cs197_initial_code_submission/AllRawData/padchest/44/' # folder with pngs
cxr_true_labels_path: Optional[str] = '../../../../../../cs197_initial_code_submission/CheXzero/data/padchest/44_cxr_labels.csv' # labels

model_dir = None # No model_dir

predictions_dir = Path('./predictions/') # predictions
cache_dir = predictions_dir / "cached" # cache of ensembled predictions

context_length: int = 77

# ------- LABELS ------  #
# Define labels to query each image | will return a prediction for each label
cxr_labels: List[str] = ['Atelectasis','Cardiomegaly', 
                         'Consolidation', 'Edema', 'Enlarged Cardiomediastinum', 'Fracture', 'Lung Lesion',
                         'Lung Opacity', 'No Finding','Pleural Effusion', 'Pleural Other', 'Pneumonia', 
                         'Pneumothorax', 'Support Devices']

# ---- TEMPLATES ----- # 
# Define set of templates | see Figure 1 for more details                        
cxr_pair_template: Tuple[str] = ("{}", "no {}")

# ----- MODEL PATHS ------ #
# If using ensemble, collect all model paths
# model_paths = []
# for subdir, dirs, files in os.walk(model_dir):
#     for file in files:
#         full_dir = os.path.join(subdir, file)
#         model_paths.append(full_dir)
        
# print(model_paths)

# Sanity check to make sure our global variables are good
print('cxrs:', cxr_filepath)
print('labs:', cxr_true_labels_path)
print('model_dir:', 'None, for now')
print('predictions_dir:', predictions_dir)
print('predications_dir_cached:', cache_dir)
print('label names:', cxr_labels)
print('context_length:', context_length)

RESIZE = 512
CENTER_CROP_SIZE = 512

cxrs: ../../../../../../cs197_initial_code_submission/CheXzero/data/padchest/images/44_cxr.h5
labs: ../../../../../../cs197_initial_code_submission/CheXzero/data/padchest/44_cxr_labels.csv
model_dir: None, for now
predictions_dir: predictions
predications_dir_cached: predictions/cached
label names: ['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Enlarged Cardiomediastinum', 'Fracture', 'Lung Lesion', 'Lung Opacity', 'No Finding', 'Pleural Effusion', 'Pleural Other', 'Pneumonia', 'Pneumothorax', 'Support Devices']
context_length: 77


## Run Inference

In [78]:
import tempfile
from pathlib import Path

import torch

from health_multimodal.text import get_cxr_bert_inference
from health_multimodal.image import get_biovil_resnet_inference
from health_multimodal.vlp import ImageTextInferenceEngine
from health_multimodal.common.visualization import plot_phrase_grounding_similarity_map

In [79]:
# # === Makes a DataLoader ===
# from torchvision.transforms import Compose, Normalize, Resize, InterpolationMode
# from zero_shot import CXRTestDataset
# import h5py
# from PIL import Image

# def make_data_loader(cxr_h5_path):
#     transformations = [
#         Normalize((101.48761, 101.48761, 101.48761), (83.43944, 83.43944, 83.43944)),
#         Resize(224, interpolation=InterpolationMode.BICUBIC)
#     ]
#     transform = Compose(transformations)
    
#     dataset = CXRTestDataset(
#         img_path=cxr_h5_path,
#         transform=transform, 
#     )
    
#     return torch.utils.data.DataLoader(dataset, shuffle=False)

# make_data_loader(cxr_filepath)

# def save_jpeg(cxr_true_labels_path):
    
#     loader = make_data_loader(cxr_h5_path)
#     for i, data in enumerate(loader):
    
#         images = np.transpose(data['img'].detach().numpy(), (0, 2, 3, 1)).squeeze() # (1, 3, 224, 224) => (1, 224, 224, 3) => (224, 224, 3)
        
#         pil_img = Image.fromarray(obj=images, mode='RGB')
#         pil_img.show()
                
# save_jpeg(cxr_filepath)

NameError: name 'cxr_h5_path' is not defined

In [106]:
# === Runs  softmax eval
def run_softmax_eval(model, loader, eval_labels: list, pair_template: tuple, context_length: int = 77): 
    """
    Run softmax evaluation to obtain a single prediction from the model.
    """
     # get pos and neg phrases
    pos = pair_template[0]
    neg = pair_template[1]
    
    eval_labels = pd.read_csv(eval_labels)
    
    preds= []
    
    for i, row in eval_labels.iterrows():
        old_path = f"{cxr_png_folder}/{row['ImageID']}"
        new_path = f"{old_path.replace('.png', '.jpg')}"
        Image.open(old_path).convert('RGB').save(new_path)
        
        pred_labels = []
        
        for label in eval_labels[:5]:    
            
            positive_score = model.get_similarity_score_from_raw_data(
                image_path=Path(new_path),
                query_text=f'{label}')
            
            negative_score = model.get_similarity_score_from_raw_data(
                image_path=Path(new_path),
                query_text=f'no {label}')
            
            sum_pred = np.exp(positive_score) + np.exp(negative_score)
            prob = np.exp(positive_score) / sum_pred
            
            pred_labels.append(prob)
        preds.append(pred_labels)
        
        print(i, end='')
        if i > 50:
            return
    return preds



In [107]:

## Run the model on the data set using ensembled models
def ensemble_models(
    cxr_filepath: str, 
    cxr_labels: List[str], 
    cxr_pair_template: Tuple[str], 
) -> Tuple[List[np.ndarray], np.ndarray]: 
    """
    Input: 
        -path to h5 and labels
        -template (to make a prompt based on the labels)
    Output:
        -predictions of labels of the h5
    """

    # Get the biovil models
    text_inference = get_cxr_bert_inference()
    image_inference = get_biovil_resnet_inference()
    image_text_inference = ImageTextInferenceEngine(
        image_inference_engine=image_inference,
        text_inference_engine=text_inference,)
    
    
    y_pred = run_softmax_eval(image_text_inference, None, cxr_labels, cxr_pair_template)
       
    return y_pred

ensemble_models(cxr_filepath, cxr_true_labels_path, cxr_pair_template)

Using downloaded and verified file: /tmp/biovil_image_resnet50_proj_size_128.pt


In [74]:
predictions, y_pred_avg = ensemble_models(
    model_paths=model_paths, 
    cxr_filepath=cxr_filepath, 
    cxr_labels=cxr_labels, 
    cxr_pair_template=cxr_pair_template, 
    cache_dir=cache_dir,
)

print(len(predictions))

NameError: name 'model_paths' is not defined

In [90]:
# save averaged preds
pred_name = "chexpert_preds.npy" # add name of preds
predictions_dir = predictions_dir / pred_name
np.save(file=predictions_dir, arr=y_pred_avg)

## (Optional) Evaluate Results
If ground truth labels are available, compute AUC on each pathology to evaluate the performance of the zero-shot model. 

In [91]:
# make test_true
test_pred = y_pred_avg
print('path', cxr_true_labels_path, 'labs', cxr_labels)

test_true = make_true_labels(cxr_true_labels_path=cxr_true_labels_path, cxr_labels=cxr_labels)
print(test_true.sum(), len(test_true), test_true.shape, test_pred.shape)

# evaluate model
cxr_results = evaluate(test_pred, test_true, cxr_labels)

# boostrap evaluations for 95% confidence intervals
bootstrap_results = bootstrap(test_pred, test_true, cxr_labels)

path /home/ubuntu/cs197_initial_code_submission/CheXzero/data/padchest/2_cxr_labels.csv labs ['humeral fracture']
4 2968 (2968, 1) ()


IndexError: tuple index out of range

In [17]:
# display AUC with confidence intervals
bootstrap_results[1]

NameError: name 'bootstrap_results' is not defined