In [1]:
import tensorflow as tf
import numpy as np
import os.path as op
import os
import shutil
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets

In [2]:
data_dir = op.expanduser("~/data/mnist")
mnist = read_data_sets(data_dir, one_hot=True)
logs_dir = '/tmp/tensorflow_logs'

Extracting /home/ogrisel/data/mnist/train-images-idx3-ubyte.gz
Extracting /home/ogrisel/data/mnist/train-labels-idx1-ubyte.gz
Extracting /home/ogrisel/data/mnist/t10k-images-idx3-ubyte.gz
Extracting /home/ogrisel/data/mnist/t10k-labels-idx1-ubyte.gz


In [3]:
tf.reset_default_graph()
sess = tf.Session()
dtype = tf.float32
learning_rate = 0.1


with tf.name_scope('input'):
    x = tf.placeholder(dtype=dtype, shape=[None, 784], name='x-input')
    y = tf.placeholder(dtype=dtype, shape=[None, 10], name='y-input')


with tf.name_scope('variables'):
    W = tf.Variable(tf.truncated_normal(shape=(784, 10), stddev=0.1,
                                        dtype=dtype),
                    name='W')
    tf.histogram_summary('weights', W)
    b = tf.Variable(tf.zeros(shape=(10,), dtype=dtype), name='b')
    tf.histogram_summary('biases', b)
    slow_direction = tf.Variable(tf.zeros(shape=[784 * 10], dtype=dtype))
    fast_direction = tf.Variable(tf.zeros(shape=[784 * 10], dtype=dtype))
    angle = tf.matmul(tf.reshape(slow_direction, [1, -1]),
                      tf.reshape(fast_direction, [-1, 1]))[0, 0]
    tf.scalar_summary('angle', angle)


with tf.name_scope('model'):
    preactivations = tf.matmul(x, W) + b
    tf.histogram_summary('preactivations', preactivations)
    y_pred = tf.nn.softmax(preactivations)
    tf.histogram_summary('predicted_probabilities', y_pred)


with tf.name_scope('loss'):
    cross_entropy = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(preactivations, y),
        name='cross_entropy')


with tf.name_scope('accuracy'):
    with tf.name_scope('correct_prediction'):
        correct_prediction = tf.equal(tf.argmax(y, 1),
                                      tf.argmax(y_pred, 1))
    with tf.name_scope('correct_prediction'):
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, dtype))
    tf.scalar_summary('accuracy', accuracy)


with tf.name_scope('gradient_directions'):
    [gW, gb] = tf.gradients(cross_entropy, [W, b])
    gW_norm = tf.sqrt(tf.reduce_sum(tf.square(gW)))
    tf.scalar_summary('gradient norm', gW_norm)
    gW_normed = tf.reshape(gW / (gW_norm + 1e-7), [-1])


def vec_normalize(vec):
    vec_norm = tf.sqrt(tf.reduce_sum(tf.square(vec)))
    return vec / (vec_norm + 1e-7)
    

with tf.name_scope('updates'):
    W_update = W.assign_add(-learning_rate * gW)
    b_update = b.assign_add(-learning_rate * gb)

    slow_rate = 0.01
    new_slow_dir = slow_rate * gW_normed + (1 - slow_rate) * slow_direction
    slow_dir_update = slow_direction.assign(vec_normalize(new_slow_dir))

    fast_rate = 0.1
    new_fast_dir = fast_rate * gW_normed + (1 - fast_rate) * fast_direction
    fast_dir_update = fast_direction.assign(vec_normalize(new_fast_dir))

summaries = tf.merge_all_summaries()
shutil.rmtree(logs_dir)
train_writer = tf.train.SummaryWriter(logs_dir + '/train', sess.graph)
test_writer = tf.train.SummaryWriter(logs_dir + '/test')


def data_dict(train=True, batch_size=128):
    """Make a TensorFlow feed_dict: maps data onto Tensor placeholders."""
    if train:
        xs, ys = mnist.train.next_batch(batch_size)
    else:
        xs, ys = mnist.test.images, mnist.test.labels
    return {x: xs.astype(np.float32), y: ys.astype(np.float32)}


sess.run(tf.initialize_all_variables())

In [4]:
for i in range(1000):
    if i % 10 == 0:
        # Evaluate on test set
        test_summaries, test_acc, test_angle = sess.run(
            [summaries, accuracy, angle],
            feed_dict=data_dict(train=False))
        test_writer.add_summary(test_summaries, i)
        print("Accuracy on test: %0.3f" % test_acc)
        print(test_angle)

    else:
        # Evaluate on train mini_batch
        train_summaries, _, _, _, _ = sess.run(
            [summaries, W_update, b_update, slow_dir_update, fast_dir_update],
            feed_dict=data_dict(train=True))
        train_writer.add_summary(train_summaries, i)

Accuracy on test: 0.063
0.0
Accuracy on test: 0.504
0.927509
Accuracy on test: 0.659
0.83169
Accuracy on test: 0.716
0.789484
Accuracy on test: 0.758
0.7517
Accuracy on test: 0.783
0.762475
Accuracy on test: 0.798
0.750865
Accuracy on test: 0.816
0.713542
Accuracy on test: 0.829
0.715733
Accuracy on test: 0.835
0.722063
Accuracy on test: 0.845
0.713856
Accuracy on test: 0.845
0.725295
Accuracy on test: 0.853
0.746347
Accuracy on test: 0.859
0.722093
Accuracy on test: 0.856
0.766615
Accuracy on test: 0.864
0.739274
Accuracy on test: 0.865
0.754107
Accuracy on test: 0.868
0.764039
Accuracy on test: 0.871
0.767359
Accuracy on test: 0.873
0.762097
Accuracy on test: 0.873
0.748504
Accuracy on test: 0.875
0.748079
Accuracy on test: 0.875
0.803382
Accuracy on test: 0.878
0.780443
Accuracy on test: 0.878
0.754247
Accuracy on test: 0.878
0.764528
Accuracy on test: 0.881
0.767631
Accuracy on test: 0.881
0.783781
Accuracy on test: 0.883
0.793499
Accuracy on test: 0.885
0.806884
Accuracy on test: 