# Imports

In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time
import h5py
import tensorflow as tf

# wget https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/examples/tutorials/mnist/input_data.py
import input_data

# Layers

In [2]:
def dense_layer(x, input_size, output_size, activation):
    """Fully-connected layer computing ``activation(matmul(x, W) + b)``.

    Args:
        x: Input tensor. It is matrix-multiplied with an
            ``[input_size, output_size]`` weight, so its last dimension is
            expected to be ``input_size``.
        input_size: Number of input features (rows of the weight matrix).
        output_size: Number of output units (columns of the weight matrix
            and length of the bias).
        activation: Callable applied to the affine output, e.g. ``tf.nn.relu``.

    Returns:
        The result of ``activation`` applied to the affine transform of ``x``.
    """
    # Weights start from a truncated normal (stddev 0.1), biases from 0.1.
    weight_init = tf.truncated_normal([input_size, output_size], stddev=0.1)
    weights = tf.Variable(weight_init, name='weight')
    bias_init = tf.constant(0.1, shape=[output_size])
    bias = tf.Variable(bias_init, name='bias')

    pre_activation = tf.matmul(x, weights) + bias
    return activation(pre_activation)

In [3]:
def highway_layer(x, size, activation, carry_bias=-1.0):
    """Highway layer: ``y = H * T + x * (1 - T)``.

    ``H`` is an ordinary dense transform of ``x`` and ``T`` is a sigmoid
    "transform gate"; ``1 - T`` is the "carry gate" that lets ``x`` pass
    through unchanged. Input and output share the same width because ``x``
    is mixed element-wise with ``H``.

    Args:
        x: Input tensor; multiplied by ``[size, size]`` weights, so its last
            dimension is expected to be ``size``.
        size: Width of the layer (both input and output).
        activation: Callable used for ``H``; it is called with a ``name``
            keyword argument, e.g. ``tf.nn.relu``.
        carry_bias: Initial value of the transform-gate bias. A negative
            value makes the sigmoid start small, so the layer initially
            favours carrying ``x`` through.

    Returns:
        The gated combination tensor, named ``'y'``.
    """
    # NOTE(review): tf.sub / tf.mul are the pre-1.0 TensorFlow op names --
    # confirm the targeted TensorFlow version before upgrading them.

    # Parameters of the plain transform H.
    weight_init = tf.truncated_normal([size, size], stddev=0.1)
    weights = tf.Variable(weight_init, name='weight')
    bias = tf.Variable(tf.constant(0.1, shape=[size]), name='bias')

    # Parameters of the transform gate T.
    gate_weight_init = tf.truncated_normal([size, size], stddev=0.1)
    gate_weights = tf.Variable(gate_weight_init, name='weight_transform')
    gate_bias = tf.Variable(tf.constant(carry_bias, shape=[size]), name='bias_transform')

    candidate = activation(tf.matmul(x, weights) + bias, name='activation')
    transform_gate = tf.sigmoid(tf.matmul(x, gate_weights) + gate_bias, name='transform_gate')
    carry_gate = tf.sub(1.0, transform_gate, name="carry_gate")

    # y = (H * T) + (x * C)
    transformed = tf.mul(candidate, transform_gate)
    carried = tf.mul(x, carry_gate)
    return tf.add(transformed, carried, 'y')

# Build Graph

In [4]:
# Create the session on the default graph; every op defined in the cells below
# is added to that same graph. `tf.Graph().as_default()` only takes effect
# inside a `with` block, so it is deliberately not called here, and nothing is
# bound to `_` (IPython's last-output variable).
sess = tf.Session()

In [5]:
# Layer widths: 784 input features, 10 output classes.
input_layer_size = 784
# use ~71 for fully-connected (plain) layers, 50 for highway layers
hidden_layer_size = 50
output_layer_size = 10

# Placeholders fed at run time; the leading None leaves the batch size open.
x = tf.placeholder(tf.float32, shape=[None, input_layer_size], name="x")
y_ = tf.placeholder(tf.float32, shape=[None, output_layer_size], name="y_")

In [6]:
layer_count = 100
carry_bias_init = -2.0

# The stack always has a dedicated input layer and a dedicated output layer.
# With fewer than two layers the output branch below is never reached, `y`
# stays None, and the loss cell would fail later with an unrelated error.
if layer_count < 2:
    raise ValueError(
        "layer_count must be at least 2 (input + output layer), got %d" % layer_count)

prev_y = None  # output of the most recently built layer
y = None       # network output, set by the last layer
for i in range(layer_count):
    # Each layer gets its own name scope: layer0, layer1, ...
    with tf.name_scope("layer{0}".format(i)) as scope:
        if i == 0: # first, input layer: project the input down to the hidden width
            prev_y = dense_layer(x, input_layer_size, hidden_layer_size, tf.nn.relu)
        elif i == layer_count - 1: # last, output layer: softmax over the classes
            y = dense_layer(prev_y, hidden_layer_size, output_layer_size, tf.nn.softmax)
        else: # hidden layers
            # Plain fully-connected alternative (see the note on hidden_layer_size):
            # prev_y = dense_layer(prev_y, hidden_layer_size, hidden_layer_size, tf.nn.relu)
            prev_y = highway_layer(prev_y, hidden_layer_size, tf.nn.relu, carry_bias=carry_bias_init)

In [7]:
with tf.name_scope("loss") as scope:
    # Cross-entropy between the one-hot labels y_ and the softmax output y,
    # summed over the batch. The softmax output is clipped away from 0 first:
    # a probability that underflows to exactly 0 would make log() return -inf
    # and turn the loss (and every gradient) into NaN.
    loss = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)), name='loss')

with tf.name_scope("train") as scope:
    # global_step is incremented by minimize() on every training step; it is
    # excluded from the trainable variables.
    global_step = tf.Variable(0, trainable=False, name='global_step')
    # Plain gradient descent with a learning rate of 1e-2.
    train_step = tf.train.GradientDescentOptimizer(1e-2, name="GradientDescent").minimize(loss, name="train_step", global_step=global_step)

with tf.name_scope("test") as scope:
    # A prediction is correct when the arg-max class of y matches the label's.
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    # Mean of the 0/1 correctness flags = fraction of correct predictions.
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"), name="accuracy")

In [8]:
# Build the op that initializes every variable created above, then run it once.
init_op = tf.initialize_all_variables()
sess.run(init_op)

# Save Graph

This step is important: to train the graph on Fomoro, we must first save it as a protobuf file.

In [9]:
# Serialize the session's graph definition in binary (non-text) protobuf form
# as 'highway.pb' under the 'models/' directory.
tf.train.write_graph(sess.graph_def, logdir='models/', name='highway.pb', as_text=False)

# Load Dataset

In [10]:
# Read the MNIST data set with one-hot encoded labels; the train, validation
# and test splits of the returned object are used by the training cell.
mnist_data_dir = "../mnist/data"
mnist = input_data.read_data_sets(mnist_data_dir, one_hot=True)

Extracting ../mnist/data/train-images-idx3-ubyte.gz
Extracting ../mnist/data/train-labels-idx1-ubyte.gz
Extracting ../mnist/data/t10k-images-idx3-ubyte.gz
Extracting ../mnist/data/t10k-labels-idx1-ubyte.gz


# Train

The training loop below is only for testing the model locally.

In [11]:
# Number of training steps. Despite the name, each "epoch" in the training
# loop processes a single mini-batch, not a full pass over the training set.
epochs = 2000
# Number of training examples drawn per step.
batch_size = 50
# Validation accuracy is reported every `checkpoint_interval` steps.
checkpoint_interval = 100

In [13]:
def _accuracy_on(dataset):
    """Run the `accuracy` op on all images and labels of one MNIST split."""
    feed = {x: dataset.images, y_: dataset.labels}
    return sess.run(accuracy, feed_dict=feed)


for i in range(epochs):
    batch_xs, batch_ys = mnist.train.next_batch(batch_size)

    # Periodically report accuracy on the validation split, measured before
    # this step's parameter update is applied.
    if not i % checkpoint_interval:
        valid_accuracy = _accuracy_on(mnist.validation)
        print("epoch %d, validation accuracy %g" % (i, valid_accuracy))

    # One gradient-descent update on the current mini-batch.
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

# Final evaluation on the held-out test split.
test_accuracy = _accuracy_on(mnist.test)
print("test accuracy %g" % test_accuracy)

epoch 0, validation accuracy 0.9176
epoch 100, validation accuracy 0.9028
epoch 200, validation accuracy 0.9224
epoch 300, validation accuracy 0.9342
epoch 400, validation accuracy 0.918
epoch 500, validation accuracy 0.9454
epoch 600, validation accuracy 0.9452
epoch 700, validation accuracy 0.9138
epoch 800, validation accuracy 0.951
epoch 900, validation accuracy 0.954
epoch 1000, validation accuracy 0.9544
epoch 1100, validation accuracy 0.9554
epoch 1200, validation accuracy 0.9632
epoch 1300, validation accuracy 0.959
epoch 1400, validation accuracy 0.9624
epoch 1500, validation accuracy 0.9546
epoch 1600, validation accuracy 0.9628
epoch 1700, validation accuracy 0.9548
epoch 1800, validation accuracy 0.9628
epoch 1900, validation accuracy 0.9638
test accuracy 0.9575
