/
rmsprop_graves.py
139 lines (111 loc) · 4.6 KB
/
rmsprop_graves.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import numpy
import chainer
from chainer.backends import cuda
from chainer import optimizer
from chainer import types
if types.TYPE_CHECKING:
    # Everything in this branch exists purely for static analysis: it is
    # only executed when ``types.TYPE_CHECKING`` is truthy.
    import typing_extensions as tpe

    class RMSpropGravesHyperparameter(tpe.Protocol):
        """Protocol class for hyperparameter of Alex Graves's RMSprop.

        This is only for PEP 544 compliant static type checkers.

        It declares the four attributes (``lr``, ``alpha``, ``momentum`` and
        ``eps``) that this module sets on ``_default_hyperparam``. The
        ``None`` values are placeholders; the ``# type:`` comments carry the
        actual declarations.
        """
        # Learning rate.
        lr = None  # type: float
        # Exponential decay rate of the raw-gradient moments.
        alpha = None  # type: float
        # Exponential decay rate of the accumulated update.
        momentum = None  # type: float
        # Small value for the numerical stability.
        eps = None  # type: float
# Module-level default hyperparameters. RMSpropGravesRule falls back to this
# object when no parent hyperparameter is given, and RMSpropGraves reads its
# constructor default values from it.
_default_hyperparam = optimizer.Hyperparameter()  # type: RMSpropGravesHyperparameter # NOQA
_default_hyperparam.lr = 1e-4  # learning rate
_default_hyperparam.alpha = 0.95  # decay rate of the raw-gradient moments
_default_hyperparam.momentum = 0.9  # decay rate of the accumulated update
_default_hyperparam.eps = 1e-4  # added under the square root for stability
class RMSpropGravesRule(optimizer.UpdateRule):

    """Update rule for Alex Graves's RMSprop.

    Three per-parameter state arrays are kept: ``n`` and ``g`` hold
    ``alpha``-decayed running averages of the squared gradient and of the
    gradient respectively, and ``delta`` holds the ``momentum``-decayed
    update that is added to the parameter on every step.

    See :class:`~chainer.optimizers.RMSpropGraves` for the default values of
    the hyperparameters.

    Args:
        parent_hyperparam (~chainer.optimizer.Hyperparameter): Hyperparameter
            that provides the default values. When it is ``None`` (or
            otherwise falsy) the module-level defaults are used instead.
        lr (float): Learning rate.
        alpha (float): Exponential decay rate of the first and second order
            moments of the raw gradient.
        momentum (float): Exponential decay rate of the first order moment of
            the adjusted gradient.
        eps (float): Small value for the numerical stability.

    """
    is_elementwise = True

    # Elementwise GPU kernel; created on the first GPU update and then shared
    # through this class attribute.
    _kernel = None

    def __init__(self, parent_hyperparam=None,
                 lr=None, alpha=None, momentum=None, eps=None):
        super(RMSpropGravesRule, self).__init__(
            parent_hyperparam or _default_hyperparam)
        # Only explicitly supplied values override the parent hyperparameter.
        overrides = (
            ('lr', lr),
            ('alpha', alpha),
            ('momentum', momentum),
            ('eps', eps),
        )
        for name, value in overrides:
            if value is not None:
                setattr(self.hyperparam, name, value)

    def init_state(self, param):
        """Create zero-filled ``n``, ``g`` and ``delta`` state arrays.

        Each array is built with ``zeros_like(param.data)`` from the array
        module of ``param.device``, inside that device's context.
        """
        with chainer.using_device(param.device):
            zeros_like = param.device.xp.zeros_like
            for key in ('n', 'g', 'delta'):
                self.state[key] = zeros_like(param.data)

    def update_core_cpu(self, param):
        """Apply one update step to ``param`` using NumPy.

        Does nothing when ``param.grad`` is ``None``. Otherwise the state
        arrays are updated in place and ``delta`` is added to ``param.data``.
        """
        gradient = param.grad
        if gradient is None:
            return

        hyper = self.hyperparam
        state = self.state
        sq_mean, mean, step = state['n'], state['g'], state['delta']

        decay = hyper.alpha
        fresh = 1 - decay

        # Running averages of grad**2 and grad, updated in place.
        sq_mean *= decay
        sq_mean += fresh * gradient * gradient
        mean *= decay
        mean += fresh * gradient

        # n - g**2 plus eps is the quantity placed under the square root.
        denominator = numpy.sqrt(sq_mean - mean * mean + hyper.eps)
        step *= hyper.momentum
        step -= hyper.lr * gradient / denominator

        param.data += step

    def update_core_gpu(self, param):
        """Apply one update step to ``param`` with an elementwise kernel.

        Does nothing when ``param.grad`` is ``None``. The kernel is built
        through ``cuda.elementwise`` the first time it is needed and stored
        on ``RMSpropGravesRule._kernel`` for later calls.
        """
        gradient = param.grad
        if gradient is None:
            return

        kernel = RMSpropGravesRule._kernel
        if kernel is None:
            kernel = RMSpropGravesRule._kernel = cuda.elementwise(
                'T grad, T lr, T alpha, T momentum, T eps',
                'T param, T avg_n, T avg_g, T delta',
                ('avg_n = alpha * avg_n + (1 - alpha) * grad * grad;\n'
                 'avg_g = alpha * avg_g + (1 - alpha) * grad;\n'
                 'delta = delta * momentum -\n'
                 'lr * grad * rsqrt(avg_n - avg_g * avg_g + eps);\n'
                 'param += delta;'),
                'rmsprop_graves')

        hyper = self.hyperparam
        state = self.state
        kernel(
            gradient, hyper.lr, hyper.alpha, hyper.momentum, hyper.eps,
            param.data, state['n'], state['g'], state['delta'])
class RMSpropGraves(optimizer.GradientMethod):

    """Alex Graves's RMSprop.

    See: https://arxiv.org/abs/1308.0850

    The constructor defaults are read from the module-level
    ``_default_hyperparam`` object. Each parameter is updated through a
    :class:`RMSpropGravesRule` whose parent hyperparameter is this
    optimizer's ``hyperparam``.

    Args:
        lr (float): Learning rate.
        alpha (float): Exponential decay rate of the first and second order
            moments of the raw gradient.
        momentum (float): Exponential decay rate of the first order moment of
            the adjusted gradient.
        eps (float): Small value for the numerical stability.

    """

    # One proxy per hyperparameter name, exposed as attributes of the
    # optimizer.
    lr = optimizer.HyperparameterProxy('lr')
    alpha = optimizer.HyperparameterProxy('alpha')
    momentum = optimizer.HyperparameterProxy('momentum')
    eps = optimizer.HyperparameterProxy('eps')

    def __init__(self, lr=_default_hyperparam.lr,
                 alpha=_default_hyperparam.alpha,
                 momentum=_default_hyperparam.momentum,
                 eps=_default_hyperparam.eps):
        super(RMSpropGraves, self).__init__()
        hyper = self.hyperparam
        hyper.lr = lr
        hyper.alpha = alpha
        hyper.momentum = momentum
        hyper.eps = eps

    def create_update_rule(self):
        """Return a new RMSpropGravesRule parented to ``self.hyperparam``."""
        rule = RMSpropGravesRule(self.hyperparam)
        return rule