# GradCAM visualisations

In [1]:
import os
import torch
import cv2
import numpy as np
import matplotlib.pyplot as plt

from prediction.chexpert_disease import CheXpertDataModule, DenseNet

In [None]:
# ---- Grad-CAM output settings -------------------------------------------
output_path: str = './map.jpg'  # file the heatmap overlay is written to
class_idx: int = 0              # output column whose score is explained

# ---- Model and data locations -------------------------------------------
img_size: int = 128
# NOTE(review): left empty here; must point at a trained checkpoint before the model cell runs.
model_path: str = ""
img_data_dir: str = "/Users/felixkrones/python_projects/data/ChestXpert/"

# Split CSVs; their file names encode the image size chosen above.
csv_train_img: str = f"../datafiles/chexpert/chexpert.sample_{img_size}_from_train_filtered_True.train.csv"
csv_val_img: str = f"../datafiles/chexpert/chexpert.sample_{img_size}_from_train_filtered_True.val.csv"
csv_test_img: str = f"../datafiles/chexpert/chexpert.sample_{img_size}_from_train_filtered_True.test.csv"

# CSV column read for test-image paths ("fake_image_path" was the commented-out alternative).
path_col_test: str = "path_preproc"

In [None]:
# Get data: configure the data module with one image per batch, so the first
# test batch is exactly one image to explain.
datamodule_kwargs = dict(
    img_data_dir=img_data_dir,
    csv_train_img=csv_train_img,
    csv_val_img=csv_val_img,
    csv_test_img=csv_test_img,
    image_size=(img_size, img_size),
    pseudo_rgb=True,
    batch_size=1,
    num_workers=4,
    path_col_test=path_col_test,
)
data = CheXpertDataModule(**datamodule_kwargs)

# Pull the first batch of the test split; the second element is discarded.
# NOTE(review): assumes each batch unpacks into (image, label). If the dataset
# yields dicts, this unpacking would produce the dict keys instead — confirm.
test_loader = data.test_dataloader()
img, _ = next(iter(test_loader))

In [None]:
# Get model
# Number of model outputs; the original notebook built the model with 14.
num_classes = 14

# Fail fast with a clear message instead of an obscure loader error.
if not model_path:
    raise ValueError("model_path is empty; set it to a trained checkpoint file before loading the model")

model_type = DenseNet
# load_from_checkpoint constructs and returns a new instance, so no throwaway
# model needs to be built first. map_location="cpu" lets a GPU-trained
# checkpoint load on a CPU-only machine; the image batch used below is never
# moved off the CPU, so the model must live there too.
model = model_type.load_from_checkpoint(
    model_path,
    map_location="cpu",
    num_classes=num_classes,
)

In [None]:
# Get predictions
model.eval()
# Keep the raw per-class scores, shape (N, num_classes), with their autograd
# graph. Grad-CAM back-propagates from pred[:, class_idx]; an argmax'd tensor
# has neither the class dimension nor a gradient.
pred = model(img)
# Highest-scoring class index per image, for inspection only.
pred_label = pred.argmax(dim=1)

In [None]:
# Grad-CAM (https://arxiv.org/pdf/1610.02391.pdf) for output column `class_idx`.
# NOTE(review): relies on the project DenseNet recording the last conv layer's
# gradients during the forward pass (read back via get_activations_gradient)
# and exposing get_activations — confirm the hook is registered on every forward.

# Run a fresh forward pass so this cell owns an intact autograd graph, and clear
# parameter grads left over from earlier runs of the cell.
model.zero_grad()
scores = model(img)

# Back-propagate the chosen class score. sum() makes it a scalar for any batch
# size; with batch_size=1 it is just that single score.
scores[:, class_idx].sum().backward()

# pull the gradients out of the model
gradients = model.get_activations_gradient()

# average the gradients over batch and spatial dims -> one weight per channel
pooled_gradients = torch.mean(gradients, dim=[0, 2, 3]).detach()

# get the activations of the last convolutional layer
activations = model.get_activations(img).detach()

# Weight every channel by its pooled gradient. The channel count comes from the
# activations themselves, not from the input image size; broadcasting applies
# the (C,) weights across (N, C, H, W).
activations = activations * pooled_gradients.view(1, -1, 1, 1)

# average the weighted channels
heatmap = torch.mean(activations, dim=1).squeeze()

# ReLU on top of the heatmap: expression (2) in the paper
heatmap = torch.relu(heatmap)

# Normalize to [0, 1]. An all-zero map stays zero instead of becoming NaN (0/0).
max_val = heatmap.max()
if max_val > 0:
    heatmap = heatmap / max_val

# Plain NumPy array for matplotlib / OpenCV.
heatmap = heatmap.cpu().numpy()

# draw the heatmap
plt.matshow(heatmap)

In [None]:
# Put heatmap on top of picture
# Accept `heatmap` as either a torch tensor or a NumPy array. `heatmap` is not
# reassigned here, so the cell can be re-run safely.
if torch.is_tensor(heatmap):
    heatmap_arr = heatmap.detach().cpu().numpy()
else:
    heatmap_arr = np.asarray(heatmap)
heatmap_arr = np.nan_to_num(heatmap_arr.astype(np.float32))

# `img` is the batch fed to the model, assumed (N, C, H, W) with N == 1
# (batch_size=1) — TODO confirm. Turn the first sample into an H x W x 3 array.
img_arr = img.detach().cpu().numpy().astype(np.float32)
if img_arr.ndim == 4:
    img_arr = img_arr[0]
if img_arr.ndim == 3:
    img_arr = np.transpose(img_arr, (1, 2, 0))  # CHW -> HWC
else:
    img_arr = img_arr[..., None]
if img_arr.shape[-1] == 1:
    img_arr = np.repeat(img_arr, 3, axis=-1)

# The dataset's intensity normalisation is not visible here, so min-max rescale
# to 0..255 for display. A constant image becomes black.
lo, hi = float(img_arr.min()), float(img_arr.max())
img_arr = (img_arr - lo) * (255.0 / (hi - lo)) if hi > lo else np.zeros_like(img_arr)

height, width = img_arr.shape[:2]
# cv2.resize takes dsize as (width, height).
heatmap_color = cv2.resize(heatmap_arr, (width, height))
heatmap_color = np.uint8(255 * np.clip(heatmap_color, 0.0, 1.0))
# applyColorMap returns BGR, which is the channel order cv2.imwrite expects.
heatmap_color = cv2.applyColorMap(heatmap_color, cv2.COLORMAP_JET)

# Blend, clip the overflow, and convert to 8-bit for JPEG output.
superimposed_img = np.clip(heatmap_color * 0.4 + img_arr, 0, 255).astype(np.uint8)
if not cv2.imwrite(output_path, superimposed_img):
    raise IOError(f"cv2.imwrite failed to write {output_path}")
