## Task 05.4: Implementation demo for feature visualisation and saliency maps

ITU KSADMAL1KU - Advanced Machine Learning for Computer Science 2023

by Stefan Heinrich, with material by Kevin Murphy.

This notebook was in part co-developed with Mingbo Cai at Uni Tokyo.

All info and static material: https://learnit.itu.dk/course/view.php?id=3022225

-------------------------------------------------------------------------------

In [None]:
# @title #### import dependencies

import torchvision.models as models
import torch
import numpy as np
from torchvision.utils import make_grid
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.cm as cm # colormap
import random
import ast
import os
import torch.nn.functional as F

In [None]:
random.seed()
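cuda_id = random.randint(0,5)
# cuda_id is not used below; on a multi-GPU server it could select a device via torch.device(f'cuda:{cuda_id}')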
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = 'cpu'
print('Using device', device)

### Introduction: This demo illustrates some of the approaches for visualizing the features learned by neural networks and how they make decisions.

#### We use an 11-layer [VGG](https://arxiv.org/pdf/1409.1556.pdf) model trained on [ImageNet](http://www.image-net.org/) as an example.
PyTorch provides several popular [models](https://pytorch.org/docs/stable/torchvision/models.html) with pre-trained parameters. Feel free to check them out after class.
##### Take a look at the output in the next cell.
**features** is a stack of layers that are applied to the input image one after another to extract increasingly abstract features. You can use `vgg11.features[:K]` to extract part of this stack and apply it to your input; this yields the features after the layer with index K-1 (Python slices exclude the end index).

**avgpool** is an adaptive pooling layer that adjusts its pooling size to the input so that the output always has a fixed spatial size (7x7). This is important because its output is flattened and passed through fully-connected layers. Those layers have fixed input dimensions, so we don't want their input to change with the size of the image.
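
To see the fixed-size behaviour in isolation, here is a small sketch using `torch.nn.AdaptiveAvgPool2d((7, 7))`, the layer type VGG uses as `avgpool`. The feature maps are random stand-ins for what the convolutional stack might produce for differently sized images; all of them flatten to 512 * 7 * 7 = 25088 values per image, the input size of the first fully-connected layer.

```python
import torch
import torch.nn as nn

avgpool = nn.AdaptiveAvgPool2d((7, 7))

flat_shapes = []
# random feature maps (batch of 2, 512 channels) with different spatial sizes
for h, w in [(7, 7), (14, 14), (10, 17)]:
    fmap = torch.randn(2, 512, h, w)
    pooled = avgpool(fmap)                      # always 2 x 512 x 7 x 7
    flat = torch.flatten(pooled, start_dim=1)   # always 2 x 25088
    flat_shapes.append(tuple(flat.shape))
    print((h, w), '->', tuple(pooled.shape), '-> flattened', tuple(flat.shape))
```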

**classifier** is the final stack of layers, which outputs a number (a logit) for each possible category in the training data. Passing these numbers for each image through a softmax function yields a vector with values in \[0, 1\] that sum to 1, giving the predicted probability of the image belonging to each category.
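
A short sketch of the softmax step, using random numbers as a stand-in for the classifier output of 4 images over the 1000 ImageNet classes. Because softmax is monotonic, the most probable class is always the class with the largest logit, which is why the notebook can take `argmax` directly on the logits.

```python
import torch

torch.manual_seed(0)
logits = torch.randn(4, 1000)          # stand-in for classifier output: 4 images x 1000 classes
probs = torch.softmax(logits, dim=1)   # normalise over the class dimension

print(probs.sum(dim=1))                # each row sums to 1 (up to float rounding)
print(torch.equal(probs.argmax(dim=1), logits.argmax(dim=1)))
```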

In [None]:
# @title #### Model

vgg11 = models.vgg11(weights='IMAGENET1K_V1')
vgg11.eval() # switch to evaluation mode (e.g. dropout is disabled), since we are not training the network
vgg11.to(device) # this command moves the parameters onto the GPU (if available) so that the model runs there
print(vgg11)

#### ImageNet dataset: all classes

In [None]:
!wget https://gist.githubusercontent.com/VishDev12/971f2835aa1adf2ad30495a25a45b1dc/raw/45ae5963783565e96698111677347a167f86c094/imagenet1000_clsidx_to_labels.txt
imagenet_classes_file = 'imagenet1000_clsidx_to_labels.txt'
with open(imagenet_classes_file, 'r') as file:
    content = file.read()
    imagenet_classes = ast.literal_eval(content)
print('ImageNet classes:', imagenet_classes)

##### Preprocessing
We define the standard preprocessing expected by the pretrained VGG model: resizing, center cropping to 224x224, and normalisation with the ImageNet channel means and standard deviations. These transformations will be applied to the input images we want to analyse.

In [None]:
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),    
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225] )])
# This normalisation (ImageNet channel means and standard deviations) is the standard preprocessing for the pretrained VGG network

transform_no_normalize = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor()])
# We make another transformation without normalization just to visualize the images

#### Sample images of one category (church) for illustration

In [None]:
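!mkdir -p church
# no `cd` needed (a `!cd` runs in its own subshell and has no effect); wget -P saves directly into ./church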
!wget -P ./church https://upload.wikimedia.org/wikipedia/commons/thumb/d/da/The_second_Catholic_Church_to_be_built_in_Montana.jpg/640px-The_second_Catholic_Church_to_be_built_in_Montana.jpg
!wget -P ./church https://www.catholiceducation.org/en/images/Churchs/church-3481187_640.jpg
!wget -P ./church https://cdn.pixabay.com/photo/2012/03/02/00/36/silhouette-20787_640.jpg
!wget -P ./church https://cdn.pixabay.com/photo/2016/10/01/14/27/dom-1707664_640.jpg

In [None]:
images = []
images_orig = []
for file in sorted(os.listdir('./church')):
    if file.endswith('.jpg'):
        img = Image.open(os.path.join('./church', file)).convert('RGB')
        images.append(transform(img))                     # normalised input for the network
        images_orig.append(transform_no_normalize(img))   # un-normalised copy for plotting

images = torch.stack(images).to(device)
# Move the batch to the same device as the model (GPU if available); otherwise PyTorch raises a device-mismatch error.

fig = plt.figure(figsize=(16,4))
img_grid = make_grid(images_orig, nrow=4)
# this is a torchvision utility to put images together into a grid for plotting

plt.imshow(img_grid.detach().numpy().transpose(1,2,0))
# the transpose command above re-arranges the order of dimensions of a tensor.
# make_grid by default generates a tensor of channel x height x width. 
# But matplotlib assumes that red, green and blue channels are in the last dimension of a tensor.
plt.show()

In [None]:
# @title ##### Pass the images through the network and see its best guess of the category

logit = vgg11(images)
# By simply passing the images as input, the network returns one logit per ImageNet class for each image

pred_id = logit.argmax(dim=1) # the class with the highest logit also has the highest softmax probability
print('predicted category ID:', pred_id)

pred_class = [imagenet_classes[class_id]  for class_id in pred_id.detach().cpu().numpy()]
print('network classification:', pred_class)

class_id = torch.mode(pred_id).values
print('ID of the class:', class_id)
# the most frequently predicted class ID; we will use it below! You may also set it manually for task 05.4.d

### **Attribution with Saliency Maps**: one approach to visualising the image regions that contribute most strongly to the network's classification decision is [Grad-CAM](https://arxiv.org/pdf/1610.02391.pdf), which we implement next

##### Basic ideas:
- For images in the same category, find the features in a layer that are generally important for predicting this category

> This is achieved by first calculating the gradient of the output unit for that category with respect to all neurons in the layer under investigation, and then averaging the gradients over space and image samples. A positive gradient means that the stronger such a feature is, the more strongly the model believes the image belongs to this category.

- For each location in the feature map, calculate the overall contribution of the features at that location to the decision, based on the calculated gradient.

> This is done by multiplying the features at that location with the gradient, taking a sum, and truncating any negative sum to 0 (the authors of the paper argue that a negative sum usually means the content at that location favours other categories)

- Last, we can resample the contribution map to the original image resolution and superimpose it on the image to visualise how much each part contributes to the network's decision (where the network "looks" to make its decision)
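
The steps above can be written as formulas, following the notation of the Grad-CAM paper: $A^k$ is feature map (channel) $k$ of the chosen layer, $y^c$ the logit of class $c$, and $Z$ the number of spatial locations. The last line is the rescaling to a maximum of 1 that the notebook applies before plotting. Note that the implementation below deviates slightly: it also averages the gradients over the images in the batch, and it uses a mean rather than a sum over channels, which only changes the scale before this rescaling.

$$
\begin{aligned}
\alpha_k^c &= \frac{1}{Z}\sum_{i}\sum_{j}\frac{\partial y^c}{\partial A^k_{ij}} \\
L^c_{\text{Grad-CAM}} &= \mathrm{ReLU}\Big(\sum_k \alpha_k^c A^k\Big) \\
\tilde{L}^c_{ij} &= \frac{L^c_{ij}}{\max_{i',j'} L^c_{i'j'}}
\end{aligned}
$$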


In [None]:
# @title ##### This function extracts features from a layer of interest, calculates the average gradient on the features, and computes the contribution map for visualisation.

def feature_importance(model, images, category, layer):
    features = model.features[:layer+1](images)
    # Pass the images through the network up to and including the chosen layer
    # and return the feature map at that layer. The +1 is needed because
    # Python slices stop just before the end index.

    features.retain_grad()
    # Intermediate activations are not leaf tensors, so by default PyTorch does not
    # keep their .grad after backward() (can you guess why?).
    # Therefore we explicitly ask these features to retain their gradient.

    output_logits = model.classifier(torch.flatten(model.avgpool(model.features[layer+1:](features)), start_dim=1))
    # run the remaining layers to get the output logits, which can be converted to class probabilities

    torch.mean(output_logits[:,category]).backward()
    # compute gradients of the (batch-averaged) logit of the selected category with respect to the
    # parameters and the features we asked to retain; they point in the direction in which this logit increases

    pooled_feature_gradient = torch.mean(features.grad, dim=[0, 2, 3])
    # average the gradient spatially and over the batch to estimate which feature channels
    # move the output positively

    heatmap = F.relu(torch.mean(features * pooled_feature_gradient[:, None, None], dim=1))
    # weight each channel by its importance and combine them into one map of positive evidence
    # per location

    heatmap = heatmap / (torch.amax(heatmap, dim=(1, 2))[:, None, None] + 1e-8)
    # rescale each map to a maximum of 1 (the small constant avoids division by zero for an all-zero map)
    return heatmap, pooled_feature_gradient, features

In [None]:
# @title ##### Now we analyze the layer right before the final pooling using the sample images. You can also try other layers

layer_investigate = 19

heatmap, pooled_feature_gradient, features = feature_importance(vgg11, images, class_id, layer_investigate)
# Yatta! Now we can visualise both the contribution map in that layer and, by blending it with the original images, which regions the network relies on to call an image a church

hms = heatmap.detach().cpu()[:,None,:,:]
# heatmaps for all sample images (shape N x 1 x H x W); we show them in one grid

plt.imshow(make_grid(hms, nrow=4)[0].numpy(), cmap='jet')
# make_grid repeats a single channel into RGB, so we plot one channel to let cmap='jet' take effect
plt.title('heatmap of feature importance in layer {}'.format(layer_investigate))
plt.show()

cmap = plt.get_cmap('jet')  # cm.get_cmap is deprecated in recent Matplotlib versions
resized_heatmap = transforms.Resize(224)(heatmap.detach())
resized_heatmap_color = cmap(resized_heatmap.cpu().numpy())[:,:,:,:3].transpose(0, 3, 1, 2)
# apply the colormap (RGBA), drop the alpha channel and move channels first: N x 3 x 224 x 224

plt.imshow(hms[0,0,:,:], cmap='jet', interpolation='nearest')
plt.show()
plt.imshow(resized_heatmap[0,:,:].cpu().detach().numpy(), cmap='jet', interpolation='nearest')
plt.show()

weight = resized_heatmap[:,None,:,:].cpu() * 0.8
highlighted_images = torch.stack(images_orig) * (1 - weight) + torch.tensor(resized_heatmap_color, dtype=torch.float32) * weight
img_grid = make_grid(highlighted_images, nrow=4)
plt.figure(figsize=(18,12))
plt.imshow(img_grid.numpy().transpose(1,2,0))
plt.title('heatmap of contribution from original images')
plt.show()


##### Looking into the *pooled_feature_gradient*

So the first key outcome of looking at the gradients was to identify which image regions contributed most to the church class. As a second outcome, we can inspect the pooled feature gradients themselves: channels with strongly positive values are likely to capture something diagnostic about churches. In the next step, we follow up on this hypothesis and visualise the features.

In [None]:
plt.bar(range(pooled_feature_gradient.shape[0]), pooled_feature_gradient.cpu().numpy())
plt.show()
# It seems feature '325' is particularly positive and could be the most important for the classification
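
Rather than reading the most positive channel off the bar plot, channels can be ranked with `torch.topk`. A sketch with a synthetic gradient vector (the values and the helper name `top_channels` are made up for illustration; in the notebook you would pass `pooled_feature_gradient`):

```python
import torch

def top_channels(pooled_grad: torch.Tensor, k: int = 5):
    """Hypothetical helper: indices and values of the k channels with the largest pooled gradient."""
    values, indices = torch.topk(pooled_grad.detach().cpu(), k)   # sorted, largest first
    return indices.tolist(), values.tolist()

# synthetic stand-in for pooled_feature_gradient (512 channels, as in VGG-11's last conv block)
grad = torch.zeros(512)
grad[42] = 0.9
grad[7] = 0.4
grad[100] = -0.7
idx, vals = top_channels(grad, k=2)
print(idx, vals)
```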

### **Feature Visualisation**: One common approach to analysing neural networks is to visualise patterns that maximise the responses of certain neurons. We saw [this good tutorial](https://distill.pub/2017/feature-visualization/) in the previous exercise.

In short, instead of optimising parameters as we do when training a network, here we optimise the input image to maximise the response of the neurons of interest. Specifically, we maximise the features that jointly contributed positively to the network's decision for the example pictures we used.

We first generate an image with random colours in all pixels. Then we set an objective that increases the values of the features that, in our sample images, contributed positively to the network's classification decision of "church". By back-propagating gradients to the image, we can adjust the random image to improve this objective. The end result highlights the features at those locations that helped the network's decision.
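
A minimal, self-contained sketch of the underlying idea (activation maximisation): freeze a network, make the input image the only trainable tensor, and run gradient ascent on the mean response of one channel. The tiny random convolutional network is a hypothetical stand-in for `vgg11.features[:layer+1]` so the example runs quickly, and the objective is a single channel rather than the contribution-weighted features used below.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# hypothetical toy network standing in for vgg11.features[:layer+1]
net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 16, 3, padding=1))
for p in net.parameters():
    p.requires_grad_(False)            # only the image is optimised

channel = 5                            # neuron (channel) whose raw response we maximise
image = torch.randn(1, 3, 32, 32, requires_grad=True)
optimizer = torch.optim.Adam([image], lr=0.05)

def objective(x):
    return net(x)[:, channel].mean()

start = objective(image).item()
for step in range(100):
    optimizer.zero_grad()
    loss = -objective(image)           # minimising the negative maximises the response
    loss.backward()
    optimizer.step()
end = objective(image).item()
print(f'mean response of channel {channel}: {start:.3f} -> {end:.3f}')
```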

Notice that with this approach, some regularisation is often needed to make the generated image closer to a natural image. The regularisation we apply here is a smoothness prior: adjacent pixels tend not to change colour too much.
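
The smoothness prior can be written as a small standalone function. This sketch has the same form as the regulariser in *feature_visualize_update* below (mean squared difference between vertically and horizontally adjacent pixels); the helper name `smoothness_penalty` is ours. A constant image gets no penalty, while a checkerboard, where every neighbour differs by 1, gets the maximum for values in \[0, 1\].

```python
import torch

def smoothness_penalty(image: torch.Tensor) -> torch.Tensor:
    """Mean squared difference between adjacent pixels; image has shape (N, C, H, W)."""
    dy = image[:, :, 1:, :] - image[:, :, :-1, :]   # neighbours along the height
    dx = image[:, :, :, 1:] - image[:, :, :, :-1]   # neighbours along the width
    return (dy ** 2).mean() + (dx ** 2).mean()

flat = torch.full((1, 3, 8, 8), 0.5)                # constant colour everywhere
checker = (torch.arange(8)[:, None] + torch.arange(8)[None, :]) % 2
checker = checker.float().expand(1, 3, 8, 8)        # alternating 0/1 pattern
print(smoothness_penalty(flat).item(), smoothness_penalty(checker).item())
```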

You can try to **switch to the other definition of the loss** term in the function *feature_visualize_update* below, which is currently commented out. It is adapted from this [paper](https://arxiv.org/pdf/1412.0035v1.pdf). The idea is to find what information is preserved by the features in the layer of interest, by searching for images that generate a representation similar to that of the sample image. Here, we additionally weight the representation by the contribution map.
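
The inversion objective from that paper can be sketched in the same style: start from noise and minimise the squared distance between the representation of the generated image and that of a reference image. Again, a small frozen random network is a hypothetical stand-in for VGG, the reference is random, and both the contribution-map weighting and the smoothness term are left out for brevity.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# hypothetical stand-in for vgg11.features[:layer+1]; frozen
phi = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2), nn.Conv2d(8, 16, 3, padding=1))
for p in phi.parameters():
    p.requires_grad_(False)

reference = torch.rand(1, 3, 32, 32)   # stand-in for a sample image
target = phi(reference)                # representation we want to reproduce

x = torch.randn(1, 3, 32, 32, requires_grad=True)
optimizer = torch.optim.Adam([x], lr=0.05)

def inversion_loss(x):
    return ((phi(x) - target) ** 2).mean()

start = inversion_loss(x).item()
for step in range(200):
    optimizer.zero_grad()
    loss = inversion_loss(x)
    loss.backward()
    optimizer.step()
end = inversion_loss(x).item()
print(f'representation distance: {start:.4f} -> {end:.4f}')
```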

You can also try to **redefine the loss** term in the function *feature_visualize_update* below to implement other methods in the tutorial above.


In [None]:
# @title ##### This function performs one optimisation step on the input image, increasing the features that contributed positively to the classification while keeping the image smooth.

def feature_visualize_update(image, model, heatmap, feature_gradient, target_features, layer, optimizer, reg_weight=0.1):
    optimizer.zero_grad()
    feature = model.features[:layer+1](image)

    loss = - torch.mean(feature * F.relu(target_features.grad * target_features).detach() * heatmap[:, None, :, :])
    # This loss moves the image in the direction that increases the responses of neurons
    # that contributed positively (gradient x activation > 0) to the classification output for the sample images.
    # (Note: the feature_gradient argument is not used by this loss.)

    # loss = torch.mean((feature - target_features.detach() * heatmap[:, None, :, :]) ** 2)
    # The alternative loss above aims to bring the representation of the generated images
    # close to that of the sample images. Try uncommenting it (and commenting out the loss above) and run again.
    # Also see what happens if the heatmap is not multiplied.
        
    regularizer = torch.mean((image[:,:,1:,:] - image[:,:,:-1,:]) ** 2) + torch.mean((image[:,:,:,1:] - image[:,:,:,:-1]) ** 2)
    
    total_loss = loss + regularizer * reg_weight
    total_loss.backward()
    optimizer.step()
    # This optimization updates the image tensor only.
    
    return image

In [None]:
# @title ##### be aware: the next step will take quite a while!

batch_size = len(images_orig)
feature_vis = torch.randn(batch_size, 3, 224, 224, device=device, requires_grad=True)
# feature_vis = images.clone().detach().requires_grad_(True)
# You can also try initialising with the (normalised) sample images

learning_rate = 0.05
optimizer = torch.optim.Adam([feature_vis], lr = learning_rate)

for it in range(800):
    feature_vis = feature_visualize_update(feature_vis, vgg11, heatmap.detach(), pooled_feature_gradient.detach(),
                                           features, layer_investigate, optimizer, reg_weight=0.1)
    if it % 100 == 0:
        fig = plt.figure(figsize=(18,9))
        img_grid = make_grid(feature_vis.detach().cpu(), normalize=True, nrow=4) # * resized_heatmap[:,None,:,:].cpu()
        # normalize=True rescales the whole grid to [0, 1] for display

        plt.imshow(img_grid.numpy().transpose(1,2,0))
        plt.show()

# Notice this is not an exact implementation of the Guided Grad-CAM in the original paper.