# Train DNN on MNIST 0 to 4

In [1]:
import tensorflow as tf
import numpy as np

In [2]:
n_inputs = 28 * 28
n_hidden = 100
n_outputs = 5

In [3]:
he_init = tf.contrib.layers.variance_scaling_initializer()
def dnn(inputs, n_hidden_layers=5, n_neurons=100, name=None,
        activation=tf.nn.relu, initializer=he_init):
    with tf.variable_scope(name, 'dnn'):
        for layer in range(n_hidden_layers):
            inputs = tf.layers.dense(inputs, n_neurons, activation=activation,
                                     kernel_initializer=initializer, name='hidden%d' % (layer + 1))
    return inputs

In [4]:
tf.reset_default_graph()

X = tf.placeholder(tf.float32, shape=(None, n_inputs), name='X')
y = tf.placeholder(tf.int64, shape=(None), name='y')

In [5]:
dnn_outputs = dnn(X)
logits = tf.layers.dense(dnn_outputs, n_outputs, name='output')

In [6]:
y_proba = tf.nn.softmax(logits, name='probability')

In [7]:
learning_rate = 0.001
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits)
loss = tf.reduce_mean(cross_entropy, name='loss')

In [8]:
optimizer = tf.train.AdamOptimizer(learning_rate)
training_op = optimizer.minimize(loss)

In [9]:
correct_preds = tf.nn.in_top_k(logits, y, 1)
accuracy = tf.reduce_mean(tf.cast(correct_preds, tf.float32), name='accuracy')

In [10]:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('/tmp/mnist')

Extracting /tmp/mnist/train-images-idx3-ubyte.gz
Extracting /tmp/mnist/train-labels-idx1-ubyte.gz
Extracting /tmp/mnist/t10k-images-idx3-ubyte.gz
Extracting /tmp/mnist/t10k-labels-idx1-ubyte.gz


In [11]:
X_train1 = mnist.train.images[mnist.train.labels < 5]
y_train1 = mnist.train.labels[mnist.train.labels < 5]
X_valid1 = mnist.validation.images[mnist.validation.labels < 5]
y_valid1 = mnist.validation.labels[mnist.validation.labels < 5]
X_test1 = mnist.test.images[mnist.test.labels < 5]
y_test1 = mnist.test.labels[mnist.test.labels < 5]

In [12]:
n_epochs = 1000
batch_size = 50

best_score = 0.
max_checks_without_progress = 20
checks_without_progress = 0

init = tf.global_variables_initializer()
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init)
    
    for epoch in range(n_epochs):
        rnd_indices = np.random.permutation(len(X_train1))
        for rnd_indices in np.array_split(rnd_indices, len(X_train1) // batch_size):
            X_batch, y_batch = X_train1[rnd_indices], y_train1[rnd_indices]
            sess.run(training_op, feed_dict={X: X_batch, y: y_batch})
        accuracy_score = sess.run(accuracy, feed_dict={X: X_valid1, y: y_valid1})
        if accuracy_score > best_score:
            print 'Epoch %d: Accuracy %.6f' % (epoch, accuracy_score)
            best_score = accuracy_score
            checks_without_progress = 0
            saver.save(sess, './deep_mnist_model.ckpt')
        else:
            checks_without_progress += 1
            if checks_without_progress > max_checks_without_progress:
                print 'Stopping early!'
                break
    print 'Best validation accuracy: %.6f' % best_score
    
with tf.Session() as sess:
    saver.restore(sess, './deep_mnist_model.ckpt')
    accuracy_score = sess.run(accuracy, feed_dict={X: X_test1, y: y_test1})
    print 'Test accuracy: %.6f' % accuracy_score

Epoch 0: Accuracy 0.979281
Epoch 1: Accuracy 0.989445
Epoch 4: Accuracy 0.990618
Epoch 7: Accuracy 0.992181
Epoch 13: Accuracy 0.992572
Epoch 15: Accuracy 0.993354
Epoch 18: Accuracy 0.993745
Epoch 19: Accuracy 0.995700
Stopping early!
Best validation accuracy: 0.995700
INFO:tensorflow:Restoring parameters from ./deep_mnist_model.ckpt
Test accuracy: 0.994162


# Train DNN on MNIST 5 to 9 by reusing existing model

In [22]:
X_train2 = mnist.train.images[mnist.train.labels > 4]
y_train2 = mnist.train.labels[mnist.train.labels > 4] - 5
X_valid2 = mnist.validation.images[mnist.validation.labels > 4]
y_valid2 = mnist.validation.labels[mnist.validation.labels > 4] - 5
X_test2 = mnist.test.images[mnist.test.labels > 4]
y_test2 = mnist.test.labels[mnist.test.labels > 4] - 5

In [29]:
train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='hidden[12345]|output')
init_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='hidden5|output')
training_op = optimizer.minimize(loss, var_list=train_vars)

n_epochs = 1000
batch_size = 50

best_score = 0.
max_checks_without_progress = 20
checks_without_progress = 0

init = tf.global_variables_initializer()
with tf.Session() as sess:
    init.run()
    saver.restore(sess, './deep_mnist_model.ckpt')
    for var in train_vars:
        sess.run(var.initializer)
    for epoch in range(n_epochs):
        rnd_indices = np.random.permutation(len(X_train2))
        for rnd_indices in np.array_split(rnd_indices, len(X_train2) // batch_size):
            X_batch, y_batch = X_train2[rnd_indices], y_train2[rnd_indices]
            sess.run(training_op, feed_dict={X: X_batch, y: y_batch})
        accuracy_score = sess.run(accuracy, feed_dict={X: X_valid2, y: y_valid2})
        if accuracy_score > best_score:
            print 'Epoch %d: Accuracy %.6f' % (epoch, accuracy_score)
            best_score = accuracy_score
            checks_without_progress = 0
            saver.save(sess, './deep_mnist_model2.ckpt')
        else:
            checks_without_progress += 1
            if checks_without_progress > max_checks_without_progress:
                print 'Stopping early!'
                break
    print 'Best validation accuracy: %.6f' % best_score

INFO:tensorflow:Restoring parameters from ./deep_mnist_model.ckpt
Epoch 0: Accuracy 0.675676
Epoch 1: Accuracy 0.705160
Epoch 2: Accuracy 0.709255
Epoch 3: Accuracy 0.711302
Epoch 4: Accuracy 0.719492
Epoch 5: Accuracy 0.721949
Epoch 6: Accuracy 0.725635
Epoch 7: Accuracy 0.729730
Epoch 8: Accuracy 0.734234
Epoch 9: Accuracy 0.737101
Epoch 10: Accuracy 0.740786
Epoch 11: Accuracy 0.744881
Epoch 14: Accuracy 0.748157
Epoch 17: Accuracy 0.748567
Epoch 18: Accuracy 0.751024
Epoch 19: Accuracy 0.758395
Epoch 26: Accuracy 0.759623
Epoch 27: Accuracy 0.761671
Epoch 29: Accuracy 0.762080
Epoch 30: Accuracy 0.764947
Epoch 33: Accuracy 0.766175
Epoch 41: Accuracy 0.767404
Epoch 45: Accuracy 0.771089
Epoch 54: Accuracy 0.772318
Epoch 65: Accuracy 0.775184
Stopping early!
Best validation accuracy: 0.775184
