From dec193bb60edd06cf783aa8cd9c30c8413a677be Mon Sep 17 00:00:00 2001
From: Seiya Tokui
Date: Fri, 28 Jul 2017 15:40:06 +0900
Subject: [PATCH 1/6] Make F.linear use new-style function interface

---
 chainer/functions/connection/linear.py   | 58 ++++++++++---------
 chainer/variable.py                      |  5 ++
 .../connection_tests/test_linear.py      | 50 +++++++++++++++-
 3 files changed, 86 insertions(+), 27 deletions(-)

diff --git a/chainer/functions/connection/linear.py b/chainer/functions/connection/linear.py
index 4c6e40837947..7bdd6f5bb7b5 100644
--- a/chainer/functions/connection/linear.py
+++ b/chainer/functions/connection/linear.py
@@ -1,14 +1,10 @@
-from chainer import function
+from chainer import function_node
+from chainer.functions.array import cast
+from chainer.functions.math import sum
 from chainer.utils import type_check
 
 
-def _as_mat(x):
-    if x.ndim == 2:
-        return x
-    return x.reshape(len(x), -1)
-
-
-class LinearFunction(function.Function):
+class LinearFunction(function_node.FunctionNode):
 
     def check_type_forward(self, in_types):
         n_in = in_types.size()
@@ -18,9 +14,9 @@ def check_type_forward(self, in_types):
         type_check.expect(
             x_type.dtype.kind == 'f',
             w_type.dtype.kind == 'f',
-            x_type.ndim >= 2,
+            x_type.ndim == 2,
             w_type.ndim == 2,
-            type_check.prod(x_type.shape[1:]) == w_type.shape[1],
+            x_type.shape[1] == w_type.shape[1],
         )
         if type_check.eval(n_in) == 3:
             b_type = in_types[2]
@@ -31,7 +27,7 @@ def check_type_forward(self, in_types):
             )
 
     def forward(self, inputs):
-        x = _as_mat(inputs[0])
+        x = inputs[0]
         W = inputs[1]
 
         if not type_check.same_types(*inputs):
@@ -43,20 +39,25 @@ def forward(self, inputs):
         if len(inputs) == 3:
             b = inputs[2]
             y += b
+        self.retain_inputs((0, 1))  # b is not retained
         return y,
 
-    def backward(self, inputs, grad_outputs):
-        x = _as_mat(inputs[0])
-        W = inputs[1]
-        gy = grad_outputs[0]
+    def backward(self, indexes, grad_outputs):
+        x, W = self.get_retained_inputs()
+        gy, = grad_outputs
 
-        gx = gy.dot(W).astype(x.dtype, copy=False).reshape(inputs[0].shape)
-        gW = gy.T.dot(x).astype(W.dtype, copy=False)
-        if len(inputs) == 3:
-            gb = gy.sum(0)
-            return gx, gW, gb
-        else:
-            return gx, gW
+        ret = []
+        if 0 in indexes:
+            gx, = LinearFunction().apply((gy, W.T))
+            ret.append(cast.cast(gx, x.dtype))
+        if 1 in indexes:
+            gW, = LinearFunction().apply((gy.T, x.T))
+            ret.append(cast.cast(gW, W.dtype))
+        if 2 in indexes:
+            gb = sum.sum(gy, axis=0)
+            ret.append(gb)
+
+        return ret
 
 
 def linear(x, W, b=None):
@@ -96,7 +97,12 @@ def linear(x, W, b=None):
         (3, 5)
 
     """
-    if b is None:
-        return LinearFunction()(x, W)
-    else:
-        return LinearFunction()(x, W, b)
+    if x.ndim > 2:
+        x = x.reshape(len(x), -1)
+
+    args = [x, W]
+    if b is not None:
+        args.append(b)
+
+    y, = LinearFunction().apply(args)
+    return y
diff --git a/chainer/variable.py b/chainer/variable.py
index 08e94e1d0462..b07445df9310 100644
--- a/chainer/variable.py
+++ b/chainer/variable.py
@@ -638,6 +638,11 @@ def requires_grad(self):
         """It indicates that ``grad`` will be set in backward calculation."""
         return self._requires_grad
 
+    @property
+    def T(self):
+        """Transposition of this variable."""
+        return chainer.functions.transpose(self)
+
     def to_cpu(self):
         """Copies the data and gradient arrays to CPU."""
         if self.data is None:
diff --git a/tests/chainer_tests/functions_tests/connection_tests/test_linear.py b/tests/chainer_tests/functions_tests/connection_tests/test_linear.py
index 6b8fa8ed739e..9345ae7b310c 100644
--- a/tests/chainer_tests/functions_tests/connection_tests/test_linear.py
+++ b/tests/chainer_tests/functions_tests/connection_tests/test_linear.py
@@ -26,6 +26,12 @@ def setUp(self):
 
         self.x = numpy.random.uniform(-1, 1, (4, 3)).astype(self.x_dtype)
         self.gy = numpy.random.uniform(-1, 1, (4, 2)).astype(self.x_dtype)
+        self.ggx = numpy.random.uniform(-1, 1, self.x.shape).astype(
+            self.x_dtype)
+        self.ggW = numpy.random.uniform(-1, 1, self.W.shape).astype(
+            self.W_dtype)
+        self.ggb = numpy.random.uniform(-1, 1, self.b.shape).astype(
+            self.x_dtype)
         self.y = self.x.dot(self.W.T) + self.b
         self.check_forward_options = {}
         self.check_backward_options = {}
@@ -78,7 +84,7 @@ def check_backward(self, x_data, W_data, b_data, y_grad):
             args = args + (b_data,)
 
         gradient_check.check_backward(
-            linear.LinearFunction(), args, y_grad,
+            linear.linear, args, y_grad,
            eps=1e-2, **self.check_backward_options)
 
     @condition.retry(3)
@@ -101,5 +107,47 @@ def test_backward_gpu_nobias(self):
         self.check_backward(cuda.to_gpu(self.x), cuda.to_gpu(self.W),
                             None, cuda.to_gpu(self.gy))
 
+    def check_double_backward(self, x_data, W_data, b_data, y_grad,
+                              x_grad_grad, W_grad_grad, b_grad_grad):
+        args = x_data, W_data
+        grad_grads = x_grad_grad, W_grad_grad
+        if b_data is not None:
+            args += b_data,
+            grad_grads += b_grad_grad,
+
+        # non-linear function for testing
+        def nonlinear(x, W, b=None):
+            y = linear.linear(x, W, b)
+            return y * y
+
+        gradient_check.check_double_backward(
+            nonlinear, args, (y_grad,), grad_grads,
+            **self.check_backward_options)
+
+    @condition.retry(3)
+    def test_double_backward_cpu(self):
+        self.check_double_backward(self.x, self.W, self.b, self.gy,
+                                   self.ggx, self.ggW, self.ggb)
+
+    @condition.retry(3)
+    def test_double_backward_cpu_nobias(self):
+        self.check_double_backward(self.x, self.W, None, self.gy,
+                                   self.ggx, self.ggW, None)
+
+    @attr.gpu
+    @condition.retry(3)
+    def test_double_backward_gpu(self):
+        self.check_double_backward(
+            cuda.to_gpu(self.x), cuda.to_gpu(self.W), cuda.to_gpu(self.b),
+            cuda.to_gpu(self.gy), cuda.to_gpu(self.ggx), cuda.to_gpu(self.ggW),
+            cuda.to_gpu(self.ggb))
+
+    @attr.gpu
+    @condition.retry(3)
+    def test_double_backward_gpu_nobias(self):
+        self.check_double_backward(
+            cuda.to_gpu(self.x), cuda.to_gpu(self.W), None,
+            cuda.to_gpu(self.gy), cuda.to_gpu(self.ggx), cuda.to_gpu(self.ggW),
+            None)
 
 testing.run_module(__name__, __file__)
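
Note on PATCH 1/6: the backward above is built entirely out of differentiable
FunctionNode applications, so the gradient of F.linear can itself be
backpropagated through. A minimal sketch of taking a second-order gradient,
assuming the v3-era `enable_double_backprop` flag on `Variable.backward` and
the `grad_var` property; the output is squared for the same reason the updated
tests wrap linear in a `nonlinear` helper, since a purely linear map has
vanishing second derivatives with respect to any single input:

    import numpy as np
    import chainer
    import chainer.functions as F

    x = chainer.Variable(np.random.randn(4, 3).astype(np.float32))
    W = chainer.Variable(np.random.randn(2, 3).astype(np.float32))

    y = F.linear(x, W)
    z = F.sum(y * y)                         # scalar, non-linear in x and W
    z.backward(enable_double_backprop=True)  # keep the backward graph recorded

    gx = x.grad_var                          # a Variable, still on the graph
    x.cleargrad()
    W.cleargrad()
    F.sum(gx * gx).backward()                # second-order grads in x.grad, W.grad
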
From d7e610e5b40e4d9df41ce27e34630f2d5a0be695 Mon Sep 17 00:00:00 2001
From: Seiya Tokui
Date: Mon, 21 Aug 2017 10:00:21 +0900
Subject: [PATCH 2/6] Use `F.linear` directly in backward for educational purposes

---
 chainer/functions/connection/linear.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/chainer/functions/connection/linear.py b/chainer/functions/connection/linear.py
index 7bdd6f5bb7b5..96bf499399c0 100644
--- a/chainer/functions/connection/linear.py
+++ b/chainer/functions/connection/linear.py
@@ -48,10 +48,10 @@ def backward(self, indexes, grad_outputs):
 
         ret = []
         if 0 in indexes:
-            gx, = LinearFunction().apply((gy, W.T))
+            gx = linear(gy, W.T)
             ret.append(cast.cast(gx, x.dtype))
         if 1 in indexes:
-            gW, = LinearFunction().apply((gy.T, x.T))
+            gW = linear(gy.T, x.T)
             ret.append(cast.cast(gW, W.dtype))
         if 2 in indexes:
             gb = sum.sum(gy, axis=0)
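
Note on PATCH 2/6: writing the gradients through `linear` itself relies on the
identities gx = gy.W and gW = gy^T.x together with linear(a, M) = a.M^T, so
that linear(gy, W.T) = gy.W and linear(gy.T, x.T) = gy^T.x. A quick NumPy
check of these identities; `linear_np` is a stand-in defined here for the
bias-free forward, not a Chainer API:

    import numpy as np

    x = np.random.randn(4, 3).astype(np.float32)   # (batch, in_size)
    W = np.random.randn(2, 3).astype(np.float32)   # (out_size, in_size)
    gy = np.random.randn(4, 2).astype(np.float32)  # (batch, out_size)

    def linear_np(a, M):
        # y = a . M^T, mirroring the bias-free part of F.linear
        return a.dot(M.T)

    assert np.allclose(linear_np(gy, W.T), gy.dot(W))      # gx = gy . W
    assert np.allclose(linear_np(gy.T, x.T), gy.T.dot(x))  # gW = gy^T . x
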
From d4e5dfb45567bcf5e0f6ec55b250e03ed0b13492 Mon Sep 17 00:00:00 2001
From: Seiya Tokui
Date: Mon, 21 Aug 2017 10:01:38 +0900
Subject: [PATCH 3/6] Use functions directly in backward

---
 chainer/functions/connection/linear.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/chainer/functions/connection/linear.py b/chainer/functions/connection/linear.py
index 96bf499399c0..d34d69f01581 100644
--- a/chainer/functions/connection/linear.py
+++ b/chainer/functions/connection/linear.py
@@ -1,6 +1,5 @@
 from chainer import function_node
-from chainer.functions.array import cast
-from chainer.functions.math import sum
+from chainer import functions as F
 from chainer.utils import type_check
 
 
@@ -49,12 +48,12 @@ def backward(self, indexes, grad_outputs):
         ret = []
         if 0 in indexes:
             gx = linear(gy, W.T)
-            ret.append(cast.cast(gx, x.dtype))
+            ret.append(F.cast(gx, x.dtype))
         if 1 in indexes:
             gW = linear(gy.T, x.T)
-            ret.append(cast.cast(gW, W.dtype))
+            ret.append(F.cast(gW, W.dtype))
         if 2 in indexes:
-            gb = sum.sum(gy, axis=0)
+            gb = F.sum(gy, axis=0)
             ret.append(gb)
 
         return ret

From 7d9ac2199f10c42b862ff4455627163ba6c0466f Mon Sep 17 00:00:00 2001
From: Seiya Tokui
Date: Mon, 21 Aug 2017 17:08:55 +0900
Subject: [PATCH 4/6] Fix for new style guide

---
 chainer/functions/connection/linear.py | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/chainer/functions/connection/linear.py b/chainer/functions/connection/linear.py
index d34d69f01581..db06b774c992 100644
--- a/chainer/functions/connection/linear.py
+++ b/chainer/functions/connection/linear.py
@@ -1,5 +1,5 @@
 from chainer import function_node
-from chainer import functions as F
+import chainer.functions
 from chainer.utils import type_check
 
 
@@ -48,12 +48,12 @@ def backward(self, indexes, grad_outputs):
         ret = []
         if 0 in indexes:
             gx = linear(gy, W.T)
-            ret.append(F.cast(gx, x.dtype))
+            ret.append(chainer.functions.cast(gx, x.dtype))
         if 1 in indexes:
             gW = linear(gy.T, x.T)
-            ret.append(F.cast(gW, W.dtype))
+            ret.append(chainer.functions.cast(gW, W.dtype))
         if 2 in indexes:
-            gb = F.sum(gy, axis=0)
+            gb = chainer.functions.sum(gy, axis=0)
             ret.append(gb)
 
         return ret
@@ -99,9 +99,10 @@ def linear(x, W, b=None):
     if x.ndim > 2:
         x = x.reshape(len(x), -1)
 
-    args = [x, W]
-    if b is not None:
-        args.append(b)
+    if b is None:
+        args = x, W
+    else:
+        args = x, W, b
 
     y, = LinearFunction().apply(args)
     return y
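
Note on PATCHES 3/6 and 4/6: these two commits only change how the helper
functions are spelled (first the `F` alias, then fully qualified
`chainer.functions` names); the computation is unchanged. PATCH 5/6 below
follows from the interface change instead: a hook entry now holds the
FunctionNode itself rather than an object with a `.function` attribute (hence
the `getattr(t[0], 'function', t[0])` fallback), and the timer test is
decoupled from linear by a one-operation link whose call records a single
`Mul`. A minimal sketch of inspecting a TimerHook history, assuming only what
that test relies on, namely that entries are `(function, elapsed_time)` pairs:

    import numpy as np
    import chainer
    import chainer.functions as F
    from chainer import function_hooks

    x = chainer.Variable(np.random.randn(3, 5).astype(np.float32))
    W = chainer.Variable(np.random.randn(2, 5).astype(np.float32))

    hook = function_hooks.TimerHook()
    with hook:                # function hooks support the context-manager protocol
        y = F.linear(x, W)
    for func, elapsed in hook.call_history:
        print(func, elapsed)  # e.g. the LinearFunction node and a float
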
From ccf23c3bff1272feb45d75520547eba9b0bbbfb0 Mon Sep 17 00:00:00 2001
From: Seiya Tokui
Date: Tue, 22 Aug 2017 17:32:54 +0900
Subject: [PATCH 5/6] Fix timer test for new linear

---
 .../function_hooks_tests/test_timer.py | 25 +++++++++++++------
 1 file changed, 18 insertions(+), 7 deletions(-)

diff --git a/tests/chainer_tests/function_hooks_tests/test_timer.py b/tests/chainer_tests/function_hooks_tests/test_timer.py
index e151d1ef854f..b5ef82a58d38 100644
--- a/tests/chainer_tests/function_hooks_tests/test_timer.py
+++ b/tests/chainer_tests/function_hooks_tests/test_timer.py
@@ -8,22 +8,34 @@
 from chainer import cuda
 from chainer import function_hooks
 from chainer import functions
-from chainer.functions.connection import linear
-from chainer import links
+from chainer.functions.math import basic_math
 from chainer import testing
 from chainer.testing import attr
 
 
 def check_history(self, t, function_type, return_type):
-    self.assertIsInstance(t[0].function, function_type)
+    func = getattr(t[0], 'function', t[0])
+    self.assertIsInstance(func, function_type)
     self.assertIsInstance(t[1], return_type)
 
 
+class SimpleLink(chainer.Link):
+    def __init__(self):
+        super(SimpleLink, self).__init__()
+        with self.init_scope():
+            init_w = numpy.random.uniform(-1, 1, (3, 5)).astype(
+                numpy.float32)
+            self.w = chainer.Parameter(init_w)
+
+    def __call__(self, x):
+        return self.w * x
+
+
 class TestTimerHookToLink(unittest.TestCase):
 
     def setUp(self):
         self.h = function_hooks.TimerHook()
-        self.l = links.Linear(5, 5)
+        self.l = SimpleLink()
         self.x = numpy.random.uniform(-0.1, 0.1, (3, 5)).astype(numpy.float32)
         self.gy = numpy.random.uniform(-0.1, 0.1, (3, 5)).astype(numpy.float32)
 
@@ -34,8 +46,7 @@ def check_forward(self, x):
         with self.h:
             self.l(chainer.Variable(x))
         self.assertEqual(1, len(self.h.call_history))
-        check_history(self, self.h.call_history[0],
-                      linear.LinearFunction, float)
+        check_history(self, self.h.call_history[0], basic_math.Mul, float)
 
     def test_forward_cpu(self):
         self.check_forward(self.x)
@@ -56,7 +67,7 @@ def check_backward(self, x, gy):
         for entry in self.h.call_history:
             if entry[0].label == '_ + _':
                 continue
-            check_history(self, entry, linear.LinearFunction, float)
+            check_history(self, entry, basic_math.Mul, float)
 
     def test_backward_cpu(self):
         self.check_backward(self.x, self.gy)

From 34355b8ec08fa09d93a954a05d92b48d6bb0f2ae Mon Sep 17 00:00:00 2001
From: Seiya Tokui
Date: Tue, 22 Aug 2017 17:34:22 +0900
Subject: [PATCH 6/6] Fix style

---
 .../functions_tests/connection_tests/test_linear.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/chainer_tests/functions_tests/connection_tests/test_linear.py b/tests/chainer_tests/functions_tests/connection_tests/test_linear.py
index 9345ae7b310c..173c81e97a86 100644
--- a/tests/chainer_tests/functions_tests/connection_tests/test_linear.py
+++ b/tests/chainer_tests/functions_tests/connection_tests/test_linear.py
@@ -150,4 +150,5 @@ def test_double_backward_gpu_nobias(self):
             cuda.to_gpu(self.gy), cuda.to_gpu(self.ggx), cuda.to_gpu(self.ggW),
             None)
 
+
 testing.run_module(__name__, __file__)
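
Note on the test suite: `gradient_check.check_double_backward` compares the
analytic second-order gradients against numerical ones, which is exactly what
the new backward enables. The check can also be run standalone; a minimal
sketch following the call pattern of the updated tests:

    import numpy as np
    import chainer.functions as F
    from chainer import gradient_check

    x = np.random.uniform(-1, 1, (4, 3)).astype(np.float32)
    W = np.random.uniform(-1, 1, (2, 3)).astype(np.float32)
    gy = np.random.uniform(-1, 1, (4, 2)).astype(np.float32)
    ggx = np.random.uniform(-1, 1, x.shape).astype(np.float32)
    ggW = np.random.uniform(-1, 1, W.shape).astype(np.float32)

    def nonlinear(x, W):
        # squared so that the second-order gradient is non-trivial
        y = F.linear(x, W)
        return y * y

    gradient_check.check_double_backward(nonlinear, (x, W), (gy,), (ggx, ggW))
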