Skip to content
Branch: master
Find file Copy path
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
110 lines (84 sloc) 3.32 KB
import chainer
from chainer.backends import cuda
from chainer.backends import intel64
from chainer import optimizer
from chainer import types
import typing_extensions as tpe
class MomentumSGDHyperparameter(tpe.Protocol):
    """Protocol class for hyperparameter of classical momentum SGD.

    This is only for PEP 544 compliant static type checkers.
    """
    lr = None  # type: float
    momentum = None  # type: float
# Module-level defaults shared by every MomentumSGDRule that is not given an
# explicit parent hyperparameter.
_default_hyperparam = optimizer.Hyperparameter()  # type: MomentumSGDHyperparameter # NOQA
_default_hyperparam.lr = 0.01
_default_hyperparam.momentum = 0.9
class MomentumSGDRule(optimizer.UpdateRule):

    """Update rule for the classical momentum SGD.

    See :class:`~chainer.optimizers.MomentumSGD` for the default values of the
    hyperparameters.

    Args:
        parent_hyperparam (~chainer.optimizer.Hyperparameter): Hyperparameter
            that provides the default values.
        lr (float): Learning rate.
        momentum (float): Exponential decay rate of the first order moment.

    """
    # Lazily-compiled elementwise CUDA kernel, shared by all instances.
    _kernel = None

    def __init__(self, parent_hyperparam=None, lr=None, momentum=None):
        super(MomentumSGDRule, self).__init__(
            parent_hyperparam or _default_hyperparam)
        # Explicit arguments override the (parent) hyperparameter defaults.
        if lr is not None:
            self.hyperparam.lr = lr
        if momentum is not None:
            self.hyperparam.momentum = momentum

    def init_state(self, param):
        """Allocate the velocity buffer ``v`` on the parameter's device."""
        with chainer.using_device(param.device):
            xp = param.device.xp
            self.state['v'] = xp.zeros_like(param.data)

        # For iDeep: convert the buffer to an iDeep weight array so the fused
        # inplace_axpby path in update_core_cpu can be used.
        if isinstance(param.data, intel64.mdarray):
            self.state['v'] = intel64.ideep.array(
                self.state['v'], itype=intel64.ideep.wgt_array)

    def update_core_cpu(self, param):
        """Apply ``v = momentum*v - lr*grad; param += v`` on CPU arrays."""
        grad = param.grad
        if grad is None:
            # Nothing to do for parameters without gradients.
            return
        v = self.state['v']
        if isinstance(v, intel64.mdarray):
            # Fused update: v = momentum * v + (-lr) * grad, in one iDeep call.
            v.inplace_axpby(self.hyperparam.momentum,
                            -self.hyperparam.lr, grad)
            param.data += v
        else:
            v *= self.hyperparam.momentum
            v -= self.hyperparam.lr * grad
            param.data += v

    def update_core_gpu(self, param):
        """Apply the momentum SGD update on GPU via a cached CUDA kernel."""
        grad = param.grad
        if grad is None:
            # Nothing to do for parameters without gradients.
            return
        if MomentumSGDRule._kernel is None:
            MomentumSGDRule._kernel = cuda.elementwise(
                'T grad, T lr, T momentum',
                'T param, T v',
                '''v = momentum * v - lr * grad;
                   param += v;''',
                'momentum_sgd')
        MomentumSGDRule._kernel(
            grad, self.hyperparam.lr, self.hyperparam.momentum, param.data,
            self.state['v'])
class MomentumSGD(optimizer.GradientMethod):

    """Momentum SGD optimizer.

    Args:
        lr (float): Learning rate.
        momentum (float): Exponential decay rate of the first order moment.

    """

    def __init__(self, lr=_default_hyperparam.lr,
                 momentum=_default_hyperparam.momentum):
        super(MomentumSGD, self).__init__()
        self.hyperparam.lr = lr
        self.hyperparam.momentum = momentum

    # Expose the hyperparameters as attributes of the optimizer itself.
    lr = optimizer.HyperparameterProxy('lr')
    momentum = optimizer.HyperparameterProxy('momentum')

    def create_update_rule(self):
        """Create an update rule bound to this optimizer's hyperparameters."""
        return MomentumSGDRule(self.hyperparam)
You can’t perform that action at this time.