<a href="https://colab.research.google.com/github/futartup/S8-assignment/blob/master/resnet_gradcam.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:

import torch
import matplotlib.pyplot as plt
from torch.autograd import Function
from torchvision import models
from torchvision import utils
import cv2
import sys
from collections import OrderedDict
import numpy as np
import argparse
import os
import torch.nn as nn

In [0]:
class ResnetGradCam(nn.Module):
  """ResNet-18 wrapper that records what Grad-CAM needs from the last conv stage.

  The network is split at ``layer4``: ``forward`` runs the stem and
  ``layer1``..``layer4``, attaches a gradient hook to the ``layer4`` output,
  then applies global average pooling and the ``fc`` classifier.  After a
  backward pass from a class score, ``get_activations_gradient`` returns the
  gradient of that score with respect to the ``layer4`` output.

  Args:
    model: Optional network to wrap.  It must expose the attributes used
      here (``conv1``, ``bn1``, ``relu``, ``maxpool``, ``layer1``..``layer4``
      and ``fc``).  When omitted, a pretrained torchvision ResNet-18 is
      loaded, which is the original behaviour.
  """

  def __init__(self, model=None):
    # Subclassing nn.Module is what makes `.eval()` and `obj(x)` available.
    super().__init__()

    # get the pretrained resnet18 model (only `models` is imported from
    # torchvision in this file, so go through that name)
    if model is None:
      model = models.resnet18(pretrained=True)
    self.resnet18 = model

    # dissect the network to access its last convolutional layer
    self.last_conv_layer = self.resnet18.layer4

    # global average pool: collapses any spatial size to 1x1 so the
    # flattened tensor has one value per channel
    self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))

    # get the classifier of resnet18
    self.classifier = self.resnet18.fc

    # placeholder for gradients
    self.gradients = None

  def activations_hook(self, grad):
    """Tensor hook: store the gradient flowing into the layer4 output."""
    self.gradients = grad

  def _features(self, x):
    """Run the input through the stem and layer1..layer4."""
    net = self.resnet18
    x = net.conv1(x)
    x = net.bn1(x)
    x = net.relu(x)
    x = net.maxpool(x)
    x = net.layer1(x)
    x = net.layer2(x)
    x = net.layer3(x)
    return self.last_conv_layer(x)

  def forward(self, x):
    """Return the classifier output for ``x`` and arm the gradient hook."""
    x = self._features(x)

    # register the hook; register_hook raises on tensors that do not
    # require grad (e.g. when called under torch.no_grad())
    if x.requires_grad:
      x.register_hook(self.activations_hook)

    x = self.global_avg_pool(x)
    # keep the batch dimension, flatten the rest for the linear classifier
    x = torch.flatten(x, 1)
    x = self.classifier(x)
    return x

  def get_activations_gradient(self):
    """Return the gradient captured by the last backward pass (or None)."""
    return self.gradients

  def get_activations(self, x):
    """Return the layer4 output for input ``x``."""
    return self._features(x)




In [0]:
def resnet_gradcam_cifar10(train_loader, class_idx=None, output_path='./map.jpg'):
  """Compute a Grad-CAM heatmap for the first image of a loader and save an overlay.

  Takes the first image of the first batch from ``train_loader``, runs it
  through ``ResnetGradCam``, backpropagates the score of ``class_idx`` and
  builds the heatmap from the layer4 activations weighted by their
  channel-wise pooled gradients.  The heatmap is shown with matplotlib and
  the colour-mapped overlay is written to ``output_path``.

  Args:
    train_loader: Iterable yielding ``(images, labels)`` batches; images are
      assumed to be float tensors shaped (N, 3, H, W) in RGB order --
      TODO confirm against the loader's transforms.
    class_idx: Index of the class score to explain.  Defaults to the class
      the model predicts for the image.
    output_path: Where the superimposed image is written.

  Returns:
    The superimposed BGR image as a uint8 numpy array of shape (H, W, 3).
  """
  resnet_obj = ResnetGradCam()
  resnet_obj.eval()
  img, _ = next(iter(train_loader))
  # explain a single image: keep the batch dimension but only the first item
  img = img[:1]

  # Keep the raw scores: argmax output is an integer tensor with no graph,
  # so it can be used to pick a class but not to backpropagate.
  logits = resnet_obj(img)
  if class_idx is None:
    class_idx = int(logits.argmax(dim=1)[0])

  logits[0, class_idx].backward()

  # pull the gradients out of the model
  gradients = resnet_obj.get_activations_gradient()

  # pool the gradients across the channels
  pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])

  # get the activations of the last convolutional layer
  activations = resnet_obj.get_activations(img).detach()

  # weight the channels by corresponding gradients (broadcast over N, H, W)
  activations = activations * pooled_gradients.view(1, -1, 1, 1)

  # average the channels of the activations; index instead of squeeze so the
  # heatmap stays 2-D even when the feature map is 1x1
  heatmap = torch.mean(activations, dim=1)[0]

  # relu on top of the heatmap
  heatmap = torch.clamp(heatmap, min=0)

  # normalize the heatmap (skip when it is all zeros to avoid NaNs)
  peak = torch.max(heatmap)
  if peak > 0:
    heatmap = heatmap / peak
  heatmap = heatmap.cpu().numpy()

  # draw the heatmap
  plt.matshow(heatmap)

  # Build a displayable image from the input tensor: channels last, rescaled
  # to 0..255 (the tensor may have been normalised by the loader), then BGR
  # because that is the channel order OpenCV writes.
  img_np = img[0].detach().cpu().permute(1, 2, 0).numpy()
  low, high = img_np.min(), img_np.max()
  span = high - low if high > low else 1.0
  img_np = np.uint8(255 * (img_np - low) / span)
  img_bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)

  heatmap = cv2.resize(heatmap, (img_bgr.shape[1], img_bgr.shape[0]))
  heatmap = np.uint8(255 * heatmap)
  heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
  # clip before casting so bright regions saturate instead of wrapping around
  superimposed_img = np.clip(heatmap * 0.4 + img_bgr, 0, 255).astype(np.uint8)
  cv2.imwrite(output_path, superimposed_img)
  return superimposed_img

In [0]:
import torch
import torch.nn.functional as F


class GradCAM:
    """Calculate GradCAM saliency map.

    Hooks the layer named by ``layer_name`` on ``model`` to capture its
    output during the forward pass and the gradient of the class score with
    respect to that output during the backward pass.

    Constructor args:
        model: Network exposing the target layer as an attribute
            (``layer1`` .. ``layer4`` on a torchvision-style ResNet).
        layer_name: Name of the target layer, e.g. ``'layer4'``.

    Call args:
        input: Input image batch with shape of (B, 3, H, W)
        class_idx: Class index for calculating GradCAM.
            If not specified, the class index that makes the highest model prediction score will be used
            (chosen per sample).
        retain_graph: Forwarded to ``backward``.
    Returns:
        mask: Saliency map of the same spatial dimension with input, shape (B, 1, H, W),
            min-max normalised to [0, 1] per sample
        logit: Model output

    Raises:
        ValueError: If ``layer_name`` does not name a layer present on ``model``.
    """

    def __init__(self, model, layer_name):
        self.model = model
        self.layer_name = layer_name
        self._target_layer()

        self.gradients = dict()
        self.activations = dict()

        def backward_hook(module, grad_input, grad_output):
            self.gradients['value'] = grad_output[0]

        def forward_hook(module, input, output):
            self.activations['value'] = output

        self.target_layer.register_forward_hook(forward_hook)
        # register_backward_hook is deprecated; prefer the full hook and keep
        # the legacy one only for PyTorch versions that lack it.
        if hasattr(self.target_layer, 'register_full_backward_hook'):
            self.target_layer.register_full_backward_hook(backward_hook)
        else:
            self.target_layer.register_backward_hook(backward_hook)

    def _target_layer(self):
        """Resolve ``self.layer_name`` to a module and store it as ``self.target_layer``."""
        name = str(self.layer_name)
        # Strip the literal prefix; str.lstrip('layer') would strip any of the
        # characters l/a/y/e/r instead.
        suffix = name[len('layer'):] if name.startswith('layer') else name
        try:
            layer_num = int(suffix)
        except ValueError:
            raise ValueError(
                "layer_name must look like 'layer<N>', got %r" % (self.layer_name,))
        attr = 'layer%d' % layer_num
        if not hasattr(self.model, attr):
            raise ValueError("model has no attribute %r" % attr)
        self.target_layer = getattr(self.model, attr)

    def saliency_map_size(self, *input_size):
        """Return the spatial size of the target layer output for an (H, W) input."""
        device = next(self.model.parameters()).device
        # Only the forward hook is needed here, so skip building a graph.
        with torch.no_grad():
            self.model(torch.zeros(1, 3, *input_size, device=device))
        return self.activations['value'].shape[2:]

    def forward(self, input, class_idx=None, retain_graph=False):
        """Compute the saliency map and logits for ``input``; see the class docstring."""
        b, _, h, w = input.size()

        logit = self.model(input)
        # Sum the selected score of every sample so the target is a scalar for
        # any batch size; for B == 1 this is the same value as before.
        if class_idx is None:
            top = logit.argmax(dim=1, keepdim=True)
            score = logit.gather(1, top).sum()
        else:
            score = logit[:, class_idx].sum()

        self.model.zero_grad()
        score.backward(retain_graph=retain_graph)
        gradients = self.gradients['value']
        activations = self.activations['value']
        b, k, u, v = gradients.size()

        # channel weights: gradients averaged over the spatial positions
        alpha = gradients.view(b, k, -1).mean(2)
        weights = alpha.view(b, k, 1, 1)

        saliency_map = (weights * activations).sum(1, keepdim=True)
        saliency_map = F.relu(saliency_map)
        saliency_map = F.interpolate(
            saliency_map, size=(h, w), mode='bilinear', align_corners=False)

        # Per-sample min-max normalisation.  A constant map has zero range;
        # divide by 1 there so the result is 0 instead of NaN.
        flat = saliency_map.reshape(b, -1)
        saliency_map_min = flat.min(dim=1)[0].view(b, 1, 1, 1)
        saliency_map_max = flat.max(dim=1)[0].view(b, 1, 1, 1)
        span = saliency_map_max - saliency_map_min
        span = torch.where(span > 0, span, torch.ones_like(span))
        saliency_map = ((saliency_map - saliency_map_min) / span).detach()

        return saliency_map, logit

    def __call__(self, input, class_idx=None, retain_graph=False):
        return self.forward(input, class_idx, retain_graph)