Add support for LR decay in all optimizers
fchollet committed Sep 10, 2016
1 parent 79edae5 commit b2e8d5a
Showing 2 changed files with 60 additions and 11 deletions.
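Every optimizer touched by this commit now applies the same time-based schedule that SGD already used: on each parameter update the effective learning rate becomes lr * 1. / (1. + decay * iterations), where iterations counts updates (batches), not epochs, and decay=0 leaves the rate untouched. As a rough standalone illustration (the helper name and sample values below are invented for this note, not part of the commit), the schedule looks like this in plain Python:

    def decayed_lr(lr, decay, iteration):
        # Effective learning rate after `iteration` updates, mirroring
        # lr *= 1. / (1. + self.decay * self.iterations) in the diff below.
        return lr * (1. / (1. + decay * iteration))

    # With lr=0.001 and decay=1e-3 the rate halves after 1000 updates
    # and drops to a third of its starting value after 2000 updates.
    for step in (0, 1000, 2000, 5000):
        print(step, decayed_lr(0.001, 1e-3, step))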
66 changes: 55 additions & 11 deletions keras/optimizers.py
@@ -135,11 +135,16 @@ def __init__(self, lr=0.01, momentum=0., decay=0.,
self.lr = K.variable(lr)
self.momentum = K.variable(momentum)
self.decay = K.variable(decay)
self.inital_decay = decay

def get_updates(self, params, constraints, loss):
grads = self.get_gradients(loss, params)
lr = self.lr * (1. / (1. + self.decay * self.iterations))
self.updates = [K.update_add(self.iterations, 1)]
self.updates = []

lr = self.lr
if self.inital_decay > 0:
lr *= (1. / (1. + self.decay * self.iterations))
self.updates.append(K.update_add(self.iterations, 1))

# momentum
shapes = [K.get_variable_shape(p) for p in params]
@@ -185,12 +190,17 @@ class RMSprop(Optimizer):
lr: float >= 0. Learning rate.
rho: float >= 0.
epsilon: float >= 0. Fuzz factor.
decay: float >= 0. Learning rate decay over each update.
'''
def __init__(self, lr=0.001, rho=0.9, epsilon=1e-8, **kwargs):
def __init__(self, lr=0.001, rho=0.9, epsilon=1e-8, decay=0.,
**kwargs):
super(RMSprop, self).__init__(**kwargs)
self.__dict__.update(locals())
self.lr = K.variable(lr)
self.rho = K.variable(rho)
self.decay = K.variable(decay)
self.inital_decay = decay
self.iterations = K.variable(0.)

def get_updates(self, params, constraints, loss):
grads = self.get_gradients(loss, params)
@@ -199,11 +209,16 @@ def get_updates(self, params, constraints, loss):
self.weights = accumulators
self.updates = []

lr = self.lr
if self.inital_decay > 0:
lr *= (1. / (1. + self.decay * self.iterations))
self.updates.append(K.update_add(self.iterations, 1))

for p, g, a in zip(params, grads, accumulators):
# update accumulator
new_a = self.rho * a + (1. - self.rho) * K.square(g)
self.updates.append(K.update(a, new_a))
new_p = p - self.lr * g / (K.sqrt(new_a) + self.epsilon)
new_p = p - lr * g / (K.sqrt(new_a) + self.epsilon)

# apply constraints
if p in constraints:
@@ -233,10 +248,13 @@ class Adagrad(Optimizer):
# References
- [Adaptive Subgradient Methods for Online Learning and Stochastic Optimization](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
'''
def __init__(self, lr=0.01, epsilon=1e-8, **kwargs):
def __init__(self, lr=0.01, epsilon=1e-8, decay=0., **kwargs):
super(Adagrad, self).__init__(**kwargs)
self.__dict__.update(locals())
self.lr = K.variable(lr)
self.decay = K.variable(decay)
self.inital_decay = decay
self.iterations = K.variable(0.)

def get_updates(self, params, constraints, loss):
grads = self.get_gradients(loss, params)
@@ -245,10 +263,15 @@ def get_updates(self, params, constraints, loss):
self.weights = accumulators
self.updates = []

lr = self.lr
if self.inital_decay > 0:
lr *= (1. / (1. + self.decay * self.iterations))
self.updates.append(K.update_add(self.iterations, 1))

for p, g, a in zip(params, grads, accumulators):
new_a = a + K.square(g) # update accumulator
self.updates.append(K.update(a, new_a))
new_p = p - self.lr * g / (K.sqrt(new_a) + self.epsilon)
new_p = p - lr * g / (K.sqrt(new_a) + self.epsilon)
# apply constraints
if p in constraints:
c = constraints[p]
@@ -278,10 +301,14 @@ class Adadelta(Optimizer):
# References
- [Adadelta - an adaptive learning rate method](http://arxiv.org/abs/1212.5701)
'''
def __init__(self, lr=1.0, rho=0.95, epsilon=1e-8, **kwargs):
def __init__(self, lr=1.0, rho=0.95, epsilon=1e-8, decay=0.,
**kwargs):
super(Adadelta, self).__init__(**kwargs)
self.__dict__.update(locals())
self.lr = K.variable(lr)
self.decay = K.variable(decay)
self.inital_decay = decay
self.iterations = K.variable(0.)

def get_updates(self, params, constraints, loss):
grads = self.get_gradients(loss, params)
@@ -291,6 +318,11 @@ def get_updates(self, params, constraints, loss):
self.weights = accumulators + delta_accumulators
self.updates = []

lr = self.lr
if self.inital_decay > 0:
lr *= (1. / (1. + self.decay * self.iterations))
self.updates.append(K.update_add(self.iterations, 1))

for p, g, a, d_a in zip(params, grads, accumulators, delta_accumulators):
# update accumulator
new_a = self.rho * a + (1. - self.rho) * K.square(g)
@@ -299,7 +331,7 @@ def get_updates(self, params, constraints, loss):
# use the new accumulator and the *old* delta_accumulator
update = g * K.sqrt(d_a + self.epsilon) / K.sqrt(new_a + self.epsilon)

new_p = p - self.lr * update
new_p = p - lr * update
# apply constraints
if p in constraints:
c = constraints[p]
@@ -333,20 +365,26 @@ class Adam(Optimizer):
- [Adam - A Method for Stochastic Optimization](http://arxiv.org/abs/1412.6980v8)
'''
def __init__(self, lr=0.001, beta_1=0.9, beta_2=0.999,
epsilon=1e-8, **kwargs):
epsilon=1e-8, decay=0., **kwargs):
super(Adam, self).__init__(**kwargs)
self.__dict__.update(locals())
self.iterations = K.variable(0)
self.lr = K.variable(lr)
self.beta_1 = K.variable(beta_1)
self.beta_2 = K.variable(beta_2)
self.decay = K.variable(decay)
self.inital_decay = decay

def get_updates(self, params, constraints, loss):
grads = self.get_gradients(loss, params)
self.updates = [K.update_add(self.iterations, 1)]

lr = self.lr
if self.inital_decay > 0:
lr *= (1. / (1. + self.decay * self.iterations))

t = self.iterations + 1
lr_t = self.lr * K.sqrt(1. - K.pow(self.beta_2, t)) / (1. - K.pow(self.beta_1, t))
lr_t = lr * K.sqrt(1. - K.pow(self.beta_2, t)) / (1. - K.pow(self.beta_1, t))

shapes = [K.get_variable_shape(p) for p in params]
ms = [K.zeros(shape) for shape in shapes]
@@ -393,18 +431,24 @@ class Adamax(Optimizer):
- [Adam - A Method for Stochastic Optimization](http://arxiv.org/abs/1412.6980v8)
'''
def __init__(self, lr=0.002, beta_1=0.9, beta_2=0.999,
epsilon=1e-8, **kwargs):
epsilon=1e-8, decay=0., **kwargs):
super(Adamax, self).__init__(**kwargs)
self.__dict__.update(locals())
self.iterations = K.variable(0.)
self.lr = K.variable(lr)
self.beta_1 = K.variable(beta_1)
self.beta_2 = K.variable(beta_2)
self.decay = K.variable(decay)
self.inital_decay = decay

def get_updates(self, params, constraints, loss):
grads = self.get_gradients(loss, params)
self.updates = [K.update_add(self.iterations, 1)]

lr = self.lr
if self.inital_decay > 0:
lr *= (1. / (1. + self.decay * self.iterations))

t = self.iterations + 1
lr_t = self.lr / (1. - K.pow(self.beta_1, t))

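One detail worth noting from the Adam hunk above: the decay composes with the existing bias correction rather than replacing it. The decayed rate lr * 1. / (1. + decay * iterations) is computed first, and the per-step size then multiplies in sqrt(1. - beta_2**t) / (1. - beta_1**t) with t = iterations + 1. A standalone sketch of that composition (function and variable names here are illustrative, not the Keras code itself):

    import math

    def adam_step_size(lr, decay, beta_1, beta_2, iterations):
        # Decay is applied to the base rate before the bias correction,
        # matching the order used in Adam.get_updates above.
        lr_decayed = lr * (1. / (1. + decay * iterations))
        t = iterations + 1
        return lr_decayed * math.sqrt(1. - beta_2 ** t) / (1. - beta_1 ** t)

    # With the defaults lr=0.001, beta_1=0.9, beta_2=0.999 and decay=1e-3:
    print(adam_step_size(0.001, 1e-3, 0.9, 0.999, 0))     # first update
    print(adam_step_size(0.001, 1e-3, 0.9, 0.999, 1000))  # after 1000 updates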
5 changes: 5 additions & 0 deletions tests/keras/test_optimizers.py
@@ -45,22 +45,27 @@ def test_sgd():

def test_rmsprop():
_test_optimizer(RMSprop())
_test_optimizer(RMSprop(decay=1e-3))


def test_adagrad():
_test_optimizer(Adagrad())
_test_optimizer(Adagrad(decay=1e-3))


def test_adadelta():
_test_optimizer(Adadelta())
_test_optimizer(Adadelta(decay=1e-3))


def test_adam():
_test_optimizer(Adam())
_test_optimizer(Adam(decay=1e-3))


def test_adamax():
_test_optimizer(Adamax())
_test_optimizer(Adamax(decay=1e-3))


def test_nadam():
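The test additions give each of the updated optimizers a second run through the existing _test_optimizer helper with decay=1e-3, so the decay path is exercised end to end. In user code the feature is just an extra constructor argument; a minimal usage sketch assuming the Keras 1.x API of this era (the toy model and random data below are invented for illustration):

    import numpy as np
    from keras.models import Sequential
    from keras.layers import Dense
    from keras.optimizers import RMSprop

    # Toy regression data, purely illustrative.
    X = np.random.random((256, 10))
    y = np.random.random((256, 1))

    model = Sequential()
    model.add(Dense(1, input_dim=10))
    # decay shrinks the learning rate a little on every batch update.
    model.compile(loss='mse', optimizer=RMSprop(lr=0.001, decay=1e-3))
    model.fit(X, y, nb_epoch=2, batch_size=32, verbose=0)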
