# GradCAM visualisations

In [None]:
import os
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import torch
import torch.nn as nn
from torchvision import models
from torchvision.transforms.functional import normalize, resize, to_pil_image
from skimage.io import imread
from skimage.transform import resize
from torchcam.methods import GradCAM
from torchcam.utils import overlay_mask
from skimage.io import imread
from skimage.io import imsave

from prediction.chexpert_disease import CheXpertDataModule, ResNet, DenseNet

In [None]:
# Parameters
class_idx = 10  # 10 = Pleural Effusion
img_size = 128
num_classes = 14

# Checkpoint and GradCAM target layer for each architecture. Keeping each pair
# together means switching `model_type` can never combine one architecture's
# checkpoint with another architecture's layer name.
MODEL_CONFIGS = {
    "ResNet": {
        "model_path": "prediction/chexpert/disease/models/resnet-all_128/version_1/checkpoints/epoch=1-step=1018.ckpt",
        "layer": "model.layer4",
    },
    "DenseNet": {
        "model_path": "prediction/chexpert/disease/models/densenet-all_128/version_0/checkpoints/epoch=9-step=5090.ckpt",
        "layer": "model.features.denseblock4.denselayer16",
    },
}
model_type = "DenseNet"  # key into MODEL_CONFIGS
model_path = MODEL_CONFIGS[model_type]["model_path"]
layer = MODEL_CONFIGS[model_type]["layer"]

image_paths = [
    "/Users/felixkrones/python_projects/data/ChestXpert/preproc_128x128_len_202/patient64542_study1_view1_frontal.jpg",
    #"/Users/felixkrones/python_projects/data/ChestXpert/preproc_128x128_len_202/patient64543_study1_view1_frontal.jpg"
]

# Output directory sits next to the run: drop the last three components of the
# checkpoint path (version dir, "checkpoints", checkpoint file) and append
# "gradcam/". The trailing slash is kept deliberately because filenames may be
# concatenated directly onto this string.
out_dir = "/".join(model_path.split("/")[:-3] + ["gradcam/"])

In [None]:
# Get model
# Resolve the configured architecture name through an explicit mapping instead
# of eval()-ing the string, so only known model classes can be selected.
_MODEL_CLASSES = {"ResNet": ResNet, "DenseNet": DenseNet}
if model_type not in _MODEL_CLASSES:
    raise ValueError(
        f"Unknown model_type {model_type!r}; expected one of {sorted(_MODEL_CLASSES)}."
    )
model = _MODEL_CLASSES[model_type].load_from_checkpoint(model_path, num_classes=num_classes)
# Inference mode (affects layers such as dropout / batch norm); gradients are
# still enabled, which the GradCAM step relies on.
model.eval()

In [None]:
# Get data
def _load_image(path):
    """Read an image file and return it as a float32 tensor of shape (3, H, W).

    Grayscale images (H, W) or (H, W, 1) are replicated to three channels,
    channels-last RGB (H, W, 3) is moved to channels-first, and an array that
    is already (3, H, W) is returned unchanged.

    Raises:
        ValueError: if the image layout is none of the above.
    """
    array = imread(path).astype(np.float32)
    tensor = torch.from_numpy(array)
    # Inspect the raw array's rank *before* adding any dimension; checking
    # shape[2] after an unsqueeze would test the width, not the channel count.
    if tensor.ndim == 3 and tensor.shape[-1] == 1:
        tensor = tensor[..., 0]  # (H, W, 1) -> (H, W)
    if tensor.ndim == 2:
        return tensor.unsqueeze(0).repeat(3, 1, 1)
    if tensor.ndim == 3 and tensor.shape[-1] == 3:
        return tensor.permute(2, 0, 1)
    if tensor.ndim == 3 and tensor.shape[0] == 3:
        return tensor
    raise ValueError(f"Image shape {tuple(array.shape)} not supported.")


if not image_paths:
    raise ValueError("image_paths is empty; nothing to visualise.")

images = []
for path in image_paths:
    # Named `image` so it still refers to the last loaded tensor after the loop.
    image = _load_image(path)
    images.append(image)

# Stack the (3, H, W) tensors into a (N, 3, H, W) batch.
images = torch.stack(images)
images.shape

In [None]:
# Show the last image of the batch: channels-first tensor -> (H, W, C) uint8
# array, the layout matplotlib expects for RGB display.
last_image = images[-1].permute(1, 2, 0).numpy().astype(np.uint8)
plt.imshow(last_image)

In [None]:
# Run GradCAM
# The extractor hooks `layer` on the model, so it must exist before the forward
# pass. No torch.no_grad() here: GradCAM needs gradients w.r.t. that layer.
cam_extractor = GradCAM(model, layer)
outs = model(images)
# torchcam returns a list with one CAM tensor per hooked layer; only one layer
# is hooked here, hence [0].
cams = cam_extractor(class_idx, outs)[0]
# Detach the hooks so re-running this cell does not stack extra hooks on `model`.
cam_extractor.remove_hooks()

In [None]:
# Overlay each CAM on its source image, display it, and save it to out_dir.
os.makedirs(out_dir, exist_ok=True)
for cam, img, path in zip(cams, images, image_paths):
    # `img` holds pixel values read by imread and cast to float32 (assumes 8-bit
    # source images, i.e. 0-255). to_pil_image multiplies floating-point tensors
    # by 255 before casting to uint8, which would overflow here, so convert to
    # uint8 explicitly first.
    base = to_pil_image(img.squeeze().clamp(0, 255).to(torch.uint8))
    heatmap = to_pil_image(cam.squeeze(0), mode="F")
    result = overlay_mask(base, heatmap, alpha=0.5)
    plt.imshow(result)
    plt.show()
    result.save(os.path.join(out_dir, os.path.basename(path)))