### Load Modules

In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
import os
from tensorflow.examples.tutorials.mnist import input_data

### Define Model Functions

In [2]:
def init_weights(shape, name):
    """Create a trainable weight variable with small random initial values.

    Args:
        shape: Shape of the variable to create (passed to tf.random_normal).
        name: Graph name given to the resulting tf.Variable.

    Returns:
        A tf.Variable initialised from a normal distribution with stddev 0.01.
    """
    initial_values = tf.random_normal(shape, stddev=0.01)
    return tf.Variable(initial_values, name=name)

def logistic_model(X, w):
    """Single linear layer: returns the matrix product of X and w.

    No softmax is applied here; the result is the raw (unnormalised) score
    tensor.

    Args:
        X: Input tensor.
        w: Weight tensor multiplied on the right of X.

    Returns:
        The tensor tf.matmul(X, w).
    """
    logits = tf.matmul(X, w)
    return logits

def mlp_model(X, w_h, w_o):
    """Two-layer perceptron: sigmoid hidden layer followed by a linear output.

    Args:
        X: Input tensor.
        w_h: Input-to-hidden weight tensor.
        w_o: Hidden-to-output weight tensor.

    Returns:
        The output tensor tf.matmul(sigmoid(tf.matmul(X, w_h)), w_o); no
        softmax is applied here.
    """
    hidden_pre_activation = tf.matmul(X, w_h)
    hidden = tf.nn.sigmoid(hidden_pre_activation)
    return tf.matmul(hidden, w_o)

### Load Dataset

In [3]:
# Read MNIST from ./MNIST_data/ with labels returned as one-hot vectors.
mnist = input_data.read_data_sets("./MNIST_data/", one_hot=True)

# Whole train / test splits as arrays: images (X) and labels (Y).
trX, trY = mnist.train.images, mnist.train.labels
teX, teY = mnist.test.images, mnist.test.labels
# Alternative: draw mini-batches on demand instead of slicing the arrays,
# e.g. batch_X, batch_Y = mnist.train.next_batch(64)

Extracting ./MNIST_data/train-images-idx3-ubyte.gz
Extracting ./MNIST_data/train-labels-idx1-ubyte.gz
Extracting ./MNIST_data/t10k-images-idx3-ubyte.gz
Extracting ./MNIST_data/t10k-labels-idx1-ubyte.gz


### Define Model

In [5]:
# Graph inputs: a batch of images with 784 features each and the matching
# 10-way label vectors. The leading None leaves the batch size unspecified.
X = tf.placeholder(tf.float32, [None, 784], name="X")
Y = tf.placeholder(tf.float32, [None, 10], name="Y")

# Architecture switch: 'Logistic' builds a single 784x10 linear layer; any
# other value builds the MLP with one 625-unit sigmoid hidden layer.
model = 'MLP'

if model == 'Logistic':
    w = init_weights([784, 10], "w")
    py_x = logistic_model(X, w)
    tf.summary.histogram("w_summary", w)
else:
    w_h = init_weights([784, 625], "w_h")
    w_o = init_weights([625, 10], "w_o")
    py_x = mlp_model(X, w_h, w_o)
    tf.summary.histogram("w_h_summary", w_h)
    tf.summary.histogram("w_o_summary", w_o)

with tf.name_scope("cost"):
    # Pass logits/labels by keyword: TensorFlow >= 1.0 rejects positional
    # arguments for this op, and keywords rule out swapping the two tensors.
    cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=py_x, labels=Y)
    cost = tf.reduce_mean(cross_entropy)
    train_op = tf.train.RMSPropOptimizer(learning_rate=0.001, decay=0.9).minimize(cost)
    tf.summary.scalar("cost", cost)

with tf.name_scope("accuracy"):
    # A prediction is correct when the arg-max class of the model output
    # matches the arg-max of the one-hot label.
    correct_pred = tf.equal(tf.argmax(Y, 1), tf.argmax(py_x, 1))
    acc_op = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
    tf.summary.scalar("accuracy", acc_op)

### Run Training

In [6]:
# Saver created after the graph is defined, so it tracks every variable that
# exists at this point (model weights and optimizer state).
saver = tf.train.Saver()
model_path = './models/'
os.makedirs(model_path, exist_ok=True)

n_epochs = 10     # full passes over the training arrays
batch_size = 128  # samples per optimisation step

with tf.Session() as sess:
    # The writer and merged summary op are reused by the resume-training cell,
    # so the writer is flushed below rather than closed.
    writer = tf.summary.FileWriter("./logs/nn_logs", sess.graph)
    merged = tf.summary.merge_all()

    tf.global_variables_initializer().run()
    print("Global Variable Initialization Done!!")

    for i in range(n_epochs):
        # Walk the training set in consecutive full-size slices; a trailing
        # partial batch (len(trX) % batch_size samples) is skipped.
        batch_starts = range(0, len(trX), batch_size)
        batch_ends = range(batch_size, len(trX) + 1, batch_size)
        for start, end in zip(batch_starts, batch_ends):
            sess.run(train_op, feed_dict={X: trX[start:end], Y: trY[start:end]})

        # Evaluate on the test arrays once per epoch and log the summaries
        # under the epoch index.
        summary, acc = sess.run([merged, acc_op], feed_dict={X: teX, Y: teY})
        writer.add_summary(summary, i)

        print("Iteration :", i+1, ", Test Accuracy :", acc)

    # Push buffered events to disk so TensorBoard can see them as soon as the
    # cell finishes instead of waiting for the writer's periodic flush.
    writer.flush()

    save_path = saver.save(sess, model_path)
    print("Model saved in :", save_path)

Global Variable Initialization Done!!
Iteration : 1 , Test Accuracy : 0.9073
Iteration : 2 , Test Accuracy : 0.9307
Iteration : 3 , Test Accuracy : 0.9448
Iteration : 4 , Test Accuracy : 0.9549
Iteration : 5 , Test Accuracy : 0.961
Iteration : 6 , Test Accuracy : 0.9665
Iteration : 7 , Test Accuracy : 0.9685
Iteration : 8 , Test Accuracy : 0.9706
Iteration : 9 , Test Accuracy : 0.9723
Iteration : 10 , Test Accuracy : 0.973
Model saved in : ./models/


### Keep Training by Loading Existing Models

In [7]:
# Epochs already logged by the initial training cell (its loop is range(10)).
# Summary steps continue from here so the new points extend the existing
# curves instead of reusing steps 0..1.
prior_epochs = 10
extra_epochs = 2
batch_size = 128

with tf.Session() as sess:
    # Restoring assigns a value to every variable the saver tracks, so no
    # separate initializer run is needed first.
    saver.restore(sess, model_path)
    print("Model loaded from :", model_path)

    for i in range(extra_epochs):
        # Same batching as the initial training: consecutive full-size slices,
        # with any trailing partial batch skipped.
        batch_starts = range(0, len(trX), batch_size)
        batch_ends = range(batch_size, len(trX) + 1, batch_size)
        for start, end in zip(batch_starts, batch_ends):
            sess.run(train_op, feed_dict={X: trX[start:end], Y: trY[start:end]})

        # `merged` and `writer` are the summary op and FileWriter created in
        # the initial training cell.
        summary, acc = sess.run([merged, acc_op], feed_dict={X: teX, Y: teY})
        writer.add_summary(summary, prior_epochs + i)

        # Checkpoint after every additional epoch.
        save_path = saver.save(sess, model_path)

        print("Iteration :", i+1, ", Test Accuracy :", acc)

    # Make sure the newly added summaries reach disk.
    writer.flush()

Model loaded from : ./models/
Iteration : 1 , Test Accuracy : 0.9446
Iteration : 2 , Test Accuracy : 0.9545
