In [None]:
# ==============================================================================
# Generate a Grad-CAM explanation (via ELI5) for a given Keras model and image,
# and plot the model's predicted probability distribution over the classes.
# ==============================================================================

In [None]:
import eli5  # https://eli5.readthedocs.io/en/latest/index.html
import tensorflow as tf  # https://www.tensorflow.org
import numpy as np  # https://numpy.org
from PIL import Image  # https://pillow.readthedocs.io/en/stable/
import matplotlib.pyplot as plt  # https://matplotlib.org
import matplotlib.cm

# ELI5's Keras explainer does not work with eager execution enabled, so switch
# TensorFlow to graph mode here. Per the TensorFlow docs this call must happen
# before any graphs, ops or tensors are created, i.e. before the model is loaded.
tf.compat.v1.disable_eager_execution()

In [None]:
def load_image(path):
    """ Open an image file with Pillow and return it fully loaded in memory.

    ``Image.open`` is lazy: it keeps the underlying file handle open until the
    pixel data is read or the object is garbage-collected. Loading the data
    inside a ``with`` block reads the pixels eagerly and then releases the file
    handle, so the returned image stays usable without leaking an open file.

    Args:
        path: Path to the image file.

    Returns:
        A ``PIL.Image.Image`` whose pixel data is already loaded.
    """
    with Image.open(path) as image:
        image.load()  # force the read now; the file is closed when the block exits
    return image


def load_image_into_array(path):
    """ Read an image file and return it as a single-image batch for the model.

    The file is read with Keras' ``load_img`` using its default arguments and
    converted with ``img_to_array``. A leading batch axis of length 1 is then
    added, so the result can be passed straight to ``model.predict``.
    """
    pil_image = tf.keras.preprocessing.image.load_img(path)
    pixels = tf.keras.preprocessing.image.img_to_array(pil_image)
    # Prepend the batch dimension: (H, W, C) -> (1, H, W, C).
    return pixels[np.newaxis, ...]

In [None]:
img_path = "PATH/TO/IMAGE"  # placeholder: point this at the image to explain
# The same file is loaded twice, in two representations:
img = load_image(img_path)  # PIL image, used below as the heatmap background
img_arr = load_image_into_array(img_path)  # batched array fed to the model

In [None]:
# Load the trained Keras model (placeholder path) and print its layer summary.
# The summary is useful to confirm the layer name used by the Grad-CAM cell.
model = tf.keras.models.load_model(filepath="PATH/TO/MODEL")
model.summary()

In [None]:
# Raw model output for the single-image batch, in the model's own class order.
predictions = model.predict(img_arr)
print(predictions)

In [None]:
# Grad-CAM heatmap from ELI5, overlaid on the original PIL image.
# The colormap lives in one variable so the heatmap and the colorbar always agree.
gradcam_cmap = matplotlib.cm.viridis
exp = eli5.show_prediction(model, img_arr, image=img, layer='conv5_3_1x1_increase/bn', colormap=gradcam_cmap)

fig, ax = plt.subplots()
ax.imshow(exp)
# The overlay returned by ELI5 is already colour-mapped, so a bare plt.colorbar()
# would draw Matplotlib's default colormap (rcParams['image.cmap']) rather than
# the one ELI5 used. Build the colorbar from an explicit ScalarMappable with the
# same colormap instead. Ticks are hidden, so only the colour ramp matters and
# the 0-1 norm is just a placeholder range.
cmap_mappable = matplotlib.cm.ScalarMappable(norm=plt.Normalize(vmin=0, vmax=1), cmap=gradcam_cmap)
cmap_mappable.set_array([])  # needed by older Matplotlib versions for data-less mappables
cbar = fig.colorbar(cmap_mappable, ax=ax)
cbar.set_ticks([])
ax.axis('off')
fig.savefig("PATH/TO/SAVE/EXP", dpi=600, bbox_inches='tight')

In [None]:
def plot_probability_distribution(pred, label_true, save_file=None,
                                  class_names=('AF', 'AN', 'DI', 'HA', 'NE', 'SA', 'SU'),
                                  display_order=(2, 0, 1, 3, 4, 5, 6)):
    """ Plot the probability distribution as given in pred in a bar plot.

    Only the first row of ``pred`` (the first sample of the batch) is plotted.
    Bars are re-ordered for display: bar ``k`` shows ``pred[0][display_order[k]]``
    and is labelled ``class_names[k]``. With the defaults, model output index 2
    (presumably the "FE" class) is shown first and labelled "AF", and the other
    classes keep their order.

    Args:
        pred: Model output, indexable as ``pred[0][i]`` (e.g. the array returned
            by ``model.predict`` for a one-image batch).
        label_true: Index of the correct class in the *display* order (i.e. into
            ``class_names``); that bar is highlighted.
        save_file: If truthy, the figure is saved to this path before showing.
        class_names: Bar labels in display order.
        display_order: For each displayed bar, the index into ``pred[0]``.

    Raises:
        ValueError: If ``class_names`` and ``display_order`` differ in length.
        IndexError: If ``label_true`` or an entry of ``display_order`` is out of range.
    """
    if len(class_names) != len(display_order):
        raise ValueError("class_names and display_order must have the same length")

    scores = np.asarray(pred)[0]  # first sample of the batch
    ordered_scores = [scores[i] for i in display_order]

    # Draw on a fresh figure so the bars never land on a still-open figure
    # (e.g. the Grad-CAM image when not running with the inline backend).
    fig, ax = plt.subplots()
    bars = ax.bar(list(class_names), ordered_scores, color='paleturquoise')
    bars[label_true].set_color('teal')  # Give the bar of the correct class a different colour.

    ax.set_yticks(np.arange(0, 1.1, 0.1))
    ax.set_xlabel('Class')
    ax.set_ylabel('Probability')
    handles = [plt.Rectangle((0, 0), 1, 1, color='teal'), plt.Rectangle((0, 0), 1, 1, color='paleturquoise')]
    ax.legend(handles, ['True label', 'Other label'])
    ax.set_title('Probability distribution')

    if save_file:
        # 'tight' keeps the axis labels inside the saved image, matching the Grad-CAM export.
        fig.savefig(save_file, dpi=600, bbox_inches='tight')

    plt.show()

# label_true=6 is the index of 'SU' in the plot's display order.
plot_probability_distribution(
    model.predict(img_arr),
    label_true=6,
    save_file="SAVEFILE_PROB_DIST",
)