In [12]:
import tensorflow as tf
import numpy as np
import input_data

from tabulate import tabulate

%matplotlib inline
from matplotlib import pyplot as plt
import pylab

import sys
sys.path.append("..")

mnist = input_data.read_data_sets('data', one_hot=True)
batch_size = mnist.test.images.shape[0]

def gen_image(ax, arr, title=''):
    two_d = (np.reshape(arr, (28, 28)) * 255).astype(np.uint8)
    ax.set_title(title)
    ax.imshow(two_d, interpolation='nearest', cmap=pylab.gray())


def print_images_and_heatmaps(xs, hs, yp, yl, iterations):
    
    y_preds = np.argmax(yp, axis=1)
    ground_truth = np.argmax(yl, axis=1)
    
    images = xs.reshape([batch_size,28,28])
    images = (images + 1)/2.0
    
    heatmaps = hs.reshape([batch_size,28,28])
    for i in range(iterations):
        print('%d: Label: %d , Prediction: %d , prob: %f' % (i, ground_truth[i], y_preds[i], yp[i,y_preds[i]]))
        fig = plt.figure()
        ax1 = fig.add_subplot(2,2,1, adjustable='box', aspect=1)
        ax2 = fig.add_subplot(2,1,1, adjustable='box', aspect=1)

        gen_image(ax1, images[i], title='Input')
        gen_image(ax2, heatmaps[i], title='Heatmap')

        plt.show()

# input dict creation as per tensorflow source code
def feed_dict(mnist, train):    
    if train:
        xs, ys = mnist.train.next_batch(batch_size)
    else:
        xs, ys = mnist.test.next_batch(batch_size)
    return (2*xs)-1, ys


with tf.Session() as sess:
    saver = tf.train.import_meta_graph('models/fully-connected.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./models/'))
    
    graph = tf.get_default_graph()
    
    x = graph.get_tensor_by_name('x-input:0')
    ground_truth = graph.get_tensor_by_name('y-input:0')
    
    y_pred = graph.get_tensor_by_name('y_pred:0')
    accuracy = graph.get_tensor_by_name('accuracy:0')
    
    relevance_layerwise = []
    accumulated_relevances = []
    
    layers = []
    R = y_pred
    for i in range(1,3):
        relevance_layerwise.append(graph.get_tensor_by_name("R{}:0".format(i)))
        accumulated_relevances.append(np.zeros(relevance_layerwise[i-1].shape))
        
        d = dict()
        d['b'] = graph.get_tensor_by_name('linear_{}/biases:0'.format(i))
        d['W'] = graph.get_tensor_by_name('linear_{}/weights:0'.format(i))
        layers.append(d)
    
    for r in relevance_layerwise:
        print(r)
    
    final_acc = 0.
    iterations = 78

    for i in range(iterations):
        print("Iteration {} out of {}".format(i, iterations))
        feed_d = {x: mnist.test.images[i*128:(i+1)*128], ground_truth: mnist.test.labels[i*128:(i+1)*128]}
        acc, relevances = sess.run([accuracy, relevance_layerwise], feed_dict=feed_d)
        final_acc += acc
        for idx, r in enumerate(relevances):
            accumulated_relevances[idx] += r
    
    final_acc /= iterations
    
    final_relevances = []
    for r in accumulated_relevances:
        final_relevances.append(np.sum(r, axis=0))
    
    print("--------------------")
    
    print("Accuracy: {}".format(final_acc))
    
    print("--------------------")
    
    print("Relevances")
    for r in final_relevances:
        print(r.shape)
    
    print("--------------------")
    
    
    ### START PRUNING
    acc_before_pruning = final_acc
#     pruning_threshold = 1
    pruning_thresholds = [i*1. for i in range(1, 40)]
    
    headers = ["Threshold", "Accuracy", "Accuracy change", "Space saved"]
    results = []
    
    cnt = 0
    for pruning_threshold in pruning_thresholds:
        print("Testing threshold {}".format(pruning_threshold))
        res = [pruning_threshold]
        
        total_matrix_entries_pruned = 0

        cnt = 0
        
        for lidx, l in enumerate(layers):
            
            l['placeholder_W'] = tf.placeholder(tf.float32, l['W'].shape)
            l['placeholder_b'] = tf.placeholder(tf.float32, l['b'].shape)
            W, b = sess.run([l['W'], l['b']])
        
            prev_cnt = cnt
            cnt = 0
            for idx, rel in enumerate(final_relevances[::-1][lidx]):
                if(rel < pruning_threshold):
                    cnt += 1
                    W[:,idx] = 0
                    b[idx] = 0

            assign_W = tf.assign(l['W'], l['placeholder_W'])
            assign_b = tf.assign(l['b'], l['placeholder_b'])

            feed_d = {l['placeholder_W']: W, l['placeholder_b']: b}
            sess.run([assign_W, assign_b], feed_dict=feed_d)
            l['pruned'] = cnt
            entries = cnt * (int(l['W'].shape[0]) + 1)
            entries += prev_cnt * (int(l['W'].shape[1]))
            entries -= cnt * prev_cnt # Subtract overlap
            total_matrix_entries_pruned += entries
            print("Pruned {} neurons and {} entries".format(cnt, entries))

        final_acc = 0.
        for i in range(iterations):
            feed_d = {x: mnist.test.images[i*128:(i+1)*128], ground_truth: mnist.test.labels[i*128:(i+1)*128]}
            acc = sess.run(accuracy, feed_dict=feed_d)
            final_acc += acc

        final_acc /= iterations

        entries_fraction = (total_matrix_entries_pruned * 1. / (785 * 1200 + 1201 * 500 + 501 * 10 ) ) * 100
        acc_fraction = ((acc_before_pruning - final_acc) / acc_before_pruning) * 100
        
        res.append(final_acc)
        res.append(acc_fraction)
        res.append(entries_fraction)
        results.append(res)
        
        print("Done pruning neurons. Old acc: {} New acc: {}".format(acc_before_pruning, final_acc))
        print("Entries pruned: {}%, Loss of accuracy: {}%".format(entries_fraction, acc_fraction))

        print("--------------------")
    print(tabulate(results, headers=headers))
#     print_images_and_heatmaps(feed_d[x], relevances[-1], predictions, feed_d[ground_truth], 10)


Extracting data/train-images-idx3-ubyte.gz
Extracting data/train-labels-idx1-ubyte.gz
Extracting data/t10k-images-idx3-ubyte.gz
Extracting data/t10k-labels-idx1-ubyte.gz
INFO:tensorflow:Restoring parameters from ./models/fully-connected
Tensor("R1:0", shape=(128, 500), dtype=float32)
Tensor("R2:0", shape=(128, 1200), dtype=float32)
Iteration 0 out of 78
Iteration 1 out of 78
Iteration 2 out of 78
Iteration 3 out of 78
Iteration 4 out of 78
Iteration 5 out of 78
Iteration 6 out of 78
Iteration 7 out of 78
Iteration 8 out of 78
Iteration 9 out of 78
Iteration 10 out of 78
Iteration 11 out of 78
Iteration 12 out of 78
Iteration 13 out of 78
Iteration 14 out of 78
Iteration 15 out of 78
Iteration 16 out of 78
Iteration 17 out of 78
Iteration 18 out of 78
Iteration 19 out of 78
Iteration 20 out of 78
Iteration 21 out of 78
Iteration 22 out of 78
Iteration 23 out of 78
Iteration 24 out of 78
Iteration 25 out of 78
Iteration 26 out of 78
Iteration 27 out of 78
Iteration 28 out of 78
Iteration

In [2]:
print pruning_thresholds

[0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 6.5, 7.0, 7.5, 8.0, 8.5, 9.0, 9.5, 10.0, 10.5]
