Neural networks are difficult to understand because of their 'black box' nature. Here we describe several ways a researcher can interrogate a neural network to understand what is happening inside it. This repo contains a pipeline for extracting information from the MNIST dataset; however, the MNIST pipeline can be replaced with your preferred dataloaders.
The `runner.py` script has everything you need to get started. It will:

- Download the MNIST dataset
- Train a model on MNIST
- Generate plots:
  - confusion matrices
  - saliency plots
  - projection embeddings
- Save off the best model

Files will be saved to the `save` folder. You can also exercise the full pipeline via `pytest`.
Example of one training epoch:
```
Global step: 0
training   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:21
evaluation ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:02
Epoch: 0/2
Training Loss: 2.285
Evaluation Loss: 2.198
```
Confusion matrices show how well a model's predictions match the true labels. Two quantities of interest can be read from a confusion matrix: precision and recall. Each one measures how well the network classifies data into the correct categories.
- precision = $\Large\frac{p}{\sum{p_{all}}}$
- recall = $\Large\frac{p}{\sum{t}}$

Where:

- $p$: correct predictions for a class
- $p_{all}$: all predictions made for that class
- $t$: true labels for that class
Precision is the number of correct predictions for a class divided by the total number of predictions made for that class. Recall is the number of correct predictions for a class divided by the total number of true labels for that class.
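As a minimal sketch (not part of this repo's pipeline), per-class precision and recall can be pulled straight out of a confusion matrix with NumPy, assuming rows are true labels and columns are predicted labels:

```python
# Sketch: per-class precision and recall from a confusion matrix.
# Convention assumed here: rows = true labels, columns = predicted labels.
import numpy as np

# Hypothetical 3-class confusion matrix (illustrative numbers only).
cm = np.array([[50,  2,  3],
               [ 4, 45,  6],
               [ 1,  5, 40]])

correct = np.diag(cm)                   # p: correct predictions per class
precision = correct / cm.sum(axis=0)    # p / all predictions made for that class
recall = correct / cm.sum(axis=1)       # p / all true labels for that class

print("precision:", precision)
print("recall:   ", recall)
```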
Figure 1: Confusion matrix at epoch=0

Figure 2: Confusion matrix at epoch=1
Saliency maps help a designer understand which parts of the input space contributed to the network's decision. To build a saliency map you need the backpropagation gradients with respect to:
- the source input: $x \rightarrow y$
- an empty baseline input: $x_{baseline} \rightarrow y_{baseline}$

The difference between the two gradients gives the delta:

$\Delta = x_{gradient} - x_{baseline\ gradient}$

Finally, you can multiply the saliency by the source image to produce an overlay showing which pixels the network attends to.
The following images are generated:
- source: Shows the source image.
- baseline: Shows where the neural network focuses attention given an empty image.
- saliency: Shows the activations the network focuses on for the source image.
- delta: Shows the saliency minus the baseline noise.
- overlay: Shows the saliency multiplied against the source image. This is what the neural network is focusing on when performing inference!
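Below is a minimal sketch of this gradient-based procedure in PyTorch, assuming a classifier `model` and a single MNIST-shaped input; the function names are illustrative and are not APIs from this repo:

```python
# Sketch: gradient saliency for a source image and an empty baseline,
# plus the delta and overlay described above.
import torch

def gradient_saliency(model, image, target_class):
    """Return |d y_target / d x| for a single input image."""
    image = image.clone().detach().requires_grad_(True)
    output = model(image)               # forward pass -> class scores
    score = output[0, target_class]     # score for the class of interest
    score.backward()                    # backprop to the input pixels
    return image.grad.detach().abs()    # gradient magnitude per pixel

def saliency_outputs(model, image, target_class):
    baseline = torch.zeros_like(image)                          # empty image
    saliency = gradient_saliency(model, image, target_class)
    baseline_sal = gradient_saliency(model, baseline, target_class)
    delta = saliency - baseline_sal                             # remove baseline noise
    overlay = delta * image                                     # highlight source pixels
    return {"source": image, "baseline": baseline_sal,
            "saliency": saliency, "delta": delta, "overlay": overlay}
```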
Figure 3: Saliency map before training. Epoch=0.
Figure 4: Saliency map after 2 epochs of training.
A neural network will cluster data in high-dimensional space during training. However, humans cannot visualize high-dimensional data; it needs to be projected down into three dimensions or fewer. This is where t-SNE is helpful.

The general idea of t-SNE is to create a low-dimensional projection of the data that matches the structure of the original, high-dimensional data as closely as possible. To create the data points for t-SNE, a set of embeddings is extracted from the network. t-SNE uses the t-distribution to measure the statistical distance between pairs of points, so data points that are close to each other in high dimensions are shown close together in the low dimension.
When t-SNE projects high-dimensional data to a lower dimension, it uses an iterative solver that performs the following:

1) Randomly project the high-dimensional points into the lower dimension.
2) Loop:
   * Iteratively move data points in the lower dimension to better match their distances in the high dimension.
A complete explanation can be found in the StatQuest video listed in the references.
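For a rough idea of what this looks like outside TensorBoard, here is a minimal sketch using scikit-learn's `TSNE` on a placeholder embedding matrix; the array shapes and plotting choices are assumptions, not taken from this repo:

```python
# Sketch: project network embeddings to 2D with t-SNE and color by class.
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

embeddings = np.random.rand(1000, 64)      # placeholder: (n_samples, embedding_dim)
labels = np.random.randint(0, 10, 1000)    # placeholder: digit labels 0-9

# Iteratively fit the low-dimensional projection.
projection = TSNE(n_components=2, init="random", perplexity=30).fit_transform(embeddings)

plt.scatter(projection[:, 0], projection[:, 1], c=labels, cmap="tab10", s=5)
plt.colorbar(label="digit class")
plt.title("t-SNE projection of embeddings")
plt.show()
```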
After running the pipeline, execute the following to see t-SNE in action:

```
tensorboard --logdir=save/logs
```
Under the `projector` tab you will find the visualization tools. You must use the Chrome browser for this to work. See the example below.
Figure 5: t-SNE of embeddings.
To generate the HTML, execute the following:

```
nox -r -s autodoc
```
William Li
- Simonyan, K., Vedaldi, A., & Zisserman, A. (2013). Deep inside convolutional networks: Visualising image classification models and saliency maps. International Conference on Learning Representations. https://www.robots.ox.ac.uk/~vgg/publications/2014/Simonyan14a/simonyan14a.pdf
- Zeiler, M. D., & Fergus, R. (2014). Visualizing and understanding convolutional networks. In Lecture Notes in Computer Science (pp. 818–833). https://doi.org/10.1007/978-3-319-10590-1_53
- Starmer, J. (2017, September 18). StatQuest: t-SNE, clearly explained. YouTube. https://www.youtube.com/watch?v=NEaUSP4YerM
- Van der Maaten, L., & Hinton, G. (2008). Visualizing data using t-SNE. Journal of Machine Learning Research, 9, 2579–2605. https://lvdmaaten.github.io/publications/papers/JMLR_2008.pdf
- Olah, C., Satyanarayan, A., Johnson, I., Carter, S., Schubert, L., Ye, K., & Mordvintsev, A. (2018). The building blocks of interpretability. Distill, 3(3). https://doi.org/10.23915/distill.00010
- Hilton, J., Cammarata, N., Carter, S., Goh, G., & Olah, C. (2020). Understanding RL vision. Distill, 5(11). https://doi.org/10.23915/distill.00029
- Olah, C. (2015). Visualizing representations: Deep learning and human beings. colah's blog. http://colah.github.io/posts/2015-01-Visualizing-Representations/