/
clip.py
83 lines (61 loc) · 2.33 KB
/
clip.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import numpy
from chainer.backends import cuda
from chainer import function_node
from chainer import utils
from chainer.utils import type_check
class Clip(function_node.FunctionNode):
    """Clips (limits) elements of input variable.

    Either bound may be ``None`` to leave that side of the interval open
    (one-sided clipping), but at least one bound must be given.

    Args:
        x_min (float or None): Lower bound, or ``None`` for no lower bound.
        x_max (float or None): Upper bound, or ``None`` for no upper bound.

    Raises:
        ValueError: If both bounds are ``None``, or if both are given and
            ``x_min >= x_max``.
    """

    def __init__(self, x_min, x_max):
        # With neither bound the function would be the identity; treat it as
        # a usage error (numpy/cupy reject this combination as well).
        if x_min is None and x_max is None:
            raise ValueError('must set either max or min')
        # When both bounds are given, x_min must be less than x_max.
        if x_min is not None and x_max is not None and x_min >= x_max:
            raise ValueError('x_min must be less than x_max.')
        self.x_min = x_min
        self.x_max = x_max

    def check_type_forward(self, in_types):
        type_check._argname(in_types, ('x',))
        x_type, = in_types
        type_check.expect(x_type.dtype.kind == 'f')

    def forward_cpu(self, inputs):
        # The input is retained because backward() builds its gradient mask
        # from it.
        self.retain_inputs((0,))
        x, = inputs
        # numpy.clip accepts None for one of its bounds; force_array keeps
        # the result an ndarray of the input dtype.
        return utils.force_array(
            numpy.clip(x, self.x_min, self.x_max),
            x.dtype),

    def forward_gpu(self, inputs):
        self.retain_inputs((0,))
        x, = inputs
        return cuda.cupy.clip(x, self.x_min, self.x_max),

    def backward(self, indexes, grad_outputs):
        x, = self.get_retained_inputs()
        # ClipGrad compares x against both bounds, so replace a missing bound
        # with an infinite one: -inf <= x and x <= inf hold for every
        # non-NaN x, which leaves that side unconstrained.
        x_min = -numpy.inf if self.x_min is None else self.x_min
        x_max = numpy.inf if self.x_max is None else self.x_max
        return ClipGrad(x.data, x_min, x_max).apply(grad_outputs)
class ClipGrad(function_node.FunctionNode):
    """Backward pass of :class:`Clip` with respect to its input.

    A boolean mask is computed once from the forward input ``x``: entries
    with ``x_min <= x <= x_max`` (inclusive at both edges) pass the upstream
    gradient through unchanged, and all other entries receive zero.
    """

    def __init__(self, x, x_min, x_max):
        # True where x lies inside the closed interval [x_min, x_max].
        at_or_above_min = x_min <= x
        at_or_below_max = x <= x_max
        self.cond = at_or_above_min & at_or_below_max

    def check_type_forward(self, in_types):
        type_check._argname(in_types, ('gy',))
        gy_type = in_types[0]
        type_check.expect(gy_type.dtype.kind == 'f')

    def forward_cpu(self, inputs):
        gy, = inputs
        # Multiplying by the boolean mask zeroes the out-of-range entries;
        # force_array keeps an ndarray with gy's dtype.
        masked = gy * self.cond
        return utils.force_array(masked, gy.dtype),

    def forward_gpu(self, inputs):
        # Select gy where the mask is set and zero elsewhere in one kernel.
        select_kernel = cuda.elementwise(
            'T gy, bool cond', 'T gx',
            'gx = cond ? gy : T(0)',
            'clip_bwd')
        gx = select_kernel(inputs[0], self.cond)
        return gx,

    def backward(self, indexes, grad_outputs):
        # gy -> gy * cond is linear in gy, so its gradient applies the same
        # mask to the incoming gradient.
        ggx = grad_outputs[0]
        return ggx * self.cond,
def clip(x, x_min, x_max):
    """Clips (limits) elements of input variable.

    Every element of ``x`` is limited to the closed interval
    ``[x_min, x_max]``: values below ``x_min`` become ``x_min`` and values
    above ``x_max`` become ``x_max``. The gradient is 1 for elements inside
    the interval, including those exactly at ``x_min`` or ``x_max``, and 0
    for elements outside it.

    Args:
        x (:class:`~chainer.Variable` or :ref:`ndarray`):
            Input variable to be clipped.
        x_min (float): Minimum value.
        x_max (float): Maximum value.

    Returns:
        ~chainer.Variable: Output variable.

    Raises:
        ValueError: If ``x_min`` is not less than ``x_max``.
    """
    node = Clip(x_min, x_max)
    outputs = node.apply((x,))
    return outputs[0]