Skip to content

LevinRoman/parameter-space-saliency

Repository files navigation

Where do Models go Wrong? Parameter-Space Saliency Maps for Explainability

This repository contains the implementation of parameter-saliency methods from our paper Where do Models go Wrong? Parameter-Space Saliency Maps for Explainability .

Abstract: Conventional saliency maps highlight input features to which neural network predictions are highly sensitive. We take a different approach to saliency, in which we identify and analyze the network parameters, rather than inputs, which are responsible for erroneous decisions. We find that samples which cause similar parameters to malfunction are semantically similar. We also show that pruning the most salient parameters for a wrongly classified sample often improves model behavior. Furthermore, fine-tuning a small number of the most salient parameters on a single sample results in error correction on other samples that are misclassified for similar reasons. Based on our parameter saliency method, we also introduce an input-space saliency technique that reveals how image features cause specific network components to malfunction. Further, we rigorously validate the meaningfulness of our saliency maps on both the dataset and case-study levels.

Getting started

This repo uses Python 3. To install the requirements, run

pip install -r requirements.txt

Basic Use

The script input_saliency.py computes both the parameter-saliency profile of an image which allows to find misbehaving filters in a neural network responsible for misclassification of a given image. In addition, the script computes the image-space saliency which highlights pixels which drive the high filter saliency values.

To compute the parameter saliency profile for a given image, the script accepts

  • either path to the raw image + image target label
python3 parameter_and_input_saliency.py --model resnet50 --image_path raw_images/great_white_shark_mispred_as_killer_whale.jpeg --image_target_label 2
  • or reference_id -- the index of the given image in ImageNet validation set.
python3 parameter_and_input_saliency.py --reference_id 107 --k_salient 10

here --reference_id specifies the image id from ImageNet validation set

--k_salient specifies the number of top salient filters to use for the input-space visualization

The resulting plots (input space colormap and filter saliency plot) will be saved to /figures

Demo

The demo raw image is in /raw_images. The results are in /figures.

Project Organization

├── README.md
├── LICENSE
├── requirements.txt 
├── utils.py  <- helper functions       
├── parameter_and_input_saliency.py  <- main script which computes both input saliency and parameter saliency
│
├── figures <- folder for resulting figures
├── helper_objects <- precomputed objects to speed up computation (inference results on ImageNet valset and parameter saliency mean and std for standardization)
│   ├─ resnet50   
│   ├─ densenet121
│   ├─ inception_v3
│   └── vgg19
├── raw_images <- images to use for parameter space saliency computation and for input space saliency visualization
└── parameter_saliency
    └── saliency_model_backprop  <- script with SaliencyModel class, parameter saliency implementation 

About

Parameter-Space Saliency Maps for Explainability

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages