Interpreting neural networks via the STREAK algorithm (streaming weak submodular maximization)

STREAK Example Code


  • Ethan R. Elenberg, Alexandros G. Dimakis, Moran Feldman, and Amin Karbasi. "Streaming Weak Submodularity: Interpreting Neural Networks on the Fly", to appear in Proc. Neural Information Processing Systems (NIPS), 2017. arXiv preprint.

Example

[Figure: Original Image (top label: daisy) | Segmented Image | Interpretation for daisy]

Given a black-box neural network and a test image, the algorithm finds a sparse explanation for the network's prediction. First, the image is segmented into regions. Then the network is rerun on copies of the image in which most of the regions are replaced by a gray reference image, and each output is recorded. The algorithm returns a small set of regions that, with all other regions grayed out, still activate the network's top label. These examples use InceptionV3 with its last layer retrained to classify different types of flowers.
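The masking step above can be sketched in NumPy as follows. This is a minimal illustration, not the repository's code: it assumes images are float arrays scaled to [0, 1], a gray reference value of 0.5, and a hypothetical `predict_proba` callable that maps a batch of images to class probabilities. The set function scores a set of visible regions by the log-probability of the top label, shifted so that the all-gray image scores zero (an illustrative normalization choice).

```python
import numpy as np

# Assumed gray reference value for images scaled to [0, 1]; the repo may differ.
GRAY = 0.5

def mask_image(image, segments, keep, gray=GRAY):
    """Return a copy of `image` in which every region not in `keep` is gray.

    image:    float array (H, W, C), values in [0, 1] (assumption)
    segments: int array (H, W) holding each pixel's region id
    keep:     iterable of region ids that stay visible
    """
    visible = np.isin(segments, list(keep))   # (H, W) boolean mask
    masked = np.full_like(image, gray)        # start from an all-gray image
    masked[visible] = image[visible]          # copy back the kept regions
    return masked

def log_likelihood_objective(predict_proba, image, segments, label):
    """Build f(S) = log p(label | image masked to S) - log p(label | all gray).

    predict_proba is a hypothetical black-box interface: it takes a batch
    (N, H, W, C) and returns probabilities of shape (N, num_classes).
    """
    def score(keep):
        masked = mask_image(image, segments, keep)
        probs = predict_proba(masked[np.newaxis])[0]
        return np.log(probs[label] + 1e-12)   # small constant avoids log(0)

    baseline = score(())                      # all regions grayed out
    return lambda keep: score(keep) - baseline
```

Evaluating `f` on a candidate set costs one forward pass of the network, which is why the number of segments drives the running time of a likelihood-based selection.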

Requirements

  • Directory 'retrained' that contains the black-box models:

    -- bottleneck_fc_model.h5 (keras)

    -- classify_image_graph_def.pb and output_labels.txt (tensorflow)

  • Directory 'sunflowers' that contains jpeg images from class sunflowers to use as queries

  • Directory 'daisy' that contains jpeg images from class daisy to use as queries

  • Directory 'outputs' to save the output images

  • LIME, TensorFlow and/or Keras, NumPy, scikit-image, and joblib packages
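Before running the scripts, a small preflight check like the one below can confirm that the layout listed above is in place. It is a hypothetical helper (`check_setup` is not part of the repository); the paths and packages come from the Requirements list, and it treats the Keras and TensorFlow model files as alternatives, matching "TensorFlow and/or Keras".

```python
import importlib.util
from pathlib import Path

# Directories from the Requirements list.
REQUIRED_DIRS = ["retrained", "sunflowers", "daisy", "outputs"]
# Either backend's model files is sufficient.
KERAS_FILES = ["retrained/bottleneck_fc_model.h5"]
TF_FILES = ["retrained/classify_image_graph_def.pb", "retrained/output_labels.txt"]
# Import names for LIME, NumPy, scikit-image, and joblib.
REQUIRED_MODULES = ["lime", "numpy", "skimage", "joblib"]

def check_setup(root="."):
    """Return a list of human-readable problems; an empty list means ready."""
    root = Path(root)
    problems = []
    for d in REQUIRED_DIRS:
        if not (root / d).is_dir():
            problems.append(f"missing directory: {d}")
    keras_ok = all((root / f).is_file() for f in KERAS_FILES)
    tf_ok = all((root / f).is_file() for f in TF_FILES)
    if not (keras_ok or tf_ok):
        problems.append("no complete Keras or TensorFlow model in retrained/")
    for name in REQUIRED_MODULES:
        if importlib.util.find_spec(name) is None:
            problems.append(f"missing package: {name}")
    if importlib.util.find_spec("tensorflow") is None and importlib.util.find_spec("keras") is None:
        problems.append("missing package: tensorflow or keras")
    return problems
```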

Usage

The main scripts are streakInterpretationExample.py and streakRegressionExample.py. The Jupyter notebook StreakImageRetraining.ipynb is also available as a convenient walkthrough of streakInterpretationExample.py. The script tf_predict.py can also be run from the command line to load the TensorFlow model and predict labels for a list of images.

python streakInterpretationExample.py image1.jpg image2.jpg
python streakRegressionExample.py
python tf_predict.py image1.jpg image2.jpg
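Turning a model's output into the printed label amounts to ranking a probability vector against the label file. The sketch below shows that step only; it assumes output_labels.txt lists one class name per line, in the same order as the model's output vector, and the helper names `load_labels` and `top_k_labels` are hypothetical rather than functions from tf_predict.py.

```python
import numpy as np

def load_labels(path):
    """Read class names, one per line (assumed format of output_labels.txt)."""
    with open(path) as fh:
        return [line.strip() for line in fh if line.strip()]

def top_k_labels(probs, labels, k=1):
    """Return the k most probable (label, probability) pairs, best first."""
    probs = np.asarray(probs)
    order = np.argsort(probs)[::-1][:k]   # indices sorted by descending probability
    return [(labels[i], float(probs[i])) for i in order]
```

For example, `top_k_labels([0.1, 0.7, 0.2], ["daisy", "sunflowers", "roses"], k=2)` ranks "sunflowers" first and "roses" second.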

A modified LimeImageExplainer class supports two new feature selection methods:

  • 'greedy_likelihood' (STREAK Likelihood) is the method described in Section 6.2 of the paper. It does not require generating a set of perturbed images, which leads to faster running times for a moderate number of image segments.

  • 'streaming_greedy' (STREAK LIME) is the method described in Section A.8 of the paper. It generates perturbed images but then uses STREAK as the feature selection method instead of forward selection, highest weights, lasso, etc. Like LIME, its running time scales with the number of perturbed images; it is consistently faster than 'forward_selection' and slower than 'highest_weights'.
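The selection idea behind both methods can be illustrated with threshold-based greedy passes: each pass keeps an arriving element only if its marginal gain clears a threshold, and several thresholds are tried. This is a simplified, offline sketch in the spirit of STREAK, not the paper's algorithm. It computes the best singleton gain `m` up front instead of streaming, and the threshold grid from `m` down to `m / (2k)` is a hypothetical choice that may differ from the paper's. The usage objective is an unweighted R² of a least-squares fit on binary perturbation features, assumed here as a stand-in for the regression that 'streaming_greedy' performs on LIME's perturbed samples.

```python
import numpy as np

def threshold_greedy(stream, f, k, tau):
    """One pass: add u if |S| < k and its marginal gain f(S + [u]) - f(S) >= tau."""
    S, value = [], f([])
    for u in stream:
        if len(S) >= k:
            break
        gain = f(S + [u]) - value
        if gain >= tau:
            S.append(u)
            value += gain               # value now equals f(S)
    return S, value

def multi_threshold_select(elements, f, k, eps=0.5):
    """Run threshold_greedy over a geometric grid of thresholds; keep the best set."""
    elements = list(elements)
    base = f([])
    m = max(f([u]) - base for u in elements)   # best singleton gain
    if m <= 0:
        return [], base
    best_S, best_val = [], base
    i = 0
    while True:
        tau = m / (1 + eps) ** i        # thresholds m, m/(1+eps), m/(1+eps)^2, ...
        if tau < m / (2 * k):           # hypothetical lower end of the grid
            break
        S, val = threshold_greedy(elements, f, k, tau)
        if val > best_val:
            best_S, best_val = S, val
        i += 1
    return best_S, best_val

def r2_objective(Z, y):
    """f(S) = R^2 of least squares of y on columns S of Z, with an intercept."""
    y = np.asarray(y, dtype=float)
    sst = np.sum((y - y.mean()) ** 2)
    def f(S):
        X = np.column_stack([np.ones(len(y))] + [Z[:, j] for j in S])
        coef, *_ = np.linalg.lstsq(X, y, rcond=None)
        resid = y - X @ coef
        return 1.0 - np.sum(resid ** 2) / sst
    return f
```

Here each column of `Z` plays the role of one image segment (1 = visible, 0 = grayed out) and `y` the black-box score for each perturbed image. The passes are independent, so a true streaming version can run them side by side over a single pass of the segments.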