In [1]:
##### from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
sys.path.append("..")
from modules.sequential import Sequential
from modules.linear import Linear
from modules.softmax import Softmax
from modules.relu import Relu
from modules.tanh import Tanh
from modules.avgpool import AvgPool
from modules.maxpool import MaxPool
from modules.convolution import Convolution
import modules.render as render
from modules.utils import Utils, Summaries, visualize
import input_data

import pylab

import tensorflow as tf
import numpy as np
%matplotlib inline

from matplotlib import pyplot as plt

# VARIABLES
batch_size = 128
learning_rate = 0.5
training_iterations = 2001
regulization = 0.001

def gen_image(ax, arr, title=''):
    two_d = (np.reshape(arr, (28, 28)) * 255).astype(np.uint8)
    ax.set_title(title)
    ax.imshow(two_d, interpolation='nearest', cmap=pylab.gray())


def print_images_and_heatmaps(xs, hs, yp, yl, iterations):
    
    y_preds = np.argmax(yp, axis=1)
    ground_truth = np.argmax(yl, axis=1)
    
    images = xs.reshape([batch_size,28,28])
    images = (images + 1)/2.0
    
    heatmaps = hs.reshape([batch_size,28,28])
    for i in range(iterations):
        print('%d: Label: %d , Prediction: %d , prob: %f' % (i, ground_truth[i], y_preds[i], yp[i,y_preds[i]]))
        fig = plt.figure()
        ax1 = fig.add_subplot(2,2,1, adjustable='box', aspect=1)
        ax2 = fig.add_subplot(2,1,1, adjustable='box', aspect=1)

        gen_image(ax1, images[i], title='Input')
        gen_image(ax2, heatmaps[i], title='Heatmap')

        plt.show()


# input dict creation as per tensorflow source code
def feed_dict(mnist, train):    
    if train:
        xs, ys = mnist.train.next_batch(batch_size)
    else:
        xs, ys = mnist.test.next_batch(batch_size)
    return (2*xs)-1, ys

mnist = input_data.read_data_sets('data', one_hot=True)
    
with tf.Session() as sess:
    
    # GRAPH
    net = Sequential([Linear(input_dim=784,output_dim=1200, act ='relu', batch_size=batch_size),
                     Linear(500, act ='relu'),
                     Linear(10, act ='linear'),
                     Softmax()])
    
    x = tf.placeholder(tf.float32, [batch_size, 784], name='x-input')
    ground_truth = tf.placeholder(tf.float32, [batch_size, 10], name='y-input')

    y_pred = net.forward(x)
    y_pred = tf.identity(y_pred, name="y_pred")
    
    correct_prediction = tf.equal(tf.argmax(ground_truth, axis=1), tf.argmax(y_pred, axis=1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name='accuracy')

    # Calculating loss
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=ground_truth, logits=y_pred))
    
    for m in net.modules:
        if hasattr(m, 'weights'):
            loss += tf.nn.l2_loss(m.weights) * regulization

#     trainer = net.fit(output=y_pred,ground_truth=y_labels,loss='softmax_crossentropy',optimizer='adagrad', opt_params=[learning_rate]).train
    trainer = net.fit(output=y_pred,ground_truth=ground_truth,loss=loss,optimizer='grad_descent', opt_params=[learning_rate]).train
    
    tf.global_variables_initializer().run()
    
    for i in range(training_iterations):        
        feed_d = feed_dict(mnist, True)
        d = {x: feed_d[0], ground_truth: feed_d[1]}
        _, pred, acc = sess.run(
          [trainer, y_pred, accuracy], feed_dict=d)
        if i % 200 == 0:
            
            feed_d = feed_dict(mnist, False)
            d = {x: feed_d[0], ground_truth: feed_d[1]}
            acc = sess.run(accuracy, feed_dict=d)
            print(acc)
    
    # RELEVANCE
    relevance_layerwise = []
    R = y_pred
    for idx, layer in enumerate(net.modules[::-1]):
        R = net.lrp_layerwise(layer, R, 'alphabeta', 1.0)
        R = tf.identity(R, name="R{}".format(idx))
        relevance_layerwise.append(R)
    
    if not os.path.exists('./models'):
        os.system('mkdir models')
    
    os.system('rm ./models/*')

    saver = tf.train.Saver()
    saver.save(sess, 'models/fully-connected-fake')
    
    feed_d = feed_dict(mnist, False)
    d = {x: feed_d[0], ground_truth: feed_d[1]}
    
    y_pred, rels, acc = sess.run([y_pred, relevance_layerwise, accuracy], feed_dict=d)
    print("------------------")
    print("Accuracy on final test set: %f" % acc)    
    print_images_and_heatmaps(d[x], rels[-1], y_pred, d[ground_truth], 10)
    
        


Extracting data/train-images-idx3-ubyte.gz
Extracting data/train-labels-idx1-ubyte.gz
Extracting data/t10k-images-idx3-ubyte.gz
Extracting data/t10k-labels-idx1-ubyte.gz
Forward Pass ... 
------------------------------------------------- 
linear_1:: [128, 784]
linear_1/weights
linear_1/biases
linear_2:: [128, 1200]
linear_2/weights
linear_2/biases
linear_3:: [128, 500]
linear_3/weights
linear_3/biases
softmax_4:: [128, 10]

------------------------------------------------- 
0.117188
0.546875
0.835938


KeyboardInterrupt: 