<img src="https://www.th-koeln.de/img/logo.svg" style="float:right;" width="200">

# Musterlösung / Sample solution 
## 10th exercise: <font color="#C70039">Interpretable Machine Learning with Shapley Values for image classification</font>
* Course: AML
* Lecturer: <a href="https://www.gernotheisenberg.de/">Gernot Heisenberg</a>
* Author of notebook: <a href="https://www.gernotheisenberg.de/">Gernot Heisenberg</a>
* Date:   28.10.2024

---------------------------------

### <font color="#ce33ff">DESCRIPTION</font>:
This notebook is an example implementation that demonstrates explainable AI (XAI) for image classification, using the built-in CIFAR-10 data set that you already came across in exercise 8.

### <font color="#FF0000">IMPORTANT NOTE</font>:

This code requires version 0.44.0 of the shap library.
That version of shap in turn only works with older versions of numpy and tensorflow, namely numpy==1.26.4 and tensorflow==2.15.0.

## Imports
Import all necessary utilities.

In [None]:
import shap # v0.44.0
import numpy as np # v1.26.4

import tensorflow as tf # v2.15.0
from   tensorflow import keras
import matplotlib.pyplot as plt
from   tensorflow.keras.models import Sequential
import ssl

## Load the built-in dataset
Load the CIFAR-10 data set known from exercise 8.

In [None]:
# SECURITY NOTE: this swaps the default HTTPS context for an unverified one,
# i.e. TLS certificates are no longer validated for HTTPS requests that use
# the default context in this process. Presumably a workaround so the dataset
# download succeeds on machines with a broken certificate store - remove it
# wherever certificate validation works.
ssl._create_default_https_context = ssl._create_unverified_context
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Convert the images to float32 and scale the pixel values to [0, 1]. The
# images keep the shape they are loaded with, so no sample count has to be
# hard-coded here.
x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255

# Flatten the label arrays to 1-D, whatever the number of samples is.
y_train = y_train.reshape(-1)
y_test = y_test.reshape(-1)

### Build a simple CNN, compile and fit the model.

In [None]:
# Build the CNN: four Conv2D + MaxPooling2D stages chained one after another,
# followed by a small dense classifier head.
#
# Each stage consumes the output of the previous stage (`x`), not the raw
# `inputs`; otherwise all but the last stage would be discarded.
# padding="same" keeps the spatial size inside each convolution, so only the
# pooling layers halve it (32 -> 16 -> 8 -> 4 -> 2). Without padding, four
# 3x3 convolutions with 2x2 pooling do not fit a 32x32 input.
inputs = tf.keras.Input(shape=(32, 32, 3))
x = inputs
for n_filters in (32, 128, 64, 32):
    x = tf.keras.layers.Conv2D(
        n_filters, (3, 3), padding="same", activation="relu"
    )(x)
    x = tf.keras.layers.MaxPooling2D((2, 2))(x)
x = tf.keras.layers.Flatten()(x)
x = tf.keras.layers.Dense(256, activation="relu")(x)
x = tf.keras.layers.Dense(64, activation="relu")(x)
# one softmax output per class
outputs = tf.keras.layers.Dense(10, activation="softmax")(x)

# inputs and outputs
model = tf.keras.Model(inputs=inputs, outputs=outputs, name="test_for_shap")

# compile the model; the labels are integer class indices, hence the sparse
# loss and the sparse accuracy metric
model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    optimizer=tf.keras.optimizers.Adam(),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

# train for 10 epochs; the test split doubles as validation data
model.fit(x_train, y_train, validation_data=(x_test, y_test), epochs=10)

### predict on the test set (one image for each class).

In [None]:
# class label list (position == class index)
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']


def _first_example_per_class(images, labels, num_classes):
    """Return ``{class_index: image}`` with the first image found per class.

    Args:
        images: sequence of images, aligned with ``labels``.
        labels: sequence of integer class labels, one per image.
        num_classes: scanning stops once this many distinct classes are found.

    Returns:
        A dict ordered by ascending class index.
    """
    examples = {}
    for image, label in zip(images, labels):
        if len(examples) == num_classes:
            break
        # setdefault keeps the first image seen for a class
        examples.setdefault(int(label), image)
    return dict(sorted(examples.items()))


# example image for each class, taken from the training set
images_dict = _first_example_per_class(x_train, y_train, len(class_names))

# example image for each class, taken from the test set
x_test_dict = _first_example_per_class(x_test, y_test, len(class_names))

# stack the test examples in class order (the dict is already sorted by class)
x_test_each_class = np.asarray(list(x_test_dict.values()))

# Compute predictions for the per-class test examples.
# NOTE(review): images_dict holds *training* images, whereas predicted_class
# is computed for the *test* images in x_test_each_class; pairing the two
# shows predictions that were not made for the displayed image.
predictions = model.predict(x_test_each_class)
predicted_class = np.argmax(predictions, axis=1)

### Visualization
#### plot function
define a plot function for actual and predicted class.

In [None]:
# plot actual and predicted class
def plot_actual_predicted(images, pred_classes):
    """Plot one example image per class in a single row of panels.

    Each image panel is titled with the true class and the predicted class.
    The leftmost panel is an empty placeholder titled "Base reference".

    Args:
        images: mapping from class index to the image shown for that class.
        pred_classes: sequence indexed by class index that holds the
            predicted class index for that class's example.
    """
    # one placeholder column plus one column per class
    _, axes = plt.subplots(1, len(class_names) + 1, figsize=(16, 15))
    axes = axes.flatten()

    # placeholder panel: a single fully transparent RGBA pixel
    ax = axes[0]
    dummy_array = np.array([[[0, 0, 0, 0]]], dtype='uint8')
    ax.set_title("Base reference")
    ax.set_axis_off()
    ax.imshow(dummy_array, interpolation='nearest')

    # one panel per class; the class index selects the column
    for k, v in images.items():
        ax = axes[k + 1]
        ax.imshow(v, cmap=plt.cm.binary)
        ax.set_title(
            f"True: {class_names[k]} \nPredict: {class_names[pred_classes[k]]}"
        )
        ax.set_axis_off()

    plt.tight_layout()
    plt.show()

#### XAI using SHAP
Now use the SHAP library to generate the Shapley values

In [None]:
# Background sample for SHAP to take an expectation over:
# 1000 distinct training images drawn at random (no replacement).
background_idx = np.random.choice(x_train.shape[0], 1000, replace=False)
background = x_train[background_idx]

# Explainer for the trained model, based on the background sample.
explainer = shap.DeepExplainer(model, background)

# Shapley values for the four test images at positions 1 to 4.
explained_samples = x_test[1:5]
shap_values = explainer.shap_values(explained_samples)

#### plot the Shapley values

In [None]:
# Plot the feature attributions for the four explained test images
# (x_test[1:5], the same slice the Shapley values were computed for).
explained_images = x_test[1:5]

# One true label per plotted image, looked up from the test labels. Passing
# the full class_names list instead would title the rows with the first four
# class names regardless of what the images actually show.
explained_labels = [class_names[int(label)] for label in y_test[1:5]]

# The images are passed as they are: they are RGB values scaled to [0, 1],
# and negating them would move every value out of the displayable range.
shap.image_plot(shap_values, explained_images, true_labels=explained_labels)

In [None]:
# Show the per-class test images with their true and predicted class.
# predicted_class was computed for exactly these test images, so the titles
# match the pictures (the training examples in images_dict were never
# passed to the model for prediction).
plot_actual_predicted(x_test_dict, predicted_class)
print()

# The Shapley values must cover the same images that are plotted, so compute
# them for the ten per-class test images instead of reusing the values of
# the four images explained earlier.
shap_values_each_class = explainer.shap_values(x_test_each_class)

# The images are already scaled to [0, 1] and are plotted in that range.
shap.image_plot(shap_values_each_class, x_test_each_class)