# Pixel-level RNN on MNIST
This is a simple exercise in using an RNN model for MNIST classification. The model scans each MNIST image pixel by pixel to classify the digit.

In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials import mnist

## Define experiment params

In [2]:
# Input dimensionality: the network sees one pixel per time step.
in_dim = 1
# Number of time steps: the total number of pixels in a 28x28 image.
T = 28 * 28
# Number of output classes.
num_classes = 10
# Number of units in the network.
num_units = 64
# Number of training iterations.
max_iter = 10000
# Initial learning rate.
init_lr = 1e-3
# Number of examples per training batch.
batch_size = 512

## data provider

In [3]:
class mnist_provider(object):
  """MNIST data provider that serves batches from a single dataset split.

  Attributes:
    mnist_data: the selected split (train / validation / test) of the
      datasets returned by `mnist.input_data.read_data_sets`.
  """

  # Maps the accepted `split` argument to the attribute name on the object
  # returned by `read_data_sets`.
  _SPLIT_ATTRS = {"train": "train", "valid": "validation", "test": "test"}

  def __init__(self, data_directory, split="train"):
    """Loads MNIST from `data_directory` and selects the requested split.

    Args:
      data_directory: directory handed to `read_data_sets`.
      split: one of "train", "valid" or "test".

    Raises:
      ValueError: if `split` is not one of the accepted names.
    """
    # Validate before loading: previously an unknown split left
    # `self.mnist_data` unset and only failed later, inside `next_batch`,
    # with an unrelated-looking AttributeError.
    if split not in self._SPLIT_ATTRS:
      raise ValueError(
        "Unknown split %r; expected one of %s"
        % (split, sorted(self._SPLIT_ATTRS)))
    mnist_data = mnist.input_data.read_data_sets(data_directory, one_hot=True)
    self.mnist_data = getattr(mnist_data, self._SPLIT_ATTRS[split])

  def next_batch(self, batch_size):
    """Returns the next batch of images and one-hot labels.

    Args:
      batch_size: number of examples requested from the underlying split.

    Returns:
      A tuple `(images, one_hot_labels)`; `images` is reshaped to
      `[-1, 28, 28, 1]`, `one_hot_labels` is passed through unchanged.
    """
    images, one_hot_labels = self.mnist_data.next_batch(batch_size)
    images = np.reshape(images, [-1, 28, 28, 1], order='C')
    return images, one_hot_labels

## RNN Model

In [4]:
class RNN(object):
  """GRU recurrent network with a linear read-out applied at every time step.

  Attributes:
    weights: dictionary of weights as tensorflow variables. Always contains
      "out"; "gates" and "candidate" are added when the cell exposes
      variables whose names match those patterns.
    biases: dictionary of biases as tensorflow variables (same keys as
      `weights`).
    outputs: rnn output pre logit (list of size number of time steps).
    state: final rnn state.
    cell: tensorflow cell object.
    logits_list: list of logits at each time point.
  """
  def __init__(self, inputs, num_units, num_classes, activation=tf.nn.tanh):
    """Builds the recurrent cell, the output layer and the per-step logits.

    Args:
      inputs: list of length time steps for rnn inputs.
      num_units: number of network units.
      num_classes: network output classes.
      activation: activation function to use for cell (default tanh).
    """
    self.weights = {}
    self.biases = {}
    with tf.variable_scope("rnn"):
      with tf.variable_scope("internal"):
        self.cell = tf.nn.rnn_cell.GRUCell(
          num_units=num_units, activation=activation)

        self.outputs, self.state = tf.contrib.rnn.static_rnn(
          cell=self.cell, inputs=inputs, dtype=tf.float32)
      with tf.variable_scope("output"):
        self.weights["out"] = tf.get_variable(
          name="w",
          shape=(num_units, num_classes),
          initializer=tf.truncated_normal_initializer(stddev=0.1))

        # Explicit 1-tuple: the previous `(num_classes)` was just an int in
        # parentheses, not a shape tuple.
        self.biases["out"] = tf.get_variable(
          name="b",
          shape=(num_classes,),
          initializer=tf.zeros_initializer())

    # Expose the cell's internal variables under short keys, matched by the
    # variable name (the cell must already have been run so they exist).
    for v in self.cell.trainable_variables:
      name = v.name
      if "gates/kernel" in name:
        self.weights["gates"] = v
      elif "gates/bias" in name:
        self.biases["gates"] = v
      elif "candidate/kernel" in name:
        self.weights["candidate"] = v
      elif "candidate/bias" in name:
        self.biases["candidate"] = v

    # Compute logits at each time step with the shared output layer.
    self.logits_list = [
      tf.matmul(output, self.weights["out"]) + self.biases["out"]
      for output in self.outputs
    ]
    

## Build Graph

In [5]:
# NOTE(review): machine-specific absolute path; change it to a directory that
# exists in your environment before running.
data_dir = "/Users/gamal/git_local_repo/playground/data/mnist"
# One provider per split. Each constructor call loads the dataset from
# `data_dir` again (see `mnist_provider.__init__`).
data_provider_train = mnist_provider(data_dir, split='train')
data_provider_valid = mnist_provider(data_dir, split='valid')
data_provider_test = mnist_provider(data_dir, split='test')

# Build everything in a dedicated graph `g`; later cells re-enter it with
# `g.as_default()`.
tf.reset_default_graph()
g = tf.Graph()

with g.as_default():
  # Input pixel sequences: (batch, T time steps, in_dim values per step).
  x = tf.placeholder(dtype=tf.float32, shape=(None, T, in_dim))
  # One-hot target labels: (batch, num_classes).
  one_hot_labels = tf.placeholder(dtype=tf.float32, shape=(None, num_classes))
  # Convert the input into a list of T tensors, one per time step, by
  # unstacking along axis 1.
  x_t = tf.unstack(x, T, 1)
  
  # Model: only the logits of the last time step are used for classification.
  rnn_model = RNN(inputs=x_t, num_units=num_units, num_classes=num_classes)
  logits = rnn_model.logits_list[-1]
  
  # Accuracy metric: fraction of examples whose true class (argmax of the
  # one-hot label) is the top-1 prediction.
  top1_op = tf.nn.in_top_k(logits, tf.argmax(one_hot_labels, 1), 1)
  accuracy = tf.reduce_mean(tf.cast(top1_op, dtype=tf.float32))
  

Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
Instructions for updating:
Please write your own downloading logic.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting /Users/gamal/git_local_repo/playground/data/mnist/train-images-idx3-ubyte.gz
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting /Users/gamal/git_local_repo/playground/data/mnist/train-labels-idx1-ubyte.gz
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting /Users/gamal/git_local_repo/playground/data/mnist/t10k-images-idx3-ubyte.gz
Extracting /Users/gamal/git_local_repo/playground/data/mnist/t10k-labels-idx1-ubyte.gz
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
Extracting /Users/gamal/git_local_repo/playground/data/mnist/train-images-idx3-ubyte.gz
Extracting /Users/gamal/git_local_repo/playground/data/mn

### Loss function

In [6]:
# Add the training objective, optimizer and initializer to the same graph `g`.
with g.as_default():
  # Mean softmax cross-entropy between the last-step logits and the one-hot
  # labels, averaged over the batch.
  loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits,
                                               labels=one_hot_labels)
  )
  # Adam with the initial learning rate from the experiment params.
  opt = tf.train.AdamOptimizer(learning_rate=init_lr)
  
  # Single op that computes gradients of `loss` and applies one update.
  train_op = opt.minimize(loss)
  
  # Created after `minimize` so that it also covers the variables the
  # optimizer adds to the graph.
  init_op = tf.global_variables_initializer()

## Optimizing

In [7]:
# Run `max_iter` training steps in a fresh session, reporting the loss and
# accuracy of the current training batch every 50 iterations.
with g.as_default():
  with tf.Session() as sess:
    sess.run(init_op)
    # `range` rather than `xrange`: `xrange` does not exist on Python 3,
    # while `range` works on both Python 2 and 3.
    for i in range(max_iter):
      imgs, lbls = data_provider_train.next_batch(
        batch_size=batch_size)
      # The provider returns images as [-1, 28, 28, 1]; flatten each image
      # into the (T, in_dim) pixel sequence expected by placeholder `x`.
      _, ls, acc = sess.run(
        (train_op, loss, accuracy),
        feed_dict={x: imgs.reshape((-1, T, in_dim)),
                   one_hot_labels: lbls}
      )

      if i % 50 == 0:
        print(
          "iter %d: (train loss: %.3f) (train accuracy: %.3f)"
          % (i, ls, acc))
    

iter 0: (train loss: 2.303) (train accuracy: 0.105)
iter 50: (train loss: 2.280) (train accuracy: 0.123)
iter 100: (train loss: 2.303) (train accuracy: 0.076)
iter 150: (train loss: 2.299) (train accuracy: 0.127)
iter 200: (train loss: 2.305) (train accuracy: 0.096)
iter 250: (train loss: 2.301) (train accuracy: 0.123)
iter 300: (train loss: 2.300) (train accuracy: 0.121)
iter 350: (train loss: 2.302) (train accuracy: 0.111)
iter 400: (train loss: 2.303) (train accuracy: 0.100)
iter 450: (train loss: 2.304) (train accuracy: 0.111)
iter 500: (train loss: 2.304) (train accuracy: 0.090)
iter 550: (train loss: 2.304) (train accuracy: 0.094)
iter 600: (train loss: 2.301) (train accuracy: 0.129)
iter 650: (train loss: 2.298) (train accuracy: 0.125)
iter 700: (train loss: 2.301) (train accuracy: 0.119)
iter 750: (train loss: 2.302) (train accuracy: 0.107)
iter 800: (train loss: 2.303) (train accuracy: 0.158)
iter 850: (train loss: 2.299) (train accuracy: 0.102)
iter 900: (train loss: 2.290) (

iter 7500: (train loss: 0.085) (train accuracy: 0.969)
iter 7550: (train loss: 0.087) (train accuracy: 0.980)
iter 7600: (train loss: 0.067) (train accuracy: 0.980)
iter 7650: (train loss: 0.146) (train accuracy: 0.951)
iter 7700: (train loss: 0.109) (train accuracy: 0.967)
iter 7750: (train loss: 0.139) (train accuracy: 0.959)
iter 7800: (train loss: 0.100) (train accuracy: 0.969)
iter 7850: (train loss: 0.091) (train accuracy: 0.975)
iter 7900: (train loss: 0.135) (train accuracy: 0.969)
iter 7950: (train loss: 0.110) (train accuracy: 0.971)
iter 8000: (train loss: 0.117) (train accuracy: 0.963)
iter 8050: (train loss: 0.057) (train accuracy: 0.980)
iter 8100: (train loss: 0.092) (train accuracy: 0.973)
iter 8150: (train loss: 0.086) (train accuracy: 0.979)
iter 8200: (train loss: 0.065) (train accuracy: 0.984)
iter 8250: (train loss: 0.115) (train accuracy: 0.961)
iter 8300: (train loss: 0.076) (train accuracy: 0.979)
iter 8350: (train loss: 0.103) (train accuracy: 0.973)
iter 8400: