Trainer (#11)
* model trainer to stop at a lower loss

* lr_scheduler, Adam, kernel_initializer

* callbacks

* trainer
haifeng-jin committed Apr 19, 2018
1 parent 912bfc4 commit c99295e
Showing 5 changed files with 107 additions and 39 deletions.
4 changes: 3 additions & 1 deletion autokeras/constant.py
@@ -25,5 +25,7 @@
 DATA_AUGMENTATION = True
 MAX_ITER_NUM = 200
 MIN_LOSS_DEC = 1e-4
-MAX_NO_IMPROVEMENT_NUM = 10
+MAX_NO_IMPROVEMENT_NUM = 100
 EPOCHS_EACH = 1
+MAX_BATCH_SIZE = 32
+LIMIT_MEMORY = False
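
Taken together, these constants configure the rewritten trainer below: MAX_NO_IMPROVEMENT_NUM raises the early-stop patience from 10 to 100 stagnant epochs, the new MAX_BATCH_SIZE caps the training batch size, and LIMIT_MEMORY opts in to TensorFlow's GPU memory-growth mode. For example, train_model in utils.py now derives its batch size as:

    batch_size = min(self.x_train.shape[0], constant.MAX_BATCH_SIZE)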
2 changes: 1 addition & 1 deletion autokeras/graph.py
@@ -48,7 +48,7 @@ def to_real_layer(layer):
     if is_layer(layer, 'Dense'):
         return Dense(layer.units, activation=layer.activation)
     if is_layer(layer, 'Conv'):
-        return layer.func(layer.filters, kernel_size=layer.kernel_size, padding='same')
+        return layer.func(layer.filters, kernel_size=layer.kernel_size, padding='same', kernel_initializer='he_normal')
     if is_layer(layer, 'Pooling'):
         return layer.func(padding='same')
     if is_layer(layer, 'BatchNormalization'):
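
Concretely, for a 2-D convolution node the new branch builds something like the following (a sketch with hypothetical filters and kernel_size values): 'he_normal' draws initial weights from a truncated normal with stddev sqrt(2 / fan_in), a good match for the ReLU activations these layers feed.

    from keras.layers import Conv2D

    # Hypothetical node with layer.func == Conv2D, filters == 64, kernel_size == 3.
    layer = Conv2D(64, kernel_size=3, padding='same',
                   kernel_initializer='he_normal')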
5 changes: 0 additions & 5 deletions autokeras/search.py
@@ -1,9 +1,7 @@
 import os
 import numpy as np
 
-from keras.losses import categorical_crossentropy
 from keras.models import load_model
-from keras.optimizers import Adadelta
 from keras import backend
 from keras.utils import plot_model
 
@@ -70,9 +68,6 @@ def add_model(self, model, x_train, y_train, x_test, y_test):
         Returns:
             History object.
         """
-        model.compile(loss=categorical_crossentropy,
-                      optimizer=Adadelta(),
-                      metrics=['accuracy'])
         if self.verbose:
             model.summary()
         ModelTrainer(model, x_train, y_train, x_test, y_test, self.verbose).train_model()
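
With this change the Searcher no longer picks a loss or optimizer: ModelTrainer.__init__ (see utils.py below) now compiles every candidate model itself, with categorical_crossentropy and Adam seeded at lr_schedule(0) = 1e-3, replacing the old Adadelta setup.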
130 changes: 99 additions & 31 deletions autokeras/utils.py
@@ -1,9 +1,14 @@
 import os
 import pickle
 import numpy as np
 
+import tensorflow as tf
+from keras import backend
+from keras.callbacks import Callback, LearningRateScheduler, ReduceLROnPlateau
+from keras.losses import categorical_crossentropy
 from keras.layers import Conv1D, Conv2D, Conv3D, MaxPooling3D, MaxPooling2D, MaxPooling1D, Dense, BatchNormalization, \
     Concatenate, Dropout, Activation, Flatten
+from keras.optimizers import Adam
 from keras.preprocessing.image import ImageDataGenerator
 from tensorflow import Dimension
@@ -32,6 +37,67 @@ def get_conv_layer_func(n_dim):
     return conv_layer_functions[n_dim - 1]
 
 
+def lr_schedule(epoch):
+    """Learning Rate Schedule
+
+    Learning rate is scheduled to be reduced after 80, 120, 160, 180 epochs.
+    Called automatically every epoch as part of callbacks during training.
+
+    # Arguments
+        epoch (int): The number of epochs
+
+    # Returns
+        lr (float32): learning rate
+    """
+    lr = 1e-3
+    if epoch > 180:
+        lr *= 0.5e-3
+    elif epoch > 160:
+        lr *= 1e-3
+    elif epoch > 120:
+        lr *= 1e-2
+    elif epoch > 80:
+        lr *= 1e-1
+    return lr
+
+
+class NoImprovementError(Exception):
+    def __init__(self, message):
+        self.message = message
+
+
+class EarlyStop(Callback):
+    def __init__(self, max_no_improvement_num=constant.MAX_NO_IMPROVEMENT_NUM, min_loss_dec=constant.MIN_LOSS_DEC):
+        super().__init__()
+        self.training_losses = []
+        self.minimum_loss = None
+        self._no_improvement_count = 0
+        self._max_no_improvement_num = max_no_improvement_num
+        self._done = False
+        self._min_loss_dec = min_loss_dec
+
+    def on_train_begin(self, logs=None):
+        self.training_losses = []
+        self._no_improvement_count = 0
+        self._done = False
+        self.minimum_loss = float('inf')
+
+    def on_epoch_end(self, batch, logs=None):
+        loss = logs.get('val_loss')
+        self.training_losses.append(loss)
+        if self._done and loss > (self.minimum_loss - self._min_loss_dec):
+            raise NoImprovementError('No improvement for {} epochs.'.format(self._max_no_improvement_num))
+
+        if loss > (self.minimum_loss - self._min_loss_dec):
+            self._no_improvement_count += 1
+        else:
+            self._no_improvement_count = 0
+            self.minimum_loss = loss
+
+        if self._no_improvement_count > self._max_no_improvement_num:
+            self._done = True
+
+
 class ModelTrainer:
     """A class that is used to train model
@@ -44,21 +110,19 @@ class ModelTrainer:
         x_test: the input test data
         y_test: the input test data labels
         verbose: verbosity mode
-        training_losses: a list to store all losses during training
-        minimum_loss: the minimum loss during training
-        _no_improvement_count: the number of iterations that don't improve the result
     """
 
     def __init__(self, model, x_train, y_train, x_test, y_test, verbose):
         """Init ModelTrainer with model, x_train, y_train, x_test, y_test, verbose"""
+        model.compile(loss=categorical_crossentropy,
+                      optimizer=Adam(lr=lr_schedule(0)),
+                      metrics=['accuracy'])
         self.model = model
         self.x_train = x_train
         self.y_train = y_train
         self.x_test = x_test
         self.y_test = y_test
         self.verbose = verbose
-        self.training_losses = []
-        self.minimum_loss = None
-        self._no_improvement_count = 0
         if constant.DATA_AUGMENTATION:
             self.datagen = ImageDataGenerator(
                 # set input mean to 0 over the dataset
@@ -87,39 +151,43 @@ def __init__(self, model, x_train, y_train, x_test, y_test, verbose):

-    def _converged(self, loss):
-        """Return whether the training is converged"""
-        self.training_losses.append(loss)
-        if loss > (self.minimum_loss - constant.MIN_LOSS_DEC):
-            self._no_improvement_count += 1
-        else:
-            self._no_improvement_count = 0
-
-        if loss < self.minimum_loss:
-            self.minimum_loss = loss
-
-        return self._no_improvement_count > constant.MAX_NO_IMPROVEMENT_NUM
-
     def train_model(self):
         """Train the model with dataset and return the minimum_loss"""
-        self.training_losses = []
-        self._no_improvement_count = 0
-        self.minimum_loss = float('inf')
-        batch_size = min(self.x_train.shape[0], 200)
-        if constant.DATA_AUGMENTATION:
-            flow = self.datagen.flow(self.x_train, self.y_train, batch_size)
-        else:
-            flow = None
-        for _ in range(constant.MAX_ITER_NUM):
+        batch_size = min(self.x_train.shape[0], constant.MAX_BATCH_SIZE)
+        terminator = EarlyStop()
+        lr_scheduler = LearningRateScheduler(lr_schedule)
+
+        lr_reducer = ReduceLROnPlateau(factor=np.sqrt(0.1),
+                                       cooldown=0,
+                                       patience=5,
+                                       min_lr=0.5e-6)
+
+        callbacks = [terminator, lr_scheduler, lr_reducer]
+        if constant.LIMIT_MEMORY:
+            config = tf.ConfigProto()
+            config.gpu_options.allow_growth = True
+            sess = tf.Session(config=config)
+            backend.set_session(sess)
+        try:
             if constant.DATA_AUGMENTATION:
-                self.model.fit_generator(flow, epochs=constant.EPOCHS_EACH)
+                flow = self.datagen.flow(self.x_train, self.y_train, batch_size)
+                self.model.fit_generator(flow,
+                                         epochs=constant.MAX_ITER_NUM,
+                                         validation_data=(self.x_test, self.y_test),
+                                         callbacks=callbacks,
+                                         verbose=self.verbose)
             else:
                 self.model.fit(self.x_train, self.y_train,
                                batch_size=batch_size,
-                               epochs=constant.EPOCHS_EACH,
+                               epochs=constant.MAX_ITER_NUM,
+                               validation_data=(self.x_test, self.y_test),
+                               callbacks=callbacks,
                                verbose=self.verbose)
-            loss, _ = self.model.evaluate(self.x_test, self.y_test, verbose=self.verbose)
-            if self._converged(loss):
-                break
-        return self.minimum_loss
+        except NoImprovementError as e:
+            if self.verbose:
+                print('Training finished!')
+                print(e.message)
 
 
 def extract_config(network):
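
End to end, the new trainer is driven the same way the tests below drive it. A minimal sketch with toy random data:

    import numpy as np
    from autokeras.generator import RandomConvClassifierGenerator
    from autokeras.utils import ModelTrainer

    model = RandomConvClassifierGenerator(3, (28, 28, 1)).generate()
    ModelTrainer(model,
                 np.random.rand(2, 28, 28, 1), np.random.rand(2, 3),  # train
                 np.random.rand(1, 28, 28, 1), np.random.rand(1, 3),  # test
                 False).train_model()

Training runs for up to MAX_ITER_NUM epochs and ends early when EarlyStop raises NoImprovementError; note that train_model no longer returns the minimum loss, unlike the loop it replaces.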
5 changes: 4 additions & 1 deletion tests/test_utils.py
@@ -1,3 +1,4 @@
+from unittest.mock import patch
 from autokeras.generator import RandomConvClassifierGenerator
 from autokeras.utils import *
 import numpy as np
@@ -9,8 +10,10 @@ def test_model_trainer():
                  np.random.rand(1, 3), False).train_model()
 
 
-def test_model_trainer_not_augmented():
+@patch('autokeras.utils.backend')
+def test_model_trainer_not_augmented(_):
     constant.DATA_AUGMENTATION = False
+    constant.LIMIT_MEMORY = True
     model = RandomConvClassifierGenerator(3, (28, 28, 1)).generate()
     ModelTrainer(model, np.random.rand(2, 28, 28, 1), np.random.rand(2, 3), np.random.rand(1, 28, 28, 1),
                  np.random.rand(1, 3), False).train_model()
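
The @patch('autokeras.utils.backend') decorator replaces the keras backend module imported by autokeras.utils with a mock (handed to the test as the ignored _ argument), so the LIMIT_MEMORY branch runs without rebinding the real Keras session.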
