In [None]:
'''
Train CNNs for use with @jkimmel's fork of @vanvalen's DeepCell

DeepCell trains standard CNNs on small, receptive-field-sized patches,
then transfers the weights from these standard networks to a
corresponding architecture that uses 'dilated' (or 'atrous') kernels,
which operate on full-sized images.

Applying 'atrous' kernels to a large image is equivalent to processing
each receptive field 'patchwise', so the weights learned by the small,
quick-to-train networks can be transferred directly to the
atrous-kernel network.
The foundations behind this approach are described here:
https://arxiv.org/abs/1412.4526

This notebook outlines how to train a vanilla CNN on prepared 
training data, as described in 00_generate_training_data.ipynb
'''

from __future__ import print_function, division #python2 compatability
from keras.optimizers import SGD, RMSprop
from cnn_functions import rate_scheduler, train_model_sample
import os
import datetime
import numpy as np

# Import the model you wish to train
# DeepCell models are listed in the model zoo, each with a specific
# receptive field size
# As a starting point, try the batch normalized model for your desired
# receptive field size
# i.e. bn_feature_net_NNxNN
from model_zoo import bn_feature_net_81x81 as the_model

In [None]:
# Define a batch size and the number of epochs to run.
# Batch sizes should be set as large as possible, with GPU RAM being the
# limiting factor; larger batches improve the gradient estimate, which
# usually improves training.
batch_size = 256
# Number of times the network will be shown the same training data.
n_epoch = 50

# Name of the dataset, i.e. the filename of the training data without its
# extension.
dataset = 'training_data'
# Name for this experiment.
expt = 'exp_name'

# Directory to save the model weights, and directory containing the dataset.
# Edit the placeholder defaults here, or set DEEPCELL_MODEL_DIR /
# DEEPCELL_DATA_DIR in the environment so the notebook runs on different
# machines without editing. os.path.join(path, '') appends a trailing
# separator only if one is missing, so overridden paths keep the same
# trailing-slash form as the defaults (in case the training helpers build
# file paths by string concatenation -- TODO confirm in cnn_functions).
direc_save = os.path.join(
    os.environ.get('DEEPCELL_MODEL_DIR', '/path/to/saved/models/'), '')
direc_data = os.path.join(
    os.environ.get('DEEPCELL_DATA_DIR', '/path/to/dataset/'), '')

In [None]:
# Set the optimizer.
# SGD works best for batch-normalized nets, while RMSprop seems to be better
# for non-normalized nets.
# Here we use SGD with Nesterov momentum plus a scheduled learning-rate decay.
# The initial learning rate is defined once so that the optimizer and the
# per-epoch schedule cannot drift apart when it is edited.
learning_rate = 0.01
# `decay` here is the optimizer's own decay term; it applies on top of the
# separate per-epoch schedule below.
optimizer = SGD(lr=learning_rate, decay=1e-6, momentum=0.9, nesterov=True)
# rate_scheduler is project-defined. It presumably returns a per-epoch
# schedule that starts at `lr` and shrinks it by a factor of `decay` each
# epoch -- TODO confirm in cnn_functions.
lr_sched = rate_scheduler(lr = learning_rate, decay = 0.95)

In [None]:
# Set how many models you'd like to train.
# Training several separately constructed models and combining their
# predictions may improve prediction performance through the ensemble
# effect, see:
# https://en.wikipedia.org/wiki/Ensemble_learning
nb_models = 1

# Arguments forwarded to the model constructor. Presumably the number of
# input image channels and the number of output classes; they must match
# the prepared training data -- TODO confirm against
# 00_generate_training_data.ipynb.
n_channels = 1
n_features = 2

# Keras' internal per-prefix counters used to auto-name layers (a private
# symbol). Imported once, before any training, so that an incompatible Keras
# version fails immediately instead of after the first model has trained.
from keras.backend.common import _UID_PREFIXES

# range(nb_models) trains exactly nb_models models, with iteration indices
# 0 .. nb_models - 1. The former range(1, nb_models) trained one model too
# few -- none at all for the default nb_models = 1.
# NOTE(review): the same `optimizer` object (created in the optimizer cell)
# is passed to every model. If the Keras optimizer keeps state between
# compiles (e.g. the iteration counter used by `decay`), consider creating a
# fresh optimizer per model -- confirm.
for iterate in range(nb_models):

    model = the_model(n_channels = n_channels, n_features = n_features, reg = 1e-5)

    train_model_sample(model = model, dataset = dataset, optimizer = optimizer,
        expt = expt, it = iterate, batch_size = batch_size, n_epoch = n_epoch,
        direc_save = direc_save,
        direc_data = direc_data,
        lr_sched = lr_sched,
        rotate = True, flip = True, shear = False)

    # Drop the reference to the trained model before building the next one.
    del model
    # Reset the Keras layer-numbering scheme so that layers are named
    # properly when training more than one model in a run.
    for key in _UID_PREFIXES.keys():
        _UID_PREFIXES[key] = 0