In [1]:
import tensorflow as tf
from tensorflow.contrib import rnn
import numpy as np

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


In [2]:
def weight_variable(shape, name):
    initial = tf.truncated_normal(shape, stddev = 0.1)
    return tf.Variable(initial, name)

def bias_variable(shape, name):
    initial = tf.constant(0.1, shape = shape)
    return tf.Variable(initial, name)

In [3]:
n_input = 28 # MNIST data input (image shape: 28*28)
n_steps = 28 # steps
n_hidden = 128 # number of neurons in fully connected layer 
n_classes = 10 # (0-9 digits)

x = tf.placeholder("float", [None, n_steps, n_input])
y = tf.placeholder("float", [None, n_classes])

weights = {
    "w_fc" : weight_variable([n_hidden, n_classes], "w_fc")
}
biases = {
    "b_fc" : bias_variable([n_classes], "b_fc") 
}

In [4]:
x_transpose = tf.transpose(x, [1, 0, 2])
print("x_transpose shape: %s" % x_transpose.get_shape())

x_transpose shape: (28, ?, 28)


In [5]:
x_reshape = tf.reshape(x_transpose, [-1, n_input])
print("x_reshape shape: %s" % x_reshape.get_shape())

x_reshape shape: (?, 28)


In [6]:
#x_split = tf.split(0, n_steps, x_reshape)
x_split = tf.split(x_reshape, n_steps, 0)
print("type of x_split: %s" % type(x_split))
print("length of x_split: %d" % len(x_split))
print("shape of x_split[0]: %s" % x_split[0].get_shape())

type of x_split: <class 'list'>
length of x_split: 28
shape of x_split[0]: (?, 28)


In [7]:
basic_rnn_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
#basic_rnn_cell = rnn.BasicRNNCell(n_hidden)
#h, states = rnn.rnn(basic_rnn_cell, x_split, dtype=tf.float32)
h, states = rnn.static_rnn(basic_rnn_cell, x_split, dtype=tf.float32)
print("type of outputs: %s" % type(h))
print("length of outputs: %d" % len(h))
print("shape of h[0]: %s" % h[0].get_shape())
print("type of states: %s" % type(states))

type of outputs: <class 'list'>
length of outputs: 28
shape of h[0]: (?, 128)
type of states: <class 'tensorflow.python.ops.rnn_cell_impl.LSTMStateTuple'>


In [8]:
h_fc = tf.matmul(h[-1], weights['w_fc']) + biases['b_fc']
y_ = h_fc

In [10]:
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=h_fc,labels=y))
optimizer = tf.train.AdamOptimizer(0.01).minimize(cost)

In [11]:
correct_prediction = tf.equal(tf.argmax(y_, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

In [12]:
batch_size = 100
init_op = tf.global_variables_initializer()
sess = tf.InteractiveSession()
sess.run(init_op)

variables_names =[v.name for v in tf.trainable_variables()]

In [16]:
for step in range(5000):
    batch_x, batch_y = mnist.train.next_batch(batch_size)
    batch_x = np.reshape(batch_x, (batch_size, n_steps, n_input))
    cost_train, accuracy_train, states_train, rnn_out = sess.run([cost, accuracy, states, h[-1]], feed_dict = {x: batch_x, y: batch_y})
    values = sess.run(variables_names)
    rnn_out_mean = np.mean(rnn_out)
    #w_rnn_mean = np.mean(rnn_out)
    for k,v in zip(variables_names, values):
        if k == 'RNN/BasicLSTMCell/Linear/Matrix:0':
            w_rnn_mean = np.mean(v)

    if step < 1500:
        if step % 100 == 0:
            print("step %d, loss %.5f, accuracy %.3f, mean of lstm weight %.5f, mean of lstm out %.5f" % (step, cost_train, accuracy_train, w_rnn_mean, rnn_out_mean))
    else:
        if step%1000 == 0: 
            print("step %d, loss %.5f, accuracy %.3f, mean of lstm weight %.5f, mean of lstm out %.5f" % (step, cost_train, accuracy_train, w_rnn_mean, rnn_out_mean))
    optimizer.run(feed_dict={x: batch_x, y: batch_y})

step 0, loss 0.07132, accuracy 0.990, mean of lstm weight -0.00770, mean of lstm out -0.00633
step 100, loss 0.12184, accuracy 0.960, mean of lstm weight -0.00770, mean of lstm out -0.00103
step 200, loss 0.08764, accuracy 0.970, mean of lstm weight -0.00770, mean of lstm out -0.00338
step 300, loss 0.02183, accuracy 0.990, mean of lstm weight -0.00770, mean of lstm out -0.00668
step 400, loss 0.00406, accuracy 1.000, mean of lstm weight -0.00770, mean of lstm out -0.00752
step 500, loss 0.00434, accuracy 1.000, mean of lstm weight -0.00770, mean of lstm out -0.01017
step 600, loss 0.02094, accuracy 0.990, mean of lstm weight -0.00770, mean of lstm out -0.00995
step 700, loss 0.01480, accuracy 1.000, mean of lstm weight -0.00770, mean of lstm out -0.00292
step 800, loss 0.01319, accuracy 1.000, mean of lstm weight -0.00770, mean of lstm out -0.01078
step 900, loss 0.04478, accuracy 0.980, mean of lstm weight -0.00770, mean of lstm out -0.00815
step 1000, loss 0.04359, accuracy 0.990, m

In [17]:
cost_test, accuracy_test = sess.run([cost, accuracy], feed_dict={x: np.reshape(mnist.test.images, [-1, 28, 28]), y: mnist.test.labels})
print("final loss %.5f, accuracy %.5f" % (cost_test, accuracy_test) )

final loss 0.06874, accuracy 0.98050
