ada_delta.py

import numpy

from chainer import backend
from chainer.backends import cuda
from chainer import optimizer
from chainer import types


if types.TYPE_CHECKING:
    import typing_extensions as tpe

    class AdaDeltaHyperparameter(tpe.Protocol):
        """Protocol class for hyperparameter of Zeiler's ADADELTA.

        This is only for PEP 544 compliant static type checkers.
        """
        rho = None  # type: float
        eps = None  # type: float


_default_hyperparam = optimizer.Hyperparameter()  # type: AdaDeltaHyperparameter # NOQA
_default_hyperparam.rho = 0.95
_default_hyperparam.eps = 1e-6


class AdaDeltaRule(optimizer.UpdateRule):

    """Update rule of Zeiler's ADADELTA.

    See :class:`~chainer.optimizers.AdaDelta` for the default values of the
    hyperparameters.

    Args:
        parent_hyperparam (~chainer.optimizer.Hyperparameter): Hyperparameter
            that provides the default values.
        rho (float): Exponential decay rate of the first and second order
            moments.
        eps (float): Small value for numerical stability.

    """
    _kernel = None

    def __init__(self, parent_hyperparam=None, rho=None, eps=None):
        super(AdaDeltaRule, self).__init__(
            parent_hyperparam or _default_hyperparam)
        if rho is not None:
            self.hyperparam.rho = rho
        if eps is not None:
            self.hyperparam.eps = eps

    def init_state(self, param):
        xp = backend.get_array_module(param.data)
        with cuda.get_device_from_array(param.data):
            self.state['msg'] = xp.zeros_like(param.data)
            self.state['msdx'] = xp.zeros_like(param.data)
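
    # Both the CPU and the GPU paths below implement the same ADADELTA step,
    # keeping running averages of the squared gradient (``msg``) and of the
    # squared update (``msdx``):
    #
    #     msg  <- rho * msg  + (1 - rho) * grad^2
    #     dx    = sqrt((msdx + eps) / (msg + eps)) * grad
    #     msdx <- rho * msdx + (1 - rho) * dx^2
    #     param <- param - dx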

    def update_core_cpu(self, param):
        grad = param.grad
        if grad is None:
            return
        msg, msdx = self.state['msg'], self.state['msdx']
        rho = self.hyperparam.rho
        eps = self.hyperparam.eps
        msg *= rho
        msg += (1 - rho) * grad * grad
        dx = numpy.sqrt((msdx + eps) / (msg + eps)) * grad
        msdx *= rho
        msdx += (1 - rho) * dx * dx
        param.data -= dx
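
    # The GPU path fuses the same four statements into a single elementwise
    # CUDA kernel, compiled once and cached on the class. It uses the
    # algebraically equivalent form m + (1 - rho) * (x - m) == rho * m +
    # (1 - rho) * x for the two running averages.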

    def update_core_gpu(self, param):
        grad = param.grad
        if grad is None:
            return
        if AdaDeltaRule._kernel is None:
            AdaDeltaRule._kernel = cuda.elementwise(
                'T grad, T one_minus_rho, T eps',
                'T param, T msg, T msdx',
                '''msg = msg + one_minus_rho * (grad * grad - msg);
                   T dx = sqrt((msdx + eps) / (msg + eps)) * grad;
                   msdx += one_minus_rho * (dx * dx - msdx);
                   param -= dx;''',
                'adadelta')
        AdaDeltaRule._kernel(
            grad, 1 - self.hyperparam.rho, self.hyperparam.eps, param.data,
            self.state['msg'], self.state['msdx'])


class AdaDelta(optimizer.GradientMethod):

    """Zeiler's ADADELTA.

    See: http://www.matthewzeiler.com/pubs/googleTR2012/googleTR2012.pdf

    Args:
        rho (float): Exponential decay rate of the first and second order
            moments.
        eps (float): Small value for numerical stability.

    """

    def __init__(self, rho=_default_hyperparam.rho,
                 eps=_default_hyperparam.eps):
        super(AdaDelta, self).__init__()
        self.hyperparam.rho = rho
        self.hyperparam.eps = eps

    rho = optimizer.HyperparameterProxy('rho')
    eps = optimizer.HyperparameterProxy('eps')

    def create_update_rule(self):
        return AdaDeltaRule(self.hyperparam)
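

# A minimal usage sketch (not part of the module itself): ``MyChain``,
# ``compute_loss`` and ``batch`` are hypothetical placeholders, while the
# optimizer calls follow the standard chainer.Optimizer API.
#
#     from chainer import optimizers
#
#     model = MyChain()                   # any chainer.Link / chainer.Chain
#     opt = optimizers.AdaDelta(rho=0.95, eps=1e-6)
#     opt.setup(model)                    # attaches an AdaDeltaRule per param
#
#     model.cleargrads()
#     loss = compute_loss(model, batch)   # hypothetical forward pass
#     loss.backward()
#     opt.update()                        # runs AdaDeltaRule.update_core_*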