# Visualization using NIfTI files

This Jupyter Notebook demonstrates the visualization of the saliency maps of a CMR image, given a pretrained binary quality classifier.

In [1]:
from __future__ import print_function

%load_ext autoreload
%autoreload 2

import numpy as np
import skimage.io as io
import matplotlib.pyplot as plt
import matplotlib.cm as mpl_color_map
from PIL import Image
import nibabel as nib

import torch
import torch.nn as nn
torch.set_printoptions(threshold=5000)
import torchvision
import torchvision.transforms as transforms

from ipywidgets import interact
import ipywidgets as widgets

from utils.custom_gradcam import GradCam
from utils.custom_guided_backprop import GuidedBackprop
from utils.custom_guided_gradcam import guided_grad_cam
from utils.misc_functions import (DictAsMember, apply_colormap_on_image,
                                  get_positive_negative_saliency, convert_to_grayscale)
from models.models import get_model

It is necessary to provide the name of the model, the path to the weights, and the path to the NIfTI image file. For instance, here, `IMAGE` corresponds to a good-quality sample test image, while `IMAGE2` corresponds to an image with motion artefacts.

In [2]:
# Architecture of the pretrained classifier; "alexnet" is the alternative
# this notebook was written to support.
MODEL_NAME = "resnet50"
MODEL_PARAMS_PATH = "resnet50_model.pth"
IMAGE = "example_data/patient019_4d.nii.gz" # Good Quality Image
IMAGE2 = "example_data/XXXXX" # Image with Motion Artefacts

# Use the first GPU when one is present, otherwise run on the CPU so the
# notebook still executes on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
n_classes = 2  # Good Quality vs Img w/ Motion Artefact

# Settings handed to GradCam; wrapped in DictAsMember before use.
# NOTE(review): "layer4" / block 2 look specific to the ResNet layout --
# confirm the right values before switching MODEL_NAME.
MODEL_ARGS = {"name": MODEL_NAME,
             "last_layer": "layer4",
             "last_block": 2}
MODEL_ARGS = DictAsMember(MODEL_ARGS)

Images are opened using the loader function in the `nibabel` package. Normalization and dtype conversion are also applied at this step. The argument of the NIfTI image loader can be changed from `IMAGE` to `IMAGE2` in order to continue testing with an image with motion artefacts.

In [3]:
# Read the NIfTI file and materialize its voxel data as a floating-point array.
proxy_img = nib.load(IMAGE)
img = proxy_img.get_fdata()

# Scale intensities so the brightest voxel maps to 255. An all-zero volume is
# left untouched instead of being divided by zero (which would yield NaNs).
peak = img.max()
if peak != 0:
    img = img / peak * 255

# Convert to 8-bit, drop singleton axes and swap the first two axes. The
# transpose requires exactly four axes to remain after the squeeze.
img = img.astype(np.uint8).squeeze().transpose(1, 0, 2, 3)
print("Shape of the image: {}".format(img.shape))

Shape of the image: (256, 216, 11, 30)


A 4-D visualization of the image can be viewed below.

In [4]:
def update(d=0, t=0, **kwargs):
    """Draw one 2-D slice of the global ``img`` volume.

    Args:
        d: Index along the third axis of ``img``; truncated to ``int``.
        t: Index along the fourth axis of ``img``; truncated to ``int``.
        **kwargs: Accepted and ignored, so extra widgets can be passed
            through ``interact`` without affecting the rendering.
    """
    depth_idx, time_idx = int(d), int(t)
    plt.subplots(figsize=(5, 5))
    io.imshow(img[:, :, depth_idx, time_idx])

def player(image):
    """Show interactive depth/time controls that drive ``update``.

    The slider and playback ranges are taken from ``image`` itself (axes 2
    and 3), so the controls always match the volume that was passed in.

    Args:
        image: Array-like with at least four axes; ``image.shape[2]`` and
            ``image.shape[3]`` set the depth and time ranges respectively.
    """
    last_depth = image.shape[2] - 1
    last_time = image.shape[3] - 1

    play = widgets.Play(
        value=0,
        min=0,
        max=last_time,
        step=1,
        interval=100,
        description="Press play",
        disabled=False
    )

    depth_slider = widgets.FloatSlider(min=0, max=last_depth, step=1)
    time_slider = widgets.FloatSlider(min=0, max=last_time, step=1)

    # Keep the playback position and the time slider in sync on the front end.
    widgets.jslink((play, 'value'), (time_slider, 'value'))

    # The Play widget is passed as the extra keyword ``t2`` so that it is
    # handed to ``interact`` with the sliders; ``update`` absorbs it through
    # ``**kwargs`` and only reads ``d`` and ``t``.
    interact(update,
             d=depth_slider,
             t2=play,
             t=time_slider,
            )

In [5]:
# Browse the loaded volume; its third and fourth axes set the control ranges.
player(image=img)

interactive(children=(FloatSlider(value=0.0, description='d', max=10.0, step=1.0), FloatSlider(value=0.0, desc…

Defining the transformations applied prior to processing the image.

In [6]:
# Spatial size every slice is resized to before entering the network.
input_size = (224, 224)

# Per-channel normalization statistics (these values match the widely used
# ImageNet mean / standard deviation).
img_mean = [0.485, 0.456, 0.406]
img_std = [0.229, 0.224, 0.225]

# Resize -> tensor conversion -> channel-wise normalization, in that order.
_transform_steps = [
    transforms.Resize(input_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=img_mean, std=img_std),
]
data_transforms = transforms.Compose(_transform_steps)

Initializing the model and loading its parameters.

In [7]:
# Build the network with untrained weights; the trained parameters are
# restored from the checkpoint right after.
model = get_model(MODEL_NAME, device, pretrained=False, n_classes=n_classes)

# The checkpoint is a mapping whose "model" entry holds the state dict;
# map_location places the tensors on the selected device while loading.
checkpoint = torch.load(MODEL_PARAMS_PATH, map_location=device)
model.load_state_dict(checkpoint["model"])

# Switch to evaluation mode before running inference.
model.eval()
print("Model loaded.")

Model loaded.


Creating a loop over depth and phase axes in order to process the images slice-by-slice. During the loop, we traverse over these image slices and collect the predictions, GradCAM maps, Guided Backpropagation maps and Guided GradCAM maps.

In [8]:
def _rescale_to_unit(values):
    """Linearly rescale *values* to start at 0 and peak at 1.

    A constant input is returned as all zeros instead of dividing by zero.
    """
    shifted = values - values.min()
    peak = shifted.max()
    return shifted / peak if peak != 0 else shifted


# Per-slice class predictions and softmax probabilities, in loop order.
preds = []
probs = []

# Every saliency map is stored at the network input resolution; the trailing
# axes mirror the depth and time axes of the loaded volume.
out_height, out_width = input_size
map_shape = (out_height, out_width, *img.shape[2:])
resized_image = np.zeros(map_shape)
gradcam_map = np.zeros(map_shape)
pos_gbp_map = np.zeros(map_shape)
neg_gbp_map = np.zeros(map_shape)
ggradcam_map = np.zeros(map_shape)

gcv2 = GradCam(model, MODEL_ARGS)
gbp = GuidedBackprop(model, MODEL_NAME)
softmax = nn.Softmax(dim=1)
# NOTE(review): color_map is not used in this cell, and matplotlib.cm.get_cmap
# was removed in Matplotlib 3.9 -- confirm whether this line is still needed.
color_map = mpl_color_map.get_cmap("hsv")

# The saliency methods are presumably gradient-based, so the loop must not be
# wrapped in torch.no_grad().
for d in range(img.shape[2]):
    for t in range(img.shape[3]):
        model.zero_grad()

        # Convert BW to Pseudo-RGB Image by repeating the slice on a new
        # trailing channel axis.
        slice_rgb = Image.fromarray(img[:, :, None, d, t].repeat(3, axis=2))
        # PIL expects (width, height).
        slice_resized = slice_rgb.resize((out_width, out_height))
        resized_image[:, :, d, t] = slice_resized.convert("L")

        # Apply the PyTorch transforms and add the batch axis.
        proc_img = data_transforms(slice_rgb)
        proc_img = torch.unsqueeze(proc_img, dim=0).to(device, dtype=torch.float32)
        proc_img.requires_grad = True

        # Forward Pass
        output = model(proc_img)
        prob = softmax(output.detach())
        _, pred = torch.max(output.detach(), 1)

        # `pred` holds the predicted class index for the single item in the
        # batch. Taking argmax of that one-element tensor would always give
        # 0, so read the index itself to explain the predicted class.
        target_class = int(pred.item())

        # GradCAM results
        cam, _ = gcv2.generate_cam(proc_img, target_class)
        cam_p, _ = apply_colormap_on_image(slice_resized, cam, 'hsv')
        gradcam_map[:, :, d, t] = cam_p.convert("L")

        # Guided BackPropagation Results
        guided_grads = gbp.generate_gradients(proc_img, target_class)
        psal, nsal = get_positive_negative_saliency(guided_grads)
        pos_gbp_map[:, :, d, t] = _rescale_to_unit(convert_to_grayscale(psal))
        neg_gbp_map[:, :, d, t] = _rescale_to_unit(convert_to_grayscale(nsal))

        # Guided Grad-CAM Results
        guided_gradcams = guided_grad_cam(cam, guided_grads)
        ggradcam_map[:, :, d, t] = convert_to_grayscale(guided_gradcams)

        preds.append(pred)
        probs.append(prob)


Now that we have obtained all of the saliency maps of interest, we can write a new update function to visualize how they change over time.

In [9]:
def update(d=0, t=0, **kwargs):
    """Draw the resized slice next to its saliency maps for one (d, t) pair.

    Four panels are shown side by side: the resized input slice, the GradCAM
    map, the positive guided-backpropagation map and the Guided GradCAM map.
    The negative guided-backpropagation map is not displayed.

    Args:
        d: Index along the third axis of the stored maps; truncated to ``int``.
        t: Index along the fourth axis of the stored maps; truncated to ``int``.
        **kwargs: Accepted and ignored, so extra widgets can be passed
            through ``interact`` without affecting the rendering.
    """
    d = int(d)
    t = int(t)

    # (volume, colormap) for each panel, left to right.
    panels = (
        (resized_image, "gray"),
        (gradcam_map, "Reds"),
        (pos_gbp_map, "Reds"),
        (ggradcam_map, "Reds"),
    )

    # Create the figure once and lay the panels out on a single grid whose
    # column count matches the number of panels actually drawn.
    plt.figure(figsize=(20, 5))
    for column, (volume, cmap) in enumerate(panels, start=1):
        plt.subplot(1, len(panels), column)
        io.imshow(volume[:, :, d, t], cmap=cmap)

In [10]:
# The argument only supplies the depth/time ranges for the controls; the
# update function reads the stored map arrays directly.
player(image=gradcam_map)

interactive(children=(FloatSlider(value=0.0, description='d', max=10.0, step=1.0), FloatSlider(value=0.0, desc…

Lastly, we can obtain the overall prediction by averaging the predictions of all slices. A value close to $1$ indicates that there is no motion artefact, while a value close to $0$ indicates that the image volume contains motion artefacts.

In [11]:
# Mean of the per-slice predicted class indices collected in `preds`.
average_prediction = sum(preds).float() / len(preds)
print("Average prediction: {}".format(average_prediction))

Average prediction: tensor([1.0000], device='cuda:0')
