improved RAdam optimizer, fixed some small issues
haojin.yang authored and Jopyth committed Feb 7, 2020
1 parent d94f09e commit 03b6e72
Showing 2 changed files with 5 additions and 4 deletions.
2 changes: 1 addition & 1 deletion example/bmxnet-examples
7 changes: 4 additions & 3 deletions python/mxnet/optimizer/contrib.py
@@ -20,6 +20,7 @@
 """Contrib optimizers."""
 from ..ndarray import (NDArray, clip, contrib, mean, sqrt, square, zeros)
 from .optimizer import Optimizer
+import math
 
 # convenience wrapper for Optimizer.Register
 register = Optimizer.register # pylint: disable=invalid-name
@@ -124,7 +125,7 @@ class Radam(Optimizer):
     N_sma_threshhold : float, optional
         Adjustable threshold for adaptive Adam
     """
-    def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
+    def __init__(self, learning_rate=0.001, beta1=0.95, beta2=0.999, epsilon=1e-5,
                  N_sma_threshhold=5, **kwargs):
         super(Radam, self).__init__(learning_rate=learning_rate, **kwargs)
         self.beta1 = beta1
@@ -159,15 +160,15 @@ def update(self, index, weight, grad, state):
         #
         buffered = self.radam_buffer[int(t % 10)]
         if t == buffered[0]:
-            N_sma, step_size = buffered[1], buffer[2]
+            N_sma, step_size = buffered[1], buffered[2]
         else:
             buffered[0] = t
             beta2_t = pow(self.beta2, t)
             N_sma_max = 2. / (1. - self.beta2) - 1.
             N_sma = N_sma_max - 2. * t * beta2_t / (1. - beta2_t)
             buffered[1] = N_sma
             if N_sma > self.N_sma_threshhold:
-                step_size = sqrt((1. - beta2_t) * (N_sma - 4.) / (N_sma_max - 4.) * (N_sma - 2.) / N_sma * N_sma_max / (N_sma_max - 2.)) / (1. - pow(self.beta1, t))
+                step_size = math.sqrt((1. - beta2_t) * (N_sma - 4.) / (N_sma_max - 4.) * (N_sma - 2.) / N_sma * N_sma_max / (N_sma_max - 2.)) / (1. - pow(self.beta1, t))
             else:
                 step_size = 1. / (1. - pow(self.beta1, t))
             buffered[2] = step_size
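
For reference, the rectified step size touched by this hunk can be reproduced outside the optimizer. The sketch below is a minimal, hypothetical restatement of the fixed branch logic (the function name radam_step_size and its default arguments are illustrative, not part of the module); it uses the new default beta1=0.95 and shows why the scalar math.sqrt is the right call here: t, beta2_t, N_sma and step_size are plain Python floats, not NDArrays, so the ndarray sqrt imported at the top of contrib.py does not apply.

import math

def radam_step_size(t, beta1=0.95, beta2=0.999, n_sma_threshhold=5):
    """Return (N_sma, step_size) for update step t, mirroring the fixed code above."""
    beta2_t = pow(beta2, t)
    n_sma_max = 2. / (1. - beta2) - 1.
    n_sma = n_sma_max - 2. * t * beta2_t / (1. - beta2_t)
    if n_sma > n_sma_threshhold:
        # variance of the adaptive term is tractable: apply RAdam's rectification factor
        step_size = math.sqrt((1. - beta2_t) * (n_sma - 4.) / (n_sma_max - 4.)
                              * (n_sma - 2.) / n_sma
                              * n_sma_max / (n_sma_max - 2.)) / (1. - pow(beta1, t))
    else:
        # too few effective samples: fall back to the un-rectified, momentum-only step
        step_size = 1. / (1. - pow(beta1, t))
    return n_sma, step_size

# Early steps stay below the threshold and take the fallback branch;
# later steps exceed it and receive the rectified step size.
print(radam_step_size(1))    # N_sma is about 1  -> fallback branch
print(radam_step_size(100))  # N_sma is about 98 -> rectified branch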
