In [1]:
from sklearn.datasets import fetch_mldata
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
import numpy as np
import tensorflow as tf

from tensorflow.contrib.framework import arg_scope
from tensorflow.contrib.layers import fully_connected, batch_norm
from tensorflow.contrib.layers import variance_scaling_initializer
from tensorflow.contrib.layers import dropout

### load data, split to 0 to 4 and 5 to 9

In [2]:
def load_mnist():
    mnist = fetch_mldata("MNIST Original")
    # without normalizing input, bcs we expect to use batch normalization
    X, y = mnist.data.astype(np.float32), mnist.target.astype(np.int32)
    X, y = shuffle(X, y)
    return X, y

In [3]:
X, y = load_mnist()
X.shape, y.shape

((70000, 784), (70000,))

In [4]:
X4, X5 = X[y<=4], X[y>=5]
y4, y5 = y[y<=4], y[y>=5]-5
X4.shape, X5.shape, y4.shape, y5.shape

((35735, 784), (34265, 784), (35735,), (34265,))

In [5]:
def make_batch_generator(X, y, batch_size=64):
    i, n = 0, X.shape[0]
    while True:
        i %= n
        yield X[i:i+batch_size], y[i:i+batch_size]
        i += batch_size
        if i >= n: i = 0

### build a dnn for 0 to 4

In [76]:
n_inputs = 28 * 28
n_hiddens = [100] * 5
n_outputs = 5
batch_size = 64
n_epoches = 20

tf.reset_default_graph()

X = tf.placeholder(tf.float32, [None, n_inputs], name="X")
y = tf.placeholder(tf.int32, [None], name="y")
is_training = tf.placeholder(tf.bool, [], name="is_training")
bn_params = {
    "is_training": is_training,
    "updates_collections": None,
    "decay": 0.9,
    "scale": True
}
he_init = variance_scaling_initializer()
keep_prob = 0.5

with tf.name_scope("dnn"):
    with arg_scope([fully_connected],
                  activation_fn=tf.nn.elu,
                  weights_initializer=he_init,
                  normalizer_fn=batch_norm,
                  normalizer_params=bn_params):
        prev = X
        for i, n_hidden in enumerate(n_hiddens):
            h = fully_connected(prev, n_hidden, scope="hidden%i" % i)
#             h = dropout(h, keep_prob, is_training=is_training)
            prev = h
    logits = fully_connected(h, n_outputs, activation_fn=None, scope="output")
    
with tf.name_scope("loss"):
    xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, y)
    loss = tf.reduce_mean(xentropy)
    
with tf.name_scope("train"):
    optimizer = tf.train.AdamOptimizer()
    train_op = optimizer.minimize(loss)
    
with tf.name_scope("eval"):
    match = tf.nn.in_top_k(logits, y, 1)
    accuracy = tf.reduce_mean(tf.cast(match, tf.float32))
    
init = tf.global_variables_initializer()
saver = tf.train.Saver()

In [77]:
[v.name for v in tf.global_variables()]

['hidden0/weights:0',
 'hidden0/BatchNorm/beta:0',
 'hidden0/BatchNorm/gamma:0',
 'hidden0/BatchNorm/moving_mean:0',
 'hidden0/BatchNorm/moving_variance:0',
 'hidden1/weights:0',
 'hidden1/BatchNorm/beta:0',
 'hidden1/BatchNorm/gamma:0',
 'hidden1/BatchNorm/moving_mean:0',
 'hidden1/BatchNorm/moving_variance:0',
 'hidden2/weights:0',
 'hidden2/BatchNorm/beta:0',
 'hidden2/BatchNorm/gamma:0',
 'hidden2/BatchNorm/moving_mean:0',
 'hidden2/BatchNorm/moving_variance:0',
 'hidden3/weights:0',
 'hidden3/BatchNorm/beta:0',
 'hidden3/BatchNorm/gamma:0',
 'hidden3/BatchNorm/moving_mean:0',
 'hidden3/BatchNorm/moving_variance:0',
 'hidden4/weights:0',
 'hidden4/BatchNorm/beta:0',
 'hidden4/BatchNorm/gamma:0',
 'hidden4/BatchNorm/moving_mean:0',
 'hidden4/BatchNorm/moving_variance:0',
 'output/weights:0',
 'output/biases:0',
 'train/beta1_power:0',
 'train/beta2_power:0',
 'train/hidden0/weights/Adam:0',
 'train/hidden0/weights/Adam_1:0',
 'train/hidden0/BatchNorm/beta/Adam:0',
 'train/hidden0/Ba

In [78]:
X_train, X_test, y_train, y_test = train_test_split(X4, y4, test_size=5000, random_state=1337)
print(X_train.shape, X_test.shape, y_train.shape, y_test.shape)
train_batches = make_batch_generator(X_train, y_train, batch_size)

(30735, 784) (5000, 784) (30735,) (5000,)


In [79]:
with tf.Session() as sess:
    init.run()

    for e in range(n_epoches):
        for b in range(X_train.shape[0] // batch_size + 1):
            X_batch, y_batch = next(train_batches)
            sess.run(train_op, feed_dict={X: X_batch, y: y_batch, is_training: True})
            if b % 1000 == 0:
                train_loss, train_acc = sess.run([loss, accuracy],
                            feed_dict={X: X_batch, y: y_batch, is_training: False})
                test_acc = sess.run(accuracy, feed_dict={X: X_test, y: y_test, is_training: False})
                print(train_loss, train_acc, test_acc)
        saver.save(sess, "mnist0to4/model", global_step=e)
    save_path = saver.save(sess, "mnist0to4/finalmodel.ckpt")

2.9519 0.390625 0.3306
0.0269714 0.984375 0.9832
0.0133622 0.984375 0.9872
0.0113916 1.0 0.988
0.00499017 1.0 0.9894
0.0150441 0.984375 0.9878
0.00771244 1.0 0.9884
0.00524801 1.0 0.986
0.00231306 1.0 0.9856
0.00125167 1.0 0.9858
0.0197477 0.984375 0.989
0.0129075 0.984375 0.99
0.0029601 1.0 0.9874
0.0151311 0.984375 0.9892
0.0240409 0.984375 0.9886
0.000307843 1.0 0.9892
0.000259688 1.0 0.9902
0.00123581 1.0 0.9902
0.000252889 1.0 0.989
0.0185084 0.984375 0.989


In [80]:
save_path

'mnist0to4/finalmodel.ckpt'

### transfer learning for 5 to 9

In [81]:
## recreate the graph
n_inputs = 28 * 28
n_hiddens = [100] * 5
n_outputs = 5
batch_size = 64
n_epoches = 10

tf.reset_default_graph()

X = tf.placeholder(tf.float32, [None, n_inputs], name="X")
y = tf.placeholder(tf.int32, [None], name="y")
is_training = tf.placeholder(tf.bool, [], name="is_training")
bn_params = {
    "is_training": is_training,
    "updates_collections": None,
    "decay": 0.9,
    "scale": True
}
he_init = variance_scaling_initializer()
keep_prob = 0.5

with tf.name_scope("dnn"):
    with arg_scope([fully_connected],
                  activation_fn=tf.nn.elu,
                  weights_initializer=he_init,
                  normalizer_fn=batch_norm,
                  normalizer_params=bn_params):
        prev = X
        for i, n_hidden in enumerate(n_hiddens):
            h = fully_connected(prev, n_hidden, scope="hidden%i" % i)
#             h = dropout(h, keep_prob, is_training=is_training)
            prev = h
    logits = fully_connected(h, n_outputs, activation_fn=None, scope="output")
    
with tf.name_scope("loss"):
    xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, y)
    loss = tf.reduce_mean(xentropy)
    
with tf.name_scope("train"):
    optimizer = tf.train.AdamOptimizer()
    train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="output|hidden[34]")
    print([v.name for v in train_vars])
    train_op = optimizer.minimize(loss, var_list=train_vars)
    
with tf.name_scope("eval"):
    match = tf.nn.in_top_k(logits, y, 1)
    accuracy = tf.reduce_mean(tf.cast(match, tf.float32))
    

saver = tf.train.Saver()
# saver = tf.train.import_meta_graph("mnist0to4/finalmodel.ckpt.meta")

['hidden3/weights:0', 'hidden3/BatchNorm/beta:0', 'hidden3/BatchNorm/gamma:0', 'hidden4/weights:0', 'hidden4/BatchNorm/beta:0', 'hidden4/BatchNorm/gamma:0', 'output/weights:0', 'output/biases:0']


In [83]:
X_train, X_test, y_train, y_test = train_test_split(X5, y5, test_size=5000, random_state=1337)
X_train, y_train = X_train[:1000], y_train[:1000] # for small dataset
print(X_train.shape, X_test.shape, y_train.shape, y_test.shape)
train_batches = make_batch_generator(X_train, y_train, batch_size)

(1000, 784) (5000, 784) (1000,) (5000,)


In [86]:
with tf.Session() as sess:
    saver.restore(sess, save_path, )
#     tf.global_variables_initializer().run()

    for e in range(n_epoches):
        for b in range(X_train.shape[0] // batch_size + 1):
            X_batch, y_batch = next(train_batches)
            sess.run(train_op, feed_dict={X: X_batch, y: y_batch, is_training: True})
            if b % 1000 == 0:
                train_loss, train_acc = sess.run([loss, accuracy],
                            feed_dict={X: X_batch, y: y_batch, is_training: False})
                test_acc = sess.run(accuracy, feed_dict={X: X_test, y: y_test, is_training: False})
                print(train_loss, train_acc, test_acc)

6.54107 0.3125 0.3768
0.818634 0.828125 0.811
0.365254 0.875 0.8494
0.268277 0.890625 0.8722
0.178556 0.953125 0.8876
0.145689 0.96875 0.8978
0.125442 0.984375 0.9078
0.112284 0.984375 0.9094
0.10106 0.96875 0.913
0.0909153 0.96875 0.9152
