## Interpreting what Convolutional Neural Networks Learn

### On the Importance of Model Interpretability in Healthcare

When starting out on our machine learning journey, most of us have seen authors describe the conundrum of the **trade-off between a model's (hypothesis's) prediction performance and its interpretability**[^islr]. The general consensus, though not always correct, is that **more complex** models have **high performance but low interpretability**, while **less complex** models are the opposite.

One might ask: when should we use a **complex** model, and when a **simpler** one? Here is a rule of thumb.

- If the only interest is in **prediction**, for example a company that wants an algorithm to predict the **price of a cryptocurrency**, then the performance measure should likely be maximized at the expense of interpretability.
- On the other hand, if our interest is in **inference**[^inference], for example in a **healthcare** setting where the task is to predict whether a patient has **melanoma (skin cancer)**, then **interpretability** is more important than **high performance**. This is because the **model** is not the only decision maker: healthcare practitioners need to understand why the model made its decision. In short, doctors want to know, and agree with, which particular region of the **skin lesion** led the model to decide whether it is cancerous.
- Interpretability also helps with error analysis, i.e. understanding where the model went wrong. For example, a beginner trained a model to classify red and white blood cells; having started out with MNIST, he naturally converted the cell images to grayscale and got poor results. A little error analysis quickly reveals the problem: in grayscale the two cell types look very similar and the model gets confused, because an important feature (color) is missing.

For more on this topic, Christoph Molnar's book *Interpretable Machine Learning*[^interpretable_ml_book] is a good read.

[^islr]: An Introduction to Statistical Learning, James, G., Witten, D., Hastie, T., & Tibshirani, R. pp.24-25
[^inference]: The difference between inference and prediction is that inference goes one step beyond prediction: we are concerned not only with the outputs of our model but also with extracting a meaningful relationship between the input features and the predictions. Note that in the ML community "inference" often just means running a trained model to obtain predictions, which is not how the statistics community uses the term.

[^interpretable_ml_book]: https://christophm.github.io/interpretable-ml-book/

### Computer Vision and Convolutional Neural Networks

Computer vision has flourished in the past decade, powering many use cases including, but not limited to, the **healthcare, automotive and facial recognition industries**. Most computer vision problems use a type of neural network called the **Convolutional Neural Network (CNN)**, though recent breakthroughs with **vision transformers**[^vision_transformers] show promising results. In this section, however, we focus on **CNNs**. For a long time, CNNs were regarded as **black-box models** because they are very **complex and hence difficult to interpret**. As the need to **interpret models** has grown, several methods have emerged that give a glimpse of what a CNN is looking at.

[^vision_transformers]: https://en.wikipedia.org/wiki/Vision_transformer

### Is a CNN a black box?

Neural networks are well defined mathematically, so why can't we interpret them the same way we interpret, say, logistic regression? The problem is that feature importance is hard to decode: a CNN for computer vision takes raw pixels as input, and the importance of an individual pixel is rarely meaningful on its own. We need explanations at the level of spatial regions.

For example, if we plot a grayscale image of the digit 3 from MNIST at the pixel level (i.e. $28 \times 28 = 784$ pixels), it is not immediately obvious how we should recover the "importance" of any single pixel.




<img src="https://storage.googleapis.com/reighns/reighns_ml_projects/docs/deep_learning/computer_vision/gradcam_and_variants/pixel_level_mnist.PNG" style="margin-left:auto; margin-right:auto"/>
<p style="text-align: center">
    <b>MNIST pixel level plot; By Hongnan G.</b>
</p>
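To see why pixel-level importance is fragile, consider how a classical linear model such as logistic regression sees an image: it flattens the $28 \times 28$ grid into 784 independent features, each with its own weight. The NumPy sketch below uses a hypothetical synthetic stroke rather than a real MNIST digit, and random weights rather than a trained model. It shows that shifting the same stroke by a few pixels moves it onto an entirely different set of weights, so the "importance" of any single pixel says little about the shape itself.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical 28x28 grayscale "digit": a bright vertical stroke (not real MNIST data).
image = np.zeros((28, 28))
image[6:22, 13:15] = 1.0

# Flattening turns the image into 784 independent features.
x = image.reshape(-1)

# A linear model has exactly one weight per pixel (random here, for illustration only).
w = rng.normal(size=x.size)
score = w @ x

# Shift the same stroke 3 pixels to the right: the shape is unchanged,
# but it now occupies a disjoint set of pixels, hence different weights.
x_shifted = np.roll(image, shift=3, axis=1).reshape(-1)
shared_pixels = np.count_nonzero(x * x_shifted)
score_shifted = w @ x_shifted

print(x.shape, shared_pixels)  # (784,) 0
print(score, score_shifted)    # two unrelated scores for the "same" stroke
```

A CNN avoids this by sliding the same small filters across the image, so a pattern is detected wherever it appears.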

### High-level understanding of CNNs

The following is adapted from [Interpretable Machine Learning](https://christophm.github.io/interpretable-ml-book/).

We start with a question: why can't we use a classical model, say an SVM or logistic regression, to classify a cat image? Can't we just flatten the pixels and feed them to the model as input?

- Flattening the pixels loses **spatial information**: the image becomes a 1D vector of independent features, so neighbouring pixels are no longer treated as neighbours and higher-level features are not captured. Intuitively, a cat is recognizable by features such as its eyes and ears, but after flattening, the pixels that make up these features are no longer grouped together. Classical ML models would need extensive **feature engineering** to work (i.e. if you could reconstruct features such as a cat's eyes from the flattened pixels, you might succeed).
- The power of a CNN is that it learns spatial features, from low-level **colors and edges** to higher-level **patterns**, directly from the image. This is enabled by its stack of **hidden layers**, each of which transforms the output of the previous one. In a sense, the hidden layers of a CNN are **implicitly performing feature engineering**.
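In contrast to flattening, a convolution slides a small filter over the 2D image, so neighbouring pixels are processed together and the output keeps the spatial layout. As a minimal NumPy sketch, the hand-crafted vertical-edge filter below stands in for the kind of filter the early layers of a trained CNN often learn. Note that deep learning libraries actually compute cross-correlation (no kernel flip) and call it convolution, which is also what this sketch does.

```python
import numpy as np

def conv2d(image, kernel):
    """Valid 2D cross-correlation (what deep learning libraries call convolution)."""
    kh, kw = kernel.shape
    h, w = image.shape
    out = np.zeros((h - kh + 1, w - kw + 1))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            # Weighted sum of the pixels under the filter at position (i, j).
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out

# Toy 6x6 image: dark left half, bright right half (a vertical edge).
image = np.zeros((6, 6))
image[:, 3:] = 1.0

# Hand-crafted vertical edge detector; a CNN learns its filters from data instead.
kernel = np.array([[-1.0, 0.0, 1.0],
                   [-1.0, 0.0, 1.0],
                   [-1.0, 0.0, 1.0]])

feature_map = conv2d(image, kernel)
print(feature_map)  # every row is [0, 3, 3, 0]: a strong response only around the edge
```

The result is still a 2D grid (a *feature map*) whose large values sit where the edge is; a CNN stacks many such filters and learns their values during training.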

> Let us outline, at a high level, the "life cycle" of a CNN.

1. The input image, usually of shape $(C, H, W)$, is fed into the CNN. Note that we **do not flatten the image**.
2. In the **convolutional layers**, the network first learns **simple features** such as edges and basic shapes; as it progresses to later layers, it learns **highly abstract features** such as more **complex textures and patterns**.
3. After the last **convolutional layer**, we apply some form of pooling and/or flatten the learned feature maps, then connect them to fully connected layers that predict the classes.
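The shapes in this life cycle can be traced with the standard output-size formula for a convolution or pooling layer, $\lfloor (H + 2p - k)/s \rfloor + 1$ for kernel size $k$, stride $s$ and padding $p$. The pure-Python sketch below assumes a VGG16-style layout (five blocks of $3 \times 3$ convolutions with padding 1, each block followed by $2 \times 2$ max pooling with stride 2) on a $3 \times 224 \times 224$ input. It only tracks shapes; no pixels are processed.

```python
def out_size(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1

c, h, w = 3, 224, 224                      # (C, H, W) input, not flattened
for out_channels in [64, 128, 256, 512, 512]:
    # 3x3 convolutions with padding 1 keep H and W; only the channel count changes.
    h, w = out_size(h, 3, padding=1), out_size(w, 3, padding=1)
    c = out_channels
    # 2x2 max pooling with stride 2 halves H and W.
    h, w = out_size(h, 2, stride=2), out_size(w, 2, stride=2)
    print((c, h, w))

# Step 3: flatten for the fully connected layers, or global-average-pool instead.
flattened = c * h * w                      # 512 * 7 * 7 features
pooled = c                                 # 512 features after global average pooling
print(flattened, pooled)                   # 25088 512
```

The five printed shapes go from $(64, 112, 112)$ down to $(512, 7, 7)$: spatial resolution shrinks while the number of channels (learned features) grows, and $25088$ matches the input size of VGG16's first fully connected layer.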

<img src="https://storage.googleapis.com/reighns/reighns_ml_projects/docs/deep_learning/computer_vision/gradcam_and_variants/cat_cnn_black_box.jpg" width="300" height="300" style="margin-left:auto; margin-right:auto"/>
<p style="text-align: center">
    <b>Cat and CNN; By Hongnan G.</b>
</p>

## CNN Interpretation

### Feature Visualization Through Convolutional Layers

Following the high-level overview of CNNs in the previous section, a natural first step is to visualize the output (activations) of each layer.

This method is useful for understanding:

- How successive CNN layers transform their inputs.
- What individual filters respond to, which builds intuition for whether a filter detects edges, shapes or something more complex.
- Which visual patterns the CNN is capturing at each depth.
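Deep learning frameworks expose intermediate outputs in different ways (for example, forward hooks in PyTorch). The sketch below stays framework-agnostic: it treats a CNN as a hypothetical list of NumPy layer functions with random, untrained filters, runs an image through them and keeps every intermediate activation. The layer names, filter counts and input size are illustrative assumptions, not VGG16.

```python
import numpy as np

rng = np.random.default_rng(0)

def conv_relu(x, filters):
    """Valid cross-correlation of a (C, H, W) input with (K, C, kh, kw) filters, then ReLU."""
    k, _, kh, kw = filters.shape
    _, h, w = x.shape
    out = np.zeros((k, h - kh + 1, w - kw + 1))
    for i in range(out.shape[1]):
        for j in range(out.shape[2]):
            patch = x[:, i:i + kh, j:j + kw]                     # (C, kh, kw)
            out[:, i, j] = np.tensordot(filters, patch, axes=3)  # one value per filter
    return np.maximum(out, 0.0)

def max_pool(x):
    """2x2 max pooling with stride 2 on a (C, H, W) array; odd edges are cropped."""
    c, h, w = x.shape
    x = x[:, : h - h % 2, : w - w % 2]
    return x.reshape(c, h // 2, 2, w // 2, 2).max(axis=(2, 4))

# Hypothetical two-block "CNN" with random, untrained filters (stand-ins only).
filters1 = rng.normal(size=(4, 3, 3, 3))
filters2 = rng.normal(size=(8, 4, 3, 3))
layers = [
    ("conv1", lambda x: conv_relu(x, filters1)),
    ("pool1", max_pool),
    ("conv2", lambda x: conv_relu(x, filters2)),
    ("pool2", max_pool),
]

image = rng.random((3, 32, 32))  # placeholder for a (C, H, W) RGB image
activations = {}
x = image
for name, layer in layers:
    x = layer(x)
    activations[name] = x        # keep every intermediate output

for name, act in activations.items():
    print(name, act.shape)
```

Each `activations[name][k]` is a 2D array; plotting it as a grayscale image (e.g. with `matplotlib.pyplot.imshow`) gives one panel per channel, which is the kind of grid shown in the visualization that follows.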

#### Visualization 

Here are the outputs of several of VGG16's convolutional layers for a cat image.

<img src="https://storage.googleapis.com/reighns/reighns_ml_projects/docs/deep_learning/computer_vision/gradcam_and_variants/vgg16_conv_layers_visualization.PNG" width="800" height="800" style="margin-left:auto; margin-right:auto"/>
<p style="text-align: center">
    <b>VGG16 Conv layers on a Cat; By Hongnan G.</b>
</p>

#### Takeaways

Our takeaways are:

- The first few layers act as a collection of various **edge and shape** detectors and the activations retain almost all of the information from the original input image.
- As we go deeper, the layers encode increasingly high-level, abstract features: the activations become less visually "informative and obvious" and more "generic", carrying less information about the specific input image and more about its class (more on this intuition later).
- The sparsity of the activations increases with depth: in the first layer almost all filters are activated by the input image, but in later layers more and more filters are blank. A blank filter means that the pattern it encodes isn't found in the input image.
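The sparsity observation can be quantified by counting, for each layer, the fraction of channels whose activation map is entirely zero (a "blank" filter). A minimal NumPy sketch, assuming activations are stored as $(C, H, W)$ arrays taken after a ReLU; the `blank_channel_fraction` helper is hypothetical, and the toy arrays are made up purely to exercise it, not measurements from VGG16.

```python
import numpy as np

def blank_channel_fraction(activation, eps=0.0):
    """Fraction of channels in a (C, H, W) activation whose maximum is <= eps."""
    per_channel_max = activation.reshape(activation.shape[0], -1).max(axis=1)
    return float(np.mean(per_channel_max <= eps))

# Made-up toy activations: 1 of 4 channels blank "early", 6 of 8 blank "deep".
early = np.ones((4, 8, 8))
early[3] = 0.0
deep = np.zeros((8, 2, 2))
deep[:2] = 0.5

print(blank_channel_fraction(early))  # 0.25
print(blank_channel_fraction(deep))   # 0.75
```

Applied to the activations of a real network layer by layer, a rising fraction would reflect the increasing sparsity described above.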

---

#### Intuition

As François Chollet puts it, a deep CNN acts as an *information distillation pipeline*. But what does that mean? Why does the cat look **less precise and less like a real cat** as we go deeper? The following analogy (similar to his) gives some intuition:

> Imagine you are asked to recognize a cat: you do so instantly. Now you are asked to draw one. Surely you can, if you can recognize it? You start drawing and compare your drawing with a real cat, only to realize that you cannot remember the specific details of a cat; you can only draw an abstract (generic) one. A neural network is similar to our brain in this respect: **it recognizes an image by filtering out irrelevant details and transforming it into high-level abstract features**. This is what happens in a CNN, and it is desirable! We do not want the model to memorize overly specific details of an image, or it may fail to generalize.


<img src="https://storage.googleapis.com/reighns/reighns_ml_projects/docs/deep_learning/computer_vision/gradcam_and_variants/cat_hn.jpg" width="500" height="600" style="margin-left:auto; margin-right:auto"/>
<p style="text-align: center">
    <b>A cat image from me</b>
</p>

<img src="https://storage.googleapis.com/reighns/reighns_ml_projects/docs/deep_learning/computer_vision/gradcam_and_variants/cats_unsplash.jpg" width="500" height="600" style="margin-left:auto; margin-right:auto"/>
<p style="text-align: center">
    <b>A cat image from Unsplash</b>
</p>

#### Readings

Readings for this section:

- François Chollet: Deep Learning With Python pp.262-267
- Christoph Molnar: Interpretable Machine Learning, Section 10.1

### Grad-CAM

- Grad-CAM interprets a decision by determining which regions of the input contributed most to it. If the model predicted a cat, was it the eyes, the ears or the body shape that determined the class?
- Gradients are very useful here. The intuition is that a gradient measures how much the output changes in response to a small change in the input.
    - E.g. for $y = f(x) = 2x$, every 1-unit change in $x$ causes a 2-unit change in the output $y$. The gradient is $2$ and measures the **rate of change of $y$ with respect to $x$**.
    - The same idea applies to images: can we find the pixels $x$ in the image that contribute most to the target $y$? During training we minimize a loss $\mathcal{L}(\hat{f}(\mathbf{x}), \mathbf{y})$ and could compute its gradient with respect to the inputs, but Grad-CAM instead differentiates the score of the class of interest directly, and does so with respect to the feature maps of the last convolutional layer rather than the raw pixels.
- If we are interested in the cat class, we focus only on its score $y_{cat}$ (the logit before the softmax).
    - Let $A_1, A_2, \dots, A_K$ be the feature maps of the last convolutional layer; each contributes to the final prediction $\hat{y}$.
    - The gradient $\partial y_{cat} / \partial A_k$ gives the rate of change of the cat score with respect to every activation in feature map $A_k$.
    - We compute these gradients for every feature map and average them over the spatial positions to get one weight per channel, $\alpha_k = \frac{1}{Z}\sum_i \sum_j \frac{\partial y_{cat}}{\partial A_k^{ij}}$, where $Z$ is the number of spatial positions. For example, with 8 feature maps of size $32 \times 32$, we compute $8 \times 32 \times 32$ gradients and average each channel's $32 \times 32 = 1024$ values, giving 8 weights.
    - The weighted sum $\sum_k \alpha_k A_k$ is a single $32 \times 32$ map. We then apply `ReLU`, setting negative values to 0, because we are only interested in the cat class: regions with negative values are not unimportant, they push the prediction towards other classes.
    - This $32 \times 32$ map, now non-negative, must be mapped back to the original image. For example, if the original cat image is $320 \times 320$, we upsample the map to $320 \times 320$ so that it can be overlaid on the image.
    - The overlaid heatmap highlights the regions where the map is positive, i.e. the areas the model focused on when predicting "cat".
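The weighting step of Grad-CAM can be sketched in a few lines of NumPy. This sketch assumes the last-layer feature maps and the gradients of the class score with respect to them have already been obtained from a framework's autograd (e.g. via hooks in PyTorch or `tf.GradientTape` in TensorFlow); the `grad_cam` helper name and the toy values are hypothetical and only illustrate the computation: spatially averaged gradients as channel weights, a weighted sum of feature maps, then ReLU.

```python
import numpy as np

def grad_cam(feature_maps, gradients):
    """Combine last-conv-layer feature maps with their gradients (Grad-CAM weighting).

    feature_maps: (K, H, W) activations A_k for one image.
    gradients:    (K, H, W) gradients of the class score y_cat w.r.t. each A_k,
                  assumed to come from a framework's autograd.
    Returns an (H, W) map scaled to [0, 1].
    """
    # Average the gradients over the spatial positions: one weight alpha_k per channel.
    alphas = gradients.mean(axis=(1, 2))              # (K,)
    # Weighted sum of the feature maps across channels.
    cam = np.tensordot(alphas, feature_maps, axes=1)  # (H, W)
    # ReLU: keep only regions that push the cat score up.
    cam = np.maximum(cam, 0.0)
    # Scale to [0, 1] for display, guarding against an all-zero map.
    peak = cam.max()
    return cam / peak if peak > 0 else cam

# Toy example with hypothetical values: channel 0 fires in the top-left corner,
# channel 1 in the bottom-right corner of a 4x4 grid.
A = np.zeros((2, 4, 4))
A[0, :2, :2] = 1.0
A[1, 2:, 2:] = 1.0
# Pretend autograd reported that channel 0 raises y_cat and channel 1 lowers it.
dycat_dA = np.stack([np.ones((4, 4)), -np.ones((4, 4))])

cam = grad_cam(A, dycat_dA)
print(cam)  # 1.0 in the top-left 2x2 block, 0.0 everywhere else
```

The bottom-right region has a negative weighted sum, so ReLU removes it even though channel 1 is strongly active there: it is evidence against "cat", not for it.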

<img src="https://storage.googleapis.com/reighns/reighns_ml_projects/docs/deep_learning/computer_vision/gradcam_and_variants/gradcam_1.jpg" width="500" height="600" style="margin-left:auto; margin-right:auto"/>
<p style="text-align: center">
    <b>Gradcam 1</b>
</p>

<img src="https://storage.googleapis.com/reighns/reighns_ml_projects/docs/deep_learning/computer_vision/gradcam_and_variants/gradcam_2.jpg" width="500" height="600" style="margin-left:auto; margin-right:auto"/>
<p style="text-align: center">
    <b>Gradcam 2</b>
</p>
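Heatmaps like these are produced by upsampling the coarse Grad-CAM map to the image resolution and blending it with the image. A minimal NumPy sketch, assuming an integer scale factor so that nearest-neighbour upsampling via `np.kron` suffices; real implementations usually use bilinear interpolation (e.g. OpenCV's `cv2.resize`) and a colour map such as jet, replaced here by the red channel for simplicity. The `upsample_nearest` and `overlay` helpers, the hot region and the blending weight `alpha=0.4` are illustrative assumptions.

```python
import numpy as np

def upsample_nearest(cam, scale):
    """Nearest-neighbour upsampling of an (h, w) map by an integer factor."""
    return np.kron(cam, np.ones((scale, scale)))

def overlay(image, cam, alpha=0.4):
    """Blend a [0, 1] heatmap into the red channel of an (H, W, 3) image in [0, 1]."""
    heat = np.zeros_like(image)
    heat[..., 0] = cam                      # show attention in red
    return (1 - alpha) * image + alpha * heat

cam = np.zeros((32, 32))
cam[8:16, 8:16] = 1.0                       # hypothetical hot region of a 32x32 map
image = np.full((320, 320, 3), 0.5)         # placeholder grey 320x320 image

cam_big = upsample_nearest(cam, 320 // 32)  # (320, 320)
blended = overlay(image, cam_big)
print(cam_big.shape, blended.shape)         # (320, 320) (320, 320, 3)
```

Nearest-neighbour upsampling produces blocky $10 \times 10$ squares; bilinear interpolation gives the smoother blobs typically seen in Grad-CAM figures.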

## References

- https://christophm.github.io/interpretable-ml-book/
- https://distill.pub/2017/feature-visualization/
- Deep Learning with Python by François Chollet