In [1]:
import tensorflow as tf
import numpy as np

In [2]:
def reset_graph(seed=42):
    """Discard the current default graph and re-seed the random generators.

    The same seed is handed to NumPy's global generator and to
    TensorFlow's `tf.set_random_seed`.

    Args:
        seed: Seed value passed to both libraries (defaults to 42).
    """
    np.random.seed(seed)
    # Reset first, then seed: the TensorFlow seed is set only after the
    # default graph has been replaced.
    tf.reset_default_graph()
    tf.set_random_seed(seed)

In [3]:
# Graph-construction cell: declares the network hyperparameters and builds the
# TensorFlow graph (inputs, model, loss/optimizer, evaluation, init/saver).
# Statement order is significant here because ops are created in the order written.

# Input geometry: 28x28 single-channel images, fed as flat vectors of
# height * width values.
height = 28
width = 28
channels = 1
n_inputs = height * width

# First convolution: 32 feature maps, 3x3 kernel, stride 1, "SAME" padding.
conv1_fmaps = 32
conv1_ksize = 3
conv1_stride = 1
conv1_pad = "SAME"

# Second convolution (named conv3 because a pooling layer sits between the two):
# 64 feature maps, 3x3 kernel, stride 2, "SAME" padding.
conv3_fmaps = 64
conv3_ksize = 3
conv3_stride = 2
conv3_pad = "SAME"

# The pooling layers use a channel window of 1, so pool4 keeps conv3's map count.
pool4_fmaps = conv3_fmaps

# Width of the fully connected hidden layer.
n_fc5 = 128

# Dropout rate applied to the fully connected layer's output.
fc5_dropout_rate = 0.5

# Number of logits produced by the output layer.
n_outputs = 10

# Start from an empty default graph with fixed seeds before creating any op.
reset_graph()

with tf.name_scope("inputs"):
    # Flat input batch; reshaped below to [batch, height, width, channels].
    X = tf.placeholder(tf.float32, shape=[None, n_inputs], name="X")
    X_reshaped = tf.reshape(X, shape=[-1, height, width, channels])
    # One integer label per example (consumed by the sparse cross-entropy below).
    y = tf.placeholder(tf.int32, shape=[None], name="y")
    # Defaults to False, so dropout is only switched on when `training: True` is fed.
    training = tf.placeholder_with_default(False, shape=[], name='training')

with tf.name_scope("model"):
    # NOTE(review): no `activation` argument is passed to either conv2d call, so
    # the convolutions are built without an explicit non-linearity - confirm
    # this is intended.
    conv1 = tf.layers.conv2d(X_reshaped, filters=conv1_fmaps, kernel_size=conv1_ksize,
                             strides=conv1_stride, padding=conv1_pad, name="conv1")
    
    # 2x2 max pooling with stride 2 and "VALID" padding.
    pool2 = tf.nn.max_pool(conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
                           padding="VALID", name="pool2")
    conv3 = tf.layers.conv2d(pool2, filters=conv3_fmaps, kernel_size=conv3_ksize,
                             strides=conv3_stride, padding=conv3_pad, name="conv3")
    
    with tf.name_scope("pool4"):
        pool4=tf.nn.max_pool(conv3, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
                             padding="VALID", name="pool4")
        # Flatten for the dense layer. The hard-coded 3 * 3 is the spatial size
        # left after the two stride-2 poolings and the stride-2 convolution
        # (28 -> 14 -> 7 -> 3); it must be updated if any size/stride above changes.
        pool4_flat=tf.reshape(pool4, shape=[-1, pool4_fmaps * 3 * 3])

    with tf.name_scope("full_connect5"):
        fc5 = tf.layers.dense(pool4_flat, n_fc5, activation=tf.nn.relu)
        # Dropout is controlled by the `training` placeholder.
        fc5_drop = tf.layers.dropout(fc5, fc5_dropout_rate,
                                     training=training)
        
    with tf.name_scope("output"):
        # Raw class scores; the softmax is only used to expose probabilities.
        logits = tf.layers.dense(fc5_drop, n_outputs, name="output")
        y_prob = tf.nn.softmax(logits, name="y_prob")

with tf.name_scope("train"):
    # Mean sparse softmax cross-entropy over the batch, minimized with Adam
    # (constructed with its default arguments).
    xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=y)
    loss = tf.reduce_mean(xentropy)
    loss_summary=tf.summary.scalar("loss", loss)
    optimizer = tf.train.AdamOptimizer()
    training_op = optimizer.minimize(loss)

with tf.name_scope("eval"):
    # An example counts as correct when its label is the top-1 logit.
    correct = tf.nn.in_top_k(logits, y, 1)
    accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
    accuracy_summary=tf.summary.scalar("accuracy", accuracy)

with tf.name_scope("init_and_saver"):
    # Created last so that every variable defined above is already in the graph.
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()

In [4]:
def get_model_params():
    """Snapshot the current value of every global variable.

    Evaluates all variables in the GLOBAL_VARIABLES collection with the
    default session in a single `run` call.

    Returns:
        dict mapping each variable's op name to the value returned by the
        session for that variable.
    """
    variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    current_values = tf.get_default_session().run(variables)

    snapshot = {}
    for variable, current_value in zip(variables, current_values):
        snapshot[variable.op.name] = current_value
    return snapshot

def restore_model_params(model_params):
    """Write previously captured values back into the graph's variables.

    For every name in `model_params`, the op called "<name>/Assign" is
    looked up in the default graph and run with its second input overridden
    (via feed_dict) by the saved value.

    Args:
        model_params: dict mapping variable op names to the values to
            restore, e.g. the result of `get_model_params()`.
    """
    assign_ops = {}
    feed_dict = {}
    for var_name, saved_value in model_params.items():
        assign_op = tf.get_default_graph().get_operation_by_name(var_name + "/Assign")
        assign_ops[var_name] = assign_op
        # inputs[1] is the tensor the Assign op reads its new value from;
        # feeding it substitutes the saved value.
        feed_dict[assign_op.inputs[1]] = saved_value
    tf.get_default_session().run(assign_ops, feed_dict=feed_dict)

In [5]:
from tensorflow.examples.tutorials.mnist import input_data
# Load the MNIST data set from ./mnist_data/; the returned object provides the
# train / validation / test splits used by the training cell.
mnist = input_data.read_data_sets("./mnist_data/")

Extracting ./mnist_data/train-images-idx3-ubyte.gz
Extracting ./mnist_data/train-labels-idx1-ubyte.gz
Extracting ./mnist_data/t10k-images-idx3-ubyte.gz
Extracting ./mnist_data/t10k-labels-idx1-ubyte.gz


In [6]:
# Training cell: mini-batch training with early stopping on the validation loss.
# The best parameters seen so far are kept in memory and restored before the
# final test evaluation and checkpoint save.

# Training schedule.
n_epochs = 1000
batch_size = 50
n_batches = mnist.train.num_examples // batch_size  # mini-batches per epoch

# Early-stopping state. The validation loss is checked every `check_interval`
# iterations; training stops at the end of an epoch once more than
# `max_checks_without_progress` consecutive checks failed to improve on it.
best_loss_val = np.inf
check_interval = 500
checks_since_last_progress = 0
max_checks_without_progress = 20
best_model_params = None

file_writer = tf.summary.FileWriter("./tf-logs", tf.get_default_graph())

try:
    with tf.Session() as sess:
        init.run()
        for epoch in range(n_epochs):
            for iteration in range(n_batches):
                X_batch, y_batch = mnist.train.next_batch(batch_size)
                # Dropout is only active here: `training` defaults to False elsewhere.
                sess.run(training_op, feed_dict={X: X_batch, y: y_batch, training: True})
                if iteration % check_interval == 0:
                    loss_val, loss_summary_str = sess.run(
                        [loss, loss_summary],
                        feed_dict={X: mnist.validation.images,
                                   y: mnist.validation.labels})
                    # Global step = number of mini-batches processed so far.
                    file_writer.add_summary(loss_summary_str,
                                            epoch * n_batches + iteration)
                    if loss_val < best_loss_val:
                        best_loss_val = loss_val
                        checks_since_last_progress = 0
                        # Keep an in-memory copy of the best weights so far.
                        best_model_params = get_model_params()
                    else:
                        checks_since_last_progress += 1

            # Training accuracy is measured on the last mini-batch of the epoch only.
            acc_train = sess.run(accuracy, feed_dict={X: X_batch, y: y_batch})
            acc_val, acc_val_summary = sess.run(
                [accuracy, accuracy_summary],
                feed_dict={X: mnist.validation.images,
                           y: mnist.validation.labels})
            file_writer.add_summary(acc_val_summary, (epoch + 1) * n_batches)
            print("Epoch {}, train accuracy: {:.4f}%, valid. accuracy: {:.4f}%, valid. best loss: {:.6f}".format(
                      epoch, acc_train * 100, acc_val * 100, best_loss_val))
            if checks_since_last_progress > max_checks_without_progress:
                print("Early stopping!")
                break

        # Roll back to the best validation-loss weights before testing and saving.
        if best_model_params:
            restore_model_params(best_model_params)
        acc_test = accuracy.eval(feed_dict={X: mnist.test.images,
                                            y: mnist.test.labels})
        print("Final accuracy on test set:", acc_test)
        save_path = saver.save(sess, "./mnist_cnn_model")
finally:
    # Always flush and close the event file, even if training is interrupted.
    file_writer.close()

Epoch 0, train accuracy: 98.0000%, valid. accuracy: 97.3000%, valid. best loss: 0.090310
Epoch 1, train accuracy: 100.0000%, valid. accuracy: 98.2400%, valid. best loss: 0.066434
Epoch 2, train accuracy: 98.0000%, valid. accuracy: 98.4800%, valid. best loss: 0.056852
Epoch 3, train accuracy: 98.0000%, valid. accuracy: 98.6000%, valid. best loss: 0.048353
Epoch 4, train accuracy: 98.0000%, valid. accuracy: 98.8600%, valid. best loss: 0.043404
Epoch 5, train accuracy: 98.0000%, valid. accuracy: 98.7800%, valid. best loss: 0.043404
Epoch 6, train accuracy: 100.0000%, valid. accuracy: 98.7800%, valid. best loss: 0.042727
Epoch 7, train accuracy: 100.0000%, valid. accuracy: 98.8000%, valid. best loss: 0.039800
Epoch 8, train accuracy: 100.0000%, valid. accuracy: 98.7400%, valid. best loss: 0.039800
Epoch 9, train accuracy: 100.0000%, valid. accuracy: 98.7800%, valid. best loss: 0.039800
Epoch 10, train accuracy: 100.0000%, valid. accuracy: 98.9000%, valid. best loss: 0.039800
Epoch 11, trai