In [2]:
import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
import models
from torchbearer import Trial
import cv2
import torch.nn.functional as F

# Per-channel mean / std used to normalise the CIFAR-10 inputs.
CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR_STD = (0.2023, 0.1994, 0.2010)

# Normalize computes (x - mean) / std, so applying it with mean' = -mean/std
# and std' = 1/std undoes the normalisation above.
inv_norm = transforms.Normalize(
    tuple(-m / s for m, s in zip(CIFAR_MEAN, CIFAR_STD)),
    tuple(1 / s for s in CIFAR_STD),
)

# CIFAR-10 test split, converted to tensors and normalised per channel.
valset = torchvision.datasets.CIFAR10(
    root='./data/cifar',
    train=False,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ]),
)

# One image per batch, drawn in shuffled order.
valloader = torch.utils.data.DataLoader(valset, batch_size=1, shuffle=True, num_workers=8)

Files already downloaded and verified


In [3]:
class ResNet_CAM(nn.Module):
    """Wrap a network so Grad-CAM can be computed at an intermediate layer.

    Every child of ``net`` except the last is chained into a sequential trunk
    that is split at index ``layer_k``; the last child is used as the
    classifier head. Only the child modules are replayed: anything the
    wrapped network's own ``forward`` does outside its children is not
    reproduced here, apart from the global average pooling applied before
    the head.

    Args:
        net: module whose children are applied in order, classifier last.
        layer_k: number of leading children placed before the Grad-CAM tap.
    """

    def __init__(self, net, layer_k):
        super(ResNet_CAM, self).__init__()
        self.resnet = net
        children = list(net.children())
        convs = nn.Sequential(*children[:-1])
        self.first_part_conv = convs[:layer_k]
        self.second_part_conv = convs[layer_k:]
        self.linear = nn.Sequential(*children[-1:])
        # Filled by the backward hook; None until a backward pass has run.
        self.gradients = None

    def forward(self, x):
        """Return the classifier output and arm the gradient hook at the tap.

        Args:
            x: input batch fed to the first child of the wrapped network.

        Returns:
            Output of the classifier head, one row per input sample.
        """
        # Drop any gradient captured by an earlier pass so it cannot be
        # mistaken for the gradient of this one.
        self.gradients = None
        x = self.first_part_conv(x)
        # register_hook raises on tensors that do not require grad (e.g.
        # under torch.no_grad()), so only tap when a backward pass is possible.
        if x.requires_grad:
            x.register_hook(self.activations_hook)
        x = self.second_part_conv(x)
        x = F.adaptive_avg_pool2d(x, (1, 1))
        # Flatten per sample instead of assuming a batch of exactly one.
        x = x.view(x.size(0), -1)
        x = self.linear(x)
        return x

    def activations_hook(self, grad):
        """Backward hook: store the gradient w.r.t. the tapped activation."""
        self.gradients = grad

    def get_activations_gradient(self):
        """Return the gradient stored by the last backward pass, or None."""
        return self.gradients

    def get_activations(self, x):
        """Return the tapped activation for ``x`` (fresh pass, no hook)."""
        return self.first_part_conv(x)

In [4]:
# Code inspired by https://medium.com/@stepanulyanin/implementing-grad-cam-in-pytorch-ea0937c31e82

def superimpose_heatmap(heatmap, img):
    """Overlay a JET-coloured heatmap on the de-normalised first image of a batch.

    Args:
        heatmap: 2-D CPU tensor. It is scaled by 255 and cast to uint8, so
            values are presumably expected in [0, 1] - confirm with caller.
        img: 4-D normalised image batch, assumed (N, C, H, W); only
            ``img[0]`` is used.

    Returns:
        Tensor in (H, W, C) layout: the RGB colour map scaled by 0.006 plus
        the de-normalised image.
    """
    height, width = img.shape[2], img.shape[3]
    # cv2.resize expects dsize as (width, height); passing (H, W) only works
    # for square images and produces a transposed size otherwise.
    resized_heatmap = cv2.resize(heatmap.numpy(), (width, height))
    resized_heatmap = np.uint8(255 * resized_heatmap)
    resized_heatmap = cv2.applyColorMap(resized_heatmap, cv2.COLORMAP_JET)
    # OpenCV colour maps are BGR; convert so channels line up with the image.
    heatmap_rgb = torch.Tensor(cv2.cvtColor(resized_heatmap, cv2.COLOR_BGR2RGB))
    superimposed_img = heatmap_rgb * 0.006 + inv_norm(img[0]).permute(1, 2, 0)

    return superimposed_img

def get_grad_cam(net, img):
    """Compute a Grad-CAM overlay for the class ``net`` predicts for ``img``.

    Args:
        net: model exposing ``get_activations_gradient()`` and
            ``get_activations(x)`` (see ``ResNet_CAM``). It is switched to
            eval mode as a side effect.
        img: normalised image batch; the backward call on the selected logit
            assumes a batch of exactly one image.

    Returns:
        The superimposed heatmap / image as a (C, H, W) tensor.
    """
    net.eval()
    pred = net(img)
    # Back-propagate from the top-scoring logit to populate the hooked gradient.
    pred[:, pred.argmax(dim=1)].backward()
    gradients = net.get_activations_gradient()
    # One weight per channel: gradient averaged over batch and spatial dims.
    pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])
    activations = net.get_activations(img).detach()
    # Weight every channel by its pooled gradient in one broadcasted multiply.
    activations = activations * pooled_gradients.view(1, -1, 1, 1)
    heatmap = torch.mean(activations, dim=1).squeeze()
    # Keep only positive evidence; stay in torch rather than relying on
    # NumPy ufuncs handing back a tensor.
    heatmap = torch.clamp(heatmap, min=0)
    peak = torch.max(heatmap)
    # An all-zero map would otherwise become NaN and corrupt the uint8 cast.
    if peak > 0:
        heatmap = heatmap / peak

    return superimpose_heatmap(heatmap, img).permute(2, 0, 1)

In [8]:
# Pretrained models will be provided upon de-anonymisation
#
# Until then each network keeps the weights models.ResNet18() gives it.
# To restore trained weights, load the matching checkpoint, e.g.
#   baseline_net.load_state_dict(torch.load('trained_models/base.pt')['model'])
# Checkpoint files per network:
#   baseline_net  -> trained_models/base.pt
#   fmix_net      -> trained_models/fup.pt
#   mixup_net     -> trained_models/mix.pt
#   fmix_plus_net -> trained_models/mixfup.pt

# Constructed in the order: baseline, fmix, mixup, fmix_plus.
baseline_net, fmix_net, mixup_net, fmix_plus_net = [models.ResNet18() for _ in range(4)]

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [84]:
# Index at which each network's child sequence is split for Grad-CAM.
layer_k = 4

# One Grad-CAM wrapper per network, all tapped at the same depth.
baseline_cam_net, fmix_cam_net, mixup_cam_net, fmix_plus_cam_net = [
    ResNet_CAM(net, layer_k)
    for net in (baseline_net, fmix_net, mixup_net, fmix_plus_net)
]

n_imgs = 10
# Row 0 holds the de-normalised inputs; rows 1-4 hold the overlays.
imgs = torch.Tensor(5, n_imgs, 3, 32, 32)
# Row order of the overlays in the saved grid: baseline, mixup, fmix, fmix+.
cam_rows = (baseline_cam_net, mixup_cam_net, fmix_cam_net, fmix_plus_cam_net)

it = iter(valloader)
for i in range(n_imgs):
    img, _ = next(it)
    imgs[0, i] = inv_norm(img[0])
    for row, cam_net in enumerate(cam_rows, start=1):
        imgs[row, i] = get_grad_cam(cam_net, img)

# Flatten (row, column) into one batch; nrow=n_imgs puts each row on its own line.
torchvision.utils.save_image(
    imgs.view(-1, 3, 32, 32),
    "Grad-CAM_at_layer" + str(layer_k) + ".png",
    nrow=n_imgs,
    pad_value=1,
)