In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys
import numpy as np
import matplotlib.pyplot as plt

from tensorflow.examples.tutorials.mnist import input_data

import tensorflow as tf

FLAGS = None

In [8]:
rng = np.random.RandomState()
crop = 1
ds_factor = 1.
eig_frac_cutoff = .001

def new_shape(crop, ds_factor):
    return int(np.around((28 - crop) / ds_factor))

def transform(X, crop=2, ds_factor=2.):
    Xp = tf.reshape(X, (-1, 28, 28))
    Xp = Xp[:, crop:28-crop, crop:28-crop]
    new = new_shape(crop, ds_factor)
    Xp = tf.image.resize_images(tf.expand_dims(Xp, axis=-1), (new, new))
    return tf.reshape(Xp, (-1, new**2))

In [9]:
mnist = input_data.read_data_sets('/tmp/tensorflow/mnist/input_data/', one_hot=True)
new_dim = new_shape(crop, ds_factor) **2
# Create the model
x = tf.placeholder(tf.float32, [None, 784])
x_small = transform(x, crop, ds_factor)
p = tf.Variable(tf.zeros((new_dim * 10) + 10))
W = tf.reshape(p[:new_dim * 10], [new_dim, 10])
b = tf.reshape(p[new_dim * 10:], [10])

y = tf.matmul(x_small, W) + b

# Define loss and optimizer
y_ = tf.placeholder(tf.float32, [None, 10])

cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))

Extracting /tmp/tensorflow/mnist/input_data/train-images-idx3-ubyte.gz
Extracting /tmp/tensorflow/mnist/input_data/train-labels-idx1-ubyte.gz
Extracting /tmp/tensorflow/mnist/input_data/t10k-images-idx3-ubyte.gz
Extracting /tmp/tensorflow/mnist/input_data/t10k-labels-idx1-ubyte.gz


In [10]:
opt = tf.train.GradientDescentOptimizer(learning_rate=1.)
grad, var = opt.compute_gradients(cross_entropy, var_list=[p])[0]
grad2 = tf.reshape(grad, (grad.shape[0], 1))

# Meaningless outside pixel make Hessian not invertable, cropping would fix need for tf.eye
hess = tf.squeeze(tf.hessians(cross_entropy, p)) #+ .0001 * tf.eye((new_dim * 10) + 10)
e, v = tf.self_adjoint_eig(tf.expand_dims(hess, axis=0))
thresh = tf.reduce_sum(e) * eig_frac_cutoff
keep = tf.reduce_sum(tf.cast(tf.greater_equal(e, thresh), tf.int32))
ep = tf.squeeze(e)[-keep:]
vp = tf.squeeze(v)[:, -keep:]
inv_hess = tf.matmul(vp, (1. / tf.expand_dims(ep, axis=0)) * vp, transpose_b=True)
grad_prime = tf.squeeze(tf.matmul(inv_hess, grad2))
train_step = opt.apply_gradients([(grad_prime, var)])

In [11]:
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()

In [12]:
plt.close()
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# Train
print('training')
for ii in range(100):
    batch_xs, batch_ys = mnist.train.next_batch(2000)
    if ii % 10 == 0:
        print(sess.run(accuracy, feed_dict={x: mnist.test.images,
                                  y_: mnist.test.labels}))
        print(sess.run(accuracy, feed_dict={x: batch_xs, y_: batch_ys}))
        plt.figure()
        n, e_float = sess.run([keep, e], feed_dict={x: batch_xs, y_: batch_ys})
        plt.plot(e_float.ravel())
        print(e_float.ravel()[-20:], n)
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

print('test')
print(sess.run(accuracy, feed_dict={x: mnist.test.images,
                                  y_: mnist.test.labels}))
plt.show()

training
0.098
0.101
[ 0.35628998  0.35629016  0.44745827  0.44745836  0.44745842  0.44745842
  0.44745848  0.4474586   0.4474586   0.44745874  0.44745928  4.13997984
  4.13998032  4.13998032  4.13998079  4.13998079  4.13998079  4.13998079
  4.13998175  4.1399827 ] 135
0.8889
0.8875
[ 0.08241394  0.09078948  0.09269261  0.09522659  0.09784698  0.10331086
  0.11912247  0.12157541  0.1431842   0.16414766  0.18712673  0.40612173
  0.43665218  0.52581567  0.65822381  0.7611351   0.90421444  1.24736309
  1.48996997  1.70660675] 119
0.8925
0.8945
[ 0.06274751  0.06377082  0.06820851  0.0706945   0.07280236  0.07880594
  0.09264552  0.09644875  0.10805403  0.11998767  0.14545912  0.29143661
  0.31474462  0.40359837  0.49199629  0.55618536  0.62603456  0.93768805
  1.16037571  1.36300516] 117
0.8969
0.8885
[ 0.05076982  0.05354552  0.05629497  0.0617409   0.0670227   0.07193024
  0.07549796  0.07816145  0.09253377  0.11558396  0.1295661   0.23748133
  0.25268039  0.33109638  0.43033043  0.4467

KeyboardInterrupt: 