Squash pull request #3977 from delta2323/optimizer-hooks-3
Squashed commit of the following:

commit 6ac796b
Merge: 50cc4cf e7e418d
Author: Ryosuke Okuta <okuta@preferred.jp>
Date:   Thu Mar 8 14:14:29 2018 +0900

    Merge pull request #3977 from delta2323/optimizer-hooks-3

    Create `chainer.optimizer_hooks` namespace and move hooks there.

commit e7e418d
Author: Kenta Oono <oono@preferred.jp>
Date:   Wed Mar 7 17:34:16 2018 +0900

    Fix documents of optimizer hooks

commit 880e5eb
Author: Kenta OONO <oono@preferred.jp>
Date:   Sat Mar 3 23:20:57 2018 +0900

    Remove unused import

commit 0d6b270
Merge: ec0d7e4 b94cd47
Author: Kenta OONO <oono@preferred.jp>
Date:   Sat Mar 3 23:20:00 2018 +0900

    Merge branch 'master' into optimizer-hooks-3

commit ec0d7e4
Merge: 588a7aa 0be1b38
Author: Kenta OONO <oono@preferred.jp>
Date:   Mon Feb 26 01:01:22 2018 +0900

    Merge branch 'master' into optimizer-hooks-3

commit 588a7aa
Author: Kenta OONO <oono@preferred.jp>
Date:   Sun Feb 25 21:59:36 2018 +0900

    Apply changes in the current master branch to GradientClipping

commit bc3492c
Author: Kenta OONO <oono@preferred.jp>
Date:   Sun Feb 25 21:49:52 2018 +0900

    s/optimizer/optimizer/hooks/ in guides

commit 026015d
Merge: f574890 72be88b
Author: Kenta OONO <oono@preferred.jp>
Date:   Sun Feb 25 21:46:16 2018 +0900

    Merge branch 'master' into optimizer-hooks-3

commit f574890
Author: Kenta OONO <oono@preferred.jp>
Date:   Mon Dec 25 15:38:09 2017 +0900

    Fix tests for deprecated optimizer hooks

commit 38b2b40
Author: Kenta OONO <oono@preferred.jp>
Date:   Mon Dec 25 15:17:18 2017 +0900

    Fix a unit test of GradientNoise

commit 94aad49
Merge: 348a8b9 f6a9222
Author: Kenta OONO <oono@preferred.jp>
Date:   Mon Dec 25 14:58:09 2017 +0900

    Merge branch 'master' into optimizer-hooks-3

commit 348a8b9
Author: Kenta OONO <oono@preferred.jp>
Date:   Mon Dec 18 14:54:49 2017 +0900

    Add unit tests for gradient clipping

commit ea172d2
Merge: 4c732eb 687593f
Author: Kenta OONO <oono@preferred.jp>
Date:   Mon Dec 18 11:39:23 2017 +0900

    Merge branch 'master' into optimizer-hooks-3

commit 4c732eb
Author: Kenta OONO <oono@preferred.jp>
Date:   Tue Nov 28 15:12:27 2017 +0900

    Fix unit tests of GradientNoise

commit a621d05
Author: Kenta OONO <oono@preferred.jp>
Date:   Tue Nov 28 14:39:53 2017 +0900

    Remove npz file

commit cfa22b1
Author: Kenta OONO <oono@preferred.jp>
Date:   Tue Nov 28 14:38:40 2017 +0900

    Take over #2943
kmaehashi committed Mar 13, 2018
1 parent 50cc4cf commit fc06976
Showing 21 changed files with 700 additions and 472 deletions.
255 changes: 30 additions & 225 deletions chainer/optimizer.py
@@ -7,30 +7,11 @@
 
 from chainer.backends import cuda
 from chainer import link as link_module
+from chainer import optimizer_hooks
 from chainer import serializer as serializer_module
 from chainer import variable
 
 
-def _sum_sqnorm(arr):
-    sq_sum = collections.defaultdict(float)
-    for x in arr:
-        with cuda.get_device_from_array(x) as dev:
-            x = x.ravel()
-            s = x.dot(x)
-            sq_sum[int(dev)] += s
-    return sum([float(i) for i in six.itervalues(sq_sum)])
-
-
-def exponential_decay_noise(xp, shape, dtype, hook, opt):
-    """Time-dependent annealed Gaussian noise function from the paper:
-    `Adding Gradient Noise Improves Learning for Very Deep Networks
-    <https://arxiv.org/pdf/1511.06807>`_.
-    """
-    std = numpy.sqrt(hook.eta / numpy.power(1 + opt.t, 0.55))
-    return xp.random.normal(0, std, shape).astype(dtype)
-
-
 class Hyperparameter(object):
 
     """Set of hyperparameter entries of an optimizer.
@@ -747,223 +728,47 @@ def __set__(self, obj, value):
         setattr(obj.hyperparam, self._attr_name, value)
 
 
-class WeightDecay(object):
-
-    """Optimizer/UpdateRule hook function for weight decay regularization.
-    This hook function adds a scaled parameter to the corresponding gradient.
-    It can be used as a regularization.
-    Args:
-        rate (float): Coefficient for the weight decay.
-    Attributes:
-        ~WeightDecay.rate (float): Coefficient for the weight decay.
-        ~WeightDecay.timing (string): Specifies when this hook should be called
-                         by the Optimizer/UpdateRule. Valid values are 'pre'
-                         (before any updates) and 'post' (after any updates).
-    .. versionadded:: 4.0.0
-       The *timing* parameter.
-    """
-    name = 'WeightDecay'
-    call_for_each_param = True
-    timing = 'pre'
-
-    def __init__(self, rate):
-        self.rate = rate
-
-    def __call__(self, rule, param):
-        p, g = param.data, param.grad
-        if p is None or g is None:
-            return
-        with cuda.get_device_from_array(p) as dev:
-            if int(dev) == -1:
-                g += self.rate * p
-            else:
-                kernel = cuda.elementwise(
-                    'T p, T decay', 'T g', 'g += decay * p', 'weight_decay')
-                kernel(p, self.rate, g)
-
-
-class Lasso(object):
-    """Optimizer/UpdateRule hook function for Lasso regularization.
-    This hook function adds a scaled parameter to the sign of each weight.
-    It can be used as a regularization.
-    Args:
-        rate (float): Coefficient for the weight decay.
-    Attributes:
-        ~Lasso.rate (float): Coefficient for the weight decay.
-        ~Lasso.timing (string): Specifies when this hook should be called by
-                         the Optimizer/UpdateRule. Valid values are 'pre'
-                         (before any updates) and 'post' (after any updates).
-    .. versionadded:: 4.0.0
-       The *timing* parameter.
-    """
-    name = 'Lasso'
-    call_for_each_param = True
-    timing = 'pre'
-
-    def __init__(self, rate):
-        self.rate = rate
-
-    def __call__(self, rule, param):
-        p, g = param.data, param.grad
-        if p is None or g is None:
-            return
-        xp = cuda.get_array_module(p)
-        with cuda.get_device_from_array(p) as dev:
-            sign = xp.sign(p)
-            if int(dev) == -1:
-                g += self.rate * sign
-            else:
-                kernel = cuda.elementwise(
-                    'T s, T decay', 'T g', 'g += decay * s', 'lasso')
-                kernel(sign, self.rate, g)
-
-
-class GradientClipping(object):
-    """Optimizer hook function for gradient clipping.
-
-    This hook function scales all gradient arrays to fit to the defined L2 norm
-    threshold.
-    Args:
-        threshold (float): L2 norm threshold.
-
-    Attributes:
-        ~GradientClipping.threshold (float): L2 norm threshold of gradient
-                         norm.
-        ~GradientClipping.timing (string): Specifies when this hook should be
-                         called by the Optimizer/UpdateRule. Valid values are
-                         'pre' (before any updates) and 'post' (after any
-                         updates).
-    .. versionadded:: 4.0.0
-       The *timing* parameter.
-    """
-    name = 'GradientClipping'
-    timing = 'pre'
-
-    def __init__(self, threshold):
-        self.threshold = threshold
-
-    def __call__(self, opt):
-        norm = numpy.sqrt(_sum_sqnorm(
-            [p.grad for p in opt.target.params(False)]))
-        rate = self.threshold / norm
-        if rate < 1:
-            for param in opt.target.params(False):
-                grad = param.grad
-                with cuda.get_device_from_array(grad):
-                    grad *= rate
-
-
-class GradientNoise(object):
-    """Optimizer/UpdateRule hook function for adding gradient noise.
-
-    This hook function simply adds noise generated by the ``noise_func``
-    to the gradient. By default it adds time-dependent annealed Gaussian
-    noise to the gradient at every training step:
-
-    .. math::
-
-        g_t \\leftarrow g_t + N(0, \\sigma_t^2)
-
-    where
-
-    .. math::
-
-        \\sigma_t^2 = \\frac{\\eta}{(1+t)^\\gamma}
-
-    with :math:`\\eta` selected from {0.01, 0.3, 1.0} and
-    :math:`\\gamma = 0.55`.
-    Args:
-        eta (float): Parameter that defines the scale of the noise, which for
-            the default noise function is recommended to be either 0.01, 0.3
-            or 1.0.
-        noise_func (function): Noise generating function which by default
-            is given by `Adding Gradient Noise Improves Learning for Very Deep\
-            Networks <https://arxiv.org/pdf/1511.06807>`_.
-    Attributes:
-        ~GradientNoise.timing (string): Specifies when this hook should be
-                         called by the Optimizer/UpdateRule. Valid values are
-                         'pre' (before any updates) and 'post' (after any
-                         updates).
-    .. versionadded:: 4.0.0
-       The *timing* parameter.
-    """
-    name = 'GradientNoise'
-    call_for_each_param = True
-    timing = 'pre'
-
-    def __init__(self, eta, noise_func=exponential_decay_noise):
-        self.eta = eta
-        self.noise_func = noise_func
-
-    def __call__(self, rule, param):
-        g = param.grad
-        if g is None:
-            return
-        xp = cuda.get_array_module(g)
-        with cuda.get_device_from_array(g) as dev:
-            noise = self.noise_func(xp, g.shape, g.dtype, self, rule)
-            if int(dev) == -1:
-                g += noise
-            else:
-                kernel = cuda.elementwise(
-                    'T noise', 'T g', 'g += noise', 'gradient_noise')
-                kernel(noise, g)
-
-
-class GradientHardClipping(object):
-
-    """Optimizer/UpdateRule hook function for gradient clipping.
-    This hook function clips all gradient arrays to be within a lower and upper
-    bound.
-    Args:
-        lower_bound (float): The lower bound of the gradient value.
-        upper_bound (float): The upper bound of the gradient value.
-    Attributes:
-        ~GradientHardClipping.lower_bound (float): The lower bound of the
-                         gradient value.
-        ~GradientHardClipping.upper_bound (float): The upper bound of the
-                         gradient value.
-        ~GradientHardClipping.timing (string): Specifies when this hook should
-                         be called by the Optimizer/UpdateRule. Valid values
-                         are 'pre' (before any updates) and 'post' (after any
-                         updates).
-    .. versionadded:: 4.0.0
-       The *timing* parameter.
-    """
-    name = 'GradientHardClipping'
-    call_for_each_param = True
-    timing = 'pre'
-
-    def __init__(self, lower_bound, upper_bound):
-        self.lower_bound = lower_bound
-        self.upper_bound = upper_bound
-
-    def __call__(self, rule, param):
-        grad = param.grad
-        if grad is None:
-            return
-        xp = cuda.get_array_module(grad)
-        with cuda.get_device_from_array(grad):
-            xp.clip(grad, self.lower_bound, self.upper_bound, out=grad)
+def make_deprecation_message(module_name):
+    return ('chainer.optimizer.{0} is deprecated from v4. '
+            'Use chainer.optimizer_hooks.{0} instead.'
+            ''.format(module_name))
+
+
+class WeightDecay(optimizer_hooks.WeightDecay):
+
+    def __init__(self, *args, **kwargs):
+        warnings.warn(make_deprecation_message('WeightDecay'),
+                      DeprecationWarning)
+        return super(WeightDecay, self).__init__(*args, **kwargs)
+
+
+class Lasso(optimizer_hooks.Lasso):
+
+    def __init__(self, *args, **kwargs):
+        warnings.warn(make_deprecation_message('Lasso'),
+                      DeprecationWarning)
+        return super(Lasso, self).__init__(*args, **kwargs)
+
+
+class GradientClipping(optimizer_hooks.GradientClipping):
+
+    def __init__(self, *args, **kwargs):
+        warnings.warn(make_deprecation_message('GradientClipping'),
+                      DeprecationWarning)
+        return super(GradientClipping, self).__init__(*args, **kwargs)
+
+
+class GradientNoise(optimizer_hooks.GradientNoise):
+
+    def __init__(self, *args, **kwargs):
+        warnings.warn(make_deprecation_message('GradientNoise'),
+                      DeprecationWarning)
+        return super(GradientNoise, self).__init__(*args, **kwargs)
+
+
+class GradientHardClipping(optimizer_hooks.GradientHardClipping):
+
+    def __init__(self, *args, **kwargs):
+        warnings.warn(make_deprecation_message('GradientHardClipping'),
+                      DeprecationWarning)
+        return super(GradientHardClipping, self).__init__(*args, **kwargs)
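
As a reference for users of the old import path, here is a minimal sketch (not part of the committed diff, assuming Chainer v4 or later) of how the backward-compatible wrappers above behave: constructing a hook through chainer.optimizer emits a DeprecationWarning, while the resulting object is an instance of the relocated chainer.optimizer_hooks class.

# Sketch only, not part of this commit; assumes Chainer v4+ is installed.
import warnings

from chainer import optimizer
from chainer import optimizer_hooks

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    hook = optimizer.WeightDecay(1e-4)  # deprecated location

# The wrapper warns and then delegates to the relocated implementation.
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
assert isinstance(hook, optimizer_hooks.WeightDecay)
print(hook.name, hook.rate)  # WeightDecay 0.0001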
5 changes: 5 additions & 0 deletions chainer/optimizer_hooks/__init__.py
@@ -0,0 +1,5 @@
from chainer.optimizer_hooks.gradient_clipping import GradientClipping # NOQA
from chainer.optimizer_hooks.gradient_hard_clipping import GradientHardClipping # NOQA
from chainer.optimizer_hooks.gradient_noise import GradientNoise # NOQA
from chainer.optimizer_hooks.lasso import Lasso # NOQA
from chainer.optimizer_hooks.weight_decay import WeightDecay # NOQA
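
For orientation, a short usage sketch of the new namespace (assumptions: Chainer v4+, with a toy chainer.links.Linear model and SGD settings chosen only for illustration): hooks are now registered on an optimizer from chainer.optimizer_hooks.

# Usage sketch, assuming Chainer v4+; the toy model and hyperparameters
# below are illustrative and not part of this commit.
import chainer
from chainer import optimizer_hooks
from chainer import optimizers

model = chainer.links.Linear(3, 2)
opt = optimizers.SGD(lr=0.01)
opt.setup(model)

# Optimizer-wide hook: runs once per update over all parameters.
opt.add_hook(optimizer_hooks.GradientClipping(threshold=1.0))
# Per-parameter hooks (call_for_each_param is True).
opt.add_hook(optimizer_hooks.WeightDecay(rate=1e-4))
opt.add_hook(optimizer_hooks.GradientHardClipping(-0.1, 0.1))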
55 changes: 55 additions & 0 deletions chainer/optimizer_hooks/gradient_clipping.py
@@ -0,0 +1,55 @@
import collections

import numpy
import six

from chainer import cuda


def _sum_sqnorm(arr):
    sq_sum = collections.defaultdict(float)
    for x in arr:
        with cuda.get_device_from_array(x) as dev:
            x = x.ravel()
            s = x.dot(x)
            sq_sum[int(dev)] += s
    return sum([float(i) for i in six.itervalues(sq_sum)])


class GradientClipping(object):
    """Optimizer hook function for gradient clipping.
    This hook function scales all gradient arrays to fit to the defined L2 norm
    threshold.
    Args:
        threshold (float): L2 norm threshold.
    Attributes:
        ~optimizer_hooks.GradientClipping.threshold (float): L2
                         norm threshold of gradient norm.
        ~optimizer_hooks.GradientClipping.timing (string): Specifies
                         when this hook should be
                         called by the Optimizer/UpdateRule. Valid values are
                         'pre' (before any updates) and 'post' (after any
                         updates).
    .. versionadded:: 4.0.0
       The *timing* parameter.
    """
    name = 'GradientClipping'
    timing = 'pre'

    def __init__(self, threshold):
        self.threshold = threshold

    def __call__(self, opt):
        norm = numpy.sqrt(_sum_sqnorm(
            [p.grad for p in opt.target.params(False)]))
        rate = self.threshold / norm
        if rate < 1:
            for param in opt.target.params(False):
                grad = param.grad
                with cuda.get_device_from_array(grad):
                    grad *= rate
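
The arithmetic in this new file can be checked with plain NumPy; the following CPU-only sketch mirrors _sum_sqnorm and __call__ without the CUDA branches (the gradient values are made up for illustration): the global L2 norm is taken over all gradient arrays, and every array is scaled by threshold / norm only when that ratio is below one.

# CPU-only sketch of the scaling rule in GradientClipping above;
# the gradient values are hypothetical.
import numpy as np

threshold = 1.0
grads = [np.array([3.0, 4.0]), np.array([12.0])]

norm = np.sqrt(sum(float(g.ravel().dot(g.ravel())) for g in grads))  # 13.0
rate = threshold / norm                                              # ~0.0769
if rate < 1:
    grads = [g * rate for g in grads]

# After scaling, the global norm equals the threshold.
print(np.sqrt(sum(float(g.dot(g)) for g in grads)))  # 1.0 (up to rounding)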
44 changes: 44 additions & 0 deletions chainer/optimizer_hooks/gradient_hard_clipping.py
@@ -0,0 +1,44 @@
from chainer import cuda


class GradientHardClipping(object):

    """Optimizer/UpdateRule hook function for gradient clipping.
    This hook function clips all gradient arrays to be within a lower and upper
    bound.
    Args:
        lower_bound (float): The lower bound of the gradient value.
        upper_bound (float): The upper bound of the gradient value.
    Attributes:
        ~optimizer_hooks.GradientHardClipping.lower_bound (float): The
                         lower bound of the gradient value.
        ~optimizer_hooks.GradientHardClipping.upper_bound (float): The
                         upper bound of the gradient value.
        ~optimizer_hooks.GradientHardClipping.timing (string): Specifies
                         when this hook should be called by the
                         Optimizer/UpdateRule. Valid values are 'pre'
                         (before any updates) and 'post'
                         (after any updates).
    .. versionadded:: 4.0.0
       The *timing* parameter.
    """
    name = 'GradientHardClipping'
    call_for_each_param = True
    timing = 'pre'

    def __init__(self, lower_bound, upper_bound):
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound

    def __call__(self, rule, param):
        grad = param.grad
        if grad is None:
            return
        xp = cuda.get_array_module(grad)
        with cuda.get_device_from_array(grad):
            xp.clip(grad, self.lower_bound, self.upper_bound, out=grad)
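
For completeness, a small CPU sketch of the element-wise operation this hook applies (the array values are hypothetical; assumes Chainer v4+ for the hook class itself):

# Sketch of the clipping GradientHardClipping performs on one gradient
# array (CPU path only; the values are hypothetical).
import numpy as np

from chainer import optimizer_hooks

hook = optimizer_hooks.GradientHardClipping(lower_bound=-0.1, upper_bound=0.1)

grad = np.array([-0.5, 0.03, 0.2])
np.clip(grad, hook.lower_bound, hook.upper_bound, out=grad)  # same op as __call__
print(grad)  # [-0.1   0.03  0.1 ]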
