In [1]:
import torch as tc
from torch.utils import data
from torchvision.models import densenet169, densenet201
from torchvision import transforms
from torchvision import datasets
import matplotlib.pyplot as plt
import numpy as np

In [2]:
class DenseNet(tc.nn.Module):
    def __init__(self):
        super(DenseNet, self).__init__()
        
        # get the pretrained DenseNet169 network
        self.densenet = densenet169(pretrained=True)
        
        # disect the network to access its last convolutional layer
        self.features_conv = self.densenet.features
        
        # add the maxpool2d pool
        #self.max_pool = tc.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        # add the average global pool
        self.global_avg_pool = tc.nn.AvgPool2d(kernel_size=7, stride=1)
        
        # get the classifier of the densenet169
        self.classifier = self.densenet.classifier
        
        # placeholder for the gradients
        self.gradients = None
    
    # hook for the gradients of the activations
    def activations_hook(self, grad):
        self.gradients = grad
        
    def forward(self, x):
        x = self.features_conv(x)
        
        # register the hook
        h = x.register_hook(self.activations_hook)
        
        # don't forget the pooling
        x = self.global_avg_pool(x)
        x = x.view((1, 1664))
        x = self.classifier(x)
        return x
    
    def get_activations_gradient(self):
        return self.gradients
    
    def get_activations(self, x):
        return self.features_conv(x)

In [6]:
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
transform = transforms.Compose([
    transforms.Resize((224, 224), antialias=True),
    transforms.ToTensor(),
    normalize
])
# define a 1 image dataset
dataset = datasets.ImageFolder(root=r'DaneTest/', transform=transform)

dataloader = data.DataLoader(dataset=dataset, shuffle=False, batch_size=1)

dense = DenseNet()
dense.eval()
img, _ = next(iter(dataloader))

pred = dense(img).argmax(dim=1)
print(pred)

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


tensor([101])


In [7]:
gradients = dense.get_activations_gradient()
print(gradients)
# pool the gradients across the channels
pooled_gradients = tc.mean(gradients, dim=[0, 2, 3])

# get the activations of the last convolutional layer
activations = dense.get_activations(img).detach()

# weight the channels by corresponding gradients
for i in range(512):
    activations[:, i, :, :] *= pooled_gradients[i]
    
# average the channels of the activations
heatmap = tc.mean(activations, dim=1).squeeze()

# relu on top of the heatmap
# expression (2) in https://arxiv.org/pdf/1610.02391.pdf
heatmap = np.maximum(heatmap, 0)

# normalize the heatmap
heatmap /= tc.max(heatmap)

# draw the heatmap
plt.matshow(heatmap.squeeze())

None


TypeError: mean() received an invalid combination of arguments - got (NoneType, dim=list), but expected one of:
 * (Tensor input, *, torch.dtype dtype)
 * (Tensor input, tuple of ints dim, bool keepdim, *, torch.dtype dtype, Tensor out)
 * (Tensor input, tuple of names dim, bool keepdim, *, torch.dtype dtype, Tensor out)
