/
relu.py
142 lines (105 loc) · 3.92 KB
/
relu.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import numpy
import chainer
from chainer.backends import cuda
from chainer import function_node
from chainer import utils
from chainer.utils import type_check
if cuda.cudnn_enabled:
    # Module-level aliases used by the cuDNN branches of ReLU.forward_gpu and
    # ReLUGrad3.forward_gpu: the cuDNN wrapper module and the ReLU
    # activation-mode constant passed to activation_forward/backward.
    # NOTE(review): these names are left undefined when cuDNN is disabled;
    # presumably should_use_cudnn('==always') is never true in that case, so
    # the branches that read them are not reached — confirm.
    cudnn = cuda.cudnn
    _mode = cuda.cuda.cudnn.CUDNN_ACTIVATION_RELU
class ReLU(function_node.FunctionNode):

    """Rectified Linear Unit, computing ``max(x, 0)`` elementwise.

    The single output is always retained for the backward pass. On the GPU
    the cuDNN activation kernel is used when ``should_use_cudnn('==always')``
    holds and the input is C-contiguous; in that case the input is retained
    as well, because the cuDNN-based gradient node (``ReLUGrad3``) needs it.
    """

    # Flipped to True by ``forward_gpu`` when the cuDNN kernel was used, so
    # that ``backward`` knows the input was retained too.
    _use_cudnn = False

    def check_type_forward(self, in_types):
        # Exactly one input, and it must have a floating-point dtype.
        type_check.expect(
            in_types.size() == 1,
            in_types[0].dtype.kind == 'f',
        )

    def forward_cpu(self, x):
        self.retain_outputs((0,))
        inp = x[0]
        out = numpy.maximum(inp, 0, dtype=inp.dtype)
        # force_array guarantees an ndarray comes back (numpy may otherwise
        # hand back a scalar for 0-dim input).
        return utils.force_array(out),

    def forward_gpu(self, x):
        inp = x[0]
        use_cudnn = (chainer.should_use_cudnn('==always')
                     and inp.flags.c_contiguous)
        if not use_cudnn:
            out = cuda.cupy.maximum(inp, 0)
        else:
            # cudnn.activation_backward (called by ReLUGrad3) consumes the
            # input as well as the output, so keep the input for backward.
            self.retain_inputs((0,))
            self._use_cudnn = True
            out = cudnn.activation_forward(inp, _mode)
        self.retain_outputs((0,))
        return out,

    def backward(self, indexes, gy):
        y = self.get_retained_outputs()[0]
        grad = gy[0]
        cudnn_grad = chainer.should_use_cudnn('==always') and self._use_cudnn
        if not cudnn_grad:
            # The output alone suffices: y > 0 exactly where x > 0.
            return ReLUGrad2(y).apply((grad,))
        x = self.get_retained_inputs()[0]
        return ReLUGrad3(x, y).apply((grad,))
def _heaviside(x):
    """Return the unit step of ``x``: 1 where ``x > 0``, else 0, in ``x``'s dtype."""
    positive = 0 < x
    return positive.astype(x.dtype)
class ReLUGrad2(function_node.FunctionNode):

    """Gradient of ReLU computed from the forward output alone.

    Given the ReLU output ``b`` and an upstream gradient ``c``, this node
    computes ``f(b, c) = H(b) * c`` elementwise, where ``H`` is the unit step
    (``H(v) = 1`` if ``v > 0`` and ``0`` otherwise).

    The gradient of ``f`` with respect to ``b`` is 0, so for efficiency no
    error is backpropagated toward ``b``: it is stored as a raw array and only
    ``c`` is an input of this node.
    """

    def __init__(self, b):
        super(ReLUGrad2, self).__init__()
        # Keep only the underlying array; b is not a graph input here.
        self.b = b.data

    def forward_cpu(self, inputs):
        positive = self.b > 0
        gx = positive * inputs[0]
        return utils.force_array(gx, dtype=gx.dtype),

    def forward_gpu(self, inputs):
        relu_bwd = cuda.elementwise(
            'T y, T gy', 'T gx',
            'gx = y > 0 ? gy : (T)0',
            'relu_bwd')
        return relu_bwd(self.b, inputs[0]),

    def backward(self, indexes, gy):
        # f is linear in c, so df/dc = H(b).
        step = _heaviside(self.b)
        return gy[0] * step,
class ReLUGrad3(function_node.FunctionNode):

    """Gradient of ReLU computed from both the forward input and output.

    Given the ReLU input ``a``, the ReLU output ``b`` and an upstream gradient
    ``c``, this node computes ``f(a, b, c) = H(b) * c`` elementwise, where
    ``H`` is the unit step (``H(v) = 1`` if ``v > 0`` and ``0`` otherwise).
    ``a`` is only consumed by the cuDNN kernel on the GPU path.

    The gradients of ``f`` with respect to ``a`` and ``b`` are 0, so for
    efficiency no error is backpropagated toward them: both are stored as raw
    arrays and only ``c`` is an input of this node.
    """

    def __init__(self, a, b):
        super(ReLUGrad3, self).__init__()
        self.a = a.data
        self.b = b.data

    def forward_cpu(self, inputs):
        gx = (self.b > 0) * inputs[0]
        # Same as ReLUGrad2.forward_cpu: with 0-dim operands numpy yields a
        # scalar rather than an ndarray, so coerce the result to an array.
        return utils.force_array(gx, dtype=gx.dtype),

    def forward_gpu(self, inputs):
        # In this file ReLUGrad3 is only constructed by ReLU.backward on its
        # cuDNN branch, so cuDNN is expected to be enabled here.
        assert chainer.should_use_cudnn('==always')
        return cudnn.activation_backward(self.a, self.b, inputs[0], _mode),

    def backward(self, indexes, gy):
        # f is linear in c, so df/dc = H(b).
        return gy[0] * _heaviside(self.b),
def relu(x):
    """Rectified Linear Unit function, applied elementwise.

    .. math:: f(x)=\\max(0, x).

    Args:
        x (:class:`~chainer.Variable` or :class:`numpy.ndarray` or \
        :class:`cupy.ndarray`):
            Input variable. A :math:`(s_1, s_2, ..., s_N)`-shaped float array.

    Returns:
        ~chainer.Variable: Output variable. A
        :math:`(s_1, s_2, ..., s_N)`-shaped float array.

    .. admonition:: Example

        >>> x = np.array([[-1, 0], [2, -3], [-2, 1]], 'f')
        >>> np.any(x < 0)
        True
        >>> y = F.relu(x)
        >>> np.any(y.data < 0)
        False
        >>> y.shape
        (3, 2)

    """
    outputs = ReLU().apply((x,))
    return outputs[0]