""" RMSProp modified to behave like Tensorflow impl
Originally cut & paste from PyTorch RMSProp
https://github.com/pytorch/pytorch/blob/063946d2b3f3f1e953a2a3b54e0b34f1393de295/torch/optim/rmsprop.py
Licensed under BSD-Clause 3 (ish), https://github.com/pytorch/pytorch/blob/master/LICENSE
Modifications Copyright 2021 Ross Wightman
"""
import torch
from torch.optim import Optimizer


class RMSpropTF(Optimizer):
    """Implements the RMSprop algorithm (TensorFlow-style epsilon)

    NOTE: This is a direct cut-and-paste of PyTorch RMSprop with eps applied inside the sqrt
    and a few other modifications to more closely match TensorFlow for the same hyper-parameters.

    Noteworthy changes include:
    1. Epsilon applied inside the square root
    2. square_avg initialized to ones
    3. LR scaling of the update accumulated in the momentum buffer

    Proposed by G. Hinton in his
    `course <http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf>`_.
    The centered version first appears in `Generating Sequences
    With Recurrent Neural Networks <https://arxiv.org/pdf/1308.0850v5.pdf>`_.

    See the usage sketch at the end of this file for a minimal example.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-2)
        momentum (float, optional): momentum factor (default: 0)
        alpha (float, optional): smoothing (decay) constant (default: 0.9)
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-10)
        centered (bool, optional): if ``True``, compute the centered RMSprop,
            where the gradient is normalized by an estimate of its variance
            (default: False)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        decoupled_decay (bool, optional): decoupled weight decay as per
            https://arxiv.org/abs/1711.05101 (default: False)
        lr_in_momentum (bool, optional): learning rate scaling is included in the
            momentum buffer update, as per the TensorFlow defaults (default: True)
    """

    def __init__(self, params, lr=1e-2, alpha=0.9, eps=1e-10, weight_decay=0, momentum=0., centered=False,
                 decoupled_decay=False, lr_in_momentum=True):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= momentum:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if not 0.0 <= alpha:
            raise ValueError("Invalid alpha value: {}".format(alpha))

        defaults = dict(
            lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay,
            decoupled_decay=decoupled_decay, lr_in_momentum=lr_in_momentum)
        super(RMSpropTF, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(RMSpropTF, self).__setstate__(state)
        # Backfill hyper-parameters missing from optimizer state saved by older versions
        for group in self.param_groups:
            group.setdefault('momentum', 0)
            group.setdefault('centered', False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError('RMSprop does not support sparse gradients')
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    state['square_avg'] = torch.ones_like(p)  # PyTorch inits to zero
                    if group['momentum'] > 0:
                        state['momentum_buffer'] = torch.zeros_like(p)
                    if group['centered']:
                        state['grad_avg'] = torch.zeros_like(p)

                square_avg = state['square_avg']
                one_minus_alpha = 1. - group['alpha']

                state['step'] += 1

                if group['weight_decay'] != 0:
                    if group['decoupled_decay']:
                        # Decoupled decay shrinks the weights directly (AdamW-style)
                        p.mul_(1. - group['lr'] * group['weight_decay'])
                    else:
                        # Classic L2 penalty folded into the gradient
                        grad = grad.add(p, alpha=group['weight_decay'])

                # Tensorflow order of ops for updating squared avg
                square_avg.add_(grad.pow(2) - square_avg, alpha=one_minus_alpha)
                # square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha)  # PyTorch original

                if group['centered']:
                    grad_avg = state['grad_avg']
                    grad_avg.add_(grad - grad_avg, alpha=one_minus_alpha)
                    # grad_avg.mul_(alpha).add_(grad, alpha=1 - alpha)  # PyTorch original
                    avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).add(group['eps']).sqrt_()  # eps in sqrt
                else:
                    avg = square_avg.add(group['eps']).sqrt_()  # eps moved in sqrt
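
                # The denominator differs from stock PyTorch RMSprop in where eps sits:
                #   PyTorch:    avg = sqrt(square_avg) + eps
                #   TensorFlow: avg = sqrt(square_avg + eps)
                # The distinction matters most when square_avg is near zero.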

                if group['momentum'] > 0:
                    buf = state['momentum_buffer']
                    if group['lr_in_momentum']:
                        # Tensorflow accumulates the LR scaling in the momentum buffer:
                        #   buf <- momentum * buf + lr * grad / avg;  p <- p - buf
                        buf.mul_(group['momentum']).addcdiv_(grad, avg, value=group['lr'])
                        p.add_(-buf)
                    else:
                        # PyTorch scales the param update by LR instead:
                        #   buf <- momentum * buf + grad / avg;  p <- p - lr * buf
                        buf.mul_(group['momentum']).addcdiv_(grad, avg)
                        p.add_(buf, alpha=-group['lr'])
                else:
                    p.addcdiv_(grad, avg, value=-group['lr'])

        return loss
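

if __name__ == '__main__':
    # Minimal usage sketch: a hypothetical linear-regression fit showing how
    # RMSpropTF slots in wherever a standard torch.optim optimizer would.
    # The model, data, and hyper-parameters below are illustrative assumptions,
    # not recommendations.
    import torch.nn.functional as F

    torch.manual_seed(0)
    model = torch.nn.Linear(10, 1)
    optimizer = RMSpropTF(model.parameters(), lr=1e-2, alpha=0.9, eps=1e-10, momentum=0.9)

    inputs = torch.randn(64, 10)
    targets = torch.randn(64, 1)
    for step in range(5):
        optimizer.zero_grad()
        loss = F.mse_loss(model(inputs), targets)
        loss.backward()
        optimizer.step()  # one RMSpropTF update, TF-style eps and momentum handling
        print('step {}: loss {:.4f}'.format(step, loss.item()))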