gradient_clipping.py
import collections

import six

from chainer import cuda


def _sum_sqnorm(arr):
    """Computes the sum of squared elements over all given arrays.

    The partial sums are accumulated per device, so each array is reduced
    on the device that holds it.
    """
    sq_sum = collections.defaultdict(float)
    for x in arr:
        with cuda.get_device_from_array(x) as dev:
            x = x.ravel()
            s = x.dot(x)
            sq_sum[int(dev)] += s
    # If only a single device is used, aggregate the square norms on it.
    if len(sq_sum) == 1:
        with cuda.get_device_from_array(arr[0]):
            return sum(six.itervalues(sq_sum))
    else:
        return sum([float(i) for i in six.itervalues(sq_sum)])
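
# For CPU (NumPy) arrays this reduces to a plain sum of squares. A minimal
# illustrative sketch, where the arrays ``a`` and ``b`` are hypothetical:
#
#     a = numpy.array([3.0, 4.0])
#     b = numpy.array([1.0, 2.0, 2.0])
#     _sum_sqnorm([a, b])   # 9 + 16 + 1 + 4 + 4 = 34.0
#
# The gradient L2 norm used by ``GradientClipping`` below is the square
# root of this value.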


class GradientClipping(object):
    """Optimizer hook function for gradient clipping.

    This hook function scales all gradient arrays to fit to the defined L2
    norm threshold.

    Args:
        threshold (float): L2 norm threshold.

    Attributes:
        ~optimizer_hooks.GradientClipping.threshold (float): L2 norm
            threshold of the gradient norm.
        ~optimizer_hooks.GradientClipping.timing (string): Specifies when
            this hook should be called by the Optimizer/UpdateRule. Valid
            values are 'pre' (before any updates) and 'post' (after any
            updates).

    .. versionadded:: 4.0.0
       The *timing* parameter.

    """
    name = 'GradientClipping'
    timing = 'pre'

    def __init__(self, threshold):
        self.threshold = threshold

    def __call__(self, opt):
        sqnorm = _sum_sqnorm([p.grad for p in opt.target.params(False)])
        with cuda.get_device_from_array(sqnorm) as dev:
            norm = cuda.get_array_module(sqnorm).sqrt(sqnorm)
            rate = self.threshold / norm
            # When no clipping is needed, skip the clipping on CPU and
            # multiply by 1.0 on the device otherwise.
            if int(dev) == -1:
                if rate >= 1:
                    return
            else:
                rate = rate.clip(None, 1)
        for param in opt.target.params(False):
            grad = param.grad
            with cuda.get_device_from_array(grad):
                grad *= rate
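

# A minimal usage sketch: attaching this hook to a Chainer optimizer rescales
# all parameter gradients before each update so that their joint L2 norm does
# not exceed ``threshold``. The model, data, and threshold values below are
# hypothetical and chosen only for illustration.
if __name__ == '__main__':
    import numpy
    import chainer
    import chainer.functions as F
    import chainer.links as L

    # A hypothetical single-layer model used only for this sketch.
    model = L.Linear(4, 1)
    optimizer = chainer.optimizers.SGD(lr=0.01)
    optimizer.setup(model)
    # timing == 'pre', so the hook rescales the gradients right before the
    # parameter update whenever their joint L2 norm exceeds the threshold.
    optimizer.add_hook(GradientClipping(threshold=1.0))

    x = numpy.random.randn(8, 4).astype(numpy.float32)
    t = numpy.random.randn(8, 1).astype(numpy.float32)
    # update() computes the loss, backpropagates the gradients, invokes the
    # hook, and then applies the SGD update.
    optimizer.update(F.mean_squared_error, model(x), t)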