In [None]:
#  -------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  -------------------------------------------------------------------------------------------

# Phrase grounding

This notebook demonstrates the multimodal models introduced in our ECCV 2022 paper:

> Boecking, B., Usuyama, N., Bannur, S., Castro, D., Schwaighofer, A., Hyland, S., Wetscherek, M., Naumann, T., Nori, A., Alvarez-Valle, J., Poon, H., & Oktay, O. (2022). *Making the Most of Text Semantics to Improve Biomedical Vision–Language Processing* ([preprint](https://arxiv.org/abs/2204.09817))

Given a chest X-ray and a text prompt, the joint model grounds the phrase in the image, i.e., it highlights the image regions whose features are most similar to the phrase.

The notebook can be run on Binder, with no coding or local installation required:

[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/microsoft/hi-ml/HEAD?labpath=hi-ml-multimodal%2Fnotebooks%2Fphrase_grounding.ipynb)
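As a rough illustration of what "grounding" means here, assume the model maps each image patch and the text prompt into a shared embedding space. A similarity map can then be formed by scoring every patch embedding against the prompt embedding with cosine similarity. The sketch below uses made-up 3-dimensional vectors on a 2 x 2 patch grid; the real model's embedding size, patch grid, and projection layers are not shown, and the function names are hypothetical.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def cosine_similarity(a: Vector, b: Vector) -> float:
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine similarity is undefined for zero vectors")
    return dot / (norm_a * norm_b)


def similarity_map(patch_embeddings: List[List[Vector]], text_embedding: Vector) -> List[List[float]]:
    """Score every patch of an H x W grid of embeddings against one text embedding."""
    return [[cosine_similarity(patch, text_embedding) for patch in row] for row in patch_embeddings]


# Toy 2 x 2 grid of patch embeddings and a toy text embedding (illustrative numbers only).
patches = [
    [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
    [[0.7, 0.7, 0.0], [0.0, 0.0, 1.0]],
]
text = [1.0, 0.0, 0.0]
sim = similarity_map(patches, text)
print(sim)
```

Patches pointing in the same direction as the text embedding score close to 1, orthogonal patches score 0, so high values mark the regions the phrase is "grounded" in.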

## Setup

Let's first install the `hi-ml-multimodal` Python package, which will allow us to import the `health_multimodal` Python module.

In [None]:
repo_branch = "main"

In [None]:
repo_url = "git+https://github.com/microsoft/hi-ml.git"
subdirectory = "hi-ml-multimodal"
pip_source = f"{repo_url}@{repo_branch}#subdirectory={subdirectory}"
%pip install --quiet {pip_source}
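The `pip_source` string follows pip's VCS URL format, `git+<url>@<ref>#subdirectory=<path>`, where `<ref>` can be a branch, a tag, or a commit hash. The hypothetical helper below only makes the pieces explicit, which can be handy for pinning the install to a tag or commit instead of `main` so that results are reproducible.

```python
from typing import Optional


def pip_vcs_requirement(repo_url: str, ref: Optional[str] = None, subdirectory: Optional[str] = None) -> str:
    """Build a pip VCS requirement of the form git+URL[@ref][#subdirectory=PATH]."""
    if not repo_url.startswith("git+"):
        repo_url = f"git+{repo_url}"
    requirement = repo_url
    if ref:
        # A branch name, tag, or full commit hash.
        requirement += f"@{ref}"
    if subdirectory:
        # Required when the package's build files live below the repository root.
        requirement += f"#subdirectory={subdirectory}"
    return requirement


print(pip_vcs_requirement("https://github.com/microsoft/hi-ml.git", "main", "hi-ml-multimodal"))
```

The returned string could be assigned to `pip_source` before the `%pip install` cell.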

In [None]:
import tempfile
from pathlib import Path

from health_multimodal.text import get_cxr_bert_inference
from health_multimodal.image import get_cxr_resnet_inference
from health_multimodal.vlp import ImageTextInferenceEngine
from health_multimodal.common.visualization import plot_phrase_grounding_similarity_map

## Load multimodal model

Load the text and image models from [Hugging Face 🤗](https://aka.ms/biovil-models) and instantiate the inference engines:

In [None]:
text_inference = get_cxr_bert_inference()
image_inference = get_cxr_resnet_inference()

Instantiate the joint inference engine:

In [None]:
image_text_inference = ImageTextInferenceEngine(
    image_inference_engine=image_inference,
    text_inference_engine=text_inference,
)
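The helper functions in the next section call `get_similarity_map_from_raw_data` with `interpolation="bilinear"`. Our reading, which is an assumption and not checked against the library source, is that the similarity map is computed on the coarse grid of image patches and then resized to the image resolution, with the interpolation mode controlling how values between patch centres are filled in. The pure-Python sketch below shows bilinear resizing of a small grid using "align corners" sampling (input and output corner values coincide); the library may use a different sampling convention, and `bilinear_resize` is a hypothetical name.

```python
from typing import List

Grid = List[List[float]]


def bilinear_resize(grid: Grid, out_h: int, out_w: int) -> Grid:
    """Resize a non-empty H x W grid to out_h x out_w with bilinear interpolation.

    Uses align-corners sampling: the four corner values of input and output match.
    """
    in_h, in_w = len(grid), len(grid[0])
    out: Grid = []
    for i in range(out_h):
        # Map the output row index to a fractional input row coordinate.
        y = i * (in_h - 1) / (out_h - 1) if out_h > 1 else 0.0
        y0 = int(y)
        y1 = min(y0 + 1, in_h - 1)
        dy = y - y0
        row = []
        for j in range(out_w):
            # Same mapping for the column index.
            x = j * (in_w - 1) / (out_w - 1) if out_w > 1 else 0.0
            x0 = int(x)
            x1 = min(x0 + 1, in_w - 1)
            dx = x - x0
            # Interpolate along x on the two neighbouring rows, then along y.
            top = grid[y0][x0] * (1 - dx) + grid[y0][x1] * dx
            bottom = grid[y1][x0] * (1 - dx) + grid[y1][x1] * dx
            row.append(top * (1 - dy) + bottom * dy)
        out.append(row)
    return out


coarse = [[0.0, 1.0], [1.0, 2.0]]  # toy 2 x 2 patch-level map
fine = bilinear_resize(coarse, 3, 3)
print(fine)
```

Bilinear interpolation gives a smooth heat map, whereas nearest-neighbour resizing would keep visible patch-sized blocks.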

## Helper visualization functions

In [None]:
def plot_phrase_grounding(image_path: Path, text_prompt: str) -> None:
    """Compute and plot the similarity map between a local image and a text prompt."""
    similarity_map = image_text_inference.get_similarity_map_from_raw_data(
        image_path=image_path,
        query_text=text_prompt,
        interpolation="bilinear",
    )
    plot_phrase_grounding_similarity_map(
        image_path=image_path,
        similarity_map=similarity_map,
    )

def plot_phrase_grounding_from_url(image_url: str, text_prompt: str) -> None:
    """Download an image into the temp directory, then plot its phrase grounding."""
    # `tempfile.tempdir` may still be None at this point, so ask `gettempdir()` for the path.
    image_path = Path(tempfile.gettempdir(), "downloaded_chest_xray.jpg")
    !curl -s -L -o {image_path} {image_url}
    plot_phrase_grounding(image_path, text_prompt)
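The `!curl` line depends on IPython's shell escape and on `curl` being installed. Where either is unavailable, for example when running the code as a plain Python script, the download could be done with the standard library instead. The sketch below is one way to do that; `download_to_temp` is a hypothetical name, and its return value could be passed to `plot_phrase_grounding`.

```python
import shutil
import tempfile
import urllib.request
from pathlib import Path


def download_to_temp(url: str, filename: str = "downloaded_chest_xray.jpg") -> Path:
    """Download `url` into the system temp directory and return the local path."""
    destination = Path(tempfile.gettempdir()) / filename
    # urlopen follows HTTP redirects by default, similar to `curl -L`.
    with urllib.request.urlopen(url) as response, open(destination, "wb") as out_file:
        # Stream the body to disk rather than reading it all into memory.
        shutil.copyfileobj(response, out_file)
    return destination
```

Example: `plot_phrase_grounding(download_to_temp(image_url), text_prompt)`.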

## Inference

We will run inference on a chest X-ray from [Radiopaedia](https://radiopaedia.org/), but any chest X-ray image can be used.

In [None]:
text_prompt = "Pneumonia in the right lung"
image_url = "https://prod-images-static.radiopaedia.org/images/1371188/0a1f5edc85aa58d5780928cb39b08659c1fc4d6d7c7dce2f8db1d63c7c737234_gallery.jpeg"

In [None]:
plot_phrase_grounding_from_url(image_url, text_prompt)