-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Squash pull request #3977 from delta2323/optimizer-hooks-3
Squashed commit of the following: commit 6ac796b Merge: 50cc4cf e7e418d Author: Ryosuke Okuta <okuta@preferred.jp> Date: Thu Mar 8 14:14:29 2018 +0900 Merge pull request #3977 from delta2323/optimizer-hooks-3 Create `chainer.optimizer_hooks` namespace and move hooks there. commit e7e418d Author: Kenta Oono <oono@preferred.jp> Date: Wed Mar 7 17:34:16 2018 +0900 Fix documents of optimizer hooks commit 880e5eb Author: Kenta OONO <oono@preferred.jp> Date: Sat Mar 3 23:20:57 2018 +0900 Remove unused import commit 0d6b270 Merge: ec0d7e4 b94cd47 Author: Kenta OONO <oono@preferred.jp> Date: Sat Mar 3 23:20:00 2018 +0900 Merge branch 'master' into optimizer-hooks-3 commit ec0d7e4 Merge: 588a7aa 0be1b38 Author: Kenta OONO <oono@preferred.jp> Date: Mon Feb 26 01:01:22 2018 +0900 Merge branch 'master' into optimizer-hooks-3 commit 588a7aa Author: Kenta OONO <oono@preferred.jp> Date: Sun Feb 25 21:59:36 2018 +0900 Apply changes in the current master branch to GradientClipping commit bc3492c Author: Kenta OONO <oono@preferred.jp> Date: Sun Feb 25 21:49:52 2018 +0900 s/optimizer/optimizer/hooks/ in guides commit 026015d Merge: f574890 72be88b Author: Kenta OONO <oono@preferred.jp> Date: Sun Feb 25 21:46:16 2018 +0900 Merge branch 'master' into optimizer-hooks-3 commit f574890 Author: Kenta OONO <oono@preferred.jp> Date: Mon Dec 25 15:38:09 2017 +0900 Fix tests for deprecated optimizer hooks commit 38b2b40 Author: Kenta OONO <oono@preferred.jp> Date: Mon Dec 25 15:17:18 2017 +0900 Fix a unit test of GradientNoise commit 94aad49 Merge: 348a8b9 f6a9222 Author: Kenta OONO <oono@preferred.jp> Date: Mon Dec 25 14:58:09 2017 +0900 Merge branch 'master' into optimizer-hooks-3 commit 348a8b9 Author: Kenta OONO <oono@preferred.jp> Date: Mon Dec 18 14:54:49 2017 +0900 Add unit tests for gradient clipping commit ea172d2 Merge: 4c732eb 687593f Author: Kenta OONO <oono@preferred.jp> Date: Mon Dec 18 11:39:23 2017 +0900 Merge branch 'master' into optimizer-hooks-3 commit 4c732eb Author: Kenta 
OONO <oono@preferred.jp> Date: Tue Nov 28 15:12:27 2017 +0900 Fix unit tests of GradientNoise commit a621d05 Author: Kenta OONO <oono@preferred.jp> Date: Tue Nov 28 14:39:53 2017 +0900 Remove npz file commit cfa22b1 Author: Kenta OONO <oono@preferred.jp> Date: Tue Nov 28 14:38:40 2017 +0900 Take over #2943
- Loading branch information
Showing
21 changed files
with
700 additions
and
472 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from chainer.optimizer_hooks.gradient_clipping import GradientClipping # NOQA | ||
from chainer.optimizer_hooks.gradient_hard_clipping import GradientHardClipping # NOQA | ||
from chainer.optimizer_hooks.gradient_noise import GradientNoise # NOQA | ||
from chainer.optimizer_hooks.lasso import Lasso # NOQA | ||
from chainer.optimizer_hooks.weight_decay import WeightDecay # NOQA |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
import collections | ||
|
||
import numpy | ||
import six | ||
|
||
from chainer import cuda | ||
|
||
|
||
def _sum_sqnorm(arr):
    """Return the total squared L2 norm of all given arrays as a float.

    Squared norms are accumulated per device so that each dot product is
    computed on the device that owns the array, then the per-device
    partial sums are combined on the host.
    """
    per_device = collections.defaultdict(float)
    for grad in arr:
        with cuda.get_device_from_array(grad) as device:
            flat = grad.ravel()
            per_device[int(device)] += flat.dot(flat)
    return sum(float(partial) for partial in six.itervalues(per_device))
|
||
|
||
class GradientClipping(object):

    """Optimizer hook function for gradient clipping.

    This hook function scales all gradient arrays to fit to the defined L2
    norm threshold.

    Args:
        threshold (float): L2 norm threshold.

    Attributes:
        ~optimizer_hooks.GradientClipping.threshold (float): L2
                         norm threshold of gradient norm.
        ~optimizer_hooks.GradientClipping.timing (string): Specifies
                         when this hook should be
                         called by the Optimizer/UpdateRule. Valid values are
                         'pre' (before any updates) and 'post' (after any
                         updates).

    .. versionadded:: 4.0.0
       The *timing* parameter.

    """
    name = 'GradientClipping'
    timing = 'pre'

    def __init__(self, threshold):
        self.threshold = threshold

    def __call__(self, opt):
        norm = numpy.sqrt(_sum_sqnorm(
            [p.grad for p in opt.target.params(False)]))
        # Guard against a zero norm: dividing by it would produce an
        # inf/nan rate (with a NumPy RuntimeWarning). When every gradient
        # is zero there is nothing to clip, so simply skip the update.
        if norm == 0:
            return
        rate = self.threshold / norm
        if rate < 1:
            for param in opt.target.params(False):
                grad = param.grad
                with cuda.get_device_from_array(grad):
                    grad *= rate
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
from chainer import cuda | ||
|
||
|
||
class GradientHardClipping(object):

    """Optimizer/UpdateRule hook function for gradient clipping.

    This hook function clips every gradient array element-wise so that all
    values lie within a lower and an upper bound.

    Args:
        lower_bound (float): The lower bound of the gradient value.
        upper_bound (float): The upper bound of the gradient value.

    Attributes:
        ~optimizer_hooks.GradientHardClipping.lower_bound (float): The
                         lower bound of the gradient value.
        ~optimizer_hooks.GradientHardClipping.upper_bound (float): The
                         upper bound of the gradient value.
        ~optimizer_hooks.GradientHardClipping.timing (string): Specifies
                         when this hook should be called by the
                         Optimizer/UpdateRule. Valid values are 'pre'
                         (before any updates) and 'post'
                         (after any updates).

    .. versionadded:: 4.0.0
       The *timing* parameter.

    """
    name = 'GradientHardClipping'
    call_for_each_param = True
    timing = 'pre'

    def __init__(self, lower_bound, upper_bound):
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound

    def __call__(self, rule, param):
        data = param.grad
        # Parameters that never received a gradient have nothing to clip.
        if data is None:
            return
        array_module = cuda.get_array_module(data)
        with cuda.get_device_from_array(data):
            # In-place clip (out=data) avoids allocating a new array.
            array_module.clip(data, self.lower_bound, self.upper_bound,
                              out=data)
Oops, something went wrong.