# MNIST LSTM Example
Adapted from [github: TensorFlow Examples](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/notebooks/3_NeuralNetworks/recurrent_network.ipynb).

To classify images using a recurrent neural network, we treat every image row
as one time step of a sequence. Because MNIST images are 28×28 px, every sample
becomes a single sequence of 28 steps, where each step is a row of 28 pixels.

In [None]:
import os
import sys

# add path to libraries for ipython
sys.path.append(os.path.expanduser("~/libs"))

import numpy as np
import tensorflow as tf
# from tensorflow.models.rnn import rnn, rnn_cell
import tensortools as tt

In [None]:
# Training hyper-parameters.
BATCH_SIZE = 128
NUM_EXAMPLES = 100000   # train while step * BATCH_SIZE < NUM_EXAMPLES
DROPOUT = 0.5           # NOTE(review): not referenced by any cell visible here
REG = 5e-4              # NOTE(review): not referenced by any cell visible here
LEARNING_RATE = 0.001   # step size passed to the Adam optimizer
DISPLAY_STEP = 10       # print minibatch stats every DISPLAY_STEP steps

# Network / input-shape parameters.
N_INPUT = 28            # pixels per image row = features per time step
N_STEPS = 28            # image rows = time steps per sequence
N_HIDDEN = 128          # LSTM state size
N_CLASSES = 10          # number of output classes (digits)

# Backward-compatible aliases for the original misspelled names, which other
# cells still reference.
NUM_EXAMPES = NUM_EXAMPLES
LEARGNING_RATE = LEARNING_RATE

In [None]:
from tensorflow.examples.tutorials.mnist import input_data

# Load the MNIST train/validation/test splits from the "MNIST_data" directory
# using the TensorFlow tutorial helper, with labels encoded as one-hot vectors.
mnist = input_data.read_data_sets(train_dir="MNIST_data", one_hot=True)

In [None]:
# Dedicated graph for this notebook: the model, training and session cells
# build their ops inside it via `with g.as_default():`.
g = tf.Graph()

In [None]:
def RNN(x, weights, biases, n_input=None, n_steps=None, n_hidden=None,
        forget_bias=1.0):
    """Build a single-layer LSTM and return class logits for a batch.

    The batch is converted into the list-of-tensors form required by the
    static ``tf.nn.rnn`` API, run through one ``BasicLSTMCell``, and the
    output of the last time step is projected with ``weights``/``biases``.

    Args:
        x: float32 tensor of shape (batch_size, n_steps, n_input).
        weights: projection matrix of shape (n_hidden, n_classes).
        biases: bias vector of shape (n_classes,).
        n_input: features per time step; defaults to the global N_INPUT.
        n_steps: number of time steps; defaults to the global N_STEPS.
        n_hidden: LSTM state size; defaults to the global N_HIDDEN. Must
            equal the first dimension of ``weights``.
        forget_bias: bias added to the LSTM forget gate (default 1.0).

    Returns:
        Unnormalized logits tensor of shape (batch_size, n_classes).
    """
    # Fall back to the notebook-level hyper-parameters so the original call
    # form RNN(x, weights, biases) builds exactly the same graph as before.
    if n_input is None:
        n_input = N_INPUT
    if n_steps is None:
        n_steps = N_STEPS
    if n_hidden is None:
        n_hidden = N_HIDDEN

    # (batch, n_steps, n_input) -> (n_steps, batch, n_input): time-major.
    x = tf.transpose(x, [1, 0, 2])
    # Flatten to (n_steps * batch, n_input) so the time axis can be split.
    x = tf.reshape(x, [-1, n_input])
    # List of n_steps tensors, each (batch, n_input) (pre-1.0 split signature:
    # split_dim, num_split, value).
    x = tf.split(0, n_steps, x)

    lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden,
                                             forget_bias=forget_bias,
                                             state_is_tuple=True)

    # Only the per-step outputs are needed; the final state is discarded.
    outputs, _ = tf.nn.rnn(lstm_cell, x, dtype=tf.float32)

    # Linear read-out of the last time step's output.
    return tf.matmul(outputs[-1], weights) + biases

In [None]:
with g.as_default():
    # Inputs: a batch of images as N_STEPS rows of N_INPUT pixels each, and the
    # matching one-hot labels.
    x = tf.placeholder(dtype=tf.float32,
                       shape=[None, N_STEPS, N_INPUT],
                       name="X")
    y_ = tf.placeholder(dtype=tf.float32,
                        shape=[None, N_CLASSES],
                        name="Y_")

    # Output projection from the LSTM's last output to N_CLASSES logits.
    weights = tf.get_variable(
        "weights",
        shape=[N_HIDDEN, N_CLASSES],
        dtype=tf.float32,
        initializer=tf.contrib.layers.xavier_initializer())
    # NOTE(review): no initializer given, so the default for get_variable is used.
    biases = tf.get_variable("biases", shape=[N_CLASSES])

    # Class logits for the batch.
    pred = RNN(x, weights, biases)

In [None]:
with g.as_default():
    with tf.name_scope("Train"):
        # Mean softmax cross-entropy over the batch. Arguments are passed by
        # keyword: TensorFlow >= 1.0 rejects positional (logits, labels), and
        # older releases accept these keyword names as well.
        cost = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y_))
        optimizer = tf.train.AdamOptimizer(
            learning_rate=LEARGNING_RATE).minimize(cost)

    with tf.name_scope("Accuracy"):
        # Fraction of examples whose arg-max logit matches the one-hot label.
        correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

In [None]:
with g.as_default():
    # Launch the graph; the `with` closes the session once evaluation is done.
    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())

        # Render the graph inline (project helper from tensortools).
        tt.visualization.show_graph(sess.graph_def)

        step = 1
        # Keep training while step * BATCH_SIZE < NUM_EXAMPES.
        while step * BATCH_SIZE < NUM_EXAMPES:
            batch_x, batch_y = mnist.train.next_batch(BATCH_SIZE)
            # Unflatten each image into N_STEPS rows of N_INPUT pixels; -1 lets
            # the batch dimension follow whatever next_batch returned.
            batch_x = batch_x.reshape((-1, N_STEPS, N_INPUT))
            feed = {x: batch_x, y_: batch_y}
            # Run one optimization step (backprop).
            sess.run(optimizer, feed_dict=feed)
            if step % DISPLAY_STEP == 0:
                # Fetch minibatch loss and accuracy in a single graph execution
                # instead of two separate forward passes.
                loss, acc = sess.run([cost, accuracy], feed_dict=feed)
                # Single-argument print(...) prints identically on Python 2 and 3.
                print("Iter " + str(step * BATCH_SIZE) + ", Minibatch Loss= " +
                      "{:.6f}".format(loss) + ", Training Accuracy= " +
                      "{:.5f}".format(acc))
            step += 1
        print("Optimization Finished!")

        # Calculate accuracy for the first 128 mnist test images.
        test_len = 128
        test_data = mnist.test.images[:test_len].reshape((-1, N_STEPS, N_INPUT))
        test_label = mnist.test.labels[:test_len]
        test_acc = sess.run(accuracy, feed_dict={x: test_data, y_: test_label})
        print("Testing Accuracy: " + str(test_acc))