<a href="https://colab.research.google.com/github/futartup/S8-assignment/blob/master/GradCam/resnet18.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:

try:
  import ipynb
except:
  !pip install ipynb --upgrade
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import os, sys
import cv2
if '/content/drive/My Drive/Colab Notebooks/S8' not in sys.path:
  sys.path.append('/content/drive/My Drive/Colab Notebooks/S8')
from ipynb.fs.full.transform_train_test_loader import *

In [0]:

class ResnetGradCam(nn.Module):
  """ResNet wrapper exposing the last conv block's activations/gradients for Grad-CAM.

  The forward pass runs the backbone's feature extractor (``conv1``, ``bn1``,
  ``relu``, ``maxpool``, ``layer1``..``layer3``) followed by ``layer4``,
  registers a gradient hook on the ``layer4`` output, then global-average-pools
  it and applies the backbone's ``fc`` classifier.

  Args:
    backbone: optional torchvision-style ResNet exposing ``conv1``, ``bn1``,
      ``relu``, ``maxpool``, ``layer1``..``layer4`` and ``fc``. When ``None``
      (the default) an ImageNet-pretrained ``torchvision.models.resnet18`` is
      loaded, exactly as before.
  """

  def __init__(self, backbone=None):
    super(ResnetGradCam, self).__init__()

    # get the (by default pretrained) resnet18 model
    if backbone is None:
      backbone = torchvision.models.resnet18(pretrained=True)
    self.resnet18 = backbone

    # Everything before the last conv block. layer4 expects layer3's output,
    # not a raw image, so these stages must run first.
    self.feature_extractor = nn.Sequential(
        backbone.conv1, backbone.bn1, backbone.relu, backbone.maxpool,
        backbone.layer1, backbone.layer2, backbone.layer3,
    )

    # dissect the network to access its last convolutional layer
    self.last_conv_layer = backbone.layer4

    # Adaptive pooling to 1x1 works for any input resolution, unlike a fixed
    # 3x3 kernel whose output size depends on the feature-map size.
    self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))

    # get the classifier of the backbone
    self.classifier = backbone.fc

    # gradient w.r.t. the last conv activations, filled in by the hook
    self.gradients = None

  def activations_hook(self, grad):
    """Store the gradient flowing back into the last conv block's output."""
    self.gradients = grad

  def forward(self, x):
    """Return class logits of shape ``(N, num_classes)`` for an image batch ``x``."""
    x = self.get_activations(x)

    # register_hook raises on tensors that do not require grad (e.g. under
    # torch.no_grad()), so only hook when autograd is tracking x.
    if x.requires_grad:
      x.register_hook(self.activations_hook)

    x = self.global_avg_pool(x)
    # Flatten per sample; the old view((1, 1000)) fixed the batch to 1 and
    # used the wrong feature count (fc consumes layer4's channel count).
    x = torch.flatten(x, 1)
    return self.classifier(x)

  def get_activations_gradient(self):
    """Return the gradient captured by the hook on the last backward pass (or None)."""
    return self.gradients

  def get_activations(self, x):
    """Return the last conv block's activations for an image batch ``x``."""
    return self.last_conv_layer(self.feature_extractor(x))




In [0]:
def _gradcam_heatmap(model, img, target_class=None):
  """Return the ReLU'd, max-normalized Grad-CAM map for ``img[0]`` as a 2-D numpy array."""
  logits = model(img)
  if target_class is None:
    # argmax only picks the class; the gradient must flow through the logit
    target_class = int(logits[0].argmax())
  logits[0, target_class].backward()

  # pull the gradients out of the model
  gradients = model.get_activations_gradient()
  if gradients is None:
    raise RuntimeError('no gradient was captured; run the forward pass with autograd enabled')

  # average the gradients over the spatial dims -> one weight per channel
  pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])

  # get the activations of the last convolutional layer
  with torch.no_grad():
    activations = model.get_activations(img)

  # weight the channels by corresponding gradients (broadcast, any channel count)
  activations = activations * pooled_gradients.view(1, -1, 1, 1)

  # average the channels, then ReLU
  # expression (2) in https://arxiv.org/pdf/1610.02391.pdf
  heatmap = F.relu(torch.mean(activations, dim=1).squeeze(0))

  # normalize to [0, 1]; an all-zero map stays zero instead of becoming NaN
  peak = heatmap.max()
  if peak > 0:
    heatmap = heatmap / peak
  return heatmap.detach().cpu().numpy()


def _overlay_heatmap(heatmap, img_chw, alpha=0.4):
  """Blend a [0, 1] heatmap (JET colormap) onto a CHW image tensor; return BGR uint8."""
  import numpy as np

  # The loader image is normalized by an unknown transform, so min-max
  # rescale it to [0, 255] for display.
  base = img_chw.detach().cpu().permute(1, 2, 0).numpy().astype(np.float32)
  base = base - base.min()
  span = base.max()
  if span > 0:
    base = base / span
  base = np.ascontiguousarray(np.uint8(255 * base))
  # NOTE(review): assumes the loader yields RGB (or single-channel) images;
  # cv2 expects BGR.
  code = cv2.COLOR_GRAY2BGR if base.shape[2] == 1 else cv2.COLOR_RGB2BGR
  base = cv2.cvtColor(base, code)

  heat = cv2.resize(heatmap.astype(np.float32), (base.shape[1], base.shape[0]))
  heat = cv2.applyColorMap(np.uint8(255 * heat), cv2.COLORMAP_JET)
  # Blend in float and clip: alpha * heat + img can exceed 255.
  blended = heat.astype(np.float32) * alpha + base.astype(np.float32)
  return np.clip(blended, 0, 255).astype(np.uint8)


def resnet_gradcam_cifar10(train_loader, image, target_class=None, out_path='./map.jpg'):
  """Compute a Grad-CAM heatmap with ResnetGradCam, plot it, and save an overlay.

  Args:
    train_loader: iterable of ``(images, labels)`` batches; the first image of
      its first batch is explained. When ``None``, ``get_train_loader()`` is
      called to build one.
    image: optional ``torch.Tensor`` (``C,H,W`` or ``N,C,H,W``) to explain
      instead of the loader batch; any non-tensor value is ignored (the
      original implementation never used this argument).
    target_class: class index to explain; defaults to the model's top
      prediction for the image.
    out_path: file the superimposed heatmap is written to via ``cv2.imwrite``.

  Returns:
    The normalized heatmap as a 2-D numpy array with values in ``[0, 1]``.

  NOTE(review): resnet18 downsamples by 32, so on 32x32 CIFAR-10 images the
  layer4 map (and hence the heatmap) is 1x1; upsample inputs for a useful map.
  """
  if isinstance(image, torch.Tensor):
    img = image.unsqueeze(0) if image.dim() == 3 else image
  else:
    if train_loader is None:
      train_loader = get_train_loader()
    img, _ = next(iter(train_loader))
  # Grad-CAM weights are per sample: explain only the first image.
  img = img[:1]

  resnet_obj = ResnetGradCam()
  resnet_obj.eval()

  heatmap = _gradcam_heatmap(resnet_obj, img, target_class)

  # draw the heatmap
  plt.matshow(heatmap)

  superimposed_img = _overlay_heatmap(heatmap, img[0])
  cv2.imwrite(out_path, superimposed_img)
  return heatmap