Merge pull request #7694 from emcastillo/fix_adam_test
Fix Adam FP16 overflow on gpu kernels
niboshi committed Jul 19, 2019
2 parents d734995 + 1eed627 commit 02d1f0c
Showing 1 changed file with 36 additions and 16 deletions.
chainer/optimizers/adam.py (52 changes: 36 additions & 16 deletions)
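
For context on the commit title: float16 can represent at most about 65504, so Adam's second-moment update, which squares the gradient, overflows easily when the arithmetic itself is done in float16. A minimal NumPy illustration (toy values, not code from this commit) of the overflow and of the compute-in-T / store-in-P pattern the diff below adopts:

    import numpy as np

    grad = np.float16(300.0)
    v = np.float16(0.0)
    one_minus_beta2 = np.float16(0.001)

    # Second-moment update done entirely in float16:
    # grad * grad = 90000 > 65504, so the square overflows to inf.
    v_bad = v + one_minus_beta2 * (grad * grad - v)
    print(v_bad)  # inf (NumPy also emits an overflow RuntimeWarning)

    # The pattern used by the fix: widen to float32 (T), compute,
    # and narrow back to float16 (P) only when storing the result.
    grad32 = np.float32(grad)
    v32 = np.float32(v)
    v_good = np.float16(v32 + np.float32(one_minus_beta2) * (grad32 * grad32 - v32))
    print(v_good)  # 90.0, finite

Storage stays float16 to keep the optimizer state small; only the intermediate arithmetic is widened.
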
@@ -240,12 +240,18 @@ def update_core_gpu(self, param):
                 'T eps, T eta, T weight_decay_rate, raw T dummy',
                 'P param, P m, P v, P vhat',
                 '''T grad_ = static_cast<T>(grad);
-                   m += one_minus_beta1 * (grad_ - m);
-                   v += one_minus_beta2 * (grad_ * grad - v);
-                   vhat = max(vhat, v);
+                   T m_ = static_cast<T>(m);
+                   T v_ = static_cast<T>(v);
+                   T vhat_ = static_cast<T>(vhat);
+                   m_ += one_minus_beta1 * (grad_ - m_);
+                   v_ += one_minus_beta2 * (grad_ * grad_ - v_);
+                   vhat_ = max(vhat_, v_);
+                   vhat = static_cast<T>(vhat_);
+                   m = static_cast<P>(m_);
+                   v = static_cast<P>(v_);
                    param -= eta *
-                       (max(min(alpha_t / (sqrt(vhat) + eps), upper),
-                            lower) * m + weight_decay_rate * param);''',
+                       (max(min(alpha_t / (sqrt(vhat_) + eps), upper),
+                            lower) * m_ + weight_decay_rate * param);''',
                 'amsbound')
             AdamRule._amsbound_kernel(
                 grad, self.alpha_t, 1 - hp.beta1,
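
All four hunks in this file make the same change: the optimizer state keeps its storage type P (the parameter's dtype, possibly float16), but each kernel now widens m, v, and vhat to the compute type T before doing any arithmetic and narrows back to P only for the final store. A standalone sketch of that pattern written directly against cupy.ElementwiseKernel (hypothetical kernel name and toy values, not Chainer's code; chainer.backends.cuda.elementwise is essentially a memoized wrapper around this class):

    import cupy
    import numpy as np

    # Hypothetical demo kernel: second-moment update with float16 storage (P)
    # and float32 arithmetic (T).
    second_moment = cupy.ElementwiseKernel(
        'T grad, T one_minus_beta2',
        'P v',
        '''T grad_ = static_cast<T>(grad);
           T v_ = static_cast<T>(v);                       // widen the fp16 state
           v_ += one_minus_beta2 * (grad_ * grad_ - v_);   // arithmetic stays in fp32
           v = static_cast<P>(v_);                         // narrow only for storage''',
        'second_moment_demo')

    v = cupy.zeros(4, dtype=cupy.float16)           # P binds to float16
    grad = cupy.full(4, 300.0, dtype=cupy.float32)  # T binds to float32
    second_moment(grad, np.float32(0.001), v)
    print(v)  # roughly [90. 90. 90. 90.]; grad * grad was evaluated in float32
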
@@ -261,11 +267,15 @@ def update_core_gpu(self, param):
                 'T eps, T eta, T weight_decay_rate, raw T dummy',
                 'P param, P m, P v',
                 '''T grad_ = static_cast<T>(grad);
-                   m += one_minus_beta1 * (grad_ - m);
-                   v += one_minus_beta2 * (grad_ * grad_ - v);
+                   T m_ = static_cast<T>(m);
+                   T v_ = static_cast<T>(v);
+                   m_ += one_minus_beta1 * (grad_ - m_);
+                   v_ += one_minus_beta2 * (grad_ * grad_ - v_);
+                   m = static_cast<P>(m_);
+                   v = static_cast<P>(v_);
                    param -= eta *
-                       (max(min(alpha_t / (sqrt(v) + eps), upper),
-                            lower) * m + weight_decay_rate * param);''',
+                       (max(min(alpha_t / (sqrt(v_) + eps), upper),
+                            lower) * m_ + weight_decay_rate * param);''',
                 'adabound')
             AdamRule._adabound_kernel(
                 grad, self.alpha_t, 1 - hp.beta1,
@@ -279,10 +289,16 @@ def update_core_gpu(self, param):
                 'T eps, T eta, T weight_decay_rate, raw T dummy',
                 'P param, P m, P v, P vhat',
                 '''T grad_ = static_cast<T>(grad);
-                   m += one_minus_beta1 * (grad_ - m);
-                   v += one_minus_beta2 * (grad_ * grad_ - v);
-                   vhat = max(vhat, v);
-                   param -= eta * (alpha_t * m / (sqrt(vhat) + eps) +
+                   T m_ = static_cast<T>(m);
+                   T v_ = static_cast<T>(v);
+                   T vhat_ = static_cast<T>(vhat);
+                   m_ += one_minus_beta1 * (grad_ - m_);
+                   v_ += one_minus_beta2 * (grad_ * grad_ - v_);
+                   vhat_ = max(vhat_, v_);
+                   vhat = static_cast<T>(vhat_);
+                   m = static_cast<P>(m_);
+                   v = static_cast<P>(v_);
+                   param -= eta * (alpha_t * m_ / (sqrt(vhat_) + eps) +
                                    weight_decay_rate * param);''',
                 'adam')
             AdamRule._amsgrad_kernel(
@@ -298,9 +314,13 @@ def update_core_gpu(self, param):
                 'T eps, T eta, T weight_decay_rate, raw T dummy',
                 'P param, P m, P v',
                 '''T grad_ = static_cast<T>(grad);
-                   m += one_minus_beta1 * (grad_ - m);
-                   v += one_minus_beta2 * (grad_ * grad_ - v);
-                   param -= eta * (alpha_t * m / (sqrt(v) + eps) +
+                   T m_ = static_cast<T>(m);
+                   T v_ = static_cast<T>(v);
+                   m_ += one_minus_beta1 * (grad_ - m_);
+                   v_ += one_minus_beta2 * (grad_ * grad_ - v_);
+                   m = static_cast<P>(m_);
+                   v = static_cast<P>(v_);
+                   param -= eta * (alpha_t * m_ / (sqrt(v_) + eps) +
                                    weight_decay_rate * param);''',
                 'adam')
             AdamRule._kernel(
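
A hedged way to exercise the fixed path end to end (a hypothetical smoke test on a CUDA machine with CuPy installed, not the test modified by this PR): give a float16 parameter a gradient whose square is far beyond the float16 range and confirm one Adam step leaves everything finite.

    import numpy as np
    import chainer
    from chainer import optimizers

    link = chainer.Link()
    with link.init_scope():
        link.w = chainer.Parameter(np.full((3,), 100.0, dtype=np.float16))
    link.to_gpu()

    opt = optimizers.Adam()
    opt.setup(link)

    # 300 ** 2 = 90000 > 65504, the float16 maximum.
    link.w.grad = link.xp.full((3,), 300.0, dtype=np.float16)
    opt.update()

    assert bool(link.xp.isfinite(link.w.array).all())
    print(link.w.array)
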
