From f777e1e5a2afee89775dab3c4a6cd05df946cb56 Mon Sep 17 00:00:00 2001 From: Kenta OONO Date: Thu, 17 Aug 2017 14:52:25 +0900 Subject: [PATCH 01/15] Apply new-style API to ReLU --- chainer/functions/activation/relu.py | 103 ++++++++++++++++++++++----- 1 file changed, 86 insertions(+), 17 deletions(-) diff --git a/chainer/functions/activation/relu.py b/chainer/functions/activation/relu.py index e41b357f7369..b3c40f442961 100644 --- a/chainer/functions/activation/relu.py +++ b/chainer/functions/activation/relu.py @@ -2,7 +2,7 @@ import chainer from chainer import cuda -from chainer import function +from chainer import function_node from chainer import utils from chainer.utils import type_check @@ -12,7 +12,7 @@ _mode = cudnn.cudnn.CUDNN_ACTIVATION_RELU -class ReLU(function.Function): +class ReLU(function.FunctionNode): """Rectified Linear Unit.""" # TODO(beam2d): Implement in-place version. @@ -24,36 +24,105 @@ def check_type_forward(self, in_types): ) def forward_cpu(self, x): - self.retain_inputs(()) self.retain_outputs((0,)) return utils.force_array(numpy.maximum(x[0], 0, dtype=x[0].dtype)), def forward_gpu(self, x): if chainer.should_use_cudnn('==always') and x[0].flags.c_contiguous: + self.retain_inputs((0,)) self._use_cudnn = True y = cudnn.activation_forward(x[0], _mode) else: - self.retain_inputs(()) self._use_cudnn = False y = cuda.cupy.maximum(x[0], 0) - self.retain_outputs((0,)) + self.retrain_outputs((0,)) return y, - def backward_cpu(self, x, gy): - y = self.output_data[0] - return utils.force_array(gy[0] * (y > 0)), - - def backward_gpu(self, x, gy): - y = self.output_data[0] - if chainer.should_use_cudnn('==always') and self._use_cudnn: - gx = cudnn.activation_backward(x[0], y, gy[0], _mode) + def backward(self, indexes, gy): + x = self.get_retained_inputs() + y = self.get_retained_outputs()[0] + if x: + # The only case to use ReLUGrad3 is compute in GPU and use_cudnn is True. + return ReLUGrad3(self._use_cudnn).apply((x[0], y, gy[0])) else: - gx = cuda.elementwise( - 'T y, T gy', 'T gx', - 'gx = y > 0 ? gy : (T)0', - 'relu_bwd')(y, gy[0]) + return ReLUGrad2(self._use_cudnn).apply((y, gy[0])) + + +class _ReLUGradBase(function.FunctionNode): + + def __init__(self, use_cudnn): + super(ReLUGrad).__init__() + self._use_cudnn = use_cudnn + + def forward_cpu(self, inputs): + b, c = inputs + y = (b > 0) * c + self.get_retained_inputs((0,)) + self.get_retained_outputs((0,)) + return y, + + def backward_cpu(self, indexes, gy): + ret = [] + if 0 in indexes: + y = self.get_retained_outputs()[0] + gb = gy * y + ret.append(gb) + if 1 in indexes: + b = self.get_retained_inputs()[0] + gc = gy * (b > 0) + ret.append(gc) + return ret + + +class ReLUGrad2(_ReLUGradBase): + + def forward_gpu(self, inputs): + b, c = inputs + gx = cuda.elementwise( + 'T y, T gy', 'T gx', + 'gx = y > 0 ? 
gy : (T)0', + 'relu_bwd')(b, c) + self.retain_inputs((0,)) + self.retain_outputs((0,)) + return gx, + + def backward_gpu(self, indexes, gy): + ret = [] + if 0 in indexes: + y = self.get_retained_outputs()[0] + gb = gy * y + ret.append(gb) + if 1 in indexes: + b = self.get_retained_inputs()[0] + gc = gy * (b > 0) + ret.append(gc) + return ret + + +class ReLUGrad3(_ReLUGradBase): + + def forward_gpu(self): + assert chainer.should_use_cudnn('==always') and self._use_cudnn + a, b, c = inputs + gx = cudnn.activation_backward(a, b, c, _mode) + self.retain_inputs((1,)) + self.retain_outputs((0,)) return gx, + def backward_gpu(self, indexes, gy): + ret = [] + if 0 in indexes: + ret.append(None) + if 1 in indexes: + y = self.get_retained_outputs()[0] + gb = gy * y + ret.append(gb) + if 2 in indexes: + b = self.get_retained_inputs()[0] + gc = gy * (b > 0) + ret.append(gc) + return ret + def relu(x): """Rectified Linear Unit function. From b5e20bc33ca1f3249dcb72c873c01acccfd777fb Mon Sep 17 00:00:00 2001 From: Kenta OONO Date: Thu, 17 Aug 2017 16:53:21 +0900 Subject: [PATCH 02/15] Pass existing unit tests --- chainer/functions/activation/relu.py | 61 ++++++++++++------- .../activation_tests/test_relu.py | 2 +- 2 files changed, 40 insertions(+), 23 deletions(-) diff --git a/chainer/functions/activation/relu.py b/chainer/functions/activation/relu.py index b3c40f442961..8dec3098eac3 100644 --- a/chainer/functions/activation/relu.py +++ b/chainer/functions/activation/relu.py @@ -12,7 +12,7 @@ _mode = cudnn.cudnn.CUDNN_ACTIVATION_RELU -class ReLU(function.FunctionNode): +class ReLU(function_node.FunctionNode): """Rectified Linear Unit.""" # TODO(beam2d): Implement in-place version. @@ -25,6 +25,7 @@ def check_type_forward(self, in_types): def forward_cpu(self, x): self.retain_outputs((0,)) + self._use_cudnn = False return utils.force_array(numpy.maximum(x[0], 0, dtype=x[0].dtype)), def forward_gpu(self, x): @@ -39,27 +40,24 @@ def forward_gpu(self, x): return y, def backward(self, indexes, gy): - x = self.get_retained_inputs() y = self.get_retained_outputs()[0] - if x: - # The only case to use ReLUGrad3 is compute in GPU and use_cudnn is True. - return ReLUGrad3(self._use_cudnn).apply((x[0], y, gy[0])) + if chainer.should_use_cudnn('==always') and self._use_cudnn: + # The only case to use ReLUGrad3 is compute is done in GPU + # and _use_cudnn is True. 
+ x = self.get_retained_inputs()[0] + return ReLUGrad3().apply((x, y, gy[0])) else: - return ReLUGrad2(self._use_cudnn).apply((y, gy[0])) + return ReLUGrad2().apply((y, gy[0])) -class _ReLUGradBase(function.FunctionNode): - - def __init__(self, use_cudnn): - super(ReLUGrad).__init__() - self._use_cudnn = use_cudnn +class ReLUGrad2(function_node.FunctionNode): def forward_cpu(self, inputs): b, c = inputs y = (b > 0) * c - self.get_retained_inputs((0,)) - self.get_retained_outputs((0,)) - return y, + self.retain_inputs((0,)) + self.retain_outputs((0,)) + return utils.force_array(y, dtype=y.dtype), def backward_cpu(self, indexes, gy): ret = [] @@ -73,9 +71,6 @@ def backward_cpu(self, indexes, gy): ret.append(gc) return ret - -class ReLUGrad2(_ReLUGradBase): - def forward_gpu(self, inputs): b, c = inputs gx = cuda.elementwise( @@ -99,12 +94,33 @@ def backward_gpu(self, indexes, gy): return ret -class ReLUGrad3(_ReLUGradBase): +class ReLUGrad3(function_node.FunctionNode): + + def forward_cpu(self, inputs): + b, c = inputs + y = (b > 0) * c + self.retain_inputs((0,)) + self.retain_outputs((0,)) + return y, + + def backward_cpu(self, indexes, gy): + ret = [] + if 0 in indexes: + ret.append(None) + if 1 in indexes: + y = self.get_retained_outputs()[0] + gb = gy * y + ret.append(gb) + if 2 in indexes: + b = self.get_retained_inputs()[0] + gc = gy * (b > 0) + ret.append(gc) + return ret - def forward_gpu(self): - assert chainer.should_use_cudnn('==always') and self._use_cudnn + def forward_gpu(self, inputs): a, b, c = inputs - gx = cudnn.activation_backward(a, b, c, _mode) + assert chainer.should_use_cudnn('==always') + y = cudnn.activation_backward(a, b, c, _mode) self.retain_inputs((1,)) self.retain_outputs((0,)) return gx, @@ -150,4 +166,5 @@ def relu(x): (3, 2) """ - return ReLU()(x) + y, = ReLU().apply((x,)) + return y diff --git a/tests/chainer_tests/functions_tests/activation_tests/test_relu.py b/tests/chainer_tests/functions_tests/activation_tests/test_relu.py index e8abe0be945e..d3bcaac787cb 100644 --- a/tests/chainer_tests/functions_tests/activation_tests/test_relu.py +++ b/tests/chainer_tests/functions_tests/activation_tests/test_relu.py @@ -60,7 +60,7 @@ def test_forward_gpu_no_cudnn(self): def check_backward(self, x_data, y_grad, use_cudnn='always'): with chainer.using_config('use_cudnn', use_cudnn): gradient_check.check_backward( - functions.ReLU(), x_data, y_grad, + functions.relu, x_data, y_grad, **self.check_backward_options) @condition.retry(3) From 6fcdfad75ea46334015055d195ea3ac5aa57eb4c Mon Sep 17 00:00:00 2001 From: Kenta Oono Date: Thu, 17 Aug 2017 17:00:54 +0900 Subject: [PATCH 03/15] Pass existing GPU tests --- chainer/functions/activation/relu.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chainer/functions/activation/relu.py b/chainer/functions/activation/relu.py index 8dec3098eac3..07771846cd85 100644 --- a/chainer/functions/activation/relu.py +++ b/chainer/functions/activation/relu.py @@ -36,7 +36,7 @@ def forward_gpu(self, x): else: self._use_cudnn = False y = cuda.cupy.maximum(x[0], 0) - self.retrain_outputs((0,)) + self.retain_outputs((0,)) return y, def backward(self, indexes, gy): @@ -123,7 +123,7 @@ def forward_gpu(self, inputs): y = cudnn.activation_backward(a, b, c, _mode) self.retain_inputs((1,)) self.retain_outputs((0,)) - return gx, + return y, def backward_gpu(self, indexes, gy): ret = [] From 31f9acc4d889667469f87857ab4e020e60943c7c Mon Sep 17 00:00:00 2001 From: Kenta Oono Date: Fri, 18 Aug 2017 13:15:03 +0900 Subject: [PATCH 
04/15] Fix relu to pass unit tests --- chainer/functions/activation/relu.py | 67 ++++++++----------- .../activation_tests/test_relu.py | 21 ++++++ 2 files changed, 49 insertions(+), 39 deletions(-) diff --git a/chainer/functions/activation/relu.py b/chainer/functions/activation/relu.py index 07771846cd85..8802cad681a0 100644 --- a/chainer/functions/activation/relu.py +++ b/chainer/functions/activation/relu.py @@ -50,27 +50,35 @@ def backward(self, indexes, gy): return ReLUGrad2().apply((y, gy[0])) +class Zero(function_node.FunctionNode): + + def forward(self, inputs): + xp = chainer.cuda.get_array_module(*inputs) + return utils.force_array(xp.zeros_like(inputs[0])), + + def backward(self, indexes, gy): + return Zero().apply(gy) + + +class Heaviside(function_node.FunctionNode): + + def forward(self, inputs): + x, = inputs + self.retain_outputs((0,)) + return utils.force_array((x > 0).astype(x.dtype)), + + def backward(self, indexes, gy): + return Zero().apply(gy) + + class ReLUGrad2(function_node.FunctionNode): def forward_cpu(self, inputs): b, c = inputs y = (b > 0) * c self.retain_inputs((0,)) - self.retain_outputs((0,)) return utils.force_array(y, dtype=y.dtype), - def backward_cpu(self, indexes, gy): - ret = [] - if 0 in indexes: - y = self.get_retained_outputs()[0] - gb = gy * y - ret.append(gb) - if 1 in indexes: - b = self.get_retained_inputs()[0] - gc = gy * (b > 0) - ret.append(gc) - return ret - def forward_gpu(self, inputs): b, c = inputs gx = cuda.elementwise( @@ -78,18 +86,16 @@ def forward_gpu(self, inputs): 'gx = y > 0 ? gy : (T)0', 'relu_bwd')(b, c) self.retain_inputs((0,)) - self.retain_outputs((0,)) return gx, - def backward_gpu(self, indexes, gy): + def backward(self, indexes, gy): ret = [] if 0 in indexes: - y = self.get_retained_outputs()[0] - gb = gy * y + gb = Zero().apply(gy)[0] ret.append(gb) if 1 in indexes: b = self.get_retained_inputs()[0] - gc = gy * (b > 0) + gc = gy[0] * Heaviside().apply((b,))[0] ret.append(gc) return ret @@ -97,45 +103,28 @@ def backward_gpu(self, indexes, gy): class ReLUGrad3(function_node.FunctionNode): def forward_cpu(self, inputs): - b, c = inputs + a, b, c = inputs y = (b > 0) * c self.retain_inputs((0,)) - self.retain_outputs((0,)) return y, - def backward_cpu(self, indexes, gy): - ret = [] - if 0 in indexes: - ret.append(None) - if 1 in indexes: - y = self.get_retained_outputs()[0] - gb = gy * y - ret.append(gb) - if 2 in indexes: - b = self.get_retained_inputs()[0] - gc = gy * (b > 0) - ret.append(gc) - return ret - def forward_gpu(self, inputs): a, b, c = inputs assert chainer.should_use_cudnn('==always') y = cudnn.activation_backward(a, b, c, _mode) self.retain_inputs((1,)) - self.retain_outputs((0,)) return y, - def backward_gpu(self, indexes, gy): + def backward(self, indexes, gy): ret = [] if 0 in indexes: ret.append(None) if 1 in indexes: - y = self.get_retained_outputs()[0] - gb = gy * y + gb = Zero().apply(gy)[0] ret.append(gb) if 2 in indexes: b = self.get_retained_inputs()[0] - gc = gy * (b > 0) + gc = gy[0] * Heaviside().apply((b,))[0] ret.append(gc) return ret diff --git a/tests/chainer_tests/functions_tests/activation_tests/test_relu.py b/tests/chainer_tests/functions_tests/activation_tests/test_relu.py index d3bcaac787cb..3ba6b98e85e0 100644 --- a/tests/chainer_tests/functions_tests/activation_tests/test_relu.py +++ b/tests/chainer_tests/functions_tests/activation_tests/test_relu.py @@ -25,6 +25,7 @@ def setUp(self): if -0.1 < self.x[i] < 0.1: self.x[i] = 0.5 self.gy = numpy.random.uniform(-1, 1, 
self.shape).astype(self.dtype) + self.ggx = numpy.random.uniform(-1, 1, self.shape).astype(self.dtype) self.check_backward_options = {} if self.dtype == numpy.float16: self.check_backward_options = {'dtype': numpy.float64} @@ -83,6 +84,26 @@ def test_backward_gpu_non_contiguous(self): def test_backward_cpu_no_cudnn(self): self.check_backward(cuda.to_gpu(self.x), cuda.to_gpu(self.gy), 'never') + def check_double_backward(self, x_data, y_grad, x_grad_grad): + + def f(x): + x = functions.relu(x) + return x * x + + gradient_check.check_double_backward(f, x_data, y_grad, x_grad_grad, + **self.check_backward_options) + + @condition.retry(1) + def test_double_backward_cpu(self): + self.check_double_backward(self.x, self.gy, self.ggx) + + @attr.gpu + @condition.retry(1) + def test_double_backward_gpu(self): + self.check_double_backward(cuda.to_gpu(self.x), + cuda.to_gpu(self.gy), + cuda.to_gpu(self.ggx)) + @testing.parameterize(*testing.product({ 'use_cudnn': ['always', 'auto', 'never'], From 7666ff6c83aa3e4f56fab4f7a6118cf93d09aad8 Mon Sep 17 00:00:00 2001 From: Kenta Oono Date: Fri, 18 Aug 2017 13:28:12 +0900 Subject: [PATCH 05/15] Add more unit tests for F.relu --- .../activation_tests/test_relu.py | 26 ++++++++++++++++--- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/tests/chainer_tests/functions_tests/activation_tests/test_relu.py b/tests/chainer_tests/functions_tests/activation_tests/test_relu.py index 3ba6b98e85e0..7784daededcd 100644 --- a/tests/chainer_tests/functions_tests/activation_tests/test_relu.py +++ b/tests/chainer_tests/functions_tests/activation_tests/test_relu.py @@ -84,14 +84,16 @@ def test_backward_gpu_non_contiguous(self): def test_backward_cpu_no_cudnn(self): self.check_backward(cuda.to_gpu(self.x), cuda.to_gpu(self.gy), 'never') - def check_double_backward(self, x_data, y_grad, x_grad_grad): - + def check_double_backward(self, x_data, y_grad, x_grad_grad, + use_cudnn='always'): def f(x): x = functions.relu(x) return x * x - gradient_check.check_double_backward(f, x_data, y_grad, x_grad_grad, - **self.check_backward_options) + with chainer.using_config('use_cudnn', use_cudnn): + gradient_check.check_double_backward( + f, x_data, y_grad, x_grad_grad, + **self.check_backward_options) @condition.retry(1) def test_double_backward_cpu(self): @@ -104,6 +106,22 @@ def test_double_backward_gpu(self): cuda.to_gpu(self.gy), cuda.to_gpu(self.ggx)) + @attr.gpu + @condition.retry(3) + def test_double_backward_gpu_non_contiguous(self): + self.check_double_backward( + cuda.cupy.asfortranarray(cuda.to_gpu(self.x)), + cuda.cupy.asfortranarray(cuda.to_gpu(self.gy)), + cuda.cupy.asfortranarray(cuda.to_gpu(self.ggx))) + + @attr.gpu + @condition.retry(3) + def test_double_backward_cpu_no_cudnn(self): + self.check_double_backward(cuda.to_gpu(self.x), + cuda.to_gpu(self.gy), + cuda.to_gpu(self.ggx), + 'never') + @testing.parameterize(*testing.product({ 'use_cudnn': ['always', 'auto', 'never'], From fcd551f8261050f7847c96895fdbc2e089355f84 Mon Sep 17 00:00:00 2001 From: Kenta Oono Date: Fri, 18 Aug 2017 13:40:31 +0900 Subject: [PATCH 06/15] Refactor --- chainer/functions/activation/relu.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/chainer/functions/activation/relu.py b/chainer/functions/activation/relu.py index 8802cad681a0..215017a0dd6f 100644 --- a/chainer/functions/activation/relu.py +++ b/chainer/functions/activation/relu.py @@ -30,6 +30,8 @@ def forward_cpu(self, x): def forward_gpu(self, x): if chainer.should_use_cudnn('==always') and 
x[0].flags.c_contiguous: + # cupy.activation_backward requires the input. + # So, we retain it for backward computation. self.retain_inputs((0,)) self._use_cudnn = True y = cudnn.activation_forward(x[0], _mode) @@ -42,8 +44,6 @@ def forward_gpu(self, x): def backward(self, indexes, gy): y = self.get_retained_outputs()[0] if chainer.should_use_cudnn('==always') and self._use_cudnn: - # The only case to use ReLUGrad3 is compute is done in GPU - # and _use_cudnn is True. x = self.get_retained_inputs()[0] return ReLUGrad3().apply((x, y, gy[0])) else: @@ -64,7 +64,6 @@ class Heaviside(function_node.FunctionNode): def forward(self, inputs): x, = inputs - self.retain_outputs((0,)) return utils.force_array((x > 0).astype(x.dtype)), def backward(self, indexes, gy): @@ -74,18 +73,18 @@ def backward(self, indexes, gy): class ReLUGrad2(function_node.FunctionNode): def forward_cpu(self, inputs): + self.retain_inputs((0,)) b, c = inputs y = (b > 0) * c - self.retain_inputs((0,)) return utils.force_array(y, dtype=y.dtype), def forward_gpu(self, inputs): + self.retain_inputs((0,)) b, c = inputs gx = cuda.elementwise( 'T y, T gy', 'T gx', 'gx = y > 0 ? gy : (T)0', 'relu_bwd')(b, c) - self.retain_inputs((0,)) return gx, def backward(self, indexes, gy): @@ -103,16 +102,16 @@ def backward(self, indexes, gy): class ReLUGrad3(function_node.FunctionNode): def forward_cpu(self, inputs): - a, b, c = inputs - y = (b > 0) * c self.retain_inputs((0,)) + _, b, c = inputs + y = (b > 0) * c return y, def forward_gpu(self, inputs): + self.retain_inputs((1,)) a, b, c = inputs assert chainer.should_use_cudnn('==always') y = cudnn.activation_backward(a, b, c, _mode) - self.retain_inputs((1,)) return y, def backward(self, indexes, gy): From 06b805d81c4e65b8836fabc6f279831604ef3f69 Mon Sep 17 00:00:00 2001 From: Kenta Oono Date: Fri, 18 Aug 2017 13:47:53 +0900 Subject: [PATCH 07/15] make _use_cudnn class variable --- chainer/functions/activation/relu.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chainer/functions/activation/relu.py b/chainer/functions/activation/relu.py index 215017a0dd6f..2630f209312b 100644 --- a/chainer/functions/activation/relu.py +++ b/chainer/functions/activation/relu.py @@ -17,6 +17,8 @@ class ReLU(function_node.FunctionNode): """Rectified Linear Unit.""" # TODO(beam2d): Implement in-place version. 
+ _use_cudnn = False + def check_type_forward(self, in_types): type_check.expect( in_types.size() == 1, @@ -25,7 +27,6 @@ def check_type_forward(self, in_types): def forward_cpu(self, x): self.retain_outputs((0,)) - self._use_cudnn = False return utils.force_array(numpy.maximum(x[0], 0, dtype=x[0].dtype)), def forward_gpu(self, x): @@ -36,7 +37,6 @@ def forward_gpu(self, x): self._use_cudnn = True y = cudnn.activation_forward(x[0], _mode) else: - self._use_cudnn = False y = cuda.cupy.maximum(x[0], 0) self.retain_outputs((0,)) return y, From 184570f316ed13c0c8b45ba822186b6065441304 Mon Sep 17 00:00:00 2001 From: Kenta Oono Date: Fri, 18 Aug 2017 13:54:02 +0900 Subject: [PATCH 08/15] Simplify backward for efficient backprop --- chainer/functions/activation/relu.py | 30 ++++++---------------------- 1 file changed, 6 insertions(+), 24 deletions(-) diff --git a/chainer/functions/activation/relu.py b/chainer/functions/activation/relu.py index 2630f209312b..d61ae4efdb85 100644 --- a/chainer/functions/activation/relu.py +++ b/chainer/functions/activation/relu.py @@ -50,24 +50,8 @@ def backward(self, indexes, gy): return ReLUGrad2().apply((y, gy[0])) -class Zero(function_node.FunctionNode): - - def forward(self, inputs): - xp = chainer.cuda.get_array_module(*inputs) - return utils.force_array(xp.zeros_like(inputs[0])), - - def backward(self, indexes, gy): - return Zero().apply(gy) - - -class Heaviside(function_node.FunctionNode): - - def forward(self, inputs): - x, = inputs - return utils.force_array((x > 0).astype(x.dtype)), - - def backward(self, indexes, gy): - return Zero().apply(gy) +def _heaviside(x): + return utils.force_array((x.data > 0).astype(x.dtype)) class ReLUGrad2(function_node.FunctionNode): @@ -90,11 +74,10 @@ def forward_gpu(self, inputs): def backward(self, indexes, gy): ret = [] if 0 in indexes: - gb = Zero().apply(gy)[0] - ret.append(gb) + ret.append(None) if 1 in indexes: b = self.get_retained_inputs()[0] - gc = gy[0] * Heaviside().apply((b,))[0] + gc = gy[0] * _heaviside(b) ret.append(gc) return ret @@ -119,11 +102,10 @@ def backward(self, indexes, gy): if 0 in indexes: ret.append(None) if 1 in indexes: - gb = Zero().apply(gy)[0] - ret.append(gb) + ret.append(None) if 2 in indexes: b = self.get_retained_inputs()[0] - gc = gy[0] * Heaviside().apply((b,))[0] + gc = gy[0] * _heaviside(b) ret.append(gc) return ret From 2d7b6e2f3a26c8eb25083b0efbd98d41396c6284 Mon Sep 17 00:00:00 2001 From: Kenta Oono Date: Mon, 21 Aug 2017 09:22:30 +0900 Subject: [PATCH 09/15] Remove an obsolete TOOD comemnt --- chainer/functions/activation/relu.py | 1 - 1 file changed, 1 deletion(-) diff --git a/chainer/functions/activation/relu.py b/chainer/functions/activation/relu.py index d61ae4efdb85..38ef05eb474b 100644 --- a/chainer/functions/activation/relu.py +++ b/chainer/functions/activation/relu.py @@ -15,7 +15,6 @@ class ReLU(function_node.FunctionNode): """Rectified Linear Unit.""" - # TODO(beam2d): Implement in-place version. 
_use_cudnn = False From d24079f24a7f2f7e088c08a4b8025adb8f9df37a Mon Sep 17 00:00:00 2001 From: Kenta Oono Date: Mon, 21 Aug 2017 09:50:37 +0900 Subject: [PATCH 10/15] Trim unneeded backprops for efficient computation --- chainer/functions/activation/relu.py | 57 +++++++++++++--------------- 1 file changed, 26 insertions(+), 31 deletions(-) diff --git a/chainer/functions/activation/relu.py b/chainer/functions/activation/relu.py index 38ef05eb474b..e30a5513a047 100644 --- a/chainer/functions/activation/relu.py +++ b/chainer/functions/activation/relu.py @@ -44,69 +44,64 @@ def backward(self, indexes, gy): y = self.get_retained_outputs()[0] if chainer.should_use_cudnn('==always') and self._use_cudnn: x = self.get_retained_inputs()[0] - return ReLUGrad3().apply((x, y, gy[0])) + return ReLUGrad3(x, y).apply((gy[0],)) else: - return ReLUGrad2().apply((y, gy[0])) + return ReLUGrad2(y).apply((gy[0],)) def _heaviside(x): - return utils.force_array((x.data > 0).astype(x.dtype)) + return utils.force_array((x > 0).astype(x.dtype)) class ReLUGrad2(function_node.FunctionNode): + def __init__(self, b): + super(ReLUGrad2).__init__() + self.b = b.data def forward_cpu(self, inputs): - self.retain_inputs((0,)) - b, c = inputs - y = (b > 0) * c + y = (self.b > 0) * inputs[0] return utils.force_array(y, dtype=y.dtype), def forward_gpu(self, inputs): - self.retain_inputs((0,)) - b, c = inputs + b = cuda.to_gpu(self.b) gx = cuda.elementwise( 'T y, T gy', 'T gx', 'gx = y > 0 ? gy : (T)0', - 'relu_bwd')(b, c) + 'relu_bwd')(b, inputs[0]) return gx, def backward(self, indexes, gy): - ret = [] if 0 in indexes: - ret.append(None) - if 1 in indexes: - b = self.get_retained_inputs()[0] + xp = cuda.get_array_module(gy[0]) + b = xp.asarray(self.b) gc = gy[0] * _heaviside(b) - ret.append(gc) - return ret + return gc, + else: + return () class ReLUGrad3(function_node.FunctionNode): + def __init__(self, a, b): + self.a = a.data + self.b = b.data def forward_cpu(self, inputs): - self.retain_inputs((0,)) - _, b, c = inputs - y = (b > 0) * c - return y, + return (self.b > 0) * inputs[0], def forward_gpu(self, inputs): - self.retain_inputs((1,)) - a, b, c = inputs + a = cuda.to_gpu(self.a) + b = cuda.to_gpu(self.b) assert chainer.should_use_cudnn('==always') - y = cudnn.activation_backward(a, b, c, _mode) - return y, + return cudnn.activation_backward(a, b, inputs[0], _mode), def backward(self, indexes, gy): - ret = [] if 0 in indexes: - ret.append(None) - if 1 in indexes: - ret.append(None) - if 2 in indexes: - b = self.get_retained_inputs()[0] + xp = cuda.get_array_module(gy[0]) + b = xp.asarray(self.b) gc = gy[0] * _heaviside(b) - ret.append(gc) - return ret + return gc, + else: + return () def relu(x): From 45e79b874217c44c4445b07d55d357949b44c08b Mon Sep 17 00:00:00 2001 From: Kenta Oono Date: Mon, 21 Aug 2017 09:50:59 +0900 Subject: [PATCH 11/15] Add documents to ReLUGrad2 and ReLUGrad3 --- chainer/functions/activation/relu.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/chainer/functions/activation/relu.py b/chainer/functions/activation/relu.py index e30a5513a047..eb117a4afdc5 100644 --- a/chainer/functions/activation/relu.py +++ b/chainer/functions/activation/relu.py @@ -54,6 +54,17 @@ def _heaviside(x): class ReLUGrad2(function_node.FunctionNode): + """Computes the gradient of the ReLU function. 
+ + This function takes 2 variables b and c, and + computes f(b, c) = sign(b) * c with backpropagation + where operations are dones in elementwise manner + and sign(x) = 1 when x > 0 is positive and 0 otherwise. + + As the gradient of f with respect to b is 0, + we do not backpropagate errors toward b for computational efficiency. + """ + def __init__(self, b): super(ReLUGrad2).__init__() self.b = b.data @@ -81,6 +92,17 @@ def backward(self, indexes, gy): class ReLUGrad3(function_node.FunctionNode): + """Computes the gradient of the ReLU function. + + This function takes 3 variables a, b, and c, and + computes f(a, b, c) = sign(b) * c with backpropagation + where operations are dones in elementwise manner + and sign(x) = 1 if x > 0 is positive and 0 otherwise. + + As the gradient of f with respect to a and b are 0, + we do not backpropagate errors toward them for computational efficiency. + """ + def __init__(self, a, b): self.a = a.data self.b = b.data From c0c2b5b99cd1ea7eb04f47b0e5aaa80212ded069 Mon Sep 17 00:00:00 2001 From: Kenta Oono Date: Mon, 21 Aug 2017 10:55:43 +0900 Subject: [PATCH 12/15] Assume that indexes are always non-empty --- chainer/functions/activation/relu.py | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/chainer/functions/activation/relu.py b/chainer/functions/activation/relu.py index eb117a4afdc5..0b0311e15070 100644 --- a/chainer/functions/activation/relu.py +++ b/chainer/functions/activation/relu.py @@ -82,13 +82,10 @@ def forward_gpu(self, inputs): return gx, def backward(self, indexes, gy): - if 0 in indexes: - xp = cuda.get_array_module(gy[0]) - b = xp.asarray(self.b) - gc = gy[0] * _heaviside(b) - return gc, - else: - return () + xp = cuda.get_array_module(gy[0]) + b = xp.asarray(self.b) + gc = gy[0] * _heaviside(b) + return gc, class ReLUGrad3(function_node.FunctionNode): @@ -117,13 +114,10 @@ def forward_gpu(self, inputs): return cudnn.activation_backward(a, b, inputs[0], _mode), def backward(self, indexes, gy): - if 0 in indexes: - xp = cuda.get_array_module(gy[0]) - b = xp.asarray(self.b) - gc = gy[0] * _heaviside(b) - return gc, - else: - return () + xp = cuda.get_array_module(gy[0]) + b = xp.asarray(self.b) + gc = gy[0] * _heaviside(b) + return gc, def relu(x): From 22549b1f6ae99af065d6a6c2b4f1081e5cb667d7 Mon Sep 17 00:00:00 2001 From: Kenta Oono Date: Mon, 21 Aug 2017 10:57:23 +0900 Subject: [PATCH 13/15] force_array is not needed --- chainer/functions/activation/relu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chainer/functions/activation/relu.py b/chainer/functions/activation/relu.py index 0b0311e15070..9c8a8c9f16d2 100644 --- a/chainer/functions/activation/relu.py +++ b/chainer/functions/activation/relu.py @@ -50,7 +50,7 @@ def backward(self, indexes, gy): def _heaviside(x): - return utils.force_array((x > 0).astype(x.dtype)) + return (x > 0).astype(x.dtype) class ReLUGrad2(function_node.FunctionNode): From 5e161fed6022f5418751449d0e1184a651a77986 Mon Sep 17 00:00:00 2001 From: Kenta Oono Date: Mon, 21 Aug 2017 10:58:46 +0900 Subject: [PATCH 14/15] No need of the conversion of arrays with asarray --- chainer/functions/activation/relu.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/chainer/functions/activation/relu.py b/chainer/functions/activation/relu.py index 9c8a8c9f16d2..161b90e856a3 100644 --- a/chainer/functions/activation/relu.py +++ b/chainer/functions/activation/relu.py @@ -83,8 +83,7 @@ def forward_gpu(self, inputs): def backward(self, indexes, 
gy): xp = cuda.get_array_module(gy[0]) - b = xp.asarray(self.b) - gc = gy[0] * _heaviside(b) + gc = gy[0] * _heaviside(self.b) return gc, @@ -115,8 +114,7 @@ def forward_gpu(self, inputs): def backward(self, indexes, gy): xp = cuda.get_array_module(gy[0]) - b = xp.asarray(self.b) - gc = gy[0] * _heaviside(b) + gc = gy[0] * _heaviside(self.b) return gc, From 5ee9fd50f49cafc29328dc8a7bf2e66b671b6409 Mon Sep 17 00:00:00 2001 From: Kenta OONO Date: Wed, 23 Aug 2017 10:57:31 +0900 Subject: [PATCH 15/15] Fix flake8 --- chainer/functions/activation/relu.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/chainer/functions/activation/relu.py b/chainer/functions/activation/relu.py index 161b90e856a3..d830017f64ce 100644 --- a/chainer/functions/activation/relu.py +++ b/chainer/functions/activation/relu.py @@ -82,9 +82,7 @@ def forward_gpu(self, inputs): return gx, def backward(self, indexes, gy): - xp = cuda.get_array_module(gy[0]) - gc = gy[0] * _heaviside(self.b) - return gc, + return gy[0] * _heaviside(self.b), class ReLUGrad3(function_node.FunctionNode): @@ -113,9 +111,7 @@ def forward_gpu(self, inputs): return cudnn.activation_backward(a, b, inputs[0], _mode), def backward(self, indexes, gy): - xp = cuda.get_array_module(gy[0]) - gc = gy[0] * _heaviside(self.b) - return gc, + return gy[0] * _heaviside(self.b), def relu(x):
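
Note (editor's sketch, not part of the patch series): taken together, these patches port F.relu to the new-style FunctionNode API so that its backward pass is itself built from differentiable nodes (ReLUGrad2 on the plain NumPy/CuPy path, ReLUGrad3 on the cuDNN path), which is what makes the double-backward tests added above pass. A minimal CPU-only check of that property, modelled directly on the added unit tests, could look as follows; the shape, dtype, and the masking of near-zero entries are illustrative choices borrowed from the tests, not anything mandated by the patches.

import numpy as np

import chainer
from chainer import functions as F
from chainer import gradient_check

# Toy inputs; entries close to zero are pushed away from it so that the
# numerical gradient of ReLU stays stable, as the unit tests do in setUp().
x = np.random.uniform(-1, 1, (3, 2)).astype(np.float32)
x[np.abs(x) < 0.1] = 0.5
gy = np.random.uniform(-1, 1, x.shape).astype(np.float32)
ggx = np.random.uniform(-1, 1, x.shape).astype(np.float32)


def f(x):
    # relu is piecewise linear, so it is composed with a square to obtain a
    # non-trivial second-order gradient, mirroring the f used in the tests.
    y = F.relu(x)
    return y * y


with chainer.using_config('use_cudnn', 'never'):
    # First- and second-order numerical checks against backprop.
    gradient_check.check_backward(F.relu, x, gy)
    gradient_check.check_double_backward(f, x, gy, ggx)

The later refactoring patches rely on the gradient of ReLU's backward output with respect to its non-differentiated inputs being zero, so ReLUGrad2 and ReLUGrad3 ultimately return a gradient only for gy; this keeps the graph built for double backpropagation small, which is the efficiency argument stated in their docstrings.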