
Commit

implemented RAdam optimizer, not tested
haojin.yang authored and Jopyth committed Feb 7, 2020
1 parent ab312c5 commit d94f09e
Showing 1 changed file with 80 additions and 1 deletion.
python/mxnet/optimizer/contrib.py (80 additions, 1 deletion)
@@ -24,7 +24,7 @@
# convenience wrapper for Optimizer.Register
register = Optimizer.register # pylint: disable=invalid-name

__all__ = ['GroupAdaGrad']
__all__ = ['GroupAdaGrad', 'Radam']


@register
@@ -98,3 +98,82 @@ def update(self, index, weight, grad, state):
        state[:] += mean(square(grad), axis=1, keepdims=True)
        div = lr * grad / sqrt(state + self.float_stable_eps)
        weight[:] -= div


@register
class Radam(Optimizer):
"""The RAdam optimizer.
A new variant of Adam, by introducing a term to rectify the variance
of the adaptive learning rate.
Paper: "On the Variance of the Adaptive Learning Rate and Beyond", Liu et al. 2019,
link: https://arxiv.org/abs/1908.03265
This optimizer accepts the following parameters in addition to those accepted
by :class:`.Optimizer`.
Parameters
----------
beta1 : float, optional
Exponential decay rate for the first moment estimates.
beta2 : float, optional
Exponential decay rate for the second moment estimates.
epsilon : float, optional
Small value to avoid division by 0.
N_sma_threshhold : float, optional
Adjustable threshold for adaptive Adam
"""
    def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
                 N_sma_threshhold=5, **kwargs):
        super(Radam, self).__init__(learning_rate=learning_rate, **kwargs)
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
        self.N_sma_threshhold = N_sma_threshhold
        # per-slot cache of (step, N_sma, step_size), keyed by t % 10
        self.radam_buffer = [[None, None, None] for _ in range(10)]

    def create_state(self, index, weight):
        return (zeros(weight.shape, weight.context, dtype=weight.dtype),  # mean
                zeros(weight.shape, weight.context, dtype=weight.dtype))  # variance

    def update(self, index, weight, grad, state):
        assert isinstance(weight, NDArray)
        assert isinstance(grad, NDArray)
        self._update_count(index)
        lr = self._get_lr(index)
        wd = self._get_wd(index)

        t = self._index_update_count[index]

        # preprocess grad
        grad = grad * self.rescale_grad + wd * weight
        if self.clip_gradient is not None:
            grad = clip(grad, -self.clip_gradient, self.clip_gradient)

        # update m_t and v_t
        m_t, v_t = state
        m_t[:] = self.beta1 * m_t + (1. - self.beta1) * grad
        v_t[:] = self.beta2 * v_t + (1. - self.beta2) * grad * grad

        # compute the SMA length N_sma and the rectified step size, caching
        # the scalar results for this step in radam_buffer
        buffered = self.radam_buffer[int(t % 10)]
        if t == buffered[0]:
            N_sma, step_size = buffered[1], buffered[2]
        else:
            buffered[0] = t
            beta2_t = pow(self.beta2, t)
            N_sma_max = 2. / (1. - self.beta2) - 1.
            N_sma = N_sma_max - 2. * t * beta2_t / (1. - beta2_t)
            buffered[1] = N_sma
            if N_sma > self.N_sma_threshhold:
                # rectification term from the paper; the operands are Python
                # floats, so math.sqrt is used here (assumes `import math` at
                # the top of the module) rather than the NDArray sqrt
                step_size = math.sqrt((1. - beta2_t) * (N_sma - 4.) / (N_sma_max - 4.)
                                      * (N_sma - 2.) / N_sma
                                      * N_sma_max / (N_sma_max - 2.)) / (1. - pow(self.beta1, t))
            else:
                step_size = 1. / (1. - pow(self.beta1, t))
            buffered[2] = step_size

        # rectified adaptive update when the variance estimate is reliable,
        # otherwise a plain (momentum-only) step
        if N_sma > self.N_sma_threshhold:
            denom = sqrt(v_t) + self.epsilon
            weight[:] -= (step_size * lr) * m_t / denom
        else:
            weight[:] -= (step_size * lr) * m_t
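
Since the class is registered through the @register decorator, it should be resolvable by its lowercase name once this contrib module is imported. A minimal usage sketch, untested and assuming the standard MXNet/Gluon APIs (mxnet.optimizer.create and gluon.Trainer) with the optimizer looked up as 'radam':

from mxnet import gluon, optimizer

# look the optimizer up by its registered name (lowercase class name)
opt = optimizer.create('radam', learning_rate=0.001, N_sma_threshhold=5)

# or let a Gluon Trainer build it from the name and a parameter dict
net = gluon.nn.Dense(10)
net.initialize()
trainer = gluon.Trainer(net.collect_params(), 'radam',
                        {'learning_rate': 0.001, 'beta1': 0.9, 'beta2': 0.999})

Constructor keyword arguments such as N_sma_threshhold are forwarded to Radam.__init__ through the optimizer_params dict, so no extra wiring should be needed beyond this commit.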
