Fix RMSProp update rule (apache#6235)
* Fix RMSProp update rule

Follow the formula presented in Alex Graves' paper;
this prevents taking the square root of a negative
value (caused by floating-point rounding error).

* Fix the formula of the non-centered version of RMSProp

* Fix RMSProp update rule in python test

* Fix RMSProp update rule in perl test
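For reference, the corrected update rules exercised by the tests below can be sketched in NumPy roughly as follows. This is a minimal sketch rather than MXNet's implementation: the function names are invented, gradient/weight clipping is omitted, the state arrays are assumed to be pre-allocated NumPy arrays updated in place, and gamma1, gamma2, lr and epsilon follow the parameter names used in the diff.

import numpy as np

def rmsprop_step(weight, grad, n, lr, gamma1, epsilon):
    # Non-centered RMSProp: epsilon now sits inside the square root.
    n[:] = (1 - gamma1) * grad * grad + gamma1 * n
    weight[:] -= lr * grad / np.sqrt(n + epsilon)

def rmsprop_alex_step(weight, grad, n, g, delta, lr, gamma1, gamma2, epsilon):
    # Centered (Alex Graves) RMSProp: epsilon moved inside sqrt(n - g*g + epsilon).
    n[:] = (1 - gamma1) * grad * grad + gamma1 * n
    g[:] = (1 - gamma1) * grad + gamma1 * g
    delta[:] = gamma2 * delta - lr * grad / np.sqrt(n - g * g + epsilon)
    weight[:] += delta
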
sifmelcara authored and piiswrong committed May 13, 2017
1 parent b005306 commit 38f7c55
Showing 3 changed files with 24 additions and 22 deletions.
4 changes: 2 additions & 2 deletions perl-package/AI-MXNet/t/test_optimizers.t
@@ -166,7 +166,7 @@ method update($index, $weight, $grad, $state)
$grad = mx->nd->clip($grad, -$self->clip_gradient, $self->clip_gradient);
}
$n .= (1 - $self->gamma1) * ($grad * $grad) + $self->gamma1 * $n;
- $weight -= $lr * $grad/(mx->nd->sqrt($n) + $self->epsilon);
+ $weight -= $lr * $grad/(mx->nd->sqrt($n + $self->epsilon));
}
else
{
@@ -177,7 +177,7 @@ method update($index, $weight, $grad, $state)
}
$n .= (1 - $self->gamma1) * ($grad * $grad) + $self->gamma1 * $n;
$g .= (1 - $self->gamma1) * $grad + $self->gamma1 * $g;
- $delta .= ($self->gamma2) * $delta - $lr * $grad/(mx->nd->sqrt($n - $g*$g) + $self->epsilon);
+ $delta .= ($self->gamma2) * $delta - $lr * $grad/(mx->nd->sqrt($n - $g*$g + $self->epsilon));
$weight += $delta;
}
if($self->clip_weights)
38 changes: 20 additions & 18 deletions src/operator/optimizer_op-inl.h
@@ -300,17 +300,17 @@ inline void RMSPropAlexUpdate(const nnvm::NodeAttrs &attrs,
delta = scalar<DType>(param.gamma2) * delta -
scalar<DType>(param.lr) *
(F<clip>(grad, DType(param.clip_gradient)) /
- (F<square_root>(state_n - state_g * state_g) +
- scalar<DType>(param.epsilon)));
+ (F<square_root>(state_n - state_g * state_g +
+ scalar<DType>(param.epsilon))));
} else {
state_n = scalar<DType>(1.f - param.gamma1) * (grad * grad) +
scalar<DType>(param.gamma1) * state_n;
state_g = scalar<DType>(1.f - param.gamma1) * grad +
scalar<DType>(param.gamma1) * state_g;
delta = scalar<DType>(param.gamma2) * delta -
scalar<DType>(param.lr) *
- (grad / (F<square_root>(state_n - state_g * state_g) +
- scalar<DType>(param.epsilon)));
+ (grad / (F<square_root>(state_n - state_g * state_g +
+ scalar<DType>(param.epsilon))));
}

if (param.clip_weights >= 0.0f) {
@@ -386,33 +386,35 @@ inline void RMSPropUpdate(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
if (param.clip_weights >= 0.0f) {
Assign(out, req[0],
F<clip>(weight -
- scalar<DType>(param.lr) *
- (F<clip>(grad, DType(param.clip_gradient)) /
- (F<square_root>(state_n) +
- scalar<DType>(param.epsilon))),
+ scalar<DType>(param.lr) *
+ (F<clip>(grad, DType(param.clip_gradient)) /
+ (F<square_root>(state_n +
+ scalar<DType>(param.epsilon)))),
DType(param.clip_weights)));
} else {
Assign(out, req[0], weight -
- scalar<DType>(param.lr) *
- (F<clip>(grad, DType(param.clip_gradient)) /
- (F<square_root>(state_n) +
- scalar<DType>(param.epsilon))));
+ scalar<DType>(param.lr) *
+ (F<clip>(grad, DType(param.clip_gradient)) /
+ (F<square_root>(state_n +
+ scalar<DType>(param.epsilon)))));
}
} else {
state_n = scalar<DType>(1.f - param.gamma1) * (grad * grad) +
scalar<DType>(param.gamma1) * state_n;
if (param.clip_weights >= 0.0f) {
Assign(out, req[0],
F<clip>(weight -
- scalar<DType>(param.lr) *
- (grad / (F<square_root>(state_n) +
- scalar<DType>(param.epsilon))),
+ scalar<DType>(param.lr) *
+ (grad /
+ (F<square_root>(state_n +
+ scalar<DType>(param.epsilon)))),
DType(param.clip_weights)));
} else {
Assign(out, req[0], weight -
- scalar<DType>(param.lr) *
- (grad / (F<square_root>(state_n) +
- scalar<DType>(param.epsilon))));
+ scalar<DType>(param.lr) *
+ (grad /
+ (F<square_root>(state_n +
+ scalar<DType>(param.epsilon)))));
}
}
});
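
The centered variant above is the one that motivated the fix: state_n tracks E[g^2] and state_g tracks E[g], so state_n - state_g * state_g is a variance estimate that is non-negative in exact arithmetic but can round to a tiny negative number in floating point, turning the square root into NaN. A small, purely illustrative float32 example; the values stand in for states a constant gradient of 0.1 would converge towards and are not an actual optimizer trace, and the epsilon is an example value rather than MXNet's default.

import numpy as np

n = np.float32(0.01)         # stands in for the running E[g^2] estimate
g_avg = np.float32(0.1)      # stands in for the running E[g] estimate
eps = np.float32(1e-8)       # example epsilon

diff = n - g_avg * g_avg
print(diff)                  # a tiny negative number caused purely by rounding
print(np.sqrt(diff) + eps)   # old formula: sqrt of a negative value -> NaN
print(np.sqrt(diff + eps))   # fixed formula: the argument stays positive, result is finite

Adding epsilon outside the square root (the old formula) cannot repair a NaN that has already been produced, which is why the fix moves it inside.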
4 changes: 2 additions & 2 deletions tests/python/unittest/test_optimizer.py
@@ -301,15 +301,15 @@ def update(self, index, weight, grad, state):
if self.clip_gradient is not None:
grad = mx.nd.clip(grad, -self.clip_gradient, self.clip_gradient)
n[:] = (1 - self.gamma1) * (grad * grad) + self.gamma1 * n
- weight[:] -= lr * grad/(mx.nd.sqrt(n) + self.epsilon)
+ weight[:] -= lr * grad/(mx.nd.sqrt(n + self.epsilon))

else:
n, g, delta = state
if self.clip_gradient is not None:
grad = mx.nd.clip(grad, -self.clip_gradient, self.clip_gradient)
n[:] = (1 - self.gamma1) * (grad * grad) + self.gamma1 * n
g[:] = (1 - self.gamma1) * grad + self.gamma1 * g
- delta[:] = (self.gamma2) * delta - lr * grad/(mx.nd.sqrt(n - g*g) + self.epsilon)
+ delta[:] = (self.gamma2) * delta - lr * grad/(mx.nd.sqrt(n - g*g + self.epsilon))
weight[:] += delta

if self.clip_weights:
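
Beyond avoiding NaN, moving epsilon inside the square root also changes behaviour while the second-moment state n is still near zero (for example on the very first update): sqrt(n) + epsilon can be a far smaller denominator than sqrt(n + epsilon). A quick arithmetic check with an example epsilon of 1e-8:

import numpy as np

eps = 1e-8
n = 0.0                      # second-moment state before any accumulation
print(np.sqrt(n) + eps)      # old denominator: 1e-08, i.e. a very large first step
print(np.sqrt(n + eps))      # new denominator: 1e-04, a much smaller first step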
