# Exercise 08

## Gradient-based attribution methods for deep neural networks

### Vanilla Gradient (Saliency map)

Use the following code as a basis to calculate a vanilla saliency map.

Hints:

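- to explain a single class, backpropagate only that class's score. You can either pass a one-hot tensor as `gradient` in `backward(gradient=...)`, which sets the outputs of all other classes to zero. Or, if you want to explain the predicted class, select the highest score with `torch.max` and call `backward()` on it.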

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image


In [None]:
# Load a ResNet-50 with ImageNet-pretrained weights (torchvision downloads
# them on first use).
# NOTE: newer torchvision releases deprecate ``pretrained=True`` in favour of
# ``weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1``.
model = torchvision.models.resnet50(pretrained=True)
print(model)

# Per-channel (R, G, B) mean/std used to normalize inputs for the pretrained model.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Transform that normalizes an image tensor as the model expects: (x - mean) / std.
normalize = transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)

# Inverse transform that maps a normalized image back to its original form
# for visualization. Normalize computes y = (x - m) / s, so
# x = y * s + m = (y - (-m / s)) / (1 / s). Deriving both lists from the same
# constants keeps the two transforms exact inverses of each other.
inv_normalize = transforms.Normalize(
    mean=[-m / s for m, s in zip(IMAGENET_MEAN, IMAGENET_STD)],
    std=[1 / s for s in IMAGENET_STD],
)

# Resize to the 224x224 size expected by the pretrained model, convert the
# PIL image to a tensor, and normalize it.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    normalize,
])

In [None]:
def saliency(input, model, target=None):
    """Compute a vanilla-gradient saliency map.

    The map is the absolute gradient of one class score with respect to the
    input pixels, reduced over the channel axis by taking the maximum.

    Args:
        input: image batch of shape (N, C, H, W). It is not modified;
            gradients are taken with respect to a detached copy.
        model: classifier whose output is indexed as (N, num_classes).
            As a side effect its parameters are frozen
            (``requires_grad=False``) and it is put into eval mode.
        target: class index to explain. ``None`` (the default) explains the
            highest-scoring (predicted) class of each sample.

    Returns:
        numpy array of shape (H, W) when N == 1, otherwise (N, H, W).
    """
    # We don't need gradients w.r.t. the weights of a trained model.
    for param in model.parameters():
        param.requires_grad = False

    # Set the model to eval mode (dropout off, batch-norm uses running stats).
    model.eval()

    # Differentiate w.r.t. a fresh leaf copy of the input. Setting
    # requires_grad on the caller's tensor fails for non-leaf tensors, and
    # repeated calls would keep accumulating into the same ``.grad``.
    x = input.detach().clone().requires_grad_(True)

    # Force grad mode so this also works when called inside torch.no_grad().
    with torch.enable_grad():
        scores = model(x)
        if target is None:
            # The highest score per sample is the score of the predicted class.
            selected = torch.max(scores, dim=1).values
        else:
            selected = scores[:, target]
        # Summing over the batch gives each sample the gradient of its own
        # score, assuming samples don't interact (true for a standard
        # classifier in eval mode).
        selected.sum().backward()

    # Saliency = maximum absolute gradient across the channel dimension.
    saliency_map = torch.max(x.grad.abs(), dim=1).values
    # Drop the batch axis in the usual single-image case so plot() gets (H, W).
    saliency_map = saliency_map.squeeze(0)
    return saliency_map.detach().cpu().numpy()

In [None]:
def plot(input_img, saliencey_map):
    """Display an image and its saliency map side by side.

    Args:
        input_img: channel-first (C, H, W) image as a torch tensor or numpy
            array, e.g. the output of ``inv_normalize``. Float images are
            assumed to be in [0, 1].
        saliencey_map: 2-D (H, W) attribution map (tensor or array), drawn
            with the ``hot`` colormap.
    """
    # Tensors may still be attached to an autograd graph or live on a GPU;
    # detach and move them to host memory before handing them to matplotlib.
    if isinstance(input_img, torch.Tensor):
        input_img = input_img.detach().cpu().numpy()
    if isinstance(saliencey_map, torch.Tensor):
        saliencey_map = saliencey_map.detach().cpu().numpy()

    # imshow expects channel-last (H, W, C) images.
    image = np.transpose(input_img, (1, 2, 0))
    # Un-normalizing can overshoot [0, 1] slightly; clip explicitly instead
    # of relying on imshow's clipping, which emits a warning.
    if np.issubdtype(image.dtype, np.floating):
        image = np.clip(image, 0.0, 1.0)

    # Plot the image and its saliency map.
    plt.figure(figsize=(10, 10))
    plt.subplot(1, 2, 1)
    plt.imshow(image)
    plt.xticks([])
    plt.yticks([])
    plt.subplot(1, 2, 2)
    plt.imshow(saliencey_map, cmap=plt.cm.hot)
    plt.xticks([])
    plt.yticks([])
    plt.show()

In [None]:
# Load the example image, forcing three RGB channels.
img = Image.open('puppy.jpg').convert('RGB')

# Preprocess it and add a leading batch axis -> shape (1, 3, 224, 224).
img_normalized = transform(img)[None]
# Take the single image back out of the batch and undo the normalization
# so it can be displayed.
img_normalized_inv = inv_normalize(img_normalized.squeeze(0))

# Vanilla-gradient attribution for the image.
saliencey_map = saliency(img_normalized, model)

# Show the image next to its saliency map.
plot(img_normalized_inv, saliencey_map)

### SmoothGrad

**Question**: Implement SmoothGrad by averaging the saliency maps of at least 10 noisy copies of the input image (at least 10 runs).

In [None]:
# SmoothGrad: average vanilla saliency maps computed on noisy copies of the
# input to reduce visual noise in the attribution.
#  1. load the puppy image
#  2. add random (Gaussian) noise to it
#  3. calculate the saliency map for each noisy image
#  4. average the saliency maps

n_samples = 20      # the exercise asks for at least 10 runs
noise_level = 0.15  # noise std as a fraction of the input's value range

img = Image.open('puppy.jpg').convert('RGB')
img_normalized = transform(img).unsqueeze(0)

# Scale the noise to the spread of the (normalized) input values.
sigma = noise_level * (img_normalized.max() - img_normalized.min()).item()

# Private generator: reproducible noise without reseeding torch's global RNG.
generator = torch.Generator().manual_seed(0)

maps = []
for _ in range(n_samples):
    noise = sigma * torch.randn(img_normalized.shape, generator=generator)
    maps.append(saliency(img_normalized + noise, model))

# Average over the noisy samples; np.squeeze drops a leftover batch axis of
# size 1 (if saliency() returns one) so plot() receives an (H, W) map.
saliencey_map = np.squeeze(np.mean(maps, axis=0))

In [None]:
# plot: rebuild the display version of the puppy image and show it next to
# the averaged saliency map computed in the previous cell.
img = Image.open('puppy.jpg').convert('RGB')
img_normalized = torch.unsqueeze(transform(img), 0)
img_normalized_inv = inv_normalize(img_normalized[0, ...])

plot(img_normalized_inv, saliencey_map)