Skip to content

Commit

Permalink
close pytorch#23796, eps_inside_sqrt option for rmsprop, reduce centered version memory usage
Browse files Browse the repository at this point in the history
  • Loading branch information
meijieru committed Aug 5, 2019
1 parent a3c165f commit ea62024
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions torch/optim/rmsprop.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@ class RMSprop(Optimizer):
alpha (float, optional): smoothing constant (default: 0.99)
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
eps_inside_sqrt (bool, optional): if ``True``, add eps inside the sqrt
(default: False)
centered (bool, optional): if ``True``, compute the centered RMSProp, in which
the gradient is normalized by an estimation of its variance
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
"""

def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False):
def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, eps_inside_sqrt=False, weight_decay=0, momentum=0, centered=False):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
Expand All @@ -37,7 +39,7 @@ def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, moment
if not 0.0 <= alpha:
raise ValueError("Invalid alpha value: {}".format(alpha))

defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay)
defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, eps_inside_sqrt=eps_inside_sqrt, centered=centered, weight_decay=weight_decay)
super(RMSprop, self).__init__(params, defaults)

def __setstate__(self, state):
Expand Down Expand Up @@ -88,9 +90,15 @@ def step(self, closure=None):
if group['centered']:
grad_avg = state['grad_avg']
grad_avg.mul_(alpha).add_(1 - alpha, grad)
avg = square_avg.addcmul(-1, grad_avg, grad_avg).sqrt().add_(group['eps'])
if group['eps_inside_sqrt']:
avg = square_avg.addcmul(-1, grad_avg, grad_avg).add_(group['eps']).sqrt_()
else:
avg = square_avg.addcmul(-1, grad_avg, grad_avg).sqrt_().add_(group['eps'])
else:
avg = square_avg.sqrt().add_(group['eps'])
if group['eps_inside_sqrt']:
avg = square_avg.add(group['eps']).sqrt_()
else:
avg = square_avg.sqrt().add_(group['eps'])

if group['momentum'] > 0:
buf = state['momentum_buffer']
Expand Down

0 comments on commit ea62024

Please sign in to comment.