In [108]:
import functools
import math
from typing import List

import PIL
import PIL.Image
import matplotlib.pyplot as plt
import numpy as np

from huggingface_hub import from_pretrained_keras
import gradio as gr
import tensorflow as tf
from tensorflow import keras
from keras import Sequential, layers

In [2]:
# Identifier passed to `from_pretrained_keras` when the model is loaded.
MODEL_CHECKPOINT = "mmenendezg/vit_pneumonia_classifier"
# [height, width] the preprocessing pipeline resizes every input image to.
IMG_SIZE = [224, 224]
# Class names, indexed by the predicted class (0 or 1).
IMG_CLASSES = ["Normal", "Pneumonia"]
# A model score strictly greater than this is labelled IMG_CLASSES[1].
THRESHOLD = 0.65

In [3]:
def process_attention(
    attention_vals: tf.Tensor,
    n_heads: int,
    h_featmap: int,
    w_featmap: int,
    h_original: int,
    w_original: int,
) -> tf.Tensor:
    """
    Process the attention weights of one example for visualization.

    The attention of the first token towards every other token is reshaped
    into a 2-D map per head, averaged over the heads, resized to the size of
    the original image and min-max normalized.

    Args:
        attention_vals: The attention weights of a single example, indexed as
            [head, query token, key token].
        n_heads: The number of attention heads.
        h_featmap: The height of the feature map.
        w_featmap: The width of the feature map.
        h_original: The height of the original image.
        w_original: The width of the original image.

    Returns:
        A tensor of shape (h_original, w_original, 1) with values scaled to
        the range [0, 1]. If the resized map is constant, every value is 0.

    Note:
        The per-head map is reshaped as (w_featmap, h_featmap). That order is
        kept as-is; it only matches (height, width) when the feature map is
        square. NOTE(review): confirm before using non-square inputs.
    """
    # Keep row 0 of every head and drop column 0, so only the attention of
    # the first token towards the remaining tokens is left.
    # (presumably the first token is the [CLS] token - TODO confirm)
    patch_attention = attention_vals[:, 0, 1:]
    attention = tf.reshape(patch_attention, (n_heads, w_featmap, h_featmap, 1))
    # Aggregation of the n heads in the last layer
    attention = tf.reduce_mean(attention, axis=0)
    attention = tf.image.resize(
        attention, size=[h_original, w_original], method="lanczos5"
    )
    # Normalize the attention values to have values from zero to one.
    # divide_no_nan returns 0 instead of NaN when the map is constant
    # (maximum equal to 0 after subtracting the minimum).
    attention -= tf.reduce_min(attention)
    attention = tf.math.divide_no_nan(attention, tf.reduce_max(attention))
    return attention


def get_attention(
    attentions: tf.Tensor,
    examples: int,
    num_attention_heads: int,
    h_featmap: int,
    w_featmap: int,
    h_original: int,
    w_original: int,
) -> List[tf.Tensor]:
    """
    Get the processed attention maps for a batch of images.

    Args:
        attentions: The attention weights of the whole batch. They are
            reshaped to (examples, num_attention_heads, tokens, tokens).
        examples: The number of examples.
        num_attention_heads: The number of attention heads.
        h_featmap: The height of the feature map.
        w_featmap: The width of the feature map.
        h_original: The height of the original image.
        w_original: The width of the original image.

    Returns:
        A list with one processed attention map per example, as returned by
        `process_attention`.

    Raises:
        ValueError: If the attention weights of one head cannot be arranged
            as a square (tokens x tokens) matrix.
    """
    attentions = tf.reshape(attentions, (examples, num_attention_heads, -1))

    # Recover the number of tokens with an exact integer square root. A
    # float32 square root followed by int() truncation can under-estimate it.
    flat_size = int(attentions.shape[-1])
    n_tokens = math.isqrt(flat_size)
    if n_tokens * n_tokens != flat_size:
        raise ValueError(
            f"Expected a square attention matrix per head, got {flat_size} values."
        )

    attentions = tf.reshape(
        attentions, (examples, num_attention_heads, n_tokens, n_tokens)
    )

    return [
        process_attention(
            attention, num_attention_heads, h_featmap, w_featmap, h_original, w_original
        )
        for attention in attentions
    ]

In [142]:
def get_image_preprocessor():
    """
    Build the preprocessing pipeline applied to an image before prediction.

    Returns:
        A Keras `Sequential` model that resizes the input to IMG_SIZE using
        nearest-neighbour interpolation and then multiplies it by 1/255.
    """
    resize_layer = layers.Resizing(
        height=IMG_SIZE[0],
        width=IMG_SIZE[1],
        interpolation="nearest",
    )
    rescale_layer = layers.Rescaling(scale=1.0 / 255.0)
    return Sequential([resize_layer, rescale_layer])


def get_attention_image(
    attentions: List[tf.Tensor], image: PIL.Image.Image
) -> PIL.Image.Image:
    """
    Overlay the first attention map on an image and colorize the result.

    Args:
        attentions: The processed attention maps; only the first one is used.
            It must be broadcastable against the image pixels.
        image: The image the attention map is applied to.

    Returns:
        A PIL image built from the viridis colormap of the attention-weighted
        grayscale image.
    """
    rescaled_image = tf.cast(image, dtype=tf.float32) / 255.0
    # Weight every pixel by its attention value
    weighted_image = rescaled_image * attentions[0]
    grayscale_image = tf.image.rgb_to_grayscale(weighted_image)
    grayscale_image = tf.squeeze(grayscale_image, axis=-1)
    # Convert to NumPy explicitly instead of relying on the colormap to
    # coerce the tensor.
    colored_image = plt.cm.viridis(grayscale_image.numpy())
    return PIL.Image.fromarray(np.uint8(colored_image * 255))


@functools.lru_cache(maxsize=1)
def _load_model():
    """Load and compile the classifier once; later calls reuse the same model."""
    model = from_pretrained_keras(MODEL_CHECKPOINT)
    model.compile()
    return model


def make_prediction(image: PIL.Image.Image):
    """
    Make a single prediction for the given image.

    The model identified by MODEL_CHECKPOINT is loaded on the first call and
    reused afterwards.

    Args:
        image: The image to predict.

    Returns:
        A tuple with the attention image and a dictionary that maps the
        predicted class name to the score returned by the model.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Raised when the button is clicked before an image is uploaded.
        raise gr.Error("Please upload a chest X-ray image before classifying.")

    # Load the model (cached after the first call)
    model = _load_model()
    model_config = model.get_layer("tf_vi_t_model").get_config()
    patch_size = model_config["patch_size"]
    # NOTE(review): IMG_SIZE[0] is the height given to the resizing layer but
    # feeds `w_featmap` here. The names are kept because `process_attention`
    # reshapes with (w_featmap, h_featmap); it makes no difference while
    # IMG_SIZE is square.
    w_featmap = IMG_SIZE[0] // patch_size
    h_featmap = IMG_SIZE[1] // patch_size

    # Convert the image to a tensorflow Dataset with the channels first
    image = image.convert("RGB")
    image_preprocessor = get_image_preprocessor()
    image_tf = tf.transpose(image_preprocessor(image), perm=[2, 0, 1])
    image_ds = tf.data.Dataset.from_tensors(image_tf).batch(1)

    # Make predictions
    model_output = model.predict(image_ds, verbose=0)

    predictions = [float(prediction) for prediction in model_output[0]]
    predicted_classes = [1 if pred > THRESHOLD else 0 for pred in predictions]

    # Obtain the attention vector, resized to the size of the uploaded image
    attentions = get_attention(
        model_output[1],
        1,
        model_config["num_attention_heads"],
        h_featmap,
        w_featmap,
        image.height,
        image.width,
    )

    # Get the attention image
    attention_image = get_attention_image(attentions, image)

    # NOTE(review): the value attached to the label is the raw model score,
    # also when the predicted class is IMG_CLASSES[0] - confirm this is the
    # confidence that should be displayed.
    return (attention_image, {IMG_CLASSES[predicted_classes[0]]: predictions[0]})

In [162]:
# Markdown badges shown in the header of the demo.
KAGGLE_NOTEBOOK = "[![Static Badge](https://img.shields.io/badge/Open_Notebook_in_Kaggle-blue?logo=kaggle&logoColor=white&labelColor=gray)](https://www.kaggle.com/code/mmenendezg/pneumonia-classifier-using-vit)"
GITHUB_REPOSITORY = "[![Static Badge](https://img.shields.io/badge/Git_Repository-purple?logo=github&logoColor=white&labelColor=gray)](https://github.com/mmenendezg/pneumonia_x_ray)"

with gr.Blocks() as demo:
    # Header: title, short description and the two badges.
    gr.Markdown(
        value=f"""
    # Pneumonia Classifier

    This is a space to test the Pneumonia Classifier model.

    {KAGGLE_NOTEBOOK}

    {GITHUB_REPOSITORY}
    """
    )

    # Main area: the uploaded image on the left, the results on the right.
    with gr.Row():
        with gr.Column():
            uploaded_image = gr.Image(
                type="pil",
                sources=["upload", "clipboard"],
                label="Chest X-ray image",
                height=550,
            )
        with gr.Column():
            labels = gr.Label(label="Prediction")
            attention_image = gr.Image(
                image_mode="L",
                label="Attention zones",
                height=425,
            )

    # Actions: run the classifier or reset every component.
    with gr.Row():
        classify_btn = gr.Button("Classify", variant="primary")
        clear_btn = gr.ClearButton(
            components=[uploaded_image, labels, attention_image],
        )

    classify_btn.click(
        fn=make_prediction,
        inputs=uploaded_image,
        outputs=[attention_image, labels],
    )

# Serve the demo outside of the notebook output; debug=True keeps the cell
# running until it is interrupted.
demo.launch(inline=False, debug=True)

Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.
Keyboard interruption in main thread... closing server.


