## Using auxiliary task to train MNIST classification
Details can be found in Chapter 11 of Aurélien Géron book "Hands-On Machine Learning with Scikit-Learn and TensorFlow"

- auxiliary task: train two dnn models and combine the output to tell whether two images are from the same digits.
- after that, reuse the one of the dnn model, extend it to train a mnist classification model on a much smaller dataset (5k images, around 500 for each digit)
- it shows my exploration of how to save/restore/reuse models in tensorflow(0.12)

In [1]:
from sklearn.datasets import fetch_mldata
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
import numpy as np
import tensorflow as tf

from tensorflow.contrib.framework import arg_scope
from tensorflow.contrib.layers import fully_connected, batch_norm
from tensorflow.contrib.layers import variance_scaling_initializer
from tensorflow.contrib.layers import dropout

### load mnist, split into 2 parts
- #1 for auxiliary training, same or not
- #2 5000 images, traditional mnist classification

In [2]:
def load_mnist():
    mnist = fetch_mldata("MNIST Original")
    # without normalizing input, bcs we expect to use batch normalization
    X, y = mnist.data.astype(np.float32), mnist.target.astype(np.int32)
#     X /= 255 # not needed if batch norm is used
    X, y = shuffle(X, y)
    return X, y

In [3]:
def make_batch_generator(X, y, batch_size=64):
    i, n = 0, X.shape[0]
    while True:
        i %= n
        yield X[i:i+batch_size], y[i:i+batch_size]
        i += batch_size
        if i >= n: i = 0

In [4]:
X, y = load_mnist()
X1, X2, y1, y2 = train_test_split(X, y, test_size=5000)

### train for identical image

In [5]:
# slow but good enough for a demo
def pair_batch_generator(X, y, batch_size=64):
    idx = np.arange(X.shape[0])
    while True:
        selected = np.random.choice(idx, batch_size)
        x1 = X[selected]
        x2 = []
        for i, label in enumerate(y[selected]):
            if i % 2 == 0:
                choice = np.random.choice(np.where(y==label)[0], 1)
            else:
                choice = np.random.choice(np.where(y!=label)[0], 1)
            x2.append(X[choice])
        x2 = np.concatenate(x2)
        yy = np.tile([0, 1], batch_size//2).astype(np.int32)
        yield x1, x2, yy
        
pair_batches = pair_batch_generator(X1, y1)
test_x1, test_x2, test_yy = next(pair_batch_generator(X1, y1, 1000))
test_x1.shape, test_x2.shape, test_yy.shape

((1000, 784), (1000, 784), (1000,))

In [6]:
n_inputs = 28 * 28
n_hiddens = [100] * 5
batch_size = 64
n_epoches = 20

tf.reset_default_graph()

x1 = tf.placeholder(tf.float32, [None, n_inputs], name="x1")
x2 = tf.placeholder(tf.float32, [None, n_inputs], name="x2")
yy = tf.placeholder(tf.int32, [None], name="yy")
is_training = tf.placeholder(tf.bool, [], name="is_training")
keep_prob = 0.5

bn_params = {
    "is_training": is_training,
    "decay": 0.9,
    "updates_collections": None,
    "scale": True
}
he_init = variance_scaling_initializer()

with tf.name_scope("dnn"):
    with arg_scope([fully_connected],
                  activation_fn=tf.nn.elu,
                  normalizer_fn=batch_norm,
                  normalizer_params=bn_params,
                  weights_initializer=he_init):
        prev1, prev2 = x1, x2
        for i, n_hidden in enumerate(n_hiddens):
            h1 = fully_connected(prev1, n_hidden,
                                 scope="dnn1/hidden%i"%i)
            h1 = dropout(h1, keep_prob, is_training=is_training)
            h2 = fully_connected(prev2, n_hidden,
                                 scope="dnn2/hidden%i"%i)
            h2 = dropout(h2, keep_prob, is_training=is_training)
            prev1, prev2 = h1, h2
        h = tf.concat(1, [h1, h2])
        hh = fully_connected(h, 100, scope="output-2")
        hh = dropout(hh, keep_prob, is_training=is_training)
        logits = fully_connected(hh, 2, activation_fn=None, scope="output")
#         logits = tf.squeeze(logits, axis=1)
        
with tf.name_scope("loss"):
#     cast_yy = tf.cast(yy, tf.float32)
#     xentropy = -cast_yy * tf.log(logits) - (1-cast_yy) * tf.log(1-logits)
    xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, yy)
    loss = tf.reduce_mean(xentropy)
    
with tf.name_scope("train"):
    optimizer = tf.train.AdamOptimizer()
    train_op = optimizer.minimize(loss)
    
with tf.name_scope("eval"):
#     labels = tf.cast(logits >= 0.5, tf.int32)
#     match = tf.equal(labels, yy)
    match = tf.nn.in_top_k(logits, yy, 1)
    accuracy = tf.reduce_mean(tf.cast(match, tf.float32))
    
init = tf.global_variables_initializer()
saver = tf.train.Saver()

In [None]:
with tf.Session() as sess:
    init.run()
#     saver.restore(sess, "./same_digit_model.ckpt")
    for i in range(n_epoches * 55000 // batch_size):
        b_x1, b_x2, b_yy = next(pair_batches)
        sess.run(train_op, feed_dict={x1: b_x1, x2: b_x2, yy: b_yy, is_training:True})
        if i % 1000 == 0:
            train_loss, train_acc = sess.run([loss, accuracy], 
                        feed_dict={x1: b_x1, x2: b_x2, yy: b_yy, is_training:False})
            test_acc = sess.run(accuracy,
                        feed_dict={x1: test_x1, x2: test_x2, yy: test_yy, is_training:False})
            print(train_loss, train_acc, test_acc)
    save_path = saver.save(sess, "same_digit_model.ckpt")

In [6]:
!ls -lh same_digit_model.ckpt*

-rw-r--r-- 1 dola dola 3.0M Feb 16 08:13 same_digit_model.ckpt.data-00000-of-00001
-rw-r--r-- 1 dola dola 4.8K Feb 16 08:13 same_digit_model.ckpt.index
-rw-r--r-- 1 dola dola 1.3M Feb 16 08:13 same_digit_model.ckpt.meta


### reconstruct dnn1 for mnist classification on dataset 2

In [6]:
X_train, X_test, y_train, y_test = train_test_split(X2, y2, test_size=0.15)
X_train.shape, X_test.shape, y_train.shape, y_test.shape

((4250, 784), (750, 784), (4250,), (750,))

In [7]:
batch_size = 32
n_epoches = 10

In [8]:
train_batches = make_batch_generator(X_train, y_train, batch_size)

In [23]:
# load the meta data into default graph
tf.reset_default_graph()
saver = tf.train.import_meta_graph("./same_digit_model.ckpt.meta")
default_graph = tf.get_default_graph()

# reuse dnn1 and extend it
n_outputs = 10
# so now you can read the existing variables from saved graph
X = default_graph.get_tensor_by_name("x1:0")
h = default_graph.get_tensor_by_name("dnn/dnn1/hidden4/Elu:0")
is_training = default_graph.get_tensor_by_name("is_training:0")
y = tf.placeholder(tf.int32, [None], "y")

he_init = variance_scaling_initializer()
bn_params = {
    "is_training": is_training,
    "decay": 0.9,
    "scale": True,
    "updates_collections": None
}
hh = fully_connected(X, 100, 
                         activation_fn=tf.nn.elu,
                         normalizer_fn=batch_norm,
                         normalizer_params=bn_params,
                         weights_initializer=he_init, 
                         scope="mnist_hidden")
hh = dropout(hh, keep_prob=0.5, is_training=is_training, scope="mnist_dropout")
logits = fully_connected(hh, n_outputs, 
                         activation_fn=None,
                         normalizer_fn=batch_norm,
                         normalizer_params=bn_params,
                         weights_initializer=he_init, 
                         scope="mnist_output")

xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, y)
loss = tf.reduce_mean(xentropy)

optimizer = tf.train.AdamOptimizer()
train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                               scope="mnist*")
train_op = optimizer.minimize(loss, var_list=train_vars)

match = tf.nn.in_top_k(logits, y, 1)
accuracy = tf.reduce_mean(tf.cast(match, tf.float32))

In [24]:
[v.name for v in train_vars]

['mnist_hidden/weights:0',
 'mnist_hidden/BatchNorm/beta:0',
 'mnist_hidden/BatchNorm/gamma:0',
 'mnist_output/weights:0',
 'mnist_output/BatchNorm/beta:0',
 'mnist_output/BatchNorm/gamma:0']

In [25]:
with tf.Session() as sess:
    tf.global_variables_initializer().run() # initialize other variables
    saver.restore(sess, "./same_digit_model.ckpt") # load existing variables
    for e in range(n_epoches):
        for i in range(5000 // batch_size):
            X_batch, y_batch = next(train_batches)
            sess.run(train_op, feed_dict={X: X_batch, y: y_batch, is_training:True})
            if i % 100 == 0:
                train_loss, train_acc = sess.run([loss, accuracy],
                            feed_dict={X: X_batch, y: y_batch, is_training:False})
                test_acc = sess.run(accuracy,
                            feed_dict={X: X_test, y: y_test, is_training: False})
                print(train_loss, train_acc, test_acc)

4.41069 0.1875 0.170667
0.785419 0.90625 0.86
0.673975 0.9375 0.864
0.613128 0.8125 0.884
0.72141 0.875 0.874667
0.641879 0.875 0.881333
0.594866 0.90625 0.885333
0.620639 0.90625 0.886667
0.499118 0.96875 0.893333
0.544461 0.875 0.886667
0.400586 0.96875 0.896
0.416376 0.96875 0.893333
0.512343 0.9375 0.890667
0.314586 1.0 0.902667
0.350421 0.96875 0.886667
0.347452 0.9375 0.894667
0.52806 0.875 0.896
0.404227 0.9375 0.892
0.249052 0.96875 0.908
0.354288 0.90625 0.909333
