This first block of code imports all the libraries needed and initialises the variables used in this example. You may change the batch_size and epochs variables, but should not change num_classes or input_shape, as these depend on the dataset being used.

In [None]:
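# Only if needed, uncomment the next line to install the required packages
# (tf-explain does not install TensorFlow itself, so it is listed explicitly)
#!pip install numpy tensorflow keras tf-explain matplotlib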

import numpy as np
import keras
from keras import layers
from tf_explain.core.occlusion_sensitivity import OcclusionSensitivity
import matplotlib.pyplot as plt


batch_size = 128
epochs = 15
num_classes = 10
input_shape = (28, 28, 1)
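As a rough guide to what batch_size and epochs control, the sketch below counts the weight updates performed during training. It assumes the 60,000 MNIST training images and the validation_split=0.1 used in the training cell further down, with the held-out validation images not used for updates; training_schedule is a hypothetical helper for illustration only.

```python
import math


def training_schedule(n_train=60000, validation_split=0.1, batch_size=128, epochs=15):
    """Hypothetical helper: count samples and gradient updates for one fit() call."""
    # A fraction of the training arrays is held out for validation (assumption: 10%)
    n_val = int(n_train * validation_split)
    n_fit = n_train - n_val
    # One weight update per batch; the last batch of each epoch may be smaller
    steps_per_epoch = math.ceil(n_fit / batch_size)
    return n_fit, n_val, steps_per_epoch, steps_per_epoch * epochs


n_fit, n_val, steps, total_steps = training_schedule()
print(n_fit, n_val, steps, total_steps)  # 54000 6000 422 6330
```

Doubling batch_size roughly halves the number of updates per epoch, while the total number of updates grows linearly with epochs.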

The next block of code handles all the data loading and reshaping so that the data can be used to train and evaluate the CNN model. Two sets of data are loaded: the training data, used to fit the model, and the test data, used to evaluate whether the model makes good predictions on images it has not seen before.

In [None]:
# Load the data and split it between train and test sets
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Scale images to the [0, 1] range
x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255
# Make sure images have shape (28, 28, 1)
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)
print("x_train shape:", x_train.shape)
print(x_train.shape[0], "train samples")
print(x_test.shape[0], "test samples")

# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
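To make the preprocessing concrete, here is a NumPy-only sketch applied to a few randomly generated stand-in images (not real MNIST data). The one-hot step picks rows of an identity matrix, which should give the same result as keras.utils.to_categorical for integer labels; preprocess is a hypothetical helper name.

```python
import numpy as np


def preprocess(images, labels, num_classes=10):
    """Hypothetical helper mirroring the cell above: scale, add channel axis, one-hot."""
    # uint8 pixels in [0, 255] become float32 values in [0, 1]
    x = images.astype("float32") / 255
    # (N, 28, 28) -> (N, 28, 28, 1): Conv2D expects an explicit channel axis
    x = np.expand_dims(x, -1)
    # Row i of the identity matrix is the one-hot vector for class i
    y = np.eye(num_classes, dtype="float32")[labels]
    return x, y


rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(5, 28, 28)).astype(np.uint8)  # stand-in data
fake_labels = np.array([3, 0, 9, 3, 1])
x, y = preprocess(fake_images, fake_labels)
print(x.shape, y.shape)  # (5, 28, 28, 1) (5, 10)
print(y[0])              # a 1 in position 3, zeros elsewhere
```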

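This block of code defines the layers of the CNN model, prints a textual summary of the architecture and compiles the model with the Adam optimiser and categorical cross-entropy loss. The commented-out layers are optional and can be uncommented to experiment with the architecture.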

In [None]:
model = keras.Sequential(
    [
        keras.Input(shape=input_shape),
        layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
        #layers.BatchNormalization(),
        #layers.Dropout(0.25),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
        #layers.BatchNormalization(),
        #layers.Dropout(0.25),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Flatten(),
        #layers.Dense(128,activation="relu"),
        layers.Dropout(0.5),
        layers.Dense(num_classes, activation="softmax"),
    ]
)

model.summary()

model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
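The output shapes and parameter counts reported by model.summary() can be checked by hand. The sketch below assumes the Keras defaults used above ("valid" padding and stride 1 for Conv2D, a stride equal to the pool size for MaxPooling2D) and only the layers that are not commented out; the layer names are just labels and may not match the names Keras assigns.

```python
def conv_output_size(size, kernel, stride=1):
    # "valid" padding: the window never extends past the image border
    return (size - kernel) // stride + 1


def architecture_summary(side=28, channels=1, num_classes=10):
    rows = []
    # Conv2D(32, 3x3): 32 kernels of shape 3x3xchannels, plus one bias per kernel
    side = conv_output_size(side, 3)
    rows.append(("conv2d", (side, side, 32), 3 * 3 * channels * 32 + 32))
    # MaxPooling2D(2x2): halves each spatial side (rounding down), no weights
    side = conv_output_size(side, 2, stride=2)
    rows.append(("max_pooling2d", (side, side, 32), 0))
    # Conv2D(64, 3x3): each kernel now spans all 32 input channels
    side = conv_output_size(side, 3)
    rows.append(("conv2d_1", (side, side, 64), 3 * 3 * 32 * 64 + 64))
    side = conv_output_size(side, 2, stride=2)
    rows.append(("max_pooling2d_1", (side, side, 64), 0))
    # Flatten and Dropout reshape or mask values but have no trainable weights
    flat = side * side * 64
    rows.append(("flatten", (flat,), 0))
    rows.append(("dropout", (flat,), 0))
    # Dense: one weight per input/output pair, plus one bias per output
    rows.append(("dense", (num_classes,), flat * num_classes + num_classes))
    return rows


summary = architecture_summary()
for name, shape, params in summary:
    print(f"{name:<16} {str(shape):<14} {params}")
print("Total params:", sum(p for _, _, p in summary))  # 34826
```

Most of the weights sit in the second convolution (18,496) and the final dense layer (16,010). Uncommenting Dense(128) would add 1600 * 128 + 128 = 204,928 parameters, so that optional layer would dominate the model size.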


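This block of code trains the model and evaluates its predictive performance on the test data. During training, validation_split=0.1 sets aside 10% of the training images; the model is not trained on them, but they are scored at the end of each epoch to report validation loss and accuracy.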

In [None]:
model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.1)

score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])
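Here score[0] is the mean categorical cross-entropy and score[1] the accuracy over the test set. A NumPy sketch of both metrics on three made-up predictions is shown below; the clipping constant eps is an assumption to keep log() finite, and Keras's own implementation may differ in such details.

```python
import numpy as np


def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    # Clip probabilities away from 0 and 1 so log() stays finite (eps is an assumption)
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    # Per-sample loss is -log(probability assigned to the correct class)
    per_sample = -np.sum(y_true * np.log(y_pred), axis=1)
    return float(np.mean(per_sample))


def accuracy(y_true, y_pred):
    # A prediction counts as correct when its most probable class matches the label
    return float(np.mean(np.argmax(y_true, axis=1) == np.argmax(y_pred, axis=1)))


y_true = np.eye(3)[[0, 2, 1]]          # one-hot labels for classes 0, 2, 1
y_pred = np.array([[0.8, 0.1, 0.1],    # correct
                   [0.2, 0.2, 0.6],    # correct
                   [0.5, 0.3, 0.2]])   # wrong: predicts 0, label is 1
print("loss:", categorical_crossentropy(y_true, y_pred))  # about 0.646
print("accuracy:", accuracy(y_true, y_pred))              # 2 of 3 correct
```

The loss still penalises the two correct predictions slightly because they are not fully confident, which is why loss and accuracy can move somewhat independently during training.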

This block of code picks an image from the test set (you can change the index to choose a different image), and shows the probabilities that the CNN model estimates for each class (i.e. digit), then shows the actual image.

In [None]:
img_index = 42
image = x_test[img_index]

pred = model.predict(np.expand_dims(image, axis=0))[0]
for digit in range(num_classes):
    print("Probability for digit {}: {}".format(digit, pred[digit]))
print("\nThe winner is {}".format(np.argmax(pred)))
print("The correct class is {}\n".format(np.argmax(y_test[img_index])))

plt.imshow(image.squeeze(),cmap='gray')
plt.show()
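The ten probabilities come from the final softmax layer, which turns the network's raw scores into non-negative values that sum to 1; the "winner" is simply the index of the largest one. A short NumPy sketch with made-up scores (not output from the trained model):

```python
import numpy as np


def softmax(scores):
    # Subtracting the maximum avoids overflow in exp() and leaves the result unchanged
    exp = np.exp(scores - np.max(scores))
    return exp / exp.sum()


scores = np.array([1.0, 2.0, 0.5, 4.0, 0.0, -1.0, 0.0, 0.5, 1.5, 0.0])  # made-up raw scores
probs = softmax(scores)
print(probs.sum())       # approximately 1.0
print(np.argmax(probs))  # 3: softmax preserves the ordering of the scores

# The same argmax recovers the digit from a one-hot label such as y_test[img_index]
one_hot_label = np.eye(10)[3]
print(np.argmax(one_hot_label))  # 3
```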

In the next block we use a specialised library called tf-explain, which implements several advanced algorithms for visualising the decision-making process of CNNs.

This example uses one of the available techniques, called occlusion sensitivity. If we cover part of the input image with a square patch of a certain size (4x4 pixels in the example below), does the network's output for a given class change? By moving the patch systematically across the whole image, we can estimate how much each region of the image influences the prediction for that class. The code below repeats this for every class.

In [None]:
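img_index = 42
image = x_test[img_index]
# tf-explain expects a tuple (images, labels); labels are not needed here
data = ([image], None)
explainer = OcclusionSensitivity()

# One heatmap per class, each computed with a 4x4-pixel occluding patch
fig, axs = plt.subplots(1, num_classes, figsize=(25, 25))
for cl in range(num_classes):
    grid = explainer.explain(data, model, cl, 4)
    axs[cl].imshow(grid)
    axs[cl].axis('off')
    axs[cl].set_title("Class = {}".format(cl))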

plt.show()
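To see what the explainer does under the hood, the occlusion procedure can be sketched in plain NumPy. This is not tf-explain's code: occlusion_map and toy_predict are hypothetical names, the patch is filled with a constant (0.0 by default, an assumption), the patch moves in steps of its own size, and each cell of the map records how much the chosen class's probability drops when that patch is hidden. toy_predict stands in for model.predict so the example runs without a trained network.

```python
import numpy as np


def occlusion_map(predict_fn, image, class_index, patch_size, fill_value=0.0):
    """Hypothetical sketch: probability drop for class_index when each patch is occluded."""
    h, w = image.shape[:2]
    baseline = predict_fn(image[np.newaxis])[0, class_index]
    rows = range(0, h, patch_size)
    cols = range(0, w, patch_size)
    occluded_batch = []
    for top in rows:
        for left in cols:
            occluded = image.copy()
            # Cover one patch_size x patch_size square (clipped at the image border)
            occluded[top:top + patch_size, left:left + patch_size, :] = fill_value
            occluded_batch.append(occluded)
    # Score every occluded copy in a single batch, as model.predict would
    probs = predict_fn(np.stack(occluded_batch))[:, class_index]
    return (baseline - probs).reshape(len(rows), len(cols))


def toy_predict(batch):
    # Hypothetical stand-in for model.predict: the class 1 probability is the
    # mean intensity of the top-left 14x14 quadrant
    p1 = batch[:, :14, :14, 0].mean(axis=(1, 2))
    return np.stack([1.0 - p1, p1], axis=1)


toy_image = np.zeros((28, 28, 1), dtype="float32")
toy_image[:14, :14, 0] = 1.0  # bright square in the top-left quadrant

smap = occlusion_map(toy_predict, toy_image, class_index=1, patch_size=4)
print(smap.shape)  # (7, 7): one value per 4x4 patch
# Upsample to the image size so the map can be overlaid on the input
heatmap = np.repeat(np.repeat(smap, 4, axis=0), 4, axis=1)
print(heatmap.shape)  # (28, 28)
```

Only patches overlapping the bright quadrant get a non-zero score. With the trained network, a call such as occlusion_map(lambda batch: model.predict(batch, verbose=0), x_test[42], class_index=cl, patch_size=4) should give a comparable per-class map, although tf-explain's own patch colour and scoring may differ.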
