In [1]:
"""
Transfer learning in TensorFlow using deep neural network.

Train initial model on MNIST digits 0-4. Use pre-trained hidden
layers to train new model on a very small subset of digits 5-9 images. 
Specifically, freeze the bottom layer from the 0-4 digit model and allow
higher layers to train for 5-9 model.

DNN includes the following:
-5 layers
-ELU activation function
-He initialization
-batch normalization
-dropout

Convolutional neural nets are better for image classification but purpose of 
this is to demonstrate transfer learning as well as above items that improve
training of deep neural nets.
"""

import copy
import os

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data


# model parameters
N_INPUTS = 28 * 28
N_NEURONS_BOTTOM = 200  # neurons in bottom layer
N_NEURONS_OTHER = 50  # neurons in other layers
N_OUTPUTS = 5  # digits 0-4 and 5-9
KEEP_PROB = 0.5  # dropout rate
BATCH_NORM_MOM = 0.9  # momentum for batch normalization
# training parameters
LEARNING_RATE = 0.01
N_EPOCHS_LOW = 25  # number of epochs to train 0-4
N_EPOCHS_HIGH = 100  # number of epochs to train 5-9
BATCH_SIZE = 100

SAVE_PATH_LOW = 'saved/low_digits/'
DATA_PATH = 'data/'


tf.reset_default_graph()

class DeepNN(object):
    """Build deep neural network to classify MNIST digits."""
    def __init__(self, learning_rate, n_neurons_bottom, 
                 n_neurons_other, n_inputs, n_outputs, keep_prob):
        self.lr = learning_rate
        self.n_neurons_bottom = n_neurons_bottom
        self.n_neurons_other = n_neurons_other
        self.n_inputs = n_inputs
        self.n_outputs = n_outputs
        self.keep_prob = keep_prob

    def _create_placeholders(self):
        """Create placeholders for inputs, outputs, and is_training.
        is_training = True --> apply dropout during training
        is_training = False --> don't apply dropout during testing
        """
        with tf.name_scope('data'):
            self.X = tf.placeholder(tf.float32, shape=(None, self.n_inputs), 
                                    name='X')
            self.y = tf.placeholder(tf.int64, shape=(None), name='y')
            self.is_training = tf.placeholder_with_default(False, 
                                                           shape=(), 
                                                           name='is_training')
    
    def _create_layer(self, prior_layer, n_neurons, name):
        """Create individual layer in neural network with he initialization, 
        batch normalization, and dropout.
        """
        # He initialialization for variables
        he_init = tf.contrib.layers.variance_scaling_initializer()
        # fully connected layer with he initialization
        hidden = tf.layers.dense(prior_layer, 
                                 n_neurons, 
                                 kernel_initializer=he_init, 
                                 name=name)
        # apply batch normalization
        bn = tf.layers.batch_normalization(hidden, 
                                           training=self.is_training, 
                                           momentum=BATCH_NORM_MOM)
        # use elu activation function
        bn_act = tf.nn.elu(bn)
        # apply dropout after batch normalization
        hidden_drop = tf.layers.dropout(bn_act, 
                                        KEEP_PROB, 
                                        training=self.is_training)
        return hidden_drop
        
    def _create_dnn(self):
        """Create deep neural network."""
        with tf.name_scope('dnn'):
            # apply dropout to inputs
            X_drop = tf.layers.dropout(self.X, self.keep_prob, 
                                       training=self.is_training)
            # create hidden layers
            hidden1 = self._create_layer(X_drop, self.n_neurons_bottom, 'hidden1')
            hidden2 = self._create_layer(hidden1, self.n_neurons_other, 'hidden2')
            hidden3 = self._create_layer(hidden2, self.n_neurons_other, 'hidden3')
            hidden4 = self._create_layer(hidden3, self.n_neurons_other, 'hidden4')
            # fully connected layer at end to compute outputs
            self.logits = tf.layers.dense(hidden4, self.n_outputs, name='outputs')

    def _create_loss(self):
        """Create cross entropy loss function."""
        with tf.name_scope('loss'):
            self.loss = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(labels=self.y, 
                                                               logits=self.logits), 
                name='loss')
            
    def _create_optimizer(self):
        """Define optimizer to minimize loss."""
        with tf.name_scope('optimizer'):
            self.optimizer = tf.train.AdamOptimizer(self.lr)
            # initially train_vars set equal to all trainable variables
            # this is modified when model is restored
            self.train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
            # only train/update variables in train_vars
            self.train_op = self.optimizer.minimize(self.loss, 
                                                    var_list=self.train_vars)
            
    def _create_eval(self):
        """Define loss evaluation metrics (accuracy)."""
        with tf.name_scope('eval'):
            correct = tf.nn.in_top_k(self.logits, self.y, 1)
            self.accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
            
    def build_model(self):
        """Build graph for deep neural network."""
        self._create_placeholders()
        self._create_dnn()
        self._create_loss()
        self._create_optimizer()
        self._create_eval()
        self.init = tf.global_variables_initializer()

        
def train_model(model, mnist, n_epochs, batch_size, save_path, restore=False):
    """Train model on MNIST data."""
    saver = tf.train.Saver()
    
    with tf.Session() as sess:
        if restore:
            # restore the bottom hidden layer from saved 0-4 digit model 
            # which will be freezed during training
            reuse_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, 
                                           scope='hidden[1]')
            reuse_vars_dict = dict([(var.op.name, var) for var in reuse_vars])
            restore_saver = tf.train.Saver(reuse_vars_dict) 
            
            # initialize variables and restore model
            sess.run(model.init)
            restore_saver.restore(sess, SAVE_PATH_LOW)
            
            # we allow hidden layers 2-4 and the outputs to train
            model.train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 
                                                scope='hidden[234]|outputs')
            # redefine train_op so it knows which variables it can update
            model.train_op = model.optimizer.minimize(model.loss, 
                                                      var_list=model.train_vars)
        else:
            # initialize variables for 0-4 digit model
            sess.run(model.init)

        for epoch in range(n_epochs):
            for iteration in range(mnist.train.num_examples // batch_size):
                # get random batch
                X_batch, y_batch = mnist.train.next_batch(batch_size)
                sess.run(model.train_op, feed_dict={model.is_training: True,
                                                    model.X: X_batch, 
                                                    model.y: y_batch})          
            if epoch % 10 == 0:
                # print test accuracy every 10 epochs
                accuracy_val = sess.run(model.accuracy, 
                                        feed_dict={model.is_training: False,
                                                   model.X: mnist.test.images,
                                                   model.y: mnist.test.labels})
                print('Epoch {0}, Test accuracy: {1:.3f}'.format(epoch, accuracy_val))
        
        print('\nSaving model...')
        if not restore and save_path: # save low digits model
            if not os.path.exists(save_path):
                os.makedirs(save_path)
            saver.save(sess, save_path)

            
def mnist_subset_0_4(mnist, cutoff):
    """Create new MNIST object for low digits (0-4)."""
    index_train = mnist.train.labels <= cutoff
    index_test = mnist.test.labels <= cutoff
        
    # create new MNIST train dataset for digits 0-4
    mnist_sub = copy.deepcopy(mnist)
    mnist_sub.train._images = mnist.train.images[index_train == True]
    mnist_sub.train._labels = mnist.train.labels[index_train == True]
    mnist_sub.train._num_examples = len(mnist_sub.train.labels)
    
    # create new MNIST test dataset for digits 0-4
    mnist_sub.test._images = mnist.test.images[index_test == True]
    mnist_sub.test._labels = mnist.test.labels[index_test == True]
    mnist_sub.test._num_examples = len(mnist_sub.test.labels)
    
    return mnist_sub


def mnist_subset_5_9(mnist, cutoff, n_samples):
    """Create new MNIST object for high digits (5-9). """
    index_train = mnist.train.labels > cutoff
    index_test = mnist.test.labels > cutoff
        
    # create small (n_samples) MNIST train dataset for digits 5-9 
    mnist_sub = copy.deepcopy(mnist)
    mnist_sub.train._images = mnist.train.images[index_train == True][:n_samples, :]
    # shift labels to 0-4 as tf.nn.sparse_softmax_cross_entropy_with_logits requires
    # them in this range. We retrain ouput layer so not an issue these labels are now
    # the same as 0-4.
    mnist_sub.train._labels = mnist.train.labels[index_train == True] - 5
    mnist_sub.train._labels = mnist_sub.train.labels[:n_samples]
    mnist_sub.train._num_examples = len(mnist_sub.train.labels)
    
    mnist_sub.test._images = mnist.test.images[index_test == True]
    mnist_sub.test._labels = mnist.test.labels[index_test == True] - 5
    mnist_sub.test._num_examples = len(mnist_sub.test.labels)
    
    return mnist_sub
    
    
def split_data(mnist, cutoff):
    """Split datasets into MNIST digits 0-4 and MNIST digits 5-9."""
    # create MNIST dataset for images with labels 0-4
    mnist_0_4 = mnist_subset_0_4(mnist, cutoff)
    # create MNIST dataset for images with labels 5-9
    # we only keep the first 500 images for training to demonstate
    # transfer learning on a small dataset
    mnist_5_9 = mnist_subset_5_9(mnist, cutoff, 500)
    
    return mnist_0_4, mnist_5_9

In [2]:
# read MNIST dataset
mnist_all = input_data.read_data_sets(DATA_PATH)
# split into 0-4 and 5-9 digit datasets
mnist_0_4, mnist_5_9 = split_data(mnist_all, 4)

Extracting data/train-images-idx3-ubyte.gz
Extracting data/train-labels-idx1-ubyte.gz
Extracting data/t10k-images-idx3-ubyte.gz
Extracting data/t10k-labels-idx1-ubyte.gz


First train deep neural network on entire MNIST digits 0-4 dataset.

In [3]:
# Create DeepNN model
model = DeepNN(LEARNING_RATE, N_NEURONS_BOTTOM, 
               N_NEURONS_OTHER, N_INPUTS, N_OUTPUTS,  KEEP_PROB)
model.build_model()
# Train DNN on digist 0-4
train_model(model, mnist_0_4, N_EPOCHS_LOW, 
            BATCH_SIZE, SAVE_PATH_LOW, restore=False)

Epoch 0, Test accuracy: 0.958
Epoch 10, Test accuracy: 0.985
Epoch 20, Test accuracy: 0.983

Saving model...


#### Transfer Learning

We use the model trained on digits 0-4 to train on digits 5-9. Rather than training on the entire 5-9 digit dataset, we select only 500 images (approx. 100 images per digit) to train on. This demonstrates pre-training on a semi-related large dataset, then actually training on the desired, yet much smaller dataset. For image classification tasks, deep neural networks tend to learn lower level features in the bottom layers and higher level features in the top layers. So we can expect the features learned in the bottom layer(s) for digits 0-4 to be similar to that for 5-9. We freeze the bottom layer from the 0-4 digit model while training the 5-9 digit model and allow the other layers to update.

In [4]:
train_model(model, mnist_5_9, N_EPOCHS_HIGH, BATCH_SIZE, save_path=None, restore=True)

INFO:tensorflow:Restoring parameters from saved/low_digits/
Epoch 0, Test accuracy: 0.561
Epoch 10, Test accuracy: 0.784
Epoch 20, Test accuracy: 0.837
Epoch 30, Test accuracy: 0.832
Epoch 40, Test accuracy: 0.873
Epoch 50, Test accuracy: 0.847
Epoch 60, Test accuracy: 0.853
Epoch 70, Test accuracy: 0.839
Epoch 80, Test accuracy: 0.837
Epoch 90, Test accuracy: 0.851

Saving model...


For comparison, train on the small 5-9 digits dataset without pre-training on 0-4 digit dataset. This achieves much lower test set accuracy than the pre-trained model.

In [5]:
tf.reset_default_graph()

# Create DeepNN model
model = DeepNN(LEARNING_RATE, N_NEURONS_BOTTOM, 
               N_NEURONS_OTHER, N_INPUTS, N_OUTPUTS,  KEEP_PROB)
model.build_model()
# Train DNN on digist 0-4
train_model(model, mnist_5_9, N_EPOCHS_HIGH, BATCH_SIZE, save_path=None, restore=False)

Epoch 0, Test accuracy: 0.574
Epoch 10, Test accuracy: 0.684
Epoch 20, Test accuracy: 0.712
Epoch 30, Test accuracy: 0.701
Epoch 40, Test accuracy: 0.732
Epoch 50, Test accuracy: 0.706
Epoch 60, Test accuracy: 0.709
Epoch 70, Test accuracy: 0.717
Epoch 80, Test accuracy: 0.712
Epoch 90, Test accuracy: 0.723

Saving model...
