In [None]:
'''
Train a Lanternfish classification CNN


This notebook outlines how to train a CNN classifier to discriminate
simulated random walks, power fliers, and fractional Brownian motion.
'''

import os

from motcube_preprocessing import *
from keras.callbacks import ModelCheckpoint, LearningRateScheduler, EarlyStopping, CSVLogger
from keras.optimizers import SGD
import numpy as np

# Lanternfish model architectures are defined in the `bestiary` module.
# multi_contextL classifies an arbitrary number of classes using stacked
# 3D-convolutional layers.
from bestiary import multi_contextL as model_fcn

In [None]:
# Define a learning rate schedule function
# and set the optimizer
def lr_schedule(rate=0.01, decay=0.8):
    '''
    Generates a schedule function with exponential decay

    alpha_new = alpha_init * decay_coeff^epoch

    Parameters
    ----------
    rate : float.
        initial learning rate (alpha_init), returned for epoch 0.
    decay : float.
        decay coefficient, applied once per epoch.

    Returns
    -------
    sched : callable
        sched(epoch) -> learning rate for that epoch, suitable for
        passing to keras.callbacks.LearningRateScheduler.
    '''
    def sched(epoch):
        # Builtin int() rather than np.int: the np.int alias was deprecated
        # in NumPy 1.20 and removed in 1.24 (AttributeError). int() truncates
        # float epochs exactly as np.int did.
        return rate * (decay ** int(epoch))
    return sched

# Exponentially decaying learning rate: 0.005 * 0.8**epoch
sched = lr_schedule(decay=0.8, rate=0.005)
# No explicit learning rate here: the LearningRateScheduler(sched) callback
# set up for training supplies the rate for each epoch.
sgd = SGD(momentum=0.5)

In [None]:
# Set data directories and training parameters
# (directory paths are placeholders -- point them at the simulated data)
train_dir = '/path/to/sim_data/train'
val_dir = '/path/to/sim_data/val'

batch_size = 12
cube_size = (156, 156, 101)  # (x, y, t) dimensions of each motion cube

# checkpoint file; the CSV training-log name is derived from it
file_name_save = 'multiclass.h5'

In [None]:
# Set training callbacks for keras .fit_generator()
#
# ModelCheckpoint saves model with lowest val loss
# LearningRateScheduler applies the scheduling function
# EarlyStopping halts training if val loss stops improving for 3 epochs
# CSVLogger writes per-epoch training results to a CSV
#
# The log name is derived with os.path.splitext instead of slicing off the
# last three characters, so it stays correct for any checkpoint extension
# (e.g. '.hdf5' or none at all), not only '.h5'.
log_file_name = os.path.splitext(file_name_save)[0] + '_train_log.csv'
callbacks = [
    ModelCheckpoint(file_name_save, monitor='val_loss', verbose=0,
                    save_best_only=True, mode='auto'),
    LearningRateScheduler(sched),
    EarlyStopping(monitor='val_loss', patience=3),
    CSVLogger(filename=log_file_name, separator=','),
]

In [None]:
# Build the training and validation motion-cube generators with
# MotcubeDataGenerator() from motcube_preprocessing.py. Both use identical
# flow settings (categorical labels, grayscale, cube_size, batch_size), so
# they share one keyword-argument dict.
flow_kwargs = dict(class_mode='categorical', color_mode='grayscale',
                   target_size=cube_size, batch_size=batch_size)

mcgen = MotcubeDataGenerator()
mc_generator = mcgen.flow_from_directory(train_dir, **flow_kwargs)

valgen = MotcubeDataGenerator()
val_generator = valgen.flow_from_directory(val_dir, **flow_kwargs)

In [None]:
# Build the network from the bestiary factory, sized to the motion cubes
# (3 output classes, 1 grayscale channel), then compile it with the SGD
# optimizer and categorical cross-entropy to match the 'categorical'
# class_mode of the data generators.
model = model_fcn(
    batch_size=batch_size,
    nb_classes=3,
    nb_channels=1,
    image_x=cube_size[0],
    image_y=cube_size[1],
    image_z=cube_size[2],
)
model.compile(
    optimizer=sgd,
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

In [None]:
# train the model
#
# Each "epoch" covers only 1/epoch_fraction of the training set, so the
# scheduler / early-stopping / checkpoint callbacks run several times per
# full pass over the data.
# NOTE: this uses the Keras 1-style generator arguments (samples_per_epoch,
# nb_epoch, nb_val_samples) and the generators' .nb_sample attribute.
nb_epoch = 30
epoch_fraction = 4
# Keras 1 counts whole batches against samples_per_epoch and warns when an
# epoch overshoots it, so round down to a multiple of batch_size; keep at
# least one batch so a tiny training set never yields a 0-sample epoch.
samples_per_epoch = max(
    batch_size,
    (mc_generator.nb_sample // epoch_fraction) // batch_size * batch_size,
)
hist = model.fit_generator(mc_generator,
                           samples_per_epoch=samples_per_epoch,
                           nb_epoch=nb_epoch,
                           callbacks=callbacks,
                           validation_data=val_generator,
                           nb_val_samples=val_generator.nb_sample)