<a href="https://colab.research.google.com/github/moradza/Explanatory-Masking-for-Deep-Learning/blob/main/ExplanatoryMasksforNeuralNetworkInterpretability.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Explanatory Deep Learning
Two different methods are studied on MNIST dataset to find important feature in image for deep learning classification. Further details inside following references.

1. Masking [ref](https://arxiv.org/pdf/1911.06876.pdf) 
2. Saliancy Map [ref](https://arxiv.org/abs/1610.02391) 

**future work: CNN and Graph Neural Network**<br> 
*additional documentations in the future*

In [1]:
import numpy as np
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
import tensorflow as tf
from tensorflow.keras import layers 
import matplotlib.pyplot as plt


model = tf.keras.Sequential([
    tf.keras.layers.Dense(256, activation='relu'),
    tf.keras.layers.Dense(10)
])

In [2]:
original_dim = 784

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype("float32")  / 255
x_test = x_test.reshape(x_test.shape[0], 784).astype("float32") 


In [3]:
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

In [4]:
model.fit(x_train, y_train, epochs=40)

Epoch 1/40
Epoch 2/40
Epoch 3/40
Epoch 4/40
Epoch 5/40
Epoch 6/40
Epoch 7/40
Epoch 8/40
Epoch 9/40
Epoch 10/40
Epoch 11/40
Epoch 12/40
Epoch 13/40
Epoch 14/40
Epoch 15/40
Epoch 16/40
Epoch 17/40
Epoch 18/40
Epoch 19/40
Epoch 20/40
Epoch 21/40
Epoch 22/40
Epoch 23/40
Epoch 24/40
Epoch 25/40
Epoch 26/40
Epoch 27/40
Epoch 28/40
Epoch 29/40
Epoch 30/40
Epoch 31/40
Epoch 32/40
Epoch 33/40
Epoch 34/40
Epoch 35/40
Epoch 36/40
Epoch 37/40
Epoch 38/40
Epoch 39/40
Epoch 40/40


<tensorflow.python.keras.callbacks.History at 0x7fab87c41590>

In [5]:
test_loss, test_acc = model.evaluate(x_test,  y_test, verbose=2)
print('\nTest accuracy:', test_acc)

313/313 - 0s - loss: 38.1366 - accuracy: 0.9822

Test accuracy: 0.982200026512146


In [6]:
# model.get_config()

In [7]:
layer_output = model.get_layer('dense').output
layer_output.trainable =False
model.trainable = False
explainer_layer1 = Dense(int(original_dim/2), activation='relu')(layer_output)#
mask = Dense(original_dim, activation='sigmoid',activity_regularizer=tf.keras.regularizers.L1L2(l1=0.00001, l2=0.00001))(explainer_layer1)#
new_input = tf.keras.layers.Multiply()([mask, model.input])
class_new = model(new_input, training=False)
model2 = Model(model.input, outputs=class_new)


In [8]:
model2.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

In [None]:
model2.fit(x_train, y_train, epochs=30)

Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30

In [None]:
test_loss, test_acc = model2.evaluate(x_test,  y_test, verbose=2)
print('\nTest accuracy:', test_acc)

In [None]:
# model2.get_config()

In [None]:
layer_output = model2.get_layer('multiply').output
modelexplainer = Model(model2.input, outputs=layer_output)

In [None]:
model.summary(), model2.summary()

In [None]:
masked_img = modelexplainer.predict(x_train)
img = x_train
def plot_image(img):
    plt.grid(False)
    plt.xticks([])
    plt.yticks([])
    plt.imshow(1-img.reshape((28,28)), cmap='gray')
    
num_rows = 8
num_cols = 4
num_images = num_rows*num_cols
plt.figure(figsize=(7*num_cols, 7*num_rows))
for i in range(num_images):
    plt.subplot(2*num_rows, 2*num_cols, 2*i+1)
    plot_image(img[i])
    plt.subplot(2*num_rows, 2*num_cols, 2*i+2)
    plot_image(masked_img[i])
plt.tight_layout()
plt.show()

In [None]:
def saliency(inputs):
    inps = tf.Variable(inputs)
    with tf.GradientTape() as tape:
        prediction = model(inps)
    return tape.gradient(prediction,inps)

In [None]:
saliency(x_train[:3]).shape

In [None]:
saliency_img = saliency(x_train).numpy()
img = x_train

num_rows = 8
num_cols = 4
num_images = num_rows*num_cols
plt.figure(figsize=(5*num_cols, 5*num_rows))
for i in range(num_images):
    plt.subplot(2*num_rows, 2*num_cols, 2*i+1)
    plot_image(img[i])
    plt.subplot(2*num_rows, 2*num_cols, 2*i+2)
    plot_image(saliency_img[i])#/np.max(saliency_img[i]))
plt.tight_layout()
plt.show()

In [None]:
num_rows = 10
num_cols = 2
num_images = num_rows*num_cols
plt.figure(figsize=(7*num_cols, 7*num_rows))
for i in range(num_images):
    plt.subplot(3*num_rows, 3*num_cols, 3*i+1)
    plot_image(saliency_img[i])
    plt.subplot(3*num_rows, 3*num_cols, 3*i+2)
    plot_image(img[i])
    plt.subplot(3*num_rows, 3*num_cols, 3*i+3)
    plot_image(masked_img[i])
plt.tight_layout()
plt.show()

In [None]:
train_loss, train_acc = model2.evaluate(masked_img,  y_train, verbose=2)
print("Accuracy of masked images: %.4f", train_acc)