## Setup and imports

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np

%matplotlib inline

import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

from torchvision import models

from captum.attr import IntegratedGradients
from captum.attr import Saliency
from captum.attr import DeepLift
from captum.attr import NoiseTunnel
from captum.attr import visualization as v

## Preparing your Data for Training with DataLoaders

In [None]:

class Net(nn.Module):
    """Small convolutional classifier: two conv/ReLU/max-pool stages, then three linear layers.

    The first linear layer expects 16 * 5 * 5 features per sample, which is what
    a 3-channel 32x32 input produces after the two 5x5 convolutions and the two
    2x2 poolings (32 -> 28 -> 14 -> 10 -> 5). The output has 10 values per sample
    (raw scores; no softmax is applied here).
    """

    def __init__(self):
        super().__init__()
        # Convolution / pooling layers (registration order is kept as-is).
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Fully connected layers.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        # One ReLU module per activation site instead of a single shared one.
        # NOTE(review): presumably required by the hook-based attribution
        # methods used later in the notebook (e.g. DeepLift) -- confirm
        # against the Captum documentation before merging these.
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.relu3 = nn.ReLU()
        self.relu4 = nn.ReLU()

    def forward(self, x):
        """Return a (batch, 10) tensor of class scores for image tensor `x`.

        `x` is a (batch, 3, H, W) tensor; a single unbatched (3, H, W) image is
        also accepted and treated as a batch of one.

        Raises:
            RuntimeError: if the spatial size does not yield 16 * 5 * 5
                features per sample (raised by the first linear layer).
        """
        if x.dim() == 3:
            # Unbatched image: add the batch axis so flattening stays per-sample.
            x = x.unsqueeze(0)
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        # Flatten everything except the batch axis. Unlike view(-1, 400), this
        # can never fold a feature-size mismatch into the batch dimension and
        # silently return predictions for the wrong number of samples.
        x = torch.flatten(x, 1)
        x = self.relu3(self.fc1(x))
        x = self.relu4(self.fc2(x))
        return self.fc3(x)

## Model initialization: build the single network instance that the optimizer,
## the training / checkpoint-loading cell and the attribution cells all use.
net = Net()

## Define Loss Function and Optimizer

In [None]:
# Loss function: cross-entropy computed on the raw class scores returned by `net`.
criterion = nn.CrossEntropyLoss()

# Optimizer: SGD with momentum over every parameter of `net`.
sgd_settings = {"lr": 0.001, "momentum": 0.9}
optimizer = optim.SGD(net.parameters(), **sgd_settings)

## Train the Model

In [None]:
# Set to False to train from scratch (5 epochs) and overwrite the checkpoint.
USE_PRETRAINED_MODEL = True
# Single source of truth for the checkpoint location (holds net.state_dict()).
MODEL_PATH = 'models/cifar_torchvision.pt'

if USE_PRETRAINED_MODEL:
    print("Using existing trained model")
    # map_location="cpu" lets a checkpoint that was written on a GPU machine
    # load on a CPU-only one; load_state_dict then copies into net's tensors.
    # NOTE(review): torch.load unpickles the file -- only load checkpoints
    # that come from a trusted source.
    net.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
else:
    # NOTE(review): `trainloader` is not defined in any cell visible here;
    # this branch needs the data-loading cell to have been run first.
    for epoch in range(5):  # loop over the dataset multiple times
        running_loss = 0.0  # loss accumulated since the last progress print
        for i, (inputs, labels) in enumerate(trainloader):
            # Clear gradients left over from the previous step; otherwise
            # backward() would accumulate into them.
            optimizer.zero_grad()

            outputs = net(inputs)              # forward pass
            loss = criterion(outputs, labels)  # loss for this mini-batch
            loss.backward()                    # backpropagate
            optimizer.step()                   # update the parameters

            # Report the mean loss of the last 2000 mini-batches.
            running_loss += loss.item()
            if i % 2000 == 1999:
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / 2000))
                running_loss = 0.0

    print('Finished Training')
    # Create the target directory first so a finished training run is not
    # lost to a missing-folder error when saving.
    os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
    torch.save(net.state_dict(), MODEL_PATH)

## Make a grid of images

In [None]:
def imshow(img, transpose = True):
    """Un-normalize an image tensor and display it with matplotlib.

    Args:
        img: image tensor whose displayable values are recovered by
            ``img / 2 + 0.5`` (presumably the inverse of the normalization
            applied when the data was loaded -- that cell is not shown here).
        transpose: when True (default), move axis 0 to the end
            (axes (0, 1, 2) -> (1, 2, 0)) before plotting. Pass False for an
            image that already has its channel axis last.
    """
    img = img / 2 + 0.5     # unnormalize
    # detach()/cpu() make this work for tensors that require grad or live on a GPU.
    npimg = img.detach().cpu().numpy()
    if transpose:
        npimg = np.transpose(npimg, (1, 2, 0))
    plt.imshow(npimg)
    plt.show()
## Fetch the first batch of images and labels from the test loader.
## NOTE(review): `testloader` and `classes` are not defined in any cell visible
## here; the data-loading cell must have been run before this one.
dataiter = iter(testloader)
# Use the built-in next(): DataLoader iterators in current PyTorch releases
# have no .next() method.
images, labels = next(dataiter)

# Show the whole batch as one image grid, followed by the true class names.
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(len(labels))))

# Class scores for the batch; the predicted class is the index of the largest score.
outputs = net(images)
_, predicted = torch.max(outputs, 1)

print('Predicted: ', ' '.join('%5s' % classes[predicted[j]]
                              for j in range(len(predicted))))

### Display results

In [None]:
## Index (within the batch) of the image that all attribution cells explain.
ind = 3
## Take that image, give it a leading batch axis of size 1 and turn on gradient
## tracking for it.
## NOTE: `input` shadows the Python builtin of the same name; the name is kept
## because the following cells refer to it.
input = torch.unsqueeze(images[ind], dim=0).requires_grad_(True)

In [None]:
## Put the model in evaluation mode before running the attribution cells.
## (As the last expression of the cell, this also displays the module summary.)
net.eval()

### Set feature attribution function

In [None]:
def attribute_image_f(algorithm, input, *, target=None, **kwargs):
    """Run an attribution algorithm on `input` after clearing the model's gradients.

    Args:
        algorithm: object exposing ``attribute(input, target=..., **kwargs)``
            (the notebook passes Captum attribution objects built on `net`).
        input: tensor handed to ``algorithm.attribute`` as its first argument.
        target: class index to attribute with respect to. When omitted, the
            notebook-level ``labels[ind]`` is used, i.e. the ground-truth label
            of the currently selected image (the original behaviour).
        **kwargs: forwarded unchanged to ``algorithm.attribute``
            (e.g. ``baselines``, ``return_convergence_delta``).

    Returns:
        Whatever ``algorithm.attribute`` returns -- callers in this notebook
        unpack a pair when they pass ``return_convergence_delta=True``.
    """
    # Reset gradients accumulated on the global model by earlier calls.
    net.zero_grad()
    if target is None:
        target = labels[ind]
    return algorithm.attribute(input, target=target, **kwargs)


### Saliency maps

In [None]:
## Saliency attribution for the selected image with respect to its true label.
saliency = Saliency(net)
grads = saliency.attribute(input, target=labels[ind].item())
## Remove only the batch axis (squeeze(0), as in the NoiseTunnel and DeepLift
## cells) so no other size-1 axis can be dropped by accident, then reorder
## axes (0, 1, 2) -> (1, 2, 0) for the visualization calls.
grads = np.transpose(grads.squeeze(0).cpu().detach().numpy(), (1, 2, 0))

### Integrated gradients

In [None]:
## Integrated Gradients against an all-zero baseline of the same shape as the
## input; also request the convergence delta of the approximation.
ig = IntegratedGradients(net)
attrig, delta = attribute_image_f(ig, input, baselines=input * 0, return_convergence_delta=True)
## Remove only the batch axis (squeeze(0), as in the NoiseTunnel and DeepLift
## cells), then reorder axes (0, 1, 2) -> (1, 2, 0) for the visualization calls.
attrig = np.transpose(attrig.squeeze(0).cpu().detach().numpy(), (1, 2, 0))
print('Approximation delta: ', abs(delta))

In [None]:
## Integrated Gradients wrapped in a NoiseTunnel ('smoothgrad_sq' variant,
## 100 noisy samples, noise standard deviation 0.2), all-zero baseline.
ig = IntegratedGradients(net)
nt = NoiseTunnel(ig)
smoothgrad_kwargs = dict(
    baselines=input * 0,
    nt_type='smoothgrad_sq',
    nt_samples=100,
    stdevs=0.2,
)
attrig_nt = attribute_image_f(nt, input, **smoothgrad_kwargs)
## Drop the batch axis and reorder axes (0, 1, 2) -> (1, 2, 0) for plotting.
attrig_nt = attrig_nt.squeeze(0).cpu().detach().numpy().transpose(1, 2, 0)

### DeepLift

In [None]:
## DeepLift attribution for the selected image against an all-zero baseline.
dl = DeepLift(net)
zero_baseline = input * 0
attrdl = attribute_image_f(dl, input, baselines=zero_baseline)
## Drop the batch axis and reorder axes (0, 1, 2) -> (1, 2, 0) for plotting.
attrdl = attrdl.squeeze(0).cpu().detach().numpy().transpose(1, 2, 0)

### Visualization of attributions

In [None]:
print('Original Image')
# Per-sample class probabilities. Report the top probability of the selected
# image only: a max over the whole softmax output would return the most
# confident value of any image in the batch, not necessarily image `ind`.
class_probs = F.softmax(outputs, dim=1)
print('Predicted:', classes[predicted[ind]],
      ' Probability:', class_probs[ind].max().item())

# Un-normalize the selected image (x / 2 + 0.5) and reorder axes
# (0, 1, 2) -> (1, 2, 0) so it can serve as the background of every overlay.
original_image = np.transpose((images[ind].cpu().detach().numpy() / 2) + 0.5, (1, 2, 0))

_ = v.visualize_image_attr(None, original_image,
                      method="original_image", title="Original Image")

# Saliency: plotted by absolute value.
_ = v.visualize_image_attr(grads, original_image, method="blended_heat_map", sign="absolute_value",
                          show_colorbar=True, title="Overlayed Gradient Magnitudes")

# Integrated Gradients: positive and negative attributions both shown.
_ = v.visualize_image_attr(attrig, original_image, method="blended_heat_map",sign="all",
                          show_colorbar=True, title="Overlayed Integrated Gradients")

# Integrated Gradients + SmoothGrad squared: absolute value, outlier_perc=10.
_ = v.visualize_image_attr(attrig_nt, original_image, method="blended_heat_map", sign="absolute_value",
                             outlier_perc=10, show_colorbar=True,
                             title="Overlayed Integrated Gradients \n with SmoothGrad Squared")

# DeepLift: positive and negative attributions both shown.
_ = v.visualize_image_attr(attrdl, original_image, method="blended_heat_map",sign="all",show_colorbar=True,
                          title="Overlayed DeepLift")