In [1]:
'''
Train a classification CNN with transfer learning

This notebook outlines how to train a CNN classifier 
to discriminate between cell types by first initializing
the network with weights from a simulated data
classifier, as trained in 01_train_classification_model
'''

import os

# motcube_preprocessing provides MotcubeDataGenerator (star import kept first so
# the explicit imports below still take precedence over anything it exports)
from motcube_preprocessing import *
from keras.callbacks import ModelCheckpoint, LearningRateScheduler, EarlyStopping, CSVLogger
from keras.optimizers import SGD
import numpy as np

# Lanternfish models are stored in bestiary
# multicontextL classifies an arbitrary number of classes using stacked
# 3D-convolutional layers
from bestiary import multi_contextL as model_fcn
# Keras contains a function to load saved models in one-line!
from keras.models import load_model

ImportError: No module named 'motcube_preprocessing'

In [None]:
# Define a learning rate schedule function
# and set the optimizer
def lr_schedule(rate=0.01, decay=0.8):
    '''
    Generates a schedule function with exp decay

    alpha_new = alpha_init * decay_coeff^epoch

    Parameters
    ----------
    rate : float.
        initial learning rate (the rate returned for epoch 0).
    decay : float.
        multiplicative decay coefficient, applied once per epoch.

    Returns
    -------
    sched : callable
        sched(epoch) -> float, the learning rate for `epoch`
        (non-integer epochs are truncated toward zero). Suitable for
        keras.callbacks.LearningRateScheduler, which expects a float.
    '''
    def sched(epoch):
        # `np.int` was only an alias of the builtin `int`; it was deprecated
        # in NumPy 1.20 and removed in 1.24, so use `int` directly.
        return float(rate * (decay ** int(epoch)))
    return sched

# Exponential decay: 0.005 at epoch 0, multiplied by 0.8 each following epoch.
sched = lr_schedule(0.005, 0.8)
# SGD with momentum; no learning rate is passed because the
# LearningRateScheduler(sched) callback sets it at every epoch.
sgd = SGD(momentum=0.5)

In [None]:
# Set directory and file information

# path to pretrained model (the simulated-motion classifier)
trained_model_path = 'nets/20170315_multiclass_bin_disk25.h5'

# paths to new cell data
# NOTE(review): absolute, machine-specific paths -- edit for your environment
train_dir = '/media/jkimmel/HDD0/myctophid/cell_data/train'
val_dir = '/media/jkimmel/HDD0/myctophid/cell_data/val'
# number of real-cell classes; used as the output width of the new model.
# Presumably this must equal the number of class subdirectories in
# train_dir / val_dir -- verify against the data.
nb_classes = 2
batch_size = 8
# (x, y, z) dimensions of each input cube
cube_size = (216,216,76)
# This notebook trains WITH transfer learning. The previous name
# ('mef_wt_v_mr_no_transfer.h5') mislabelled the artifact and could collide
# with the output of a no-transfer baseline run.
file_name_save = 'mef_wt_v_mr_transfer.h5'

In [None]:
# Set training callbacks for keras .fit_generator()
#
# ModelCheckpoint saves model with lowest val loss
# LearningRateScheduler applies the scheduling function
# EarlyStopping halts training if val loss stops improving
# CSVLogger writes training data to a CSV
#
# Strip the extension with os.path.splitext instead of slicing off a fixed
# 3 characters, so the log name stays correct if file_name_save's suffix
# changes (e.g. '.hdf5').
log_file_name = os.path.splitext(file_name_save)[0] + '_train_log.csv'
callbacks = [ModelCheckpoint(file_name_save, monitor='val_loss', verbose=0, save_best_only=True, mode='auto'),
             LearningRateScheduler(sched),
             # stop after 3 epochs without val_loss improvement
             EarlyStopping(monitor='val_loss', patience=3),
             CSVLogger(filename=log_file_name, separator=',')]

In [None]:
# Instantiate data generators
# NOTE: here we use the horizontal_flip and vertical_flip
# features of MotcubeDataGenerator to diversify our small
# training set with non-destructive permutations.
# The validation generator is deliberately NOT augmented. val_loss drives
# both ModelCheckpoint and EarlyStopping, so it should be measured on the
# unmodified validation cubes every epoch. Otherwise it is measured on
# flipped copies (presumably chosen at random, as in Keras'
# ImageDataGenerator -- confirm).
mcgen = MotcubeDataGenerator(horizontal_flip=True, vertical_flip=True)
mc_generator = mcgen.flow_from_directory(train_dir, class_mode='categorical', color_mode='grayscale', target_size=cube_size, batch_size=batch_size)
valgen = MotcubeDataGenerator(horizontal_flip=False, vertical_flip=False)
val_generator = valgen.flow_from_directory(val_dir, class_mode='categorical', color_mode='grayscale', target_size=cube_size, batch_size=batch_size)

In [None]:
# initialize and compile the model
# The output layer needs one unit per real-cell class, so use the configured
# nb_classes (2). The previous hard-coded 3 matched the simulated-data
# network instead, and would not match the 2-class categorical targets
# expected from the generators.
model = model_fcn(batch_size=batch_size, nb_classes=nb_classes, nb_channels=1,
                  image_x=cube_size[0], image_y=cube_size[1], image_z=cube_size[2])
model.compile(optimizer=sgd, metrics=['accuracy'], loss='categorical_crossentropy')

Now that we've instantiated our new classification model, we need to transfer the weights learned by our pretrained model. Transfer learning carries features learned in one task over as the initialization for another, which can speed up learning or make otherwise intractable learning problems feasible.

Here, we'll transfer weights learned by classifying our large simulated motion data set to a new model, which will classify our much smaller real cell motility data set.

In [None]:
# Load the pretrained (simulated-data) classifier, then pull the list of
# weight arrays out of it (w0) and out of the new network (w1).
pretrain_model = load_model(trained_model_path)
w0, w1 = pretrain_model.get_weights(), model.get_weights()

The final layers of the network are fully connected *Dense* layers. The weight shape of the first Dense layer depends on the input cube size, so it changes whenever the input cube size changes.

Here, the new cubes are `(216, 216, 76)`, a different size from the original `(156, 156, 101)` simulated cubes, so we can't transfer weights from the first Dense layer.

The final Dense layer also has `n = nb_classes` units. Our 3-class simulated motion network therefore has one more output unit, and correspondingly more output-layer parameters, than the new 2-class real cell data network.

To remedy this, we transfer only the convolutional layer weights and keep the new network's freshly initialized Dense layer weights.

In [None]:
# Transfer conv layer weights and apply to the model
#
# The first `n_conv_weights` arrays returned by get_weights() are the
# convolutional layer weights. The remaining arrays belong to the Dense
# layers, which are different sizes in the two networks, so those keep the
# new model's fresh initialization.
n_conv_weights = 16
# Build a NEW list. The previous `w = w0` followed by slice assignment
# aliased the two names and silently overwrote the pretrained weights in w0.
w = list(w0[:n_conv_weights]) + list(w1[n_conv_weights:])
# Guard the hard-coded split: the merged list must line up with the new
# model, and every transferred array must fit its destination.
assert len(w) == len(w1), 'merged weight list has {} arrays, model expects {}'.format(len(w), len(w1))
for idx, (old, new) in enumerate(zip(w0[:n_conv_weights], w1[:n_conv_weights])):
    assert np.shape(old) == np.shape(new), \
        'weight {}: pretrained shape {} != new shape {}'.format(idx, np.shape(old), np.shape(new))
model.set_weights(w)

In [None]:
# Fit the model: up to 30 epochs (the callbacks may stop training earlier).
# Each epoch draws mc_generator.nb_sample training samples and is validated
# on val_generator.nb_sample validation samples.
hist = model.fit_generator(
    mc_generator,
    samples_per_epoch=mc_generator.nb_sample,
    nb_epoch=30,
    callbacks=callbacks,
    validation_data=val_generator,
    nb_val_samples=val_generator.nb_sample,
)