Exploration of RISE with mnist binary

Function : Exploration of RISE mnist binary
Author : Team DIANNA
Contributor :
First Built : 2021.08.25
Last Update : 2021.08.25
Note : We ran the method using the our own trained model on mnist and various instances from mnist dataset. Results look random. There is no sense that we can make of the heatmaps.

In [1]:
import os
import dianna
import onnx
import onnxruntime
import numpy as np
%matplotlib inline
from matplotlib import pyplot as plt
from torchvision import datasets, transforms

In [5]:
# load data
dataset_root = os.path.expanduser('./leafsnap-dataset-30subset/')
img_size = 128
# apply same transform as during training: resize and crop to a square image, then convert to tensor
transform = transforms.Compose([transforms.Resize(img_size),
                                transforms.CenterCrop(img_size),
                                transforms.ToTensor()])

test_data = datasets.ImageFolder(os.path.join(dataset_root, 'dataset/split/test'), transform=transform,
                                 is_valid_file=lambda fname: fname.endswith('.jpg'))
nsample = len(test_data)
nspecies = len(test_data.classes)
print(f'Number of samples: {nsample}')
print(f'Number of species: {nspecies}')

X_test = np.array([instance[0].numpy() for instance in test_data])
X_test = np.transpose(X_test, (0, 2, 3, 1))
y_test = np.array([instance[1] for instance in test_data])

Number of samples: 739
Number of species: 30


UnidentifiedImageError: cannot identify image file <_io.BufferedReader name='./leafsnap-dataset-30subset/dataset/split/test\\abies_concolor\\12995309740717.jpg'>

In [None]:
X_test.shape
plt.imshow(X_test[0])

# Predict classes for test data

In [None]:
from scipy.special import softmax

def run_model(data):
    data = np.transpose(data, (0, 3, 1, 2)).astype(np.float32)
    fname = os.path.expanduser('~/surfdrive/Shared/datasets/leafsnap/leafsnap_model.onnx')
    # get ONNX predictions
    sess = onnxruntime.InferenceSession(fname)
    input_name = sess.get_inputs()[0].name
    output_name = sess.get_outputs()[0].name

    
    onnx_input = {input_name: data}
    pred_onnx = sess.run([output_name], onnx_input)
    
    return softmax(pred_onnx[0], axis=1)


pred_onnx = run_model(X_test)

Print class and image of a single instance in the test data

In [None]:
i_instance = 50
target_class = y_test[i_instance]
print(target_class)
print(pred_onnx[i_instance])
plt.imshow(X_test[i_instance])  # 0 for channel

In [None]:
# heatmaps = dianna.explain(run_model, X_test[[i_instance]], method="RISE", n_masks=2000, feature_res=8)

In [None]:
from dianna.methods import RISE
explainer = RISE(n_masks=2000, feature_res=16)
heatmaps = explainer(run_model, X_test[[i_instance]])

In [None]:
from dianna import visualization
visualization.plot_image(heatmaps[target_class], X_test[i_instance], heatmap_cmap='bwr')
visualization.plot_image(heatmaps[target_class], heatmap_cmap='gray')

In [None]:
def describe(arr):
    print('shape:',arr.shape, 'min:',np.min(arr), 'max:',np.max(arr), 'std:',np.std(arr))

describe(heatmaps[target_class])
# describe(heatmaps[1])

In [None]:
for i in range(10):
    plt.imshow(explainer.masks[i])
    plt.show()