# What Do Neural Networks Learn?

You are advised to run this Jupyter Notebook on Google Colab. From the Colab toolbar, select *Runtime* > *Change runtime type* > *T4 GPU* > *Save* before running the Notebook.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

from keras.applications import Xception
import keras.applications.xception as xception

from keras.datasets import mnist

from keras import Model
from keras import Input
from keras.layers import Rescaling
from keras.layers import Dense
from keras.layers import Flatten
from keras.layers import Conv2D
from keras.layers import MaxPooling2D

from keras.optimizers import RMSprop

from keras.callbacks import EarlyStopping

from keras.utils import load_img, img_to_array, array_to_img, save_img

In [None]:
# The third demo in this Notebook requires an extra Python library: opencv-python
# Installing it on your own machine may downgrade your version of numpy - which we don't want.
# So it is better to run this Notebook on Google Colab.

import cv2

In [None]:
import os

# Work relative to the current directory by default. When running on Google
# Colab, mount Google Drive and work from the Colab Notebooks folder instead.
base_dir = "."
if 'google.colab' in str(get_ipython()):
  from google.colab import drive
  drive.mount('/content/drive')
  # You may need to change this, depending on where your notebooks are on Google Drive
  base_dir = "./drive/My Drive/Colab Notebooks/"

# All datasets used by this Notebook live in a "datasets" folder under base_dir
dataset_dir = os.path.join(base_dir, "datasets")

## Acknowledgments
- The first two pieces of visualization code come from: F. Chollet: *Deep Learning with Python (2nd edn)*, Manning Publications, 2021
- The third piece of visualization code is slightly modified from [Adrian Rosebrock's web site](https://pyimagesearch.com/2020/03/09/grad-cam-visualize-class-activation-maps-with-keras-tensorflow-and-deep-learning/).
- The code for adding spurious correlations to MNIST is adapted from [https://github.com/dtak/rrr/blob/master/rrr/decoy_mnist.py](https://github.com/dtak/rrr/blob/master/rrr/decoy_mnist.py), which is the repo that accompanies the paper: Andrew Slavin Ross, Michael C. Hughes and Finale Doshi-Velez: *Right for the Right Reasons: Training Differentiable Models by Constraining their Explanations*, Proceedings of the Twenty-Sixth International Joint Conference on Artificial Intelligence, pp.2662-2670, 2017


## Visualizing Convolutional Neural Networks

- We'll use three visualizations to gain insight into what a network learns.
- You do not have to understand the code!
- We will run these visualizations on a convolutional neural network called Xception that has been pre-trained on the ImageNet dataset.

In [None]:
# In some cases, we'll just use the base: Xception with ImageNet weights but
# without its top (classification) layers (include_top=False)
xception_base = Xception(weights="imagenet", include_top=False)

In [None]:
# In other cases, we'll use the top as well: the complete Xception model with
# ImageNet weights, including its top (classification) layers (include_top=True)
xception_model = Xception(weights="imagenet", include_top=True)

In [None]:
# Print a layer-by-layer summary of the complete model
xception_model.summary()

In [None]:
# We'll also make use of a cat image that we used in a previous lecture.
cat_image_file = "wikipedia_cats_and_dogs/Orange_tabby_cat_sitting_on_fallen_leaves-Hisashi-01A.jpg"
img_path = os.path.join(dataset_dir, cat_image_file)

# Load the image resized to 299 x 299 and convert it to a NumPy array
img = load_img(img_path, target_size=(299, 299))
img_array = img_to_array(img)

# Add a leading batch axis so that the image can be passed to a model
img_tensor = img_array[np.newaxis, ...]

In [None]:
# Display the (single) image in the batch; the float pixel values are cast to
# uint8 so that they are drawn as ordinary 0-255 colour values
plt.imshow(img_tensor[0].astype("uint8"))
plt.axis("off")
plt.show()

## Visualizations of the activations of convolutional and pooling layers

Create a model that returns the activations of the convolutional and pooling layers of the Xception model:

In [None]:
# Pick out the Conv2D and MaxPooling2D layers of the Xception base, in order
selected_layers = [
    candidate for candidate in xception_base.layers
    if isinstance(candidate, (Conv2D, MaxPooling2D))
]

# Keep their outputs (to build the model) and their names (to label the plots)
layer_outputs = [selected.output for selected in selected_layers]
layer_names = [selected.name for selected in selected_layers]

# A model with the same input as the base and one output per selected layer
activation_model = Model(xception_base.input, layer_outputs)

Feed our example image into the model in order to compute the layer activations:

In [None]:
# Xception's ImageNet weights expect pixels scaled to [-1, 1], so preprocess the
# raw 0-255 image before computing activations. A copy is passed because
# preprocess_input may modify a NumPy array in place, and img_tensor (which
# shares memory with img_array) is still needed unchanged for display later.
activations = activation_model.predict(xception.preprocess_input(img_tensor.copy()))

Now we can plot the activations of the layers.

E.g., here are the activations of feature map 1 in convolutional layer 0. What do you think this feature map detects?

In [None]:
# activations[0] is the output of the first selected layer; index the single
# image in the batch (0) and channel 1
plt.matshow(activations[0][0, :, :, 1], cmap="viridis")

E.g., here are the activations of feature map 9 in convolutional layer 0. What do you think this feature map detects?

In [None]:
# Same layer and image as before, but channel 9
plt.matshow(activations[0][0, :, :, 9], cmap="viridis")

 Now a visualization of all activations of all feature maps in all convolutional and pooling layers:

In [None]:
images_per_row = 16

# Iterate over the layers, drawing one figure per layer
for layer_name, layer_activation in zip(layer_names, activations):
    # This is the number of features (channels) in the layer's output
    n_features = layer_activation.shape[-1]

    # The output has shape (1, size, size, n_features); it is assumed to be
    # square, which holds for the square input image used in this Notebook
    size = layer_activation.shape[1]

    # Number of grid rows needed to show every channel. Ceiling division means
    # that a final, partially-filled row is still shown (e.g. 728 channels need
    # 46 rows of 16) instead of its channels being silently dropped.
    n_rows = -(-n_features // images_per_row)

    # We will tile the channels in this matrix, with a 1-pixel gap between tiles
    display_grid = np.zeros(((size + 1) * n_rows - 1,
                             images_per_row * (size + 1) - 1))

    for channel_index in range(n_features):
        grid_row, grid_col = divmod(channel_index, images_per_row)
        channel_image = layer_activation[0, :, :, channel_index].copy()
        # Post-process the feature to make it visually palatable: standardize,
        # then map to roughly 0-255. A constant channel (zero standard
        # deviation) is left as it is to avoid dividing by zero.
        channel_std = channel_image.std()
        if channel_std > 0:
            channel_image -= channel_image.mean()
            channel_image /= channel_std
            channel_image *= 64
            channel_image += 128
        channel_image = np.clip(channel_image, 0, 255).astype("uint8")
        display_grid[
            grid_row * (size + 1): (grid_row + 1) * size + grid_row,
            grid_col * (size + 1): (grid_col + 1) * size + grid_col] = channel_image

    # Display the grid, sized so that each tile is about one inch square
    scale = 1. / size
    plt.figure(figsize=(scale * display_grid.shape[1],
                        scale * display_grid.shape[0]))
    plt.title(layer_name)
    plt.grid(False)
    plt.imshow(display_grid, aspect="auto", cmap="viridis")

plt.show()

- Lower layers are edge detectors and, because these edges are common, there is a lot of activation.
- Higher in the network, features become more abstract and hence activation is less about the image and more about the class.
- In higher layers, there are cases of almost no activation, meaning the feature is not present at all.

## Visualizations of the inputs that convolutional layers are receptive to

In this visualization, we display the kinds of inputs that feature maps respond to. This is done by gradient ascent on the input space:
- start from a blankish input image
- find the changes to the input that maximise the response of a feature map.

We specify which layer we are interested in. You can change this to any of the other layers, e.g. "block2_sepconv1", "block4_sepconv1", "block10_sepconv1"...


In [None]:
# The layer whose feature maps we will visualize. To try a different layer,
# comment this line out and uncomment one of the alternatives below.
layer_name = "block2_sepconv1"
#layer_name = "block4_sepconv1"
#layer_name = "block10_sepconv1"

In [None]:
# Look up the chosen layer by name in the Xception base
layer = xception_base.get_layer(name=layer_name)
# A model that maps an input image to the output of that layer
feature_extractor = Model(xception_base.input, layer.output)

In [None]:
def compute_loss(image, filter_index):
    """Return the mean activation of one filter of the chosen layer.

    The image is passed through `feature_extractor`, the channel given by
    `filter_index` is selected, and a 2-pixel border is dropped on every side
    before averaging, so only the interior of the feature map contributes.
    """
    feature_maps = feature_extractor(image)
    interior = feature_maps[:, 2:-2, 2:-2, filter_index]
    return tf.reduce_mean(interior)

In [None]:
@tf.function
def gradient_ascent_step(image, filter_index, learning_rate):
    """Take one gradient-ascent step on the image and return the new image.

    The gradient of `compute_loss` with respect to the image is L2-normalized,
    scaled by `learning_rate` and added to the image, moving it in the
    direction that increases the chosen filter's mean activation.
    """
    with tf.GradientTape() as tape:
        # The image is not a tf.Variable, so it has to be watched explicitly
        tape.watch(image)
        loss = compute_loss(image, filter_index)
    normalized_grads = tf.math.l2_normalize(tape.gradient(loss, image))
    return image + learning_rate * normalized_grads

In [None]:
img_width = 200
img_height = 200

def generate_filter_pattern(filter_index):
    """Return an image (NumPy array) that strongly activates the given filter.

    Starts from a random image with values drawn uniformly from [0.4, 0.6) and
    applies 30 gradient-ascent steps with a step size of 10. The batch axis is
    removed from the result.
    """
    n_steps = 30
    step_size = 10.
    pattern = tf.random.uniform(
        shape=(1, img_width, img_height, 3),
        minval=0.4,
        maxval=0.6)
    for _ in range(n_steps):
        pattern = gradient_ascent_step(pattern, filter_index, step_size)
    return pattern[0].numpy()

In [None]:
def deprocess_image(image):
    """Convert a generated pattern into a displayable uint8 image.

    The values are standardized (zero mean, unit standard deviation), mapped
    to roughly 0-255 (standard deviation 64, centred on 128), clipped, cast to
    uint8, and finally 25 pixels are cropped from each side of the first two
    axes.

    The input array is not modified: the work is done on a copy. A constant
    image (zero standard deviation) is not divided, so it comes out as uniform
    mid-grey (128) rather than producing NaNs.

    Parameters:
        image: array of shape (height, width, channels).

    Returns:
        uint8 array of shape (height - 50, width - 50, channels).
    """
    # Work on a copy so that the caller's array is left untouched
    image = np.array(image, copy=True)
    if not np.issubdtype(image.dtype, np.floating):
        # The in-place arithmetic below needs a floating-point array
        image = image.astype("float32")
    image -= image.mean()
    std = image.std()
    if std > 0:
        image /= std
    image *= 64
    image += 128
    image = np.clip(image, 0, 255).astype("uint8")
    # Crop the borders
    return image[25:-25, 25:-25, :]

So here are the kinds of inputs that channel 2 (i.e. `filter_index=2`) in the chosen layer (block2_sepconv1 by default) responds to. What do you think it responds to?

In [None]:
# Generate the pattern for the filter with index 2, convert it to a displayable
# image and show it
plt.imshow(deprocess_image(generate_filter_pattern(filter_index=2)))
plt.axis("off")
plt.show()

Now a visualization for every feature map  in the layer:

In [None]:
margin = 5
n = 8

# Generate one pattern for each cell of the n x n grid, i.e. for the first
# n * n filters of the layer
all_images = []
for filter_index in range(n * n):
    image = deprocess_image(
        generate_filter_pattern(filter_index)
    )
    all_images.append(image)

# deprocess_image crops 25 pixels from each side of every pattern
cropped_width = img_width - 25 * 2
cropped_height = img_height - 25 * 2
# Size of the whole grid, with a margin between (but not around) the tiles
width = n * cropped_width + (n - 1) * margin
height = n * cropped_height + (n - 1) * margin
stitched_filters = np.zeros((width, height, 3))

# Copy each pattern into its tile of the grid
for i in range(n):
    for j in range(n):
        image = all_images[i * n + j]
        stitched_filters[
            (cropped_width + margin) * i : (cropped_width + margin) * i + cropped_width,
            (cropped_height + margin) * j : (cropped_height + margin) * j
            + cropped_height,
            :,
        ] = image

# Make sure the output folder exists before saving; otherwise save_img fails
# the first time the Notebook is run
output_path = os.path.join(base_dir, f"visualizations/filters_for_layer_{layer_name}.png")
os.makedirs(os.path.dirname(output_path), exist_ok=True)
save_img(output_path, stitched_filters)

- If you look at the images (saved in the `visualizations` folder), you'll see that the feature maps in lower layers, e.g. block2_sepconv1, respond to simple edges and colours.
- The feature maps in slightly later layers, e.g. block4_sepconv1, respond to simple textures made from combinations of edges and colours.
- The feature maps in later layers respond to natural-looking textures resembling feathers, leaves, etc.

## Visualizations of heatmaps that show parts of an image that most contribute to a classification.

For a given input image and a predicted class, this will show which parts of the image were most useful in making the classification. This is sometimes called a **heatmap** or a **saliency map**.
- For every pixel, we compute a score indicating how important that pixel is in predicting the class.
- We display the scores as a heatmap.

This can be helpful in debugging models: we can see whether the model is paying attention to the 'right' parts of the image.

In [None]:
class GradCAM:
    """Grad-CAM heatmaps for one class of a Keras image classifier.

    The heatmap shows which regions of an input image contributed most to the
    model's score for the class `classIdx`, measured at the layer `layerName`.
    """

    def __init__(self, model, classIdx, layerName=None):
        """Store the model, the class index and the layer to visualize.

        Parameters:
            model: the Keras model to explain.
            classIdx: index of the class (column of the model's output) whose
                score is explained.
            layerName: name of the layer at which the class activation map is
                computed; if None, the last layer with a 4D output is used.

        Raises:
            ValueError: if layerName is None and no layer has a 4D output.
        """
        self.model = model
        self.classIdx = classIdx
        self.layerName = layerName
        # if the layer name is None, attempt to automatically find
        # the target output layer
        if self.layerName is None:
            self.layerName = self.find_target_layer()

    def find_target_layer(self):
        """Return the name of the last layer in the model whose output is 4D.

        Raises:
            ValueError: if the model has no layer with a 4D output.
        """
        # attempt to find the final convolutional layer in the network
        # by looping over the layers of the network in reverse order
        for layer in reversed(self.model.layers):
            # check to see if the layer has a 4D output
            if len(layer.output.shape) == 4:
                return layer.name
        # otherwise, we could not find a 4D layer so the GradCAM
        # algorithm cannot be applied
        raise ValueError("Could not find 4D layer. Cannot apply GradCAM.")

    def compute_heatmap(self, image, eps=1e-8):
        """Compute the class activation heatmap for a single image.

        Parameters:
            image: a batch containing the image, in the form the model expects
                as input; only the first item of the batch is used, and
                image.shape[1] and image.shape[2] give the heatmap's height
                and width.
            eps: small constant added to the denominator when normalizing, to
                avoid division by zero for a constant map.

        Returns:
            A 2D uint8 array with values in 0-255, resized to the height and
            width of the input image.
        """
        # construct our gradient model by supplying (1) the inputs
        # to our pre-trained model, (2) the output of the (presumably)
        # final 4D layer in the network, and (3) the output of the
        # softmax activations from the model.
        # model.inputs is already a list, so it is passed as it is rather
        # than being wrapped in another list.
        gradModel = Model(
            inputs=self.model.inputs,
            outputs=[self.model.get_layer(self.layerName).output,
                     self.model.output])
        # record operations for automatic differentiation
        with tf.GradientTape() as tape:
            # cast the image tensor to a float-32 data type, pass the
            # image through the gradient model, and grab the loss
            # associated with the specific class index
            inputs = tf.cast(image, tf.float32)
            (convOutputs, predictions) = gradModel(inputs)
            loss = predictions[:, self.classIdx]
        # use automatic differentiation to compute the gradients
        grads = tape.gradient(loss, convOutputs)
        # compute the guided gradients: keep a gradient only where both the
        # layer output and the gradient itself are positive
        castConvOutputs = tf.cast(convOutputs > 0, "float32")
        castGrads = tf.cast(grads > 0, "float32")
        guidedGrads = castConvOutputs * castGrads * grads
        # the convolution and guided gradients have a batch dimension
        # (which we don't need) so let's grab the volume itself and
        # discard the batch
        convOutputs = convOutputs[0]
        guidedGrads = guidedGrads[0]
        # average the guided gradients over the two spatial axes to get one
        # weight per channel, then sum the channels weighted in this way
        weights = tf.reduce_mean(guidedGrads, axis=(0, 1))
        cam = tf.reduce_sum(tf.multiply(weights, convOutputs), axis=-1)
        # grab the spatial dimensions of the input image and resize
        # the output class activation map to match the input image
        # dimensions (cv2.resize takes the size as (width, height))
        (w, h) = (image.shape[2], image.shape[1])
        heatmap = cv2.resize(cam.numpy(), (w, h))
        # normalize the heatmap such that all values lie in the range
        # [0, 1], scale the resulting values to the range [0, 255],
        # and then convert to an unsigned 8-bit integer
        numer = heatmap - np.min(heatmap)
        denom = (heatmap.max() - heatmap.min()) + eps
        heatmap = numer / denom
        heatmap = (heatmap * 255).astype("uint8")
        # return the resulting heatmap to the calling function
        return heatmap

    def overlay_heatmap(self, heatmap, image, alpha=0.5,
                        colormap=cv2.COLORMAP_VIRIDIS):
        """Colour a heatmap and blend it with the image.

        Parameters:
            heatmap: 2D uint8 heatmap, as returned by compute_heatmap, with
                the same height and width as `image`.
            image: the image to overlay, with 0-255 pixel values in RGB
                channel order (or a single channel).
            alpha: weight of the image in the blend; the coloured heatmap gets
                weight 1 - alpha.
            colormap: the OpenCV colour map to apply to the heatmap.

        Returns:
            A 2-tuple (coloured heatmap, blended image). Both are in RGB
            channel order, ready for display with matplotlib; the blended
            image is an integer array.
        """
        # apply the supplied color map to the heatmap
        heatmap = cv2.applyColorMap(heatmap, colormap)
        # OpenCV returns the coloured heatmap in BGR channel order; convert it
        # to RGB so that it matches the image and is displayed with the right
        # colours by matplotlib
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
        # overlay the heatmap on the input image
        output = image * alpha + heatmap * (1 - alpha)
        output = output.astype(int)
        # return a 2-tuple of the color mapped heatmap and the output,
        # overlaid image
        return (heatmap, output)

Let's apply it to our cat image.

In [None]:
# Xception's ImageNet weights expect pixels scaled to [-1, 1]. Preprocess a copy
# of the image: preprocess_input may modify a NumPy array in place, and the
# original 0-255 image (img_array) is still needed for the overlay below.
xception_input = xception.preprocess_input(img_tensor.copy())

# The class we explain is the one the model predicts for the image
preds = xception_model.predict(xception_input)
i = np.argmax(preds[0])

cam = GradCAM(xception_model, i)
heatmap = cam.compute_heatmap(xception_input)
# resize the resulting heatmap to the original input image dimensions
# and then overlay heatmap on top of the image
heatmap = cv2.resize(heatmap, (img_array.shape[1], img_array.shape[0]))
(heatmap, output) = cam.overlay_heatmap(heatmap, img_array, alpha=0.5)

In [None]:
# The colour-mapped heatmap on its own
plt.matshow(heatmap)

In [None]:
# The heatmap blended with the original image
plt.matshow(output)

## Shortcuts

In [None]:
import math

def show_images(images):
    """Display the images in a grid with five images per row.

    Images are placed row by row. Grid cells left over in the last row are
    hidden rather than drawn as empty axes. Nothing is drawn if `images` is
    empty.

    Parameters:
        images: a sequence (or array) of images that matplotlib's imshow can
            display.
    """
    num_images = len(images)
    if num_images == 0:
        # plt.subplots cannot create a grid with zero rows
        return
    num_per_row = 5
    num_rows = math.ceil(num_images / num_per_row)
    # squeeze=False keeps axes two-dimensional even when there is only one row
    fig, axes = plt.subplots(num_rows, num_per_row, figsize=(num_per_row, num_rows),
                             squeeze=False)
    # axes.flat goes through the grid row by row, so image k lands in row
    # k // num_per_row and column k % num_per_row
    for ax, image in zip(axes.flat, images):
        ax.imshow(image, cmap=plt.cm.binary, interpolation="nearest")
    # Hide ticks and frames on every cell, including unused ones
    for ax in axes.flat:
        ax.axis("off")
    fig.tight_layout()
    plt.show()

In [None]:
def augment_image(image, digit, mult=25):
  """Return a copy of the image with a 4 x 4 patch painted in one corner.

  The corner is chosen at random: one coin flip picks the first four or the
  last four rows, a second picks the first four or the last four columns.
  Every pixel of the patch is set to 255 - mult * digit, so the patch
  intensity encodes the digit. The input image is not modified.
  """
  first_four = [0, 1, 2, 3]
  last_four = [-1, -2, -3, -4]
  # Two independent draws: rows first, then columns
  row_idx = first_four if np.random.rand() > 0.5 else last_four
  col_idx = first_four if np.random.rand() > 0.5 else last_four
  patched = image.copy()
  # np.ix_ selects the whole 4 x 4 block (across any trailing channel axis)
  patched[np.ix_(row_idx, col_idx)] = 255 - mult * digit
  return patched

In [None]:
def augment_images(images, labels=None, mult=25):
  """Return a copy of the images, each with a corner patch added by augment_image.

  Parameters:
    images: array of shape (number of images, height, width, depth).
    labels: if given, labels[i] is the digit encoded by the patch of image i;
      if None, a digit from 0-9 is drawn at random for every image.
    mult: multiplier passed on to augment_image, which sets the patch
      intensity to 255 - mult * digit.

  Returns:
    A new array of the same shape as images (created with np.zeros, so with
    NumPy's default float dtype) holding the augmented images.
  """
  digits = range(10)
  count, height, width, depth = images.shape
  augmented_images = np.zeros(shape=(count, height, width, depth))
  for idx in range(count):
    digit = np.random.choice(digits) if labels is None else labels[idx]
    # Convert to a plain int so that mult * digit cannot overflow when the
    # labels are small unsigned integers (e.g. uint8), and forward mult so
    # that the caller's value is actually used
    augmented_images[idx] = augment_image(images[idx], int(digit), mult)
  return augmented_images

In [None]:
# Load the MNIST training and test sets (images and their labels)
(X_train, y_train), (X_test, y_test) = mnist.load_data()

In [None]:
# Give every image an explicit channel axis (28 x 28 x 1), matching the input
# shape of the model defined below
train_shape = (60000, 28, 28, 1)
test_shape = (10000, 28, 28, 1)
X_train = np.reshape(X_train, train_shape)
X_test = np.reshape(X_test, test_shape)

In [None]:
# A small convolutional network for 28 x 28 x 1 images
inputs = Input(shape=(28, 28, 1))
# Scale the 0-255 pixel values to 0-1
rescaled = Rescaling(scale=1./255)(inputs)
# One convolutional layer followed by max pooling
conv_features = Conv2D(filters=64, kernel_size=(3, 3), activation="relu")(rescaled)
pooled = MaxPooling2D(pool_size=(2, 2))(conv_features)
# Flatten and classify into the ten digits
flattened = Flatten()(pooled)
outputs = Dense(units=10, activation="softmax")(flattened)
mnist_model = Model(inputs, outputs)

In [None]:
# The labels are integers (not one-hot vectors), hence the sparse version of
# categorical cross-entropy; accuracy is tracked as a metric
mnist_model.compile(optimizer=RMSprop(learning_rate=0.0001), loss="sparse_categorical_crossentropy", metrics=["accuracy"])

In [None]:
# Add a corner patch to every training image; because the labels are passed,
# the intensity of the patch encodes the image's true digit (the shortcut)
X_train_aug = augment_images(X_train, y_train)

In [None]:
# The first five augmented training images of the digit 3
show_images(X_train_aug[y_train == 3][:5])

In [None]:
# The first five augmented training images of the digit 7
show_images(X_train_aug[y_train == 7][:5])

In [None]:
# Train for at most 20 epochs, holding out 25% of the training data for
# validation. Training stops early once the validation loss has failed to
# improve for 2 epochs, and the weights from the best epoch are restored.
mnist_model.fit(X_train_aug, y_train, epochs=20, batch_size=32, verbose=0, validation_split=0.25,
                  callbacks=[EarlyStopping(monitor="val_loss", patience=2, restore_best_weights=True)])

In [None]:
# Test images whose corner patch encodes the true digit, as in the training set
X_test_aug = augment_images(X_test, y_test)

In [None]:
# evaluate returns the loss followed by the compiled metrics, so index 1 is the
# accuracy: here, on test images whose patch encodes the true digit
mnist_model.evaluate(X_test_aug, y_test, verbose=0)[1]

In [None]:
# No labels are passed, so each patch encodes a randomly chosen digit instead
# of the image's true digit
X_test_aug_rand = augment_images(X_test)

In [None]:
# The first five test images of the digit 3, now with randomly chosen patches
show_images(X_test_aug_rand[y_test == 3][:5])

In [None]:
# Accuracy (index 1 of evaluate's result) on the test images whose patch no
# longer encodes the true digit
mnist_model.evaluate(X_test_aug_rand, y_test, verbose=0)[1]

In [None]:
# Classify the test images that have randomly chosen patches (once), then split
# those same images according to whether the prediction matches the true label.
# The images must come from the same array the predictions were made on;
# otherwise the "correct"/"incorrect" images would not be the ones classified.
predicted_labels = np.argmax(mnist_model.predict(X_test_aug_rand), axis=1)
correct = X_test_aug_rand[predicted_labels == y_test]
incorrect = X_test_aug_rand[predicted_labels != y_test]

In [None]:
# Let's look at the fifth image (index 4) that it gets right.
img = correct[4]
img_array = img_to_array(img)
# Add a leading batch axis so that the image can be passed to the model
img_tensor = np.expand_dims(img_array, axis=0)

In [None]:
# The class we explain is the one the model predicts for this image
preds = mnist_model.predict(img_tensor)
i = np.argmax(preds[0])

# No layer name is given, so GradCAM uses the last layer with a 4D output
cam = GradCAM(mnist_model, i)
heatmap = cam.compute_heatmap(img_tensor)
# Resize the heatmap to the image's dimensions (cv2.resize takes (width, height))
# and overlay it on the image
heatmap = cv2.resize(heatmap, (img_array.shape[1], img_array.shape[0]))
(heatmap, output) = cam.overlay_heatmap(heatmap, img_array, alpha=0.5)

In [None]:
# The colour-mapped heatmap for the correctly classified image
plt.matshow(heatmap)

In [None]:
# Let's look at the fifth image (index 4) that it gets wrong.
img = incorrect[4]
img_array = img_to_array(img)
# Add a leading batch axis so that the image can be passed to the model
img_tensor = np.expand_dims(img_array, axis=0)

In [None]:
# The class we explain is the one the model predicts for this image
preds = mnist_model.predict(img_tensor)
i = np.argmax(preds[0])

# No layer name is given, so GradCAM uses the last layer with a 4D output
cam = GradCAM(mnist_model, i)
heatmap = cam.compute_heatmap(img_tensor)
# Resize the heatmap to the image's dimensions (cv2.resize takes (width, height))
# and overlay it on the image
heatmap = cv2.resize(heatmap, (img_array.shape[1], img_array.shape[0]))
(heatmap, output) = cam.overlay_heatmap(heatmap, img_array, alpha=0.5)

In [None]:
# The colour-mapped heatmap for the misclassified image
plt.matshow(heatmap)

<b>Discussion question:</b> Do you think any of these visualizations would be useful to someone who wanted to debug or audit a model?