<a href="https://colab.research.google.com/github/futartup/S8-assignment/blob/master/resnet_gradcam.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import os, sys
import numpy as np
import matplotlib.pyplot as plt
import cv2

In [0]:
class ResnetGradCam(nn.Module):
  """Grad-CAM wrapper exposing the last convolutional feature map of a ResNet.

  The wrapped ``model`` must provide ``conv1``, ``bn1``, ``layer1``-``layer4``
  and ``linear`` submodules; they are shared with the wrapper, not copied.
  ``forward`` mirrors the classifier and, when autograd is recording, stores
  the gradient of the ``layer4`` output in ``self.gradients`` during
  ``backward()``.
  """

  def __init__(self, model):
    super(ResnetGradCam, self).__init__()

    # The wrapped (already trained) network.
    self.resnet18 = model

    # Stem + residual stages, ending at the last convolutional layer.
    # A ReLU follows the stem batch-norm, as in the ResNet architecture; it
    # was previously missing, so these features and the wrapper's
    # predictions did not match the wrapped model's own forward pass.
    # NOTE(review): confirm this matches the wrapped model's forward().
    self.features_conv = nn.Sequential(
        self.resnet18.conv1,
        self.resnet18.bn1,
        nn.ReLU(),
        self.resnet18.layer1,
        self.resnet18.layer2,
        self.resnet18.layer3,
        self.resnet18.layer4,
    )

    # Classifier head applied after pooling.
    self.linear = self.resnet18.linear

    # Gradient of the last feature map; set by activations_hook on backward().
    self.gradients = None

  def activations_hook(self, grad):
    """Backward hook: remember the gradient w.r.t. the last feature map."""
    self.gradients = grad

  def forward(self, x):
    """Return class scores for ``x``, hooking the last feature map's gradient."""
    x = self.features_conv(x)

    # register_hook raises on tensors that do not require grad (e.g. under
    # torch.no_grad()); skip it there so plain inference still works.
    if x.requires_grad:
      x.register_hook(self.activations_hook)

    # Fixed 4x4 pooling; presumably the layer4 map is 4x4 for the expected
    # input size — TODO confirm before using other resolutions.
    x = F.avg_pool2d(x, 4)
    x = x.view(x.size(0), -1)
    return self.linear(x)

  def get_activations_gradient(self):
    """Return the gradient captured by the most recent backward(), or None."""
    return self.gradients

  def get_activations(self, x):
    """Return the last convolutional feature map for ``x`` (no hook)."""
    return self.features_conv(x)

In [0]:
def getheatmap(pred, class_pred, netx, img):
    """Compute a Grad-CAM heatmap for class ``class_pred``.

    Args:
        pred: model output of shape (batch, num_classes) produced by
            ``netx(img)`` and still attached to the autograd graph.
        class_pred: index of the class to explain.
        netx: object exposing ``get_activations(x)`` and
            ``get_activations_gradient()`` (e.g. ResnetGradCam) whose forward
            pass produced ``pred``.
        img: the input batch that produced ``pred``.

    Returns:
        A CPU tensor with singleton dimensions squeezed, clamped at 0 and
        scaled so its maximum is 1. It is all zeros (instead of NaN) when no
        location has a positive contribution.

    Raises:
        RuntimeError: if the backward pass did not record a gradient.
    """
    # Backpropagate the class score. Summing keeps backward() valid for
    # batches > 1 and is identical to the scalar case for a single image.
    pred[:, class_pred].sum().backward()

    gradients = netx.get_activations_gradient()
    if gradients is None:
        raise RuntimeError(
            "no gradient was captured; run netx(img) with autograd enabled "
            "before calling getheatmap")

    # Global-average-pool the gradients over batch and spatial dims:
    # one importance weight per channel (alpha_k in the Grad-CAM paper).
    pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])

    # Recompute the last-layer activations on the prediction's device
    # rather than assuming CUDA; no graph is needed for this pass.
    with torch.no_grad():
        activations = netx.get_activations(img.to(pred.device)).detach()

    # Weight every channel by its pooled gradient (broadcast works for any
    # channel count, not only 512).
    activations = activations * pooled_gradients.view(1, -1, 1, 1)

    # Average the weighted channels into a single spatial map.
    heatmap = torch.mean(activations, dim=1).squeeze()

    # ReLU on the map: expression (2) in https://arxiv.org/pdf/1610.02391.pdf
    heatmap = torch.clamp(heatmap, min=0).cpu()

    # Normalise to [0, 1]; guard the all-zero case against 0/0 = NaN.
    peak = torch.max(heatmap)
    if peak > 0:
        heatmap = heatmap / peak
    return heatmap


def imshow(img, ax):
    """Draw a channel-first image tensor on the matplotlib axis ``ax``.

    The tensor is mapped back with ``x / 2 + 0.5``. This presumably undoes a
    mean=0.5 / std=0.5 normalisation (confirm with the data pipeline). It is
    then moved to the CPU and transposed from (C, H, W) to (H, W, C), the
    layout matplotlib expects.
    """
    restored = (img / 2 + 0.5).cpu().numpy()
    ax.imshow(restored.transpose(1, 2, 0))


def superposeimage(heatmap, img, alpha=0.4, path='./map.jpg'):
    """Blend a heatmap over ``img`` with a JET colour map and write it to disk.

    Args:
        heatmap: 2-D map (tensor or array) with values expected in [0, 1],
            e.g. from getheatmap. NaNs are treated as 0 and values are
            clipped to [0, 1].
        img: image array of shape (H, W, 3). It is added to the colour map
            unchanged; the OpenCV colour map is BGR, so ``img`` should be
            BGR as well for the written colours to be correct.
        alpha: weight of the colour map in the blend (previous fixed value 0.4).
        path: output file passed to ``cv2.imwrite`` (previously fixed './map.jpg').

    Returns:
        The blended uint8 image that was written.
    """
    heat = np.nan_to_num(np.array(heatmap), nan=0.0)
    # cv2.resize takes dsize as (width, height).
    heat = cv2.resize(heat, (img.shape[1], img.shape[0]))
    heat = np.uint8(255 * np.clip(heat, 0, 1))
    heat = cv2.applyColorMap(heat, cv2.COLORMAP_JET)

    blended = heat * alpha + img
    # Saturate to 8 bits explicitly (round, then clip) instead of relying on
    # imwrite's implicit depth conversion of a float array.
    blended = np.clip(np.rint(blended), 0, 255).astype(np.uint8)
    cv2.imwrite(path, blended)
    return blended


def gradcamof(net, imgs, classes):
    """Plot Grad-CAM explanations for a sequence of image batches.

    For each batch in ``imgs``, draw a 1x3 figure. The panels show the
    input, the raw heatmap, and the heatmap blended over the input, and are
    computed for the first image of the batch. The predicted class name is
    printed and used as the title of the first panel.

    Side effects: writes ``img1.png`` and, via superposeimage,
    ``./map.jpg`` in the current working directory.

    Args:
        net: model with conv1/bn1/layer1..layer4/linear submodules, as
            required by ResnetGradCam.
        imgs: iterable of image tensors; presumably (1, 3, H, W) batches
            normalised to [-1, 1] like imshow expects — confirm with caller.
        classes: sequence mapping class index to display name.
    """
    netx = ResnetGradCam(net)
    netx.eval()
    # Run on whatever device the model lives on instead of assuming CUDA.
    device = next(net.parameters()).device

    for img in imgs:
        _, axes = plt.subplots(nrows=1, ncols=3)

        # Forward pass (registers the gradient hook); most likely class.
        pred = netx(img.to(device))
        class_pred = pred[0].argmax().item()

        # Save the unnormalised image (same mapping imshow uses) so the
        # overlay background matches panel 0 instead of a clipped version.
        torchvision.utils.save_image(img[0] / 2 + 0.5, 'img1.png')

        imshow(torchvision.utils.make_grid(img), axes[0])
        print(classes[class_pred])
        axes[0].set_title(str(classes[class_pred]))

        # Raw heatmap.
        heatmap = getheatmap(pred, class_pred, netx, img)
        axes[1].matshow(heatmap.squeeze())

        # cv2.imread returns BGR, the same order as the applyColorMap output
        # inside superposeimage; converting to RGB here would swap the red
        # and blue channels of the background in the written overlay.
        base_bgr = cv2.imread("./img1.png")
        superposeimage(heatmap, base_bgr)

        # Read the overlay back and convert to RGB for matplotlib.
        overlay = cv2.cvtColor(cv2.imread('./map.jpg'), cv2.COLOR_BGR2RGB)
        axes[2].imshow(overlay, interpolation='bicubic')