## DIANNA demo with MNIST

This notebook showcases the use of DIANNA on a subset of the MNIST
dataset. MNIST contains handwritten digits from 0 to 9. Here, we only
use 0 and 1. A pre-trained binary classifier is then loaded and examined with DIANNA.

#### Install and import packages

In [None]:
!pip install git+https://github.com/dianna-ai/dianna.git

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import onnxruntime as ort
from scipy.special import softmax

np.random.seed(42)

#### Load binary MNIST dataset

In [None]:
data = np.load('data/binary-mnist-data.npz')
X = data['X_test'].astype(np.float32).reshape([-1, 28, 28, 1])
y = data['y_test']

In [None]:
# Select a few random samples
selection = np.random.choice(len(X), 9, replace=False)
X_examples = X[selection]
y_examples = y[selection]

# Visualize the selected data
fig, axes = plt.subplots(3, 3, figsize=(6, 6))
for idx, ax in enumerate(axes.flatten()):
    ax.imshow(X_examples[idx], cmap='gray')
    ax.set_title(f'label: {y_examples[idx]}')
    ax.axis('off')

#### Load pre-trained model

DIANNA includes a tool to load an ONNX-format model.

In [None]:
from dianna.utils.onnx_runner import SimpleModelRunner

In [None]:
# Wrap the model in a class so that predictions take a single call and the file is loaded only once.
class Runner:
    def __init__(self):
        self.model_runner = SimpleModelRunner('models/binary-mnist-model.onnx')

    def __call__(self, input_data):
        input_data = input_data.reshape(-1, 1, 28, 28)
        output = self.model_runner(input_data)
        # Normalise over the class axis so each sample's probabilities sum to 1
        return softmax(output, axis=1)
    
runner = Runner()

y_pred = np.argmax(runner(X_examples), axis=1)
print(y_pred)
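
The runner turns the raw model outputs into class probabilities with a softmax. For a batch of inputs, the normalisation should run over the class axis so that each row sums to one. Calling `scipy.special.softmax` without `axis` normalises over the whole array instead. The argmax is the same either way, but the probabilities are not. A minimal sketch with made-up logits (illustrative values, not outputs of the model above):

```python
import numpy as np
from scipy.special import softmax

# Hypothetical logits for a batch of 3 images and 2 classes.
logits = np.array([[2.0, -1.0],
                   [0.5, 0.5],
                   [-3.0, 1.0]])

per_sample = softmax(logits, axis=1)  # normalise each row separately
whole_array = softmax(logits)         # axis=None: normalise over all 6 values

print(per_sample.sum(axis=1))   # every row sums to 1
print(whole_array.sum(axis=1))  # rows no longer sum to 1
```

This matters for RISE, which weights each mask by the predicted probability rather than by the predicted label.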

Let's visualize the data again and add the predicted labels.

In [None]:
fig, axes = plt.subplots(3, 3, figsize=(6, 6))
for idx, ax in enumerate(axes.flatten()):
    ax.imshow(X_examples[idx], cmap='gray')
    ax.set_title(f'label: {y_examples[idx]} pred: {y_pred[idx]}')
    ax.axis('off')

The model classifies all 9 of these examples correctly. Now we can move on to the explainable AI part.

## Explainable AI with DIANNA

In [None]:
import dianna

Select one image to explain with DIANNA.

In [None]:
image = X_examples[:1]

The simplest way to use DIANNA is with `dianna.explain_image` or `dianna.explain_text`.  
Here we use the RISE method, which explains the input image by masking random parts of it and
then checking how the output of the model changes.
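
To make the masking idea concrete, here is a simplified sketch of a RISE-style computation in plain NumPy. It is not DIANNA's implementation: `rise_sketch` and `toy_model` are hypothetical names, and the masks are blocky grids rather than the smoothly upsampled, randomly shifted masks of the original method.

```python
import numpy as np

def rise_sketch(model, image, n_masks=2000, p_keep=0.5, grid=7, seed=0):
    """Simplified RISE: average random binary masks, weighted by the model output.

    `model` maps a batch of shape (N, H, W) to class probabilities (N, n_classes).
    """
    rng = np.random.default_rng(seed)
    height, width = image.shape
    # Coarse random grid: each cell is kept with probability p_keep.
    coarse = (rng.random((n_masks, grid, grid)) < p_keep).astype(np.float32)
    # Blow the grid up to image size by repeating each cell, then crop.
    cell_h, cell_w = -(-height // grid), -(-width // grid)  # ceiling division
    masks = coarse.repeat(cell_h, axis=1).repeat(cell_w, axis=2)[:, :height, :width]
    # Predictions for every masked copy of the image: (n_masks, n_classes).
    scores = model(image[None] * masks)
    # Per class, sum the masks weighted by that class's score, then normalise.
    return np.tensordot(scores, masks, axes=(0, 0)) / (n_masks * p_keep)

def toy_model(batch):
    # Hypothetical stand-in for a classifier: the class 1 probability is the
    # mean brightness of the top-left quadrant of each image.
    p1 = batch[:, :14, :14].mean(axis=(1, 2))
    return np.stack([1.0 - p1, p1], axis=1)

image = np.ones((28, 28), dtype=np.float32)
saliency = rise_sketch(toy_model, image)
print(saliency.shape)  # one heatmap per class
```

Because the toy model only looks at the top-left quadrant, the class 1 heatmap comes out higher there than elsewhere, which is the behaviour RISE is meant to reveal.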

In [None]:
explanation = dianna.explain_image(runner, image, method='RISE',
                                   n_masks=5000, p_keep=.1)

For each class (two in this case), an explanation with a shape equal to the input image is generated.

In [None]:
print(explanation.shape)

The explanations can be considered images and can be plotted with e.g. `matplotlib`.  
DIANNA also includes visualization tools.

In [None]:
from dianna.visualization import plot_image

In [None]:
# Explanation for class 0
plot_image(explanation[0], heatmap_cmap='bwr', show_plot=False)
plt.title('Explanation for class 0')
plot_image(explanation[0], original_data=image[0], heatmap_cmap='bwr', data_cmap='gray', show_plot=False)
plt.title('Explanation for class 0 with original image');

Note that in these plots, red means important and blue means unimportant. We see that the left, bottom, and right side of the zero are most in favour of classifying the image as a zero, while the top part is least in favour of it.

We can also look at the heatmap for class 1.

In [None]:
# Explanation for class 1
plot_image(explanation[1], heatmap_cmap='bwr', show_plot=False)
plt.title('Explanation for class 1')
plot_image(explanation[1], original_data=image[0], heatmap_cmap='bwr', data_cmap='gray', show_plot=False)
plt.title('Explanation for class 1 with original image');

This looks very similar to the explanation for class 0, but with inverted colours as expected: any pixel in favour of zero is against one, and vice-versa.
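
The inversion follows from the two class probabilities summing to one. In a RISE-style weighted average of masks (a simplified form assumed here for illustration, not taken from DIANNA's code), the two heatmaps add up to the average mask divided by `p_keep`, which is close to 1 at every pixel. One map is therefore roughly one minus the other. A small numerical check with random stand-in probabilities:

```python
import numpy as np

rng = np.random.default_rng(0)
n_masks, p_keep = 4000, 0.5

# Random binary masks and hypothetical class 1 probabilities, one per mask.
masks = (rng.random((n_masks, 28, 28)) < p_keep).astype(np.float64)
p1 = rng.random(n_masks)
probs = np.stack([1.0 - p1, p1], axis=1)  # two classes, rows sum to 1

# Simplified RISE estimate: probability-weighted average of the masks.
heatmaps = np.tensordot(probs, masks, axes=(0, 0)) / (n_masks * p_keep)

total = heatmaps[0] + heatmaps[1]
print(total.min(), total.max())  # both close to 1
```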