In [None]:
#  -------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  -------------------------------------------------------------------------------------------

In [None]:
import os
import tempfile
import urllib.request
from pathlib import Path

import matplotlib.pyplot as plt
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer

from health_multimodal.text.inference_engine import TextInferenceEngine
from health_multimodal.image.inference_engine import ImageInferenceEngine
from health_multimodal.vlp.inference_engine import ImageTextInferenceEngine

from health_multimodal.image.model.model import ImageModel
from health_multimodal.image.data.transforms import create_chest_xray_transform_for_inference

from health_multimodal.common.visualization import plot_phrase_grounding_similarity_map


In [None]:
# Report whether PyTorch can see a CUDA device; the boolean is shown as this cell's output.
cuda_available = torch.cuda.is_available()
cuda_available

In [None]:
# Path to the BioViL image-encoder checkpoint. Set the BIOVIL_IMAGE_CHECKPOINT environment
# variable to your local copy of the .pt file; the fallback is the original author's
# workstation path and will not exist on other machines.
PRETRAINED_RESNET = os.environ.get(
    "BIOVIL_IMAGE_CHECKPOINT",
    "/home/ozoktay/workspace/hi-ml/multimodal/health_multimodal/checkpoints/biovil_image_resnet50_proj_size_128.pt",
)
# Fail fast with an actionable message, before spending time downloading the text model.
if not Path(PRETRAINED_RESNET).is_file():
    raise FileNotFoundError(
        f"BioViL image checkpoint not found at {PRETRAINED_RESNET!r}. "
        "Set the BIOVIL_IMAGE_CHECKPOINT environment variable to the checkpoint's path."
    )

# Load the text inference engine. URL is a Hugging Face Hub model ID; trust_remote_code=True
# allows the custom model/tokenizer code shipped in that Hub repo to be executed.
URL = "microsoft/BiomedVLP-CXR-BERT-specialized"
text_inference = TextInferenceEngine(
    tokenizer=AutoTokenizer.from_pretrained(URL, trust_remote_code=True),
    text_model=AutoModel.from_pretrained(URL, trust_remote_code=True))

# Load the image inference engine.
# joint_feature_size=128 presumably has to match the checkpoint's projection size
# ("proj_size_128" in the file name) — confirm if swapping checkpoints.
# Inference transform: resize to 512, then centre-crop to 480.
image_inference = ImageInferenceEngine(
    image_model=ImageModel(img_model_type="resnet50", joint_feature_size=128, pretrained_model_path=PRETRAINED_RESNET),
    transforms=create_chest_xray_transform_for_inference(resize=512, center_crop_size=480))

# Instantiate the joint inference engine that combines the image and text engines above.
image_text_inference = ImageTextInferenceEngine(image_inference_engine=image_inference,
                                                text_inference_engine=text_inference)


In [None]:

def plot_phrase_grounding(image_path, text_prompt):
    """Plot the phrase-grounding similarity map of ``text_prompt`` over the image at ``image_path``.

    Uses the module-level ``image_text_inference`` engine to compute the similarity map
    and ``plot_phrase_grounding_similarity_map`` to draw it. Returns nothing.
    """
    similarity = image_text_inference.get_similarity_map_from_raw_data(
        image_path=image_path,
        query_text=text_prompt,
    )
    plot_phrase_grounding_similarity_map(
        image_path=image_path,
        similarity_map=similarity,
    )

def plot_phrase_grounding_from_url(image_url, text_prompt):
    """Download an image from ``image_url`` and plot its phrase-grounding map for ``text_prompt``.

    The image is written to a fixed file name in the system temporary directory
    (overwritten on every call) and then handed to ``plot_phrase_grounding``.

    Raises:
        urllib.error.URLError: if the download fails, including HTTP error statuses,
            so a stale or invalid file is never plotted silently.
    """
    # ``tempfile.tempdir`` stays ``None`` until something calls ``gettempdir()``;
    # ``gettempdir()`` always resolves (and caches) the directory.
    image_path = Path(tempfile.gettempdir(), 'downloaded_chest_xray.jpg')
    # Pure-Python download: raises on HTTP errors (``curl -s`` did not) and never passes
    # the URL through a shell. Some image hosts may reject urllib's default User-Agent.
    request = urllib.request.Request(image_url, headers={"User-Agent": "Mozilla/5.0"})
    with urllib.request.urlopen(request, timeout=60) as response:
        image_path.write_bytes(response.read())
    plot_phrase_grounding(image_path, text_prompt)


In [None]:
# Demo 1: ground a free-text finding on a public chest X-ray downloaded from Radiopaedia.
image_url = "https://prod-images-static.radiopaedia.org/images/1371188/0a1f5edc85aa58d5780928cb39b08659c1fc4d6d7c7dce2f8db1d63c7c737234_gallery.jpeg"
text_prompt = "Pneumonia in the right lung"
plot_phrase_grounding_from_url(image_url=image_url, text_prompt=text_prompt)


In [None]:
# Demo 2: ground a finding on a local MIMIC-CXR image (.nii.gz). MIMIC-CXR is distributed via
# PhysioNet under credentialed access, so the file only exists where it has been downloaded.
# Set MIMIC_CXR_IMAGE_PATH to a local chest X-ray to run this cell; otherwise it is skipped
# so that Restart & Run All still completes.
text_prompt = "small right pleural effusion"
image_path = os.environ.get(
    "MIMIC_CXR_IMAGE_PATH",
    "/datasetdrive/MIMIC-CXR-V2-512-NIFTI/files/p15/p15881002/s59658268/ed730ed6-e391f6a6-55e52913-a66b2844-da028e10.nii.gz",
)
if Path(image_path).is_file():
    plot_phrase_grounding(image_path=image_path, text_prompt=text_prompt)
else:
    print(f"Skipping MIMIC-CXR demo: image not found at {image_path!r}. "
          "Set MIMIC_CXR_IMAGE_PATH to a local chest X-ray to run it.")
