exigo

Explainable AI tools for deciphering how neural networks work. Includes saliency maps, confusion matrices, and feature representation examples.




📒 Description



Neural networks are difficult to understand because of their 'black box' nature. This repo describes several ways a researcher can interrogate a neural network to explain what it is doing. It contains a pipeline for extracting this information from a model trained on the MNIST dataset. However, the MNIST pipeline can be replaced with your preferred dataloaders.

∮ Running Pipeline

The script has everything you need to get started. It will:

  1. Download the MNIST dataset
  2. Train a model on MNIST
  3. Generate plots:
    • confusion matrices
    • saliency plots
    • projection embeddings
  4. Save off the best model

Files will be saved to the save folder. You can also exercise the full pipeline via pytest.

Example of one training epoch:

Global step: 0
training ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:21
evaluation ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:02
Epoch:               0/2
    Training   Loss: 2.285
    Evaluation Loss: 2.198

🤔 Confusion Matrix

Confusion matrices show how well a model's predictions performed. There are two variables of interest when examining a confusion matrix: precision and recall. Both can be extracted from a confusion matrix, per class, to determine how well a network classifies data into separate categories.

  • precision = $\Large\frac{p}{\sum{p_{all}}}$

  • recall = $\Large\frac{p}{\sum{t}}$


    • $p$: the number of correct predictions for that class
    • $p_{all}$: all predictions made for that class
    • $t$: the true labels for that class

    Precision is the number of correct predictions for a class divided by the total number of predictions made for that class. Recall is the number of correct predictions for a class divided by the total number of true labels for that class.
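
The two formulas above can be sketched with NumPy as follows. This is a minimal illustration and not the code used by the pipeline: the function name `precision_recall` is hypothetical, and it assumes that rows of the matrix are true labels and columns are predicted labels.

```python
import numpy as np


def precision_recall(confusion_matrix):
    """Return per-class (precision, recall) from a square confusion matrix.

    Assumption: rows are true labels, columns are predicted labels.
    """
    cm = np.asarray(confusion_matrix, dtype=float)
    correct = np.diag(cm)              # p: correct predictions per class
    predicted_totals = cm.sum(axis=0)  # sum of p_all: predictions made per class
    true_totals = cm.sum(axis=1)       # sum of t: true labels per class

    # Classes with no predictions (or no true labels) get 0 instead of NaN.
    precision = np.divide(correct, predicted_totals,
                          out=np.zeros_like(correct),
                          where=predicted_totals > 0)
    recall = np.divide(correct, true_totals,
                       out=np.zeros_like(correct),
                       where=true_totals > 0)
    return precision, recall


# Two-class example: 10 true samples of each class.
example = [[8, 2],   # true class 0: 8 predicted as 0, 2 predicted as 1
           [1, 9]]   # true class 1: 1 predicted as 0, 9 predicted as 1
precision, recall = precision_recall(example)
print(precision)  # 8/9 and 9/11
print(recall)     # 8/10 and 9/10
```

If your matrix is laid out the other way around (rows are predictions), swap the two `axis` arguments.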


Figure 1: Confusion matrix at epoch=0


Figure 2: Confusion matrix at epoch=1

🗺️ Saliency Maps

Saliency maps help a designer understand which parts of an input contributed to the decision making process. To build a saliency map you need to find the backpropagation gradients with respect to:

  • $x$ -> $y$
  • $x_{baseline}$ -> $y_{baseline}$

The gradients for $x$ show where the network focused when given the true label. The gradients for $x_{baseline}$ show where the network focused when given no label. This matters because you only want to measure the focus of the network above the baseline value. The difference is calculated as:

  • $\Delta$ = $x_{gradient}$ - $x_{baseline\ gradient}$

Finally, you can multiply $\Delta$ element-wise with the original source image $x$ to determine where the network focused during a forward pass. Areas with higher values are more relevant to the decision process.

The following images are generated:

  • source: Shows the source image.
  • baseline: Shows where the neural network focuses attention when given an empty image.
  • saliency: Shows where the network is focused for the source image.
  • delta: Shows the delta, which is the saliency minus the baseline noise.
  • overlay: Shows the saliency multiplied by the source. This is what the neural network is focusing on when performing its inference function!
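
The steps above can be sketched in a few lines of NumPy. This is only an illustration under stated assumptions, not the pipeline's implementation: the network is a hypothetical two-layer toy model (`score = w2 @ tanh(w1 @ x)`) whose input gradient can be written by hand, the baseline is assumed to be an all-zeros image, and the baseline gradient is taken for the same class as the source image. In practice you would get these gradients from your framework's autograd.

```python
import numpy as np


def class_score(x, w1, w2, label):
    """Score of class `label` for the toy network w2 @ tanh(w1 @ x)."""
    return float(w2[label] @ np.tanh(w1 @ x))


def class_score_gradient(x, w1, w2, label):
    """Gradient of the class score with respect to the input x (chain rule)."""
    hidden = np.tanh(w1 @ x)
    # d tanh(u)/du = 1 - tanh(u)^2, then backpropagate through w1.
    return w1.T @ (w2[label] * (1.0 - hidden ** 2))


def saliency_images(x, w1, w2, label):
    """Compute the arrays behind the five plots for a flattened image x."""
    x_baseline = np.zeros_like(x)  # assumption: "empty image" means all zeros
    saliency = class_score_gradient(x, w1, w2, label)
    baseline = class_score_gradient(x_baseline, w1, w2, label)
    delta = saliency - baseline    # focus above the baseline
    overlay = delta * x            # element-wise product with the source
    return {"source": x, "baseline": baseline, "saliency": saliency,
            "delta": delta, "overlay": overlay}


# Hypothetical sizes: a 4x4 "image" flattened to 16 values, 8 hidden units, 3 classes.
rng = np.random.default_rng(0)
w1 = rng.normal(size=(8, 16))
w2 = rng.normal(size=(3, 8))
image = rng.uniform(size=16)
images = saliency_images(image, w1, w2, label=1)
print(images["overlay"].reshape(4, 4))
```

Because the overlay is an element-wise product, pixels that are zero in the source image are always zero in the overlay, regardless of the gradient at that pixel.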


Figure 3: Saliency map before training. Epoch=0.


Figure 4: Saliency map after 2 epochs of training.


📽️ Projection Embeddings

A neural network will cluster data in high dimensional space during training. However, humans cannot visualize high dimensional data, so it needs to be projected down into three dimensions or fewer. This is where t-SNE is helpful.

The general idea with t-SNE is to create a lower dimensional projection of the data that preserves, as closely as possible, the structure the data has in its original dimensionality. To create the data points for t-SNE, a set of $x$ inputs is passed through a neural network to output the corresponding $y$ embeddings. Each $y$ embedding represents a point in high dimensional space.

t-SNE converts the distance between each $y$ data point and every other $y$ data point into a pairwise similarity. It uses a Gaussian distribution for the similarities in the high dimensional space and a Student t-distribution for the similarities in the low dimensional space. When projecting from high dimensions to low dimensions, the solver minimizes the KL divergence between these two sets of similarities so that the relationships between the $y$ data points are preserved. This is how clusters emerge when using t-SNE. Data points that are close to each other in high dimensions are shown close together in low dimensions.

When t-SNE projects high dimensional data to a lower dimension it uses an iterative solver that performs the following:

1) Randomly project high dimension points to lower dimension
2) Loop:
    * Iteratively move data points in the lower dimension to match their distances in high dimensions.

A complete explanation can be found in this video.
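
The iterative solver described above can be sketched with NumPy. This is a deliberately simplified illustration, not a replacement for a real t-SNE implementation: the function names are hypothetical, a single fixed Gaussian bandwidth `sigma` is assumed instead of the per-point perplexity search, and plain gradient descent is used without momentum or early exaggeration.

```python
import numpy as np


def pairwise_sq_dists(points):
    """Squared Euclidean distance between every pair of rows."""
    sq = np.sum(points ** 2, axis=1)
    return np.maximum(sq[:, None] + sq[None, :] - 2.0 * points @ points.T, 0.0)


def high_dim_similarities(y, sigma=1.0):
    """Gaussian similarities p_ij (fixed sigma is a simplifying assumption)."""
    p = np.exp(-pairwise_sq_dists(y) / (2.0 * sigma ** 2))
    np.fill_diagonal(p, 0.0)
    p = p / p.sum(axis=1, keepdims=True)  # conditional p_{j|i}
    p = (p + p.T) / (2.0 * len(y))        # symmetrize so all p_ij sum to 1
    return np.maximum(p, 1e-12)


def low_dim_similarities(z):
    """Student-t similarities q_ij, plus the unnormalized kernel."""
    kernel = 1.0 / (1.0 + pairwise_sq_dists(z))
    np.fill_diagonal(kernel, 0.0)
    return np.maximum(kernel / kernel.sum(), 1e-12), kernel


def tiny_tsne(y, n_iter=300, learning_rate=10.0, seed=0):
    """Project y to 2-D by gradient descent on KL(P || Q)."""
    rng = np.random.default_rng(seed)
    p = high_dim_similarities(y)
    z = rng.normal(scale=1e-2, size=(len(y), 2))  # 1) random low-dim start
    kl_history = []
    for _ in range(n_iter):                       # 2) loop
        q, kernel = low_dim_similarities(z)
        kl_history.append(float(np.sum(p * np.log(p / q))))
        # Gradient: 4 * sum_j (p_ij - q_ij) * kernel_ij * (z_i - z_j)
        w = (p - q) * kernel
        grad = 4.0 * (np.diag(w.sum(axis=1)) - w) @ z
        z = z - learning_rate * grad              # move the low-dim points
    return z, kl_history


# Two synthetic clusters in 10 dimensions stand in for the $y$ embeddings.
rng = np.random.default_rng(1)
embeddings = np.vstack([rng.normal(0.0, 0.5, size=(20, 10)),
                        rng.normal(3.0, 0.5, size=(20, 10))])
projection, kl_history = tiny_tsne(embeddings)
print(projection.shape, kl_history[0], kl_history[-1])
```

The printed KL divergence should drop between the first and last iteration as the low dimensional points move to match the high dimensional similarities.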

After running the pipeline execute the following to see t-SNE in action:

tensorboard --logdir=save/logs

Under the projector tab you will find the visualization tools. You must use the Chrome browser for this to work. See example below.


Figure 5: t-SNE of embeddings.

🧭 Website Documentation

To generate the HTML execute the following:

nox -r -s autodoc

🖋️ Authors

William Li
