In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os
import sys
import time

from six.moves import xrange
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist
import math

In [None]:
# Parsed command-line options (the namespace returned by
# parser.parse_known_args()). Stays None until the `__main__` cell assigns it;
# fill_feed_dict, do_eval and run_training read it as a module-level global.
FLAGS = None

In [None]:
def placeholder_inputs(batch_size):
    """Create the input placeholders for one fixed-size mini-batch.

    Args:
        batch_size: number of examples per batch. It is baked into both
            placeholder shapes, so every feed must supply exactly this many rows.

    Returns:
        images: float32 placeholder of shape (batch_size, mnist.IMAGE_PIXELS).
        labels: int32 placeholder of shape (batch_size,).
    """
    images = tf.placeholder(tf.float32, shape=(batch_size, mnist.IMAGE_PIXELS))
    # Trailing comma makes this a real 1-tuple; a bare `(batch_size)` is just an
    # int and only works because TensorShape happens to accept a scalar.
    labels = tf.placeholder(tf.int32, shape=(batch_size,))
    return images, labels

def fill_feed_dict(data_set, images_pl, labels_pl):
    """Pull the next mini-batch from `data_set` and bind it to the placeholders.

    Args:
        data_set: object exposing `next_batch(batch_size, fake_data)`, e.g. one
            of the splits returned by `input_data.read_data_sets`.
        images_pl: placeholder that receives the image batch.
        labels_pl: placeholder that receives the label batch.

    Returns:
        dict mapping each placeholder to its batch, ready for `feed_dict=`.
    """
    next_images, next_labels = data_set.next_batch(FLAGS.batch_size,
                                                   FLAGS.fake_data)
    return {images_pl: next_images, labels_pl: next_labels}

def do_eval(sess,
            eval_correct,
            images_placeholder,
            labels_placeholder,
            data_set):
    """Run one pass of evaluation over `data_set` and print the precision.

    The split is consumed in whole batches of FLAGS.batch_size. Trailing
    examples that do not fill a complete batch are skipped.

    Args:
        sess: TF session holding the trained variables.
        eval_correct: tensor from `mnist.evaluation`; presumably a scalar count
            of correct predictions in one batch (it is summed per step below).
        images_placeholder: images placeholder fed by `fill_feed_dict`.
        labels_placeholder: labels placeholder fed by `fill_feed_dict`.
        data_set: split to evaluate (exposes `num_examples` and `next_batch`).

    Returns:
        float: precision, i.e. summed `eval_correct` divided by examples evaluated.

    Raises:
        ValueError: if the split holds fewer examples than one batch.
    """
    # Count *batches*, not examples: each step consumes FLAGS.batch_size
    # examples, so one pass over the split is num_examples // batch_size steps.
    num_steps = data_set.num_examples // FLAGS.batch_size
    num_examples = num_steps * FLAGS.batch_size
    if num_examples == 0:
        raise ValueError(
            "batch_size (%d) is larger than the data set (%d examples)"
            % (FLAGS.batch_size, data_set.num_examples))

    true_count = 0
    for _ in xrange(num_steps):
        feed_dict = fill_feed_dict(data_set, images_placeholder, labels_placeholder)
        true_count += sess.run(eval_correct, feed_dict=feed_dict)

    precision = float(true_count) / num_examples
    print("Num examples: %d, true_count: %d, precision: %f" % (num_examples, true_count, precision))
    return precision

In [None]:
def run_training(eval_interval=1000):
    """Build the MNIST feed-forward graph, train it, and periodically evaluate.

    Runs FLAGS.max_steps optimisation steps on the training split. Every
    `eval_interval` steps, and after the final step, precision is reported on
    the training, validation and test splits via `do_eval`.

    Args:
        eval_interval: evaluate every this many training steps (default 1000).

    Raises:
        ValueError: if `eval_interval` is less than 1.
    """
    if eval_interval < 1:
        raise ValueError("eval_interval must be >= 1, got %r" % (eval_interval,))

    data_sets = input_data.read_data_sets(FLAGS.input_data_dir, FLAGS.fake_data)
    # (header printed before the eval, split to evaluate) in reporting order.
    eval_splits = (
        ("Training data eval:", data_sets.train),
        ("Validation data eval:", data_sets.validation),
        ("Test data eval:", data_sets.test),
    )

    with tf.Graph().as_default():
        images, labels = placeholder_inputs(FLAGS.batch_size)
        logits = mnist.inference(images, FLAGS.hidden1, FLAGS.hidden2)
        train_loss = mnist.loss(logits, labels)
        train_op = mnist.training(train_loss, FLAGS.learning_rate)
        eval_correct = mnist.evaluation(logits, labels)

        init = tf.global_variables_initializer()
        # The context manager closes the session when training ends or is
        # interrupted, instead of leaking it (important when re-running cells).
        with tf.Session() as sess:
            sess.run(init)

            for step in xrange(FLAGS.max_steps):
                feed_dict = fill_feed_dict(data_sets.train, images, labels)
                sess.run(train_op, feed_dict=feed_dict)

                completed = step + 1
                if completed % eval_interval == 0 or completed == FLAGS.max_steps:
                    for title, split in eval_splits:
                        print(title)
                        do_eval(sess, eval_correct, images, labels, split)

In [None]:
if __name__ == '__main__':
    # Command-line options for the training run. parse_known_args() ignores
    # unrecognised arguments rather than failing. Presumably this is so extra
    # arguments (e.g. a Jupyter kernel's own flags) don't break the notebook.
    parser = argparse.ArgumentParser()
    parser.add_argument('--learning_rate', type=float, default=0.01,
                        help='Initial learning rate.')
    parser.add_argument('--max_steps', type=int, default=20000,
                        help='Number of steps to run trainer.')
    parser.add_argument('--hidden1', type=int, default=128,
                        help='Number of units in hidden layer 1.')
    parser.add_argument('--hidden2', type=int, default=32,
                        help='Number of units in hidden layer 2.')
    parser.add_argument('--batch_size', type=int, default=100,
                        help='Batch size.  Must divide evenly into the dataset sizes.')
    parser.add_argument('--input_data_dir', type=str,
                        default=os.path.join(os.getenv('TEST_TMPDIR', '/tmp'),
                                             'tensorflow/mnist/input_data'),
                        help='Directory to put the input data.')
    parser.add_argument('--fake_data', action='store_true', default=False,
                        help='If true, uses fake data for unit testing.')

    FLAGS, unparsed = parser.parse_known_args()
    run_training()