## DeepExplain - Inception V3 for natural image classification
### TensorFlow-Slim model example

In [2]:
%%bash
# Download the TF-Slim Inception V3 checkpoints (standard and adversarially
# trained) and unpack them into data/models.

# Stop at the first failing command so a broken download is never unpacked
# or deleted silently.
set -euo pipefail

# tar -C requires the target directory to exist.
mkdir -p data/models

for archive in inception_v3_2016_08_28.tar.gz adv_inception_v3_2017_08_18.tar.gz; do
    # -nv keeps one summary line per file instead of the full progress log.
    wget -nv "http://download.tensorflow.org/models/${archive}"
    tar -xvzf "${archive}" -C data/models
    rm "${archive}"
done

inception_v3.ckpt
adv_inception_v3.ckpt.data-00000-of-00001
adv_inception_v3.ckpt.index
adv_inception_v3.ckpt.meta


--2017-12-15 19:41:00--  http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz
Resolving download.tensorflow.org (download.tensorflow.org)... 216.58.205.48, 2a00:1450:4002:806::2010
Connecting to download.tensorflow.org (download.tensorflow.org)|216.58.205.48|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 100885009 (96M) [application/x-tar]
Saving to: ‘inception_v3_2016_08_28.tar.gz’

     0K .......... .......... .......... .......... ..........  0% 4.83M 20s
    50K .......... .......... .......... .......... ..........  0% 5.75M 18s
   100K .......... .......... .......... .......... ..........  0% 23.1M 14s
   150K .......... .......... .......... .......... ..........  0% 23.4M 11s
   200K .......... .......... .......... .......... ..........  0% 24.8M 10s
   250K .......... .......... .......... .......... ..........  0% 26.7M 9s
   300K .......... .......... .......... .......... ..........  0% 34.7M 8s
   350K .......... .......... ....

In [2]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tempfile, sys, os
sys.path.insert(0, os.path.abspath('..'))

import numpy as np
from scipy.misc import imread
import tensorflow as tf
from tensorflow.contrib.slim.nets import inception

slim = tf.contrib.slim
    
from deepexplain.tensorflow import DeepExplain

In [3]:
def load_images(batch_shape):
    """Yield batches of PNG images from ``data/images`` scaled to [-1, 1].

    No resizing is performed: each decoded image is written straight into one
    row of the batch array, so it must already have shape ``batch_shape[1:]``.

    Args:
        batch_shape: sequence ``[batch_size, dim1, dim2, dim3]`` describing the
            4-D batch array to fill.

    Yields:
        ``(filenames, images)`` where ``filenames`` is the list of base names
        of the files in the batch and ``images`` is a float array of shape
        ``batch_shape``. The final batch may hold fewer than ``batch_size``
        images; its unused trailing rows are left at zero. Nothing is yielded
        when no file matches.
    """
    images = np.zeros(batch_shape)
    filenames = []
    idx = 0
    batch_size = batch_shape[0]
    for filepath in tf.gfile.Glob(os.path.join('data/images', '*.png')):
        print(filepath)
        with tf.gfile.Open(filepath, 'rb') as f:
            # np.float64 instead of the `np.float` alias, which recent NumPy
            # releases no longer provide. Dividing by 255 assumes imread
            # returns 8-bit pixel values.
            image = imread(f, mode='RGB').astype(np.float64) / 255.0
        # Images for inception classifier are normalized to be in [-1, 1] interval.
        images[idx, :, :, :] = image * 2.0 - 1.0
        filenames.append(os.path.basename(filepath))
        idx += 1
        # Batching must happen per file, inside the loop; otherwise only the
        # last file found would ever be emitted.
        if idx == batch_size:
            yield filenames, images
            # Start fresh containers so a batch already handed to the caller
            # is never overwritten.
            filenames = []
            images = np.zeros(batch_shape)
            idx = 0
    # Flush the trailing, partially filled batch.
    if idx > 0:
        yield filenames, images

In [5]:
# Build Inception V3, restore its weights from the checkpoint and print the
# predicted class index for every batch produced by load_images.

# [batch_size, 299, 299, 3]: load_images reads the first entry as the batch
# size and fills one 299x299x3 slot per image.
batch_shape = [1, 299, 299, 3]
# NOTE(review): presumably the 1000 ImageNet classes plus one background
# class — confirm against the checkpoint.
num_classes = 1001
# Extracted into data/models by the download cell at the top of the notebook.
checkpoint = 'data/models/inception_v3.ckpt'

# Run computation
# The session is intentionally left open: the attribution cell reuses `sess`.
sess = tf.Session()

# NOTE(review): the model is created inside the DeepExplain context,
# presumably so DeepExplain can instrument the graph it will later explain —
# confirm against the DeepExplain documentation.
with DeepExplain(session=sess, graph=sess.graph) as de:
    # Input placeholder with the fixed batch shape defined above.
    X = tf.placeholder(tf.float32, shape=batch_shape)

    with slim.arg_scope(inception.inception_v3_arg_scope()):
        # Only the end_points dict is kept; the first return value is discarded.
        _, end_points = inception.inception_v3(X, num_classes=num_classes, is_training=False)

    logits = end_points['Logits']
    # Index of the largest logit along axis 1, i.e. one predicted class per row.
    yi = tf.argmax(logits, 1)


    # Load the values of the slim model variables from `checkpoint` into `sess`.
    saver = tf.train.Saver(slim.get_model_variables())
    saver.restore(sess, checkpoint)

    # After the loop `xi` still holds the last batch; the attribution cell
    # explains exactly that batch.
    for filenames, xi in load_images(batch_shape):
        labels = sess.run(yi, feed_dict={X: xi})
        print (filenames, labels)

INFO:tensorflow:Restoring parameters from data/models/inception_v3.ckpt
data/images/1c2e9fe8b0b2fdf2.png
['1c2e9fe8b0b2fdf2.png'] [58]


In [None]:
from utils import plot, plt

# Compute one attribution map per explanation method for the last batch `xi`
# classified above. Relies on `sess`, `end_points`, `X`, `xi` and `num_classes`
# from the classification cell.
with DeepExplain(session=sess) as de:
    logits = end_points['Logits']
    # Target = the logit of the predicted class only. A one-hot mask zeroes
    # every other logit. Multiplying by the class *index* instead would merely
    # scale all logits, give a zero target for class 0, and fail to broadcast
    # for batch sizes above 1.
    predicted_class = tf.argmax(logits, 1)
    target = logits * tf.one_hot(predicted_class, num_classes)
    attributions = {
        # Gradient-based
        'Saliency maps':        de.explain('saliency', target, X, xi),
        'Gradient * Input':     de.explain('grad*input', target, X, xi),
        'Integrated Gradients': de.explain('intgrad', target, X, xi),
        'Epsilon-LRP':          de.explain('elrp', target, X, xi),
        'DeepLIFT (Rescale)':   de.explain('deeplift', target, X, xi),
        # Perturbation-based
        # NOTE(review): the logged run reports "step 1" for a 299x299 input,
        # which is likely very slow — consider a larger step if DeepExplain
        # exposes one.
        '_Occlusion [10x10]':      de.explain('occlusion', target, X, xi, window_shape=(10,10,3))
    }

DeepExplain: running "saliency" explanation method (1)
DeepExplain: running "grad*input" explanation method (2)
DeepExplain: running "intgrad" explanation method (3)
DeepExplain: running "elrp" explanation method (4)
DeepExplain: running "deeplift" explanation method (5)
DeepExplain: running "occlusion" explanation method (6)
Input shape: (299, 299, 3); window_shape (10, 10, 3); step 1
