## Test SHAP using binary MNIST
**Function        : Test SHAP using binary MNIST**<br>
**Author          : Team DIANNA**<br>
**Contributor     :**<br>
**First Built     : 2021.06.23**<br>
**Last Update     : 2021.06.24**<br>
**Library         : os, numpy, matplotlib, torch, captum, onnx, onnxruntime**<br>
**Description     : In this notebook we test the XAI method SHAP using a trained binary MNIST model.**<br>
**Return Values   : Shapley scores**<br>
**Note            : We use the Captum library to perform SHAP.**<br>

In [1]:
%matplotlib inline
import os
import time as tt
import numpy as np
# DL framework
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
import torch.onnx
# XAI framework
from captum.attr import GradientShap
from captum.attr import IntegratedGradients
from captum.attr import Occlusion
from captum.attr import visualization as viz
# for plotting
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
# report and monitoring with Weights & Biases
#import wandb
# ONNX model and runtime
import onnx
import onnxruntime as ort

### Paths to the dataset, the model, and the output directory

In [2]:
# Paths to the binary MNIST data, the trained ONNX model, and the output
# directory. Each one can be overridden with an environment variable so the
# notebook runs on machines other than the author's; the defaults are the
# original local paths.
datapath = os.environ.get(
    'DIANNA_MNIST_DATA_PATH',
    '/mnt/d/NLeSC/DIANNA/data/mnist/binary-MNIST')
model_path = os.environ.get(
    'DIANNA_MNIST_MODEL_PATH',
    '/mnt/d/NLeSC/DIANNA/codebase/dianna/example_data/model_generation/MNIST')
output_path = os.environ.get(
    'DIANNA_OUTPUT_PATH',
    '/mnt/d/NLeSC/DIANNA/codebase/dianna/example_data/xai_method_study')
# exist_ok=True already makes this a no-op when the directory exists, so no
# separate (racy) existence check is needed.
os.makedirs(output_path, exist_ok=True)

### Load data (binary MNIST)

In [3]:
# Load binary MNIST (digits 0 and 1) from a local .npz archive.
# The context manager closes the archive even if one of the expected keys
# is missing; indexing the archive reads each array into memory, so the
# arrays remain usable after the `with` block ends.
with np.load(os.path.join(datapath, 'binary-mnist.npz')) as fd:
    # training set
    train_X = fd['X_train']
    train_y = fd['y_train']
    # testing set
    test_X = fd['X_test']
    test_y = fd['y_test']

# dimensions of data
print("dimensions of mnist:")
print("dimensions of training set", train_X.shape)
print("dimensions of training set label", train_y.shape)
print("dimensions of testing set", test_X.shape)
print("dimensions of testing set label", test_y.shape)
# statistics of training set
print("statistics of training set:")
print("Digits: 0 1")
print("labels: {}".format(np.unique(train_y)))
# np.bincount counts occurrences of each non-negative integer label
print("Class distribution: {}".format(np.bincount(train_y)))
print("Labels of training set", train_y[:20])

dimensions of mnist:
dimensions or training set (12665, 784)
dimensions or training set label (12665,)
dimensions or testing set (2115, 784)
dimensions or testing set label (2115,)
statistics of training set:
Digits: 0 1
labels: [0 1]
Class distribution: [5923 6742]
Labels of training set [0 1 1 1 1 0 1 1 0 0 1 0 0 1 0 1 0 0 1 1]


In [4]:
# Wrap the NumPy test split as PyTorch tensors (no DataLoader is built here):
# pixel values as float32, labels as int64 class indices.
# PyTorch convolutions expect (batch, channel, height, width), so each
# flattened 784-pixel row is viewed as a single-channel 28x28 image.
test_X_torch = torch.from_numpy(test_X).float().view(-1, 1, 28, 28)
test_y_torch = torch.from_numpy(test_y).long()

### Load model (ONNX model trained for binary MNIST)

In [5]:
# Load the exported ONNX model and validate it before running inference.
onnx_file = os.path.join(model_path, 'mnist_model.onnx')
onnx_model = onnx.load(onnx_file)
# check_model raises if the model's IR is not well formed.
onnx.checker.check_model(onnx_model)
# Print a human-readable dump of the computation graph.
graph_text = onnx.helper.printable_graph(onnx_model.graph)
print('Model :\n\n{}'.format(graph_text))

Model :

graph torch-jit-export (
  %input[FLOAT, batch_sizex1x28x28]
) initializers (
  %layer1.0.weight[FLOAT, 16x1x5x5]
  %layer1.0.bias[FLOAT, 16]
  %layer2.0.weight[FLOAT, 32x16x5x5]
  %layer2.0.bias[FLOAT, 32]
  %fc1.weight[FLOAT, 32x1568]
  %fc1.bias[FLOAT, 32]
  %fc2.weight[FLOAT, 10x32]
  %fc2.bias[FLOAT, 10]
  %26[INT64, 1]
) {
  %9 = Conv[dilations = [1, 1], group = 1, kernel_shape = [5, 5], pads = [2, 2, 2, 2], strides = [1, 1]](%input, %layer1.0.weight, %layer1.0.bias)
  %10 = MaxPool[ceil_mode = 0, kernel_shape = [2, 2], pads = [0, 0, 0, 0], strides = [2, 2]](%9)
  %11 = Relu(%10)
  %12 = Conv[dilations = [1, 1], group = 1, kernel_shape = [5, 5], pads = [2, 2, 2, 2], strides = [1, 1]](%11, %layer2.0.weight, %layer2.0.bias)
  %13 = MaxPool[ceil_mode = 0, kernel_shape = [2, 2], pads = [0, 0, 0, 0], strides = [2, 2]](%12)
  %14 = Relu(%13)
  %15 = Shape(%14)
  %16 = Constant[value = <Scalar Tensor []>]()
  %17 = Gather[axis = 0](%15, %16)
  %19 = Unsqueeze[axes = [0]](%17)
 

### Predict the class of the input image <br>
For how the model was exported to ONNX and how ONNX models can be run, see the PyTorch ONNX documentation: https://pytorch.org/docs/stable/onnx.html <br>

In [20]:
# Run the ONNX model with ONNX Runtime on the first test image.
# Request the CPU execution provider explicitly: since onnxruntime 1.9, builds
# that ship several execution providers (e.g. onnxruntime-gpu) raise an error
# when `providers` is omitted, and CPU is available in every build.
sess = ort.InferenceSession(onnx_file, providers=['CPUExecutionProvider'])
input_name = sess.get_inputs()[0].name
output_name = sess.get_outputs()[0].name

# ONNX Runtime consumes NumPy arrays; slicing with [:1] keeps the batch
# dimension so the input matches the model's (batch_size, 1, 28, 28) input.
onnx_input = {input_name: test_X_torch[:1].detach().numpy().astype(np.float32)}
pred_onnx = sess.run([output_name], onnx_input)[0]
# Predicted class = index of the highest output score for each sample.
predicted = np.argmax(pred_onnx, axis=1)
print("prediction", predicted)
print("ground truth", test_y[0])

prediction [1]
ground truth 1


### Gradient-based attribution <br>
Compute attributions using Integrated Gradients and visualize them on the image. Note that Captum cannot run directly on the ONNX model, so this step is currently disabled.

In [None]:
# Captum cannot work on the ONNX model: `onnx_model` (from onnx.load above) is
# an ONNX graph description, not a callable PyTorch module, and gradient-based
# methods such as IntegratedGradients need a PyTorch forward function through
# which autograd can compute input gradients.
# TODO(review): to enable this cell, load the original PyTorch model that was
# exported to mnist_model.onnx and pass it to IntegratedGradients instead.
#integrated_gradients = IntegratedGradients(onnx_model)
#attributions_ig = integrated_gradients.attribute(test_X_torch[:1,:,:,:], target=test_y_torch[0], n_steps=100)