/
gradient_clipping.py
55 lines (42 loc) · 1.55 KB
/
gradient_clipping.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import collections
import numpy
import six
from chainer import cuda
def _sum_sqnorm(arr):
    """Return the combined squared L2 norm of all arrays in ``arr``.

    Each array is flattened and dotted with itself. The results are summed
    per device (keyed by the id returned for the array's device), and only
    the per-device totals are converted to Python floats and added up.

    Args:
        arr: Iterable of arrays supporting ``ravel()`` and ``dot()``
            (presumably NumPy/CuPy gradient arrays). ``None`` entries,
            such as parameters that received no gradient, are skipped.

    Returns:
        The sum of squares of every element of every non-``None`` array.
        This is ``0`` when there is nothing to sum.
    """
    sq_sum = collections.defaultdict(float)
    for x in arr:
        if x is None:
            # A parameter without a gradient contributes nothing. Without
            # this check, x.ravel() would fail on None.
            continue
        with cuda.get_device_from_array(x) as dev:
            flat = x.ravel()
            sq_sum[int(dev)] += flat.dot(flat)
    return sum(float(v) for v in sq_sum.values())
class GradientClipping(object):

    """Optimizer hook function for gradient clipping.

    This hook function scales all gradient arrays to fit to the defined L2 norm
    threshold.

    Args:
        threshold (float): L2 norm threshold.

    Attributes:
        ~optimizer_hooks.GradientClipping.threshold (float): L2
            norm threshold of gradient norm.
        ~optimizer_hooks.GradientClipping.timing (string): Specifies
            when this hook should be called by the Optimizer/UpdateRule.
            Valid values are 'pre' (before any updates) and 'post' (after
            any updates).

    .. versionadded:: 4.0.0
       The *timing* parameter.

    """
    name = 'GradientClipping'
    timing = 'pre'

    def __init__(self, threshold):
        self.threshold = threshold

    def __call__(self, opt):
        """Scale the gradients of ``opt.target`` in place if needed.

        The global L2 norm of all gradients is computed first. If it exceeds
        ``self.threshold``, every gradient is multiplied by
        ``threshold / norm``.

        Args:
            opt: Optimizer whose ``target`` link's parameters (from
                ``params(False)``) have their gradients clipped.
        """
        # Collect the parameters once so the norm and the scaling loop see
        # the same set.
        params = list(opt.target.params(False))
        norm = numpy.sqrt(_sum_sqnorm([p.grad for p in params]))
        if norm == 0:
            # All-zero (or absent) gradients need no clipping. Returning
            # here also avoids dividing threshold by zero.
            return
        rate = self.threshold / norm
        # Scale only when the norm exceeds the threshold. A NaN norm makes
        # this comparison False, so gradients are left untouched.
        if rate < 1:
            for param in params:
                grad = param.grad
                if grad is None:
                    # Nothing to scale; `None *= rate` would raise TypeError.
                    continue
                with cuda.get_device_from_array(grad):
                    grad *= rate