[WIP] Stochastic Gradient Descent with Restarts - Callback #3525

Closed
wants to merge 2 commits into from

67 changes: 67 additions & 0 deletions keras/callbacks.py
@@ -516,3 +516,70 @@ def on_epoch_end(self, epoch, logs={}):
            summary_value.tag = name
            self.writer.add_summary(summary, epoch)
        self.writer.flush()


class SGDRScheduler(Callback):
    '''Schedule learning rates with restarts

    A simple restart technique for stochastic gradient descent.
    The learning rate decays after each batch and periodically resets to its
    initial value. Optionally, the learning rate is additionally reduced by a
    fixed factor at a predefined set of epochs.

    # Arguments
        epochsize: Number of samples per epoch during training.
        batchsize: Number of samples per batch during training.
        epochs_to_restart: Initial number of epochs before the first restart.
        mult_factor: Factor by which epochs_to_restart is multiplied
            after each restart.
        lr_fac: Factor by which the learning rate is multiplied at the
            epochs given in lr_reduction_epochs.
        lr_reduction_epochs: Fixed list of epochs at which to reduce
            the learning rate.

    # References
        - [SGDR: Stochastic Gradient Descent with Warm Restarts](http://arxiv.org/abs/1608.03983)
    '''
    def __init__(self,
                 epochsize,
                 batchsize,
                 epochs_to_restart=2,
                 mult_factor=2,
                 lr_fac=0.1,
                 lr_reduction_epochs=(60, 120, 160)):
        super(SGDRScheduler, self).__init__()
        self.epoch = -1
        self.batch_since_restart = 0
        self.next_restart = epochs_to_restart
        self.epochsize = epochsize
        self.batchsize = batchsize
        self.epochs_to_restart = epochs_to_restart
        self.mult_factor = mult_factor
        self.batches_per_epoch = self.epochsize / self.batchsize
        self.lr_fac = lr_fac
        self.lr_reduction_epochs = lr_reduction_epochs
        self.lr_log = []

    def on_train_begin(self, logs={}):
        self.lr = K.get_value(self.model.optimizer.lr)

    def on_epoch_begin(self, epoch, logs={}):
        self.epoch += 1

    def on_batch_end(self, batch, logs={}):
        fraction_to_restart = self.batch_since_restart / \
            (self.batches_per_epoch * self.epochs_to_restart)
        lr = 0.5 * self.lr * (1 + np.cos(fraction_to_restart * np.pi))
        K.set_value(self.model.optimizer.lr, lr)

        self.batch_since_restart += 1
        self.lr_log.append(lr)

    def on_epoch_end(self, epoch, logs={}):
        if self.epoch + 1 == self.next_restart:
            self.batch_since_restart = 0
            self.epochs_to_restart *= self.mult_factor
            self.next_restart += self.epochs_to_restart

        if (self.epoch + 1) in self.lr_reduction_epochs:
            self.lr *= self.lr_fac
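
For reference, here is a minimal usage sketch (not part of the diff) showing how the proposed callback could be wired into a Keras 1.x training run. The toy model, data, and every parameter value are illustrative assumptions, the Keras 1.x fit(nb_epoch=...) signature is assumed, and the SGDRScheduler import only resolves once this patch is applied.

import numpy as np
from keras.models import Sequential
from keras.layers import Dense
from keras.callbacks import SGDRScheduler  # only available with this patch applied

# Toy binary-classification data, purely for illustration.
x_train = np.random.rand(1000, 20)
y_train = np.random.randint(2, size=(1000, 1))

model = Sequential()
model.add(Dense(32, activation='relu', input_dim=20))
model.add(Dense(1, activation='sigmoid'))
model.compile(optimizer='sgd', loss='binary_crossentropy')

# epochsize is the number of training samples, batchsize the mini-batch size;
# within each restart cycle the learning rate is annealed from its initial
# value towards zero with a cosine schedule, then reset at the restart.
sgdr = SGDRScheduler(epochsize=1000,
                     batchsize=32,
                     epochs_to_restart=2,
                     mult_factor=2,
                     lr_fac=0.1,
                     lr_reduction_epochs=(60, 120, 160))

model.fit(x_train, y_train,
          batch_size=32,
          nb_epoch=10,
          callbacks=[sgdr])

# sgdr.lr_log holds the learning rate that was applied on every batch.

With these settings the first restart occurs after epoch 2; since the cycle length is multiplied by mult_factor=2 at each restart, subsequent restarts fall after epochs 6, 14, 30, and so on, while lr_fac only takes effect at the epochs listed in lr_reduction_epochs.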