In [None]:
#  ------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  ------------------------------------------------------------------------------------------
import sys
import os

# GradCam visualizations

In [None]:
import matplotlib.pyplot as plt
import numpy as np


from typing import Optional, Any
def _plot_one_map(map: np.ndarray,
                  image: np.ndarray,
                  channel: int,
                  slice: int,
                  row: int,
                  col: Optional[int]) -> None:
    """
    Plots one visualization overlaid on one image for one
    specific channel and slice at a chosen subplot.
    """
    if col is not None:
        current_ax = ax[row, col]
    else:
        current_ax = ax[row]
    current_ax.imshow(image[channel, slice], cmap=plt.gray())
    pos = current_ax.imshow(
        map[channel, slice],
        vmin=map.min(),
        vmax=map.max(),
        cmap=plt.jet(),
        alpha=0.7
    )
    current_ax.set_xticks([])
    current_ax.set_yticks([])
    if slice == 0:
        plt.colorbar(pos, plt.jet(), current_ax)
    current_ax.set_anchor('W')

In [None]:
# Default parameter values. They look like placeholders that are presumably replaced
# when the notebook is generated/executed (e.g. a papermill "parameters" cell) - TODO confirm.
subject_id = "The subject id"
# Folder read below for image.npy, gradcam.npy, guided_grad_cam.npy and non_image_pseudo_cam.npy.
gradcam_dir = "Path/to/gradcam"
has_non_image_features = True
has_image_features = True
# Model output and ground truth; only printed in the "Model prediction" section.
probas = ["[0], [0]"]
ground_truth_labels = ["[0], [0]"]
# x-tick labels for the non-imaging feature importance bar chart (one per bar).
non_image_labels = [""]
# When True, the GradCam section plots channel 0 only, in a single column.
encode_jointly = False
imaging_feature_type = "Image"
target_position = "Default"
# Value of imaging_feature_type for which the input channels are split into two halves
# (the second half is plotted as segmentations).
value_image_and_segmentation = "ImageAndSegmentation"

In [None]:
# Load the arrays saved in `gradcam_dir`. The imaging arrays are loaded only when the
# model has imaging inputs; the non-imaging pseudo-CAM only when it has non-imaging inputs.
if has_image_features:
    image = np.load(os.path.join(gradcam_dir, "image.npy"))
    grad_cam = np.load(os.path.join(gradcam_dir, "gradcam.npy"))
    guided_grad_cam = np.load(os.path.join(gradcam_dir, "guided_grad_cam.npy"))
if has_non_image_features:
    gradcam_non_image_features = np.load(os.path.join(gradcam_dir, "non_image_pseudo_cam.npy"))

In [None]:
if has_image_features:
    # The first two axes of the stored input are indexed as (channel, slice) by the
    # plotting helper, so they give the subplot grid size.
    channels, slices = image.shape[:2]
    if imaging_feature_type == value_image_and_segmentation:
        # The channels are split into two halves: the plots below draw channels
        # [0, channels) and [channels, 2 * channels) (the latter titled as segmentations).
        channels //= 2
else:
    channels = gradcam_non_image_features.shape[0]

if has_non_image_features:
    # Express each feature's importance as a percentage of the total. Skip the
    # normalisation when the total is zero, which would otherwise yield 0/0 -> NaN
    # (and a RuntimeWarning) and an empty bar chart.
    total_importance = gradcam_non_image_features.sum()
    if total_importance != 0:
        gradcam_non_image_features = gradcam_non_image_features / total_importance * 100

### Model prediction

In [None]:
# Summary of which subject/target this report covers, next to the model's output
# and the ground truth. One print of newline-joined lines == four separate prints.
print("\n".join([
    f"Subject ID: {subject_id}",
    f"Target position: {target_position}",
    f"Probability predicted by the model {probas}",
    f"Ground truth label {ground_truth_labels}",
]))

### Plot feature importance

In [None]:
if has_non_image_features:
    feature_importance = gradcam_non_image_features.flatten()
    # Axes.bar broadcasts length-1 inputs, so a label/feature count mismatch could
    # silently stack every bar at one x position instead of failing; check it explicitly.
    if len(feature_importance) != len(non_image_labels):
        raise ValueError(
            f"Got {len(feature_importance)} non-imaging feature importances but "
            f"{len(non_image_labels)} labels."
        )
    fig, ax = plt.subplots(figsize=(20, 7))
    x_data = np.arange(len(non_image_labels))
    ax.bar(x_data, feature_importance)  # type: ignore
    ax.set_xticks(x_data)
    ax.set_xticklabels(non_image_labels, rotation=90)
    ax.set_ylabel("Relative importance (%)")
    ax.set_title("Relative non-imaging feature importance (%)")
else:
    print("This model only uses imaging features")

### GradCam maps

In [None]:
if has_image_features:
    # squeeze=False keeps `ax` a 2-D array even with a single slice or channel, so the
    # ax[row, col] lookup in _plot_one_map never hits a bare Axes or a 1-D array.
    if encode_jointly:
        # Joint encoding: only channel 0 of the GradCam is plotted, one row per slice
        # (presumably the map is shared across channels - TODO confirm).
        fig, ax = plt.subplots(slices, 1, figsize=(10, 5 * slices), squeeze=False)
        for i in range(slices):
            _plot_one_map(grad_cam, image, 0, i, i, 0)
    else:
        fig, ax = plt.subplots(slices, channels, figsize=(20, 2 * slices), squeeze=False)
        for i in range(slices):
            for channel in range(channels):
                _plot_one_map(grad_cam, image, channel, i, i, channel)
    fig.suptitle("GradCam maps")
else:
    # A bare string expression here was a no-op; print it so the reader sees why no plot appears.
    print("This model is not using images")

### GuidedGradCam maps

In [None]:
if imaging_feature_type == value_image_and_segmentation:
    # This cell plots only the first `channels` channels; the second half is plotted as
    # segmentations by the next cell, so this half is titled as images.
    figure_title = "GuidedGradCam for images (imaging_type is ImageAndSegmentation)"
else:
    figure_title = f"GuidedGradCam (imaging_type is {imaging_feature_type})"

if has_image_features:
    # squeeze=False keeps `ax` 2-D so ax[row, col] works for single-slice/channel inputs.
    fig, ax = plt.subplots(slices, channels, figsize=(20, 2 * slices), squeeze=False)
    for i in range(slices):
        for channel in range(channels):
            _plot_one_map(guided_grad_cam, image, channel, i, i, channel)
    fig.suptitle(figure_title)
else:
    # A bare string expression here was a no-op; print it so the reader sees why no plot appears.
    print("This model is not using images")

In [None]:
if imaging_feature_type == value_image_and_segmentation and has_image_features:
    # The second half of the channels (offset by `channels`) holds the segmentations;
    # each GuidedGradCam map is drawn over its own segmentation channel.
    # squeeze=False keeps `ax` 2-D so ax[row, col] works for single-slice/channel inputs.
    fig, ax = plt.subplots(slices, channels, figsize=(20, 2 * slices), squeeze=False)
    for i in range(slices):
        for channel in range(channels):
            _plot_one_map(guided_grad_cam, image, channel + channels, i, i, channel)
    fig.suptitle("GuidedGradCam for Segmentations (imaging_type is ImageAndSegmentation)")