In [None]:
'''
Train a Lanternfish autoencoder model

Lanternfish includes networks that learn feature-space representations of
two-dimensional trajectories, each encoded as a 3D "motion cube" volume.

This notebook shows how to train an autoencoder on simulated
random walk, power-law flier, and fractional Brownian motion data.
'''
import os

from motcube_preprocessing import *
from keras.callbacks import ModelCheckpoint, LearningRateScheduler, EarlyStopping, CSVLogger
import numpy as np

# Lanternfish includes multiple autoencoder models, utilizing stacked
# 3D-convolutions, max-pooling, and upsampling layers
# motcube_ae is a motion cube autoencoder for (156,156,100) cubes
# ZeroPadding layers can be altered to accommodate alternative sizes
from bestiary import motcube_ae as model_fcn

In [None]:
# Define a wrapper to provide data generator inputs as the desired output
def auto_gt_generator(generator):
    '''
    Pair every batch from `generator` with itself.

    An autoencoder is trained to reconstruct its input, so each input batch
    also serves as its own ground-truth target.

    Parameters
    ----------
    generator : iterable
        Source of input batches (e.g. a MotcubeDataGenerator flow).

    Yields
    ------
    tuple
        `(item, item)` for each item drawn from `generator`, lazily and in order.
    '''
    for item in generator:
        pair = (item, item)
        yield pair

In [None]:
# Set directories containing training and val data
train_dir = '/path/to/ae_data/train'
val_dir = '/path/to/ae_data/val'
# Set training parameters
batch_size = 12
cube_size = (156,156,100)
file_name_save = 'autoencoder.h5'

In [None]:
# Set model callbacks
#
# NOTE(review): `sched` is not defined anywhere in this notebook. It presumably
# comes from `from motcube_preprocessing import *` -- confirm. If it does not,
# this cell raises NameError on a fresh kernel, and a schedule function
# `sched(epoch) -> learning_rate` must be defined before this cell.
#
# Derive the CSV log name from the checkpoint name with os.path.splitext so any
# extension works ('.h5', '.hdf5', ...). Slicing off a fixed 3 characters only
# works for '.h5'.
train_log_file = os.path.splitext(file_name_save)[0] + '_train_log.csv'

callbacks = [
    # Save to `file_name_save` only when val_loss improves on the best so far
    ModelCheckpoint(file_name_save, monitor = 'val_loss', verbose = 0, save_best_only = True, mode = 'auto'),
    LearningRateScheduler(sched),
    # Stop training once val_loss has not improved for 3 consecutive epochs
    EarlyStopping(monitor='val_loss', patience=3),
    CSVLogger(filename=train_log_file, separator=','),
]

In [None]:
# Build one MotcubeDataGenerator flow per split. Both splits use identical flow
# settings, collected once here. class_mode=None presumably yields bare input
# batches with no labels (mirroring Keras' ImageDataGenerator); verify against
# motcube_preprocessing.
shared_flow_args = dict(class_mode=None, color_mode='grayscale',
                        target_size=cube_size, batch_size=batch_size)

mcgen = MotcubeDataGenerator()
mc_generator = mcgen.flow_from_directory(train_dir, **shared_flow_args)
valgen = MotcubeDataGenerator()
val_generator = valgen.flow_from_directory(val_dir, **shared_flow_args)

# Feed each input batch back as its own target for autoencoder training
train_ae_gen = auto_gt_generator(mc_generator)
val_ae_gen = auto_gt_generator(val_generator)

In [None]:
# Build the single-channel motion-cube autoencoder for the configured batch and cube shape
model = model_fcn(batch_size=batch_size, nb_channels=1,
                  image_x=cube_size[0], image_y=cube_size[1], image_z=cube_size[2])

# Loss choice: binary_crossentropy only suits motcubes with binary voxel
# values. For motcubes built with non-binary kernels, switch the loss to
# 'mse' (mean squared error).
model.compile(optimizer='adadelta', loss='binary_crossentropy')

# Keras 1.x-style fit_generator arguments (samples_per_epoch / nb_epoch /
# nb_val_samples). Each training epoch draws half of `nb_sample` (presumably
# the number of training cubes found). Validation draws the full validation count.
train_samples_per_epoch = mc_generator.nb_sample // 2
hist = model.fit_generator(
    train_ae_gen,
    samples_per_epoch=train_samples_per_epoch,
    nb_epoch=30,
    callbacks=callbacks,
    validation_data=val_ae_gen,
    nb_val_samples=val_generator.nb_sample,
)