diff --git a/chainer/functions/activation/relu.py b/chainer/functions/activation/relu.py
index e41b357f7369..d830017f64ce 100644
--- a/chainer/functions/activation/relu.py
+++ b/chainer/functions/activation/relu.py
@@ -2,7 +2,7 @@
 
 import chainer
 from chainer import cuda
-from chainer import function
+from chainer import function_node
 from chainer import utils
 from chainer.utils import type_check
 
@@ -12,10 +12,11 @@
     _mode = cudnn.cudnn.CUDNN_ACTIVATION_RELU
 
 
-class ReLU(function.Function):
+class ReLU(function_node.FunctionNode):
 
     """Rectified Linear Unit."""
-    # TODO(beam2d): Implement in-place version.
+
+    _use_cudnn = False
 
     def check_type_forward(self, in_types):
         type_check.expect(
@@ -24,36 +25,94 @@ def check_type_forward(self, in_types):
         )
 
     def forward_cpu(self, x):
-        self.retain_inputs(())
         self.retain_outputs((0,))
         return utils.force_array(numpy.maximum(x[0], 0, dtype=x[0].dtype)),
 
     def forward_gpu(self, x):
         if chainer.should_use_cudnn('==always') and x[0].flags.c_contiguous:
+            # cupy.activation_backward requires the input.
+            # So, we retain it for backward computation.
+            self.retain_inputs((0,))
             self._use_cudnn = True
             y = cudnn.activation_forward(x[0], _mode)
         else:
-            self.retain_inputs(())
-            self._use_cudnn = False
             y = cuda.cupy.maximum(x[0], 0)
         self.retain_outputs((0,))
         return y,
 
-    def backward_cpu(self, x, gy):
-        y = self.output_data[0]
-        return utils.force_array(gy[0] * (y > 0)),
-
-    def backward_gpu(self, x, gy):
-        y = self.output_data[0]
+    def backward(self, indexes, gy):
+        y = self.get_retained_outputs()[0]
         if chainer.should_use_cudnn('==always') and self._use_cudnn:
-            gx = cudnn.activation_backward(x[0], y, gy[0], _mode)
+            x = self.get_retained_inputs()[0]
+            return ReLUGrad3(x, y).apply((gy[0],))
         else:
-            gx = cuda.elementwise(
-                'T y, T gy', 'T gx',
-                'gx = y > 0 ? gy : (T)0',
-                'relu_bwd')(y, gy[0])
+            return ReLUGrad2(y).apply((gy[0],))
+
+
+def _heaviside(x):
+    return (x > 0).astype(x.dtype)
+
+
+class ReLUGrad2(function_node.FunctionNode):
+    """Computes the gradient of the ReLU function.
+
+    This function takes 2 variables b and c, and
+    computes f(b, c) = sign(b) * c with backpropagation
+    where operations are done in an elementwise manner
+    and sign(x) = 1 when x > 0 and 0 otherwise.
+
+    As the gradient of f with respect to b is 0,
+    we do not backpropagate errors toward b for computational efficiency.
+    """
+
+    def __init__(self, b):
+        super(ReLUGrad2, self).__init__()
+        self.b = b.data
+
+    def forward_cpu(self, inputs):
+        y = (self.b > 0) * inputs[0]
+        return utils.force_array(y, dtype=y.dtype),
+
+    def forward_gpu(self, inputs):
+        b = cuda.to_gpu(self.b)
+        gx = cuda.elementwise(
+            'T y, T gy', 'T gx',
+            'gx = y > 0 ? gy : (T)0',
+            'relu_bwd')(b, inputs[0])
         return gx,
 
+    def backward(self, indexes, gy):
+        return gy[0] * _heaviside(self.b),
+
+
+class ReLUGrad3(function_node.FunctionNode):
+    """Computes the gradient of the ReLU function.
+
+    This function takes 3 variables a, b, and c, and
+    computes f(a, b, c) = sign(b) * c with backpropagation
+    where operations are done in an elementwise manner
+    and sign(x) = 1 if x > 0 and 0 otherwise.
+
+    As the gradients of f with respect to a and b are 0,
+    we do not backpropagate errors toward them for computational efficiency.
+ """ + + def __init__(self, a, b): + self.a = a.data + self.b = b.data + + def forward_cpu(self, inputs): + return (self.b > 0) * inputs[0], + + def forward_gpu(self, inputs): + a = cuda.to_gpu(self.a) + b = cuda.to_gpu(self.b) + assert chainer.should_use_cudnn('==always') + return cudnn.activation_backward(a, b, inputs[0], _mode), + + def backward(self, indexes, gy): + return gy[0] * _heaviside(self.b), + def relu(x): """Rectified Linear Unit function. @@ -81,4 +140,5 @@ def relu(x): (3, 2) """ - return ReLU()(x) + y, = ReLU().apply((x,)) + return y diff --git a/tests/chainer_tests/functions_tests/activation_tests/test_relu.py b/tests/chainer_tests/functions_tests/activation_tests/test_relu.py index e8abe0be945e..7784daededcd 100644 --- a/tests/chainer_tests/functions_tests/activation_tests/test_relu.py +++ b/tests/chainer_tests/functions_tests/activation_tests/test_relu.py @@ -25,6 +25,7 @@ def setUp(self): if -0.1 < self.x[i] < 0.1: self.x[i] = 0.5 self.gy = numpy.random.uniform(-1, 1, self.shape).astype(self.dtype) + self.ggx = numpy.random.uniform(-1, 1, self.shape).astype(self.dtype) self.check_backward_options = {} if self.dtype == numpy.float16: self.check_backward_options = {'dtype': numpy.float64} @@ -60,7 +61,7 @@ def test_forward_gpu_no_cudnn(self): def check_backward(self, x_data, y_grad, use_cudnn='always'): with chainer.using_config('use_cudnn', use_cudnn): gradient_check.check_backward( - functions.ReLU(), x_data, y_grad, + functions.relu, x_data, y_grad, **self.check_backward_options) @condition.retry(3) @@ -83,6 +84,44 @@ def test_backward_gpu_non_contiguous(self): def test_backward_cpu_no_cudnn(self): self.check_backward(cuda.to_gpu(self.x), cuda.to_gpu(self.gy), 'never') + def check_double_backward(self, x_data, y_grad, x_grad_grad, + use_cudnn='always'): + def f(x): + x = functions.relu(x) + return x * x + + with chainer.using_config('use_cudnn', use_cudnn): + gradient_check.check_double_backward( + f, x_data, y_grad, x_grad_grad, + **self.check_backward_options) + + @condition.retry(1) + def test_double_backward_cpu(self): + self.check_double_backward(self.x, self.gy, self.ggx) + + @attr.gpu + @condition.retry(1) + def test_double_backward_gpu(self): + self.check_double_backward(cuda.to_gpu(self.x), + cuda.to_gpu(self.gy), + cuda.to_gpu(self.ggx)) + + @attr.gpu + @condition.retry(3) + def test_double_backward_gpu_non_contiguous(self): + self.check_double_backward( + cuda.cupy.asfortranarray(cuda.to_gpu(self.x)), + cuda.cupy.asfortranarray(cuda.to_gpu(self.gy)), + cuda.cupy.asfortranarray(cuda.to_gpu(self.ggx))) + + @attr.gpu + @condition.retry(3) + def test_double_backward_cpu_no_cudnn(self): + self.check_double_backward(cuda.to_gpu(self.x), + cuda.to_gpu(self.gy), + cuda.to_gpu(self.ggx), + 'never') + @testing.parameterize(*testing.product({ 'use_cudnn': ['always', 'auto', 'never'],