From 0c125a27b4903a607898fac8f38935ddf8f6d699 Mon Sep 17 00:00:00 2001 From: niboshi Date: Tue, 29 Aug 2017 23:53:58 +0900 Subject: [PATCH 1/7] Allows list, tuple and None in to_cpu and to_gpu --- chainer/cuda.py | 59 +++++++++++++++++++++++---- tests/chainer_tests/test_cuda.py | 68 ++++++++++++++++++++++++++++++++ 2 files changed, 119 insertions(+), 8 deletions(-) diff --git a/chainer/cuda.py b/chainer/cuda.py index 327ecf18da3f..cc739f81590b 100644 --- a/chainer/cuda.py +++ b/chainer/cuda.py @@ -237,16 +237,17 @@ def to_gpu(array, device=None, stream=None): """Copies the given CPU array to the specified device. Args: - array: Array to be sent to GPU. + array (numpy.ndarray, cupy.ndarray, list or tuple): Array or arrays to + be sent to GPU. device: Device specifier. stream (~cupy.cuda.Stream): *(deprecated since v3.0.0)* CUDA stream. If not ``None``, the copy runs asynchronously. Returns: - cupy.ndarray: Array on GPU. + cupy.ndarray, list or tuple: Array or arrays on GPU. - If ``array`` is already on the GPU device specified by ``device``, - this function just returns ``array`` without performing any copy. + If some of the arrays are already on GPU, then this function just + returns those arrays without performing any copy. """ if stream is not None: @@ -255,6 +256,27 @@ def to_gpu(array, device=None, stream=None): 'Please remove it.', DeprecationWarning) check_cuda_available() + with Device(device) as device_: + if isinstance(array, (list, tuple)): + d = {} + ret = [] + for arr in array: + if arr is None: + ret.append(None) + else: + arr2 = d.get(id(arr)) + if arr2 is None: + arr2 = _array_to_gpu(arr, device_, stream) + d[id(arr)] = arr2 + ret.append(arr2) + return type(array)(ret) + else: + return _array_to_gpu(array, device_, stream) + + +def _array_to_gpu(array, device, stream): + if array is None: + return None if isinstance(array, (numpy.number, numpy.bool_)): array = numpy.asarray(array) if not isinstance(array, (cupy.ndarray, numpy.ndarray)): @@ -303,16 +325,37 @@ def to_cpu(array, stream=None): """Copies the given GPU array to host CPU. Args: - array: Array to be sent to CPU. + array (numpy.ndarray, cupy.ndarray, list or tuple): Array or arrays to + be sent to CPU. stream (cupy.cuda.Stream): CUDA stream. Returns: - numpy.ndarray: Array on CPU. + numpy.ndarray, list or tuple: Array on CPU. - If given ``array`` is already on CPU, then this function just returns - ``array`` without performing any copy. + If some of the arrays are already on CPU, then this function just + returns those arrays without performing any copy. 
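    A minimal usage sketch of the container handling added here (assuming
    CuPy is available; ``None`` entries pass through, and arrays already on
    the CPU are returned as-is, not copied):

        x_gpu = cuda.cupy.arange(3)
        x_cpu = numpy.arange(3)
        y = to_cpu((x_gpu, x_cpu, None))
        # y is a tuple: (a numpy.ndarray copy of x_gpu, x_cpu itself, None)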
""" + if isinstance(array, (list, tuple)): + d = {} + ret = [] + for arr in array: + if arr is None: + ret.append(None) + else: + arr2 = d.get(id(arr)) + if arr2 is None: + arr2 = _array_to_cpu(arr, stream) + d[id(arr)] = arr2 + ret.append(arr2) + return type(array)(ret) + else: + return _array_to_cpu(array, stream) + + +def _array_to_cpu(array, stream): + if array is None: + return None if isinstance(array, ndarray): check_cuda_available() with get_device_from_array(array): diff --git a/tests/chainer_tests/test_cuda.py b/tests/chainer_tests/test_cuda.py index dd8db9a27001..8ca31ac8fdd9 100644 --- a/tests/chainer_tests/test_cuda.py +++ b/tests/chainer_tests/test_cuda.py @@ -194,6 +194,40 @@ def test_cupy_array_async2(self): self.assertIsInstance(y, numpy.ndarray) cuda.cupy.testing.assert_array_equal(self.x, y) + def test_single_none(self): + assert cuda.to_cpu(None) is None + + def _check_list_tuple(self, typ): + assert typ in (list, tuple) + a = numpy.random.uniform(-1, 1, (0,)) + b = numpy.random.uniform(-1, 1, (2, 3)) + c = cuda.cupy.random.uniform(-1, 1, (0,)) + d = cuda.cupy.random.uniform(-1, 1, (2, 2)) + xs = typ([a, b, c, d, None, a, b, None, c, d]) + xs_cpu = cuda.to_cpu(xs) + + assert isinstance(xs_cpu, typ) + assert len(xs) == len(xs_cpu) + for i in (0, 1, 2, 3, 5, 6, 8, 9): + assert isinstance(xs_cpu[i], numpy.ndarray) + cuda.cupy.testing.assert_array_equal(xs[i], xs_cpu[i]) + assert xs_cpu[0] is a + assert xs_cpu[1] is b + assert xs_cpu[2] is xs_cpu[8] + assert xs_cpu[3] is xs_cpu[9] + assert xs_cpu[4] is None + assert xs_cpu[5] is a + assert xs_cpu[6] is b + assert xs_cpu[7] is None + + @attr.gpu + def test_list(self): + self._check_list_tuple(list) + + @attr.gpu + def test_tuple(self): + self._check_list_tuple(tuple) + def test_variable(self): x = chainer.Variable(self.x) with self.assertRaises(TypeError): @@ -337,6 +371,40 @@ def test_cupy_array_async3(self): self.assertIsNot(x, y) # Do copy cuda.cupy.testing.assert_array_equal(x, y) + def test_single_none(self): + assert cuda.to_gpu(None) is None + + def _check_list_tuple(self, typ): + assert typ in (list, tuple) + a = numpy.random.uniform(-1, 1, (0,)) + b = numpy.random.uniform(-1, 1, (2, 3)) + c = cuda.cupy.random.uniform(-1, 1, (0,)) + d = cuda.cupy.random.uniform(-1, 1, (2, 2)) + xs = typ([a, b, c, d, None, a, b, None, c, d]) + xs_gpu = cuda.to_gpu(xs) + + assert isinstance(xs_gpu, typ) + assert len(xs) == len(xs_gpu) + for i in (0, 1, 2, 3, 5, 6, 8, 9): + assert isinstance(xs_gpu[i], cuda.cupy.ndarray) + cuda.cupy.testing.assert_array_equal(xs[i], xs_gpu[i]) + assert xs_gpu[0] is xs_gpu[5] + assert xs_gpu[1] is xs_gpu[6] + assert xs_gpu[2] is c + assert xs_gpu[3] is d + assert xs_gpu[4] is None + assert xs_gpu[7] is None + assert xs_gpu[8] is c + assert xs_gpu[9] is d + + @attr.gpu + def test_list(self): + self._check_list_tuple(list) + + @attr.gpu + def test_tuple(self): + self._check_list_tuple(tuple) + @attr.gpu def test_variable_gpu(self): x = chainer.Variable(self.x) From ad247e5fc7940d7faec8ae0489b8b2a403a24d1d Mon Sep 17 00:00:00 2001 From: niboshi Date: Mon, 13 Nov 2017 15:51:07 +0900 Subject: [PATCH 2/7] Require attr.gpu for to_gpu(None ) --- tests/chainer_tests/test_cuda.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/chainer_tests/test_cuda.py b/tests/chainer_tests/test_cuda.py index 8ca31ac8fdd9..30ddf49d78e9 100644 --- a/tests/chainer_tests/test_cuda.py +++ b/tests/chainer_tests/test_cuda.py @@ -371,6 +371,7 @@ def test_cupy_array_async3(self): self.assertIsNot(x, y) # Do copy 
cuda.cupy.testing.assert_array_equal(x, y) + @attr.gpu def test_single_none(self): assert cuda.to_gpu(None) is None From f63970e694740363221d9e78c3aa91930b36794b Mon Sep 17 00:00:00 2001 From: niboshi Date: Tue, 14 Nov 2017 11:47:05 +0900 Subject: [PATCH 3/7] Refactor unit tests for various backend configuration --- chainer/testing/backend.py | 126 ++++++++++ chainer/testing/parameterized.py | 30 ++- .../connection_tests/test_deconvolution_2d.py | 237 ++++++++++-------- 3 files changed, 283 insertions(+), 110 deletions(-) create mode 100644 chainer/testing/backend.py diff --git a/chainer/testing/backend.py b/chainer/testing/backend.py new file mode 100644 index 000000000000..8026908d17b4 --- /dev/null +++ b/chainer/testing/backend.py @@ -0,0 +1,126 @@ +import functools +import unittest + +import numpy + +import chainer +from chainer import cuda +from chainer.testing import attr + + +class BackendConfig(object): + + _props = [ + ('use_cuda', False), + ('use_cudnn', 'never'), + ('cudnn_deterministic', False), + ('autotune', False), + ] + + def __init__(self, params): + assert isinstance(params, dict) + self._contexts = [] + + # Default values + for k, v in self._props: + setattr(self, k, v) + # Specified values + for k, v in params.items(): + if not hasattr(self, k): + raise ValueError('Parameter {} is not defined'.format(k)) + setattr(self, k, v) + + @property + def xp(self): + if self.use_cuda: + return cuda.cupy + else: + return numpy + + def __enter__(self): + self._contexts = [ + chainer.using_config( + 'use_cudnn', self.use_cudnn), + chainer.using_config( + 'cudnn_deterministic', self.cudnn_deterministic), + chainer.using_config( + 'autotune', self.autotune), + ] + for c in self._contexts: + c.__enter__() + return self + + def __exit__(self, typ, value, traceback): + for c in reversed(self._contexts): + c.__exit__(typ, value, traceback) + + def __repr__(self): + lst = [] + for k, _ in self._props: + lst.append('{}={!r}'.format(k, getattr(self, k))) + return ''.format(' '.join(lst)) + + def get_func_str(self): + """Returns a string that can be used in method name""" + lst = [] + for k, _ in self._props: + val = getattr(self, k) + if val is True: + val = 'true' + elif val is False: + val = 'false' + else: + val = str(val) + lst.append('{}_{}'.format(k, val)) + return '__'.join(lst) + + def get_pytest_marks(self): + marks = [] + if self.use_cuda: + marks.append(attr.gpu) + if self.use_cudnn != 'never': + marks.append(attr.cudnn) + + assert all(callable(_) for _ in marks) + return marks + + +def _wrap_backend_test_method(impl, param, method_name): + backend_config = BackendConfig(param) + marks = backend_config.get_pytest_marks() + new_method_name = '{}__{}'.format( + method_name, backend_config.get_func_str()) + + @functools.wraps(impl) + def func(self, *args, **kwargs): + impl(self, backend_config, *args, **kwargs) + + func.__name__ = new_method_name + + # Apply test marks + for mark in marks: + func = mark(func) + + return func, new_method_name + + +def inject_backend_tests(method_names, params): + assert isinstance(method_names, list) + assert isinstance(params, list) + assert all(isinstance(_, dict) for _ in params) + + def wrap(case): + assert issubclass(case, unittest.TestCase) + for method_name in method_names: + impl = getattr(case, method_name) + delattr(case, method_name) + for i_param, param in enumerate(params): + new_impl, new_method_name = _wrap_backend_test_method( + impl, param, method_name) + if hasattr(case, new_method_name): + raise RuntimeError( + 'Test fixture already 
exists: {}'.format( + new_method_name)) + setattr(case, new_method_name, new_impl) + return case + return wrap diff --git a/chainer/testing/parameterized.py b/chainer/testing/parameterized.py index 1865a0af93e8..f6c7b01a78c0 100644 --- a/chainer/testing/parameterized.py +++ b/chainer/testing/parameterized.py @@ -81,10 +81,32 @@ def f(klass): def product(parameter): - keys = sorted(parameter) - values = [parameter[key] for key in keys] - values_product = itertools.product(*values) - return [dict(zip(keys, vals)) for vals in values_product] + if isinstance(parameter, dict): + keys = sorted(parameter) + values = [parameter[key] for key in keys] + values_product = itertools.product(*values) + return [dict(zip(keys, vals)) for vals in values_product] + + elif isinstance(parameter, list): + # list of lists of dicts + if not all(isinstance(_, list) for _ in parameter): + raise TypeError('parameter must be list of lists of dicts') + if not all(isinstance(_, dict) for l in parameter for _ in l): + raise TypeError('parameter must be list of lists of dicts') + + product = list(itertools.product(*parameter)) + lst = [] + for dict_lst in product: + a = {} + for d in dict_lst: + a.update(d) + lst.append(a) + return lst + + else: + raise TypeError( + 'parameter must be either dict or list. Actual: {}'.format( + type(parameter))) def product_dict(*parameters): diff --git a/tests/chainer_tests/functions_tests/connection_tests/test_deconvolution_2d.py b/tests/chainer_tests/functions_tests/connection_tests/test_deconvolution_2d.py index 592284257650..6b1e3bb8d560 100644 --- a/tests/chainer_tests/functions_tests/connection_tests/test_deconvolution_2d.py +++ b/tests/chainer_tests/functions_tests/connection_tests/test_deconvolution_2d.py @@ -9,6 +9,7 @@ from chainer import gradient_check from chainer import testing from chainer.testing import attr +from chainer.testing import backend from chainer.testing import condition from chainer.testing import parameterize from chainer.utils import conv @@ -20,27 +21,44 @@ def _pair(x): return x, x -@parameterize(*(testing.product({ - 'c_contiguous': [True], - 'test_outsize': [True, False], - 'nobias': [True], - 'stride': [1, 2], - 'use_cudnn': ['always'], - 'cudnn_deterministic': [True, False], - 'x_dtype': [numpy.float32], - 'W_dtype': [numpy.float32], - 'autotune': [True, False], -}) + testing.product({ - 'c_contiguous': [False], - 'test_outsize': [True], - 'nobias': [False], - 'stride': [1, 2], - 'use_cudnn': ['always', 'never'], - 'cudnn_deterministic': [False], - 'x_dtype': [numpy.float16, numpy.float32, numpy.float64], - 'W_dtype': [numpy.float16, numpy.float32, numpy.float64], - 'autotune': [False], -}))) +@parameterize(*(testing.product([ + testing.product({ + 'c_contiguous': [True], + 'test_outsize': [True, False], + 'nobias': [True], + 'stride': [1, 2], + 'x_dtype': [numpy.float32], + 'W_dtype': [numpy.float32], + }) + + testing.product({ + 'c_contiguous': [False], + 'test_outsize': [True], + 'nobias': [False], + 'stride': [1, 2], + 'x_dtype': [numpy.float16, numpy.float32, numpy.float64], + 'W_dtype': [numpy.float16, numpy.float32, numpy.float64], + }), +]))) +@backend.inject_backend_tests( + ['test_forward', 'test_backward', 'test_double_backward'], + # CPU tests + [{ + 'use_cuda': False, + }] + # GPU tests + + testing.product([ + [{'use_cuda': True}], + + # Without cuDNN + testing.product({ + 'use_cudnn': ['never'], + }) + # With cuDNN + + testing.product({ + 'use_cudnn': ['always'], + 'cudnn_deterministic': [True, False], + 'autotune': [True, False], + })])) 
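# Sketch of how the decorator above expands (based only on
# chainer/testing/backend.py introduced in this patch): each listed test
# method is removed and re-added once per parameter dict, receiving a
# BackendConfig built from that dict as an extra argument, with attr.gpu /
# attr.cudnn pytest marks attached as needed. For example, the parameter
# {'use_cuda': True, 'use_cudnn': 'always', 'cudnn_deterministic': True,
#  'autotune': True} would roughly produce a method named
# test_forward__use_cuda_true__use_cudnn_always__cudnn_deterministic_true__autotune_true.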
class TestDeconvolution2DFunction(unittest.TestCase): in_channels = 3 @@ -52,11 +70,12 @@ def setUp(self): kh, kw = _pair(self.ksize) sh, sw = _pair(self.stride) ph, pw = _pair(self.pad) - self.W = numpy.random.normal( + + W = numpy.random.normal( 0, numpy.sqrt(1. / (kh * kw * self.in_channels)), (self.in_channels, self.out_channels, kh, kw) ).astype(self.W_dtype) - self.b = None if self.nobias else numpy.random.uniform( + b = None if self.nobias else numpy.random.uniform( -1, 1, self.out_channels).astype(self.x_dtype) N = 2 @@ -64,17 +83,21 @@ def setUp(self): outh = conv.get_deconv_outsize(inh, kh, sh, ph) outw = conv.get_deconv_outsize(inw, kw, sw, pw) self.outsize = (outh, outw) if self.test_outsize else None - self.x = numpy.random.uniform( + x = numpy.random.uniform( -1, 1, (N, self.in_channels, inh, inw)).astype(self.x_dtype) - self.gy = numpy.random.uniform( + gy = numpy.random.uniform( -1, 1, (N, self.out_channels, outh, outw)).astype(self.x_dtype) - self.ggx = numpy.random.uniform(-1, 1, self.x.shape).astype( + ggx = numpy.random.uniform(-1, 1, x.shape).astype( self.x_dtype) - self.ggW = numpy.random.uniform(-1, 1, self.W.shape).astype( + ggW = numpy.random.uniform(-1, 1, W.shape).astype( self.W_dtype) - self.ggb = None if self.nobias else numpy.random.uniform( - -1, 1, self.b.shape).astype(self.x_dtype) + ggb = None if self.nobias else numpy.random.uniform( + -1, 1, b.shape).astype(self.x_dtype) + + self.inputs = [x, W, b] + self.grad_outputs = [gy] + self.grad_grad_inputs = [ggx, ggW, ggb] self.test_forward_options = {} self.check_backward_options = {'dtype': numpy.float64} @@ -87,54 +110,67 @@ def setUp(self): self.check_backward_options.update(atol=5e-4, rtol=5e-3) self.check_double_backward_options.update(atol=5e-3, rtol=5e-2) - @attr.gpu - def test_forward_consistency(self): - x_cpu = chainer.Variable(self.x) - W_cpu = chainer.Variable(self.W) - b_cpu = None if self.nobias else chainer.Variable(self.b) - with chainer.using_config('cudnn_deterministic', - self.cudnn_deterministic): - y_cpu = F.deconvolution_2d( - x_cpu, W_cpu, b_cpu, stride=self.stride, pad=self.pad, + def forward_cpu(self, inputs): + x, W, b = inputs + x_cpu = chainer.Variable(x) + W_cpu = chainer.Variable(W) + b_cpu = None if b is None else chainer.Variable(b) + y_cpu = F.deconvolution_2d( + x_cpu, W_cpu, b_cpu, stride=self.stride, pad=self.pad, + outsize=self.outsize) + return y_cpu, + + def check_forward(self, inputs, backend_config): + y_expected, = self.forward_cpu(inputs) + + if backend_config.use_cuda: + inputs = cuda.to_gpu(inputs) + + x, W, b = inputs + x = chainer.Variable(x) + W = chainer.Variable(W) + b = None if b is None else chainer.Variable(b) + + with backend_config: + y_actual = F.deconvolution_2d( + x, W, b, stride=self.stride, pad=self.pad, outsize=self.outsize) - x_gpu = chainer.Variable(cuda.to_gpu(self.x)) - W_gpu = chainer.Variable(cuda.to_gpu(self.W)) - b_gpu = None if self.nobias else chainer.Variable( - cuda.to_gpu(self.b)) - with chainer.using_config('use_cudnn', self.use_cudnn): - with chainer.using_config('cudnn_deterministic', - self.cudnn_deterministic): - with chainer.using_config('autotune', self.autotune): - y_gpu = F.deconvolution_2d( - x_gpu, W_gpu, b_gpu, stride=self.stride, pad=self.pad, - outsize=self.outsize) - - self.assertEqual(y_cpu.data.dtype, self.x_dtype) - self.assertEqual(y_gpu.data.dtype, self.x_dtype) + assert y_expected.data.dtype == self.x_dtype + assert y_actual.data.dtype == self.x_dtype testing.assert_allclose( - y_cpu.data, y_gpu.data.get(), 
**self.test_forward_options) + y_expected.data, y_actual.data.get(), **self.test_forward_options) @attr.gpu - def test_forward_consistency_im2col(self): - self.use_cudnn = 'never' - self.test_forward_consistency() + def test_forward(self, backend_config): + # Forward test does not currently target CPU backend. + # It only tests for consistency between GPU and CPU computation. + if not backend_config.use_cuda: + return + self.check_forward(self.inputs, backend_config) + + def check_backward(self, inputs, grad_outputs, backend_config): - def check_backward(self, x_data, W_data, b_data, y_grad): - xp = cuda.get_array_module(x_data) + xp = backend_config.xp + if backend_config.use_cuda: + inputs = cuda.to_gpu(inputs) + grad_outputs = cuda.to_gpu(grad_outputs) + + x_data, W_data, b_data = inputs + y_grad, = grad_outputs if not self.c_contiguous: x_data = xp.asfortranarray(x_data) W_data = xp.asfortranarray(W_data) y_grad = xp.asfortranarray(y_grad) - self.assertFalse(x_data.flags.c_contiguous) - self.assertFalse(W_data.flags.c_contiguous) - self.assertFalse(y_grad.flags.c_contiguous) + assert not x_data.flags.c_contiguous + assert not W_data.flags.c_contiguous + assert not y_grad.flags.c_contiguous if b_data is not None: - b = xp.empty((len(b_data) * 2,), dtype=self.b.dtype) + b = xp.empty((len(b_data) * 2,), dtype=b_data.dtype) b[::2] = b_data b_data = b[::2] - self.assertFalse(b_data.flags.c_contiguous) + assert not b_data.flags.c_contiguous args = (x_data, W_data) if b_data is not None: @@ -144,27 +180,26 @@ def f(*args): return F.deconvolution_2d( *args, stride=self.stride, pad=self.pad, outsize=self.outsize) - with chainer.using_config('use_cudnn', self.use_cudnn): - with chainer.using_config('cudnn_deterministic', - self.cudnn_deterministic): - with chainer.using_config('autotune', self.autotune): - gradient_check.check_backward( - f, args, y_grad, **self.check_backward_options) + with backend_config: + gradient_check.check_backward( + f, args, y_grad, **self.check_backward_options) @condition.retry(10) - def test_backward_cpu(self): - self.check_backward(self.x, self.W, self.b, self.gy) + def test_backward(self, backend_config): + self.check_backward(self.inputs, self.grad_outputs, backend_config) - @attr.gpu - @condition.retry(10) - def test_backward_gpu(self): - b = None if self.b is None else cuda.to_gpu(self.b) - self.check_backward(cuda.to_gpu(self.x), cuda.to_gpu(self.W), - b, cuda.to_gpu(self.gy)) + def check_double_backward( + self, inputs, grad_outputs, grad_grad_inputs, backend_config): + xp = backend_config.xp + + if backend_config.use_cuda: + inputs = cuda.to_gpu(inputs) + grad_outputs = cuda.to_gpu(grad_outputs) + grad_grad_inputs = cuda.to_gpu(grad_grad_inputs) - def check_double_backward(self, x_data, W_data, b_data, y_grad, - x_grad_grad, W_grad_grad, b_grad_grad): - xp = cuda.get_array_module(x_data) + x_data, W_data, b_data = inputs + y_grad, = grad_outputs + x_grad_grad, W_grad_grad, b_grad_grad = grad_grad_inputs if not self.c_contiguous: x_data = xp.asfortranarray(x_data) @@ -172,21 +207,21 @@ def check_double_backward(self, x_data, W_data, b_data, y_grad, y_grad = xp.asfortranarray(y_grad) x_grad_grad = xp.asfortranarray(x_grad_grad) W_grad_grad = xp.asfortranarray(W_grad_grad) - self.assertFalse(x_data.flags.c_contiguous) - self.assertFalse(W_data.flags.c_contiguous) - self.assertFalse(y_grad.flags.c_contiguous) - self.assertFalse(x_grad_grad.flags.c_contiguous) - self.assertFalse(W_grad_grad.flags.c_contiguous) + assert not x_data.flags.c_contiguous + assert not 
W_data.flags.c_contiguous + assert not y_grad.flags.c_contiguous + assert not x_grad_grad.flags.c_contiguous + assert not W_grad_grad.flags.c_contiguous if b_data is not None: - b = xp.empty((len(b_data) * 2,), dtype=self.b.dtype) + b = xp.empty((len(b_data) * 2,), dtype=b_data.dtype) b[::2] = b_data b_data = b[::2] - self.assertFalse(b_data.flags.c_contiguous) + assert not b_data.flags.c_contiguous - ggb = xp.empty((len(b_data) * 2,), dtype=self.b.dtype) + ggb = xp.empty((len(b_data) * 2,), dtype=b_grad_grad.dtype) ggb[::2] = b_grad_grad b_grad_grad = ggb[::2] - self.assertFalse(b_grad_grad.flags.c_contiguous) + assert not b_grad_grad.flags.c_contiguous args = (x_data, W_data) grad_grads = (x_grad_grad, W_grad_grad) @@ -199,26 +234,16 @@ def f(*args): *args, stride=self.stride, pad=self.pad, outsize=self.outsize) return y * y # make the function nonlinear - with chainer.using_config('use_cudnn', self.use_cudnn): - with chainer.using_config('cudnn_deterministic', - self.cudnn_deterministic): - gradient_check.check_double_backward( - f, args, y_grad, grad_grads, - **self.check_double_backward_options) + with backend_config: + gradient_check.check_double_backward( + f, args, y_grad, grad_grads, + **self.check_double_backward_options) @condition.retry(10) - def test_double_backward_cpu(self): - self.check_double_backward(self.x, self.W, self.b, self.gy, - self.ggx, self.ggW, self.ggb) - - @attr.gpu - @condition.retry(10) - def test_double_backward_gpu(self): + def test_double_backward(self, backend_config): self.check_double_backward( - cuda.to_gpu(self.x), cuda.to_gpu(self.W), - None if self.b is None else cuda.to_gpu(self.b), - cuda.to_gpu(self.gy), cuda.to_gpu(self.ggx), cuda.to_gpu(self.ggW), - None if self.ggb is None else cuda.to_gpu(self.ggb)) + self.inputs, self.grad_outputs, self.grad_grad_inputs, + backend_config) @testing.parameterize(*testing.product({ From 7c92369b1709907c9ce6bc574319d894914313d5 Mon Sep 17 00:00:00 2001 From: niboshi Date: Mon, 13 Nov 2017 11:12:26 +0900 Subject: [PATCH 4/7] Refactor (de)convolution_2d --- chainer/function_node.py | 6 + .../functions/connection/convolution_2d.py | 181 ++++++++-------- .../functions/connection/deconvolution_2d.py | 193 +++++++++--------- 3 files changed, 196 insertions(+), 184 deletions(-) diff --git a/chainer/function_node.py b/chainer/function_node.py index 4431c7c11046..9227ef10ba5c 100644 --- a/chainer/function_node.py +++ b/chainer/function_node.py @@ -284,6 +284,12 @@ def apply(self, inputs): return ret def _check_data_type_forward(self, in_data): + xp = cuda.get_array_module(*in_data) + if not all([isinstance(_, xp.ndarray) for _ in in_data]): + raise ValueError('numpy and cupy must not be used together\n' + '{}' + .format(', '.join(str(type(_)) for _ in in_data))) + in_type = type_check.get_light_types(in_data) try: with type_check.light_mode: diff --git a/chainer/functions/connection/convolution_2d.py b/chainer/functions/connection/convolution_2d.py index ea69d129f02e..820459881bd2 100644 --- a/chainer/functions/connection/convolution_2d.py +++ b/chainer/functions/connection/convolution_2d.py @@ -98,20 +98,27 @@ def check_type_forward(self, in_types): b_type.shape[0] == w_type.shape[0], ) + def _get_out_size(self, inputs): + x, W = inputs[:2] + _, _, kh, kw = W.shape + _, _, h, w = x.shape + out_h = conv.get_conv_outsize( + h, kh, self.sy, self.ph, cover_all=self.cover_all, d=self.dy) + if out_h <= 0: + raise RuntimeError('Height in the output should be positive.') + out_w = conv.get_conv_outsize( + w, kw, self.sx, 
self.pw, cover_all=self.cover_all, d=self.dx) + if out_w <= 0: + raise RuntimeError('Width in the output should be positive.') + return out_h, out_w + def forward_cpu(self, inputs): self.retain_inputs((0, 1)) # retain only x and W - x, W = inputs[:2] - b = inputs[2] if len(inputs) == 3 else None - if not all([isinstance(i, numpy.ndarray) for i in inputs]): - if b is not None: - raise ValueError('numpy and cupy must not be used together\n' - 'type(W): {0}, type(x): {1}, type(b): {2}' - .format(type(W), type(x), type(b))) - else: - raise ValueError('numpy and cupy must not be used together\n' - 'type(W): {0}, type(x): {1}' - .format(type(W), type(x))) + if len(inputs) == 2: + (x, W), b = inputs, None + else: + x, W, b = inputs kh, kw = W.shape[2:] col = conv.im2col_cpu( @@ -125,82 +132,24 @@ def forward_cpu(self, inputs): def forward_gpu(self, inputs): self.retain_inputs((0, 1)) # retain only x and W - x, W = inputs[:2] - b = inputs[2] if len(inputs) == 3 else None - - if not all([isinstance(i, cuda.ndarray) for i in inputs]): - if b is not None: - raise ValueError('numpy and cupy must not be used together\n' - 'type(W): {0}, type(x): {1}, type(b): {2}' - .format(type(W), type(x), type(b))) - else: - raise ValueError('numpy and cupy must not be used together\n' - 'type(W): {0}, type(x): {1}' - .format(type(W), type(x))) + if len(inputs) == 2: + (x, W), b = inputs, None + else: + x, W, b = inputs out_c, _, kh, kw = W.shape - n, c, h, w = x.shape - - out_h = conv.get_conv_outsize(h, kh, self.sy, self.ph, - cover_all=self.cover_all, d=self.dy) - assert out_h > 0, 'Height in the output should be positive.' - out_w = conv.get_conv_outsize(w, kw, self.sx, self.pw, - cover_all=self.cover_all, d=self.dx) - assert out_w > 0, 'Width in the output should be positive.' + n, _, h, w = x.shape + out_h, out_w = self._get_out_size(inputs) y = cuda.cupy.empty((n, out_c, out_h, out_w), dtype=x.dtype) + if (not self.cover_all and chainer.should_use_cudnn('>=auto') and x.dtype == W.dtype and ((self.dy == 1 and self.dx == 1) or _cudnn_version >= 6000)): - x = cuda.cupy.ascontiguousarray(x) - W = cuda.cupy.ascontiguousarray(W) - if b is not None: - b = cuda.cupy.ascontiguousarray(b) - - use_tensor_core = chainer.should_use_cudnn_tensor_core(x.dtype) - handle = cudnn.get_handle() - x_desc = cudnn.create_tensor_descriptor(x) - y_desc = cudnn.create_tensor_descriptor(y) + # cuDNN implementation + return self._forward_cudnn(x, W, b, y) - filter_desc = cudnn.create_filter_descriptor(W) - conv_param = ((self.ph, self.pw), (self.sy, self.sx), x.dtype) - conv_desc = cudnn.create_convolution_descriptor( - *conv_param, dilation=(self.dy, self.dx), - use_tensor_core=use_tensor_core) - if b is not None: - bias_desc = cudnn.create_tensor_descriptor( - b[None, :, None, None]) - workspace_size = cuda.get_max_workspace_size() - workspace = cuda.cupy.empty((workspace_size,), dtype='b') - if configuration.config.autotune and _cudnn_version >= 5000: - algo = get_algorithm_fwd( - x, W, y, conv_param, handle, x_desc, filter_desc, - conv_desc, y_desc, workspace) - else: - algo = libcudnn.getConvolutionForwardAlgorithm( - handle, x_desc.value, filter_desc.value, - conv_desc.value, y_desc.value, _fwd_pref, workspace_size) - - if use_tensor_core: - # Only CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM - # supports Tensor-Core in cuDNN7. 
- algo = libcudnn.CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM # NOQA - - oz_dtype = 'd' if x.dtype == 'd' else 'f' - one = numpy.array(1, dtype=oz_dtype).ctypes - zero = numpy.array(0, dtype=oz_dtype).ctypes - libcudnn.convolutionForward( - handle, one.data, x_desc.value, x.data.ptr, - filter_desc.value, W.data.ptr, conv_desc.value, - algo, workspace.data.ptr, workspace_size, zero.data, - y_desc.value, y.data.ptr) - - # TODO(beam2d): Support unshared bias - if b is not None: - cudnn.add_tensor( - handle, one.data, bias_desc.value, b.data.ptr, - one.data, y_desc.value, y.data.ptr) else: # Implementation using im2col col = conv.im2col_gpu( @@ -215,6 +164,59 @@ def forward_gpu(self, inputs): return y, + def _forward_cudnn(self, x, W, b, y): + x = cuda.cupy.ascontiguousarray(x) + W = cuda.cupy.ascontiguousarray(W) + if b is not None: + b = cuda.cupy.ascontiguousarray(b) + + use_tensor_core = chainer.should_use_cudnn_tensor_core(x.dtype) + + handle = cudnn.get_handle() + x_desc = cudnn.create_tensor_descriptor(x) + y_desc = cudnn.create_tensor_descriptor(y) + + filter_desc = cudnn.create_filter_descriptor(W) + conv_param = ((self.ph, self.pw), (self.sy, self.sx), x.dtype) + conv_desc = cudnn.create_convolution_descriptor( + *conv_param, dilation=(self.dy, self.dx), + use_tensor_core=use_tensor_core) + if b is not None: + bias_desc = cudnn.create_tensor_descriptor( + b[None, :, None, None]) + workspace_size = cuda.get_max_workspace_size() + workspace = cuda.cupy.empty((workspace_size,), dtype='b') + if configuration.config.autotune and _cudnn_version >= 5000: + algo = get_algorithm_fwd( + x, W, y, conv_param, handle, x_desc, filter_desc, + conv_desc, y_desc, workspace) + else: + algo = libcudnn.getConvolutionForwardAlgorithm( + handle, x_desc.value, filter_desc.value, + conv_desc.value, y_desc.value, _fwd_pref, workspace_size) + + if use_tensor_core: + # Only CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM + # supports Tensor-Core in cuDNN7. 
+ algo = libcudnn.CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM # NOQA + + oz_dtype = 'd' if x.dtype == 'd' else 'f' + one = numpy.array(1, dtype=oz_dtype).ctypes + zero = numpy.array(0, dtype=oz_dtype).ctypes + libcudnn.convolutionForward( + handle, one.data, x_desc.value, x.data.ptr, + filter_desc.value, W.data.ptr, conv_desc.value, + algo, workspace.data.ptr, workspace_size, zero.data, + y_desc.value, y.data.ptr) + + # TODO(beam2d): Support unshared bias + if b is not None: + cudnn.add_tensor( + handle, one.data, bias_desc.value, b.data.ptr, + one.data, y_desc.value, y.data.ptr) + + return y, + def backward(self, indexes, grad_outputs): x, W = self.get_retained_inputs() gy, = grad_outputs @@ -271,19 +273,30 @@ def forward_cpu(self, inputs): def forward_gpu(self, inputs): self.retain_inputs((0, 1)) x, gy = inputs - _, out_c, out_h, out_w = gy.shape - n, c, h, w = x.shape - if (self.cover_all or not chainer.should_use_cudnn('>=auto') or - x.dtype != self.W_dtype or - ((self.dy > 1 or self.dx > 1) and _cudnn_version < 6000)): + if (not self.cover_all and chainer.should_use_cudnn('>=auto') and + x.dtype == self.W_dtype and + ((self.dy == 1 and self.dx == 1) or _cudnn_version >= 6000)): + + # cuDNN implementation + return self._forward_cudnn(x, gy) + + else: + # Implementation using im2col + _, out_c, out_h, out_w = gy.shape + n, c, h, w = x.shape + col = conv.im2col_gpu( x, self.kh, self.kw, self.sy, self.sx, self.ph, self.pw, cover_all=self.cover_all, dy=self.dy, dx=self.dx) gW = cuda.cupy.tensordot( gy, col, ((0, 2, 3), (0, 4, 5))).astype(self.W_dtype, copy=False) - return gW, + return gW, + + def _forward_cudnn(self, x, gy): + _, out_c, out_h, out_w = gy.shape + n, c, h, w = x.shape gW = cuda.cupy.empty((out_c, c, self.kh, self.kw), dtype=self.W_dtype) x = cuda.cupy.ascontiguousarray(x) diff --git a/chainer/functions/connection/deconvolution_2d.py b/chainer/functions/connection/deconvolution_2d.py index 1a2460965192..b5c20beaf509 100644 --- a/chainer/functions/connection/deconvolution_2d.py +++ b/chainer/functions/connection/deconvolution_2d.py @@ -106,37 +106,37 @@ def check_type_forward(self, in_types): b_type.shape[0] == w_type.shape[1] ) - def forward_cpu(self, inputs): - self.retain_inputs((0, 1)) # only retain x and W - x, W = inputs[:2] - b = inputs[2] if len(inputs) == 3 else None - - if not all([isinstance(i, numpy.ndarray) for i in inputs]): - if b is not None: - raise ValueError('numpy and cupy must not be used together\n' - 'type(W): {0}, type(x): {1}, type(b): {2}' - .format(type(W), type(x), type(b))) - else: - raise ValueError('numpy and cupy must not be used together\n' - 'type(W): {0}, type(x): {1}' - .format(type(W), type(x))) - + def _calc_out_size(self, x, W): + """Calculates and stores `outh` and `outw`.""" kh, kw = W.shape[2:] - _, _, h, w = x.shape - gcol = numpy.tensordot(W, x, (0, 1)).astype(x.dtype, copy=False) + _, _, in_h, in_w = x.shape # - k, m, n: shape of out_channel # - b: number of inputs # - h, w: height and width of kernels # k, m, n, b, h, w -> b, k, m, n, h, w - gcol = numpy.rollaxis(gcol, 3) if self.outh is None: - self.outh = conv.get_deconv_outsize(h, kh, self.sy, self.ph, - d=self.dy) - assert self.outh > 0, 'Height in the output should be positive.' 
+ self.outh = conv.get_deconv_outsize( + in_h, kh, self.sy, self.ph, d=self.dy) + if self.outh <= 0: + raise RuntimeError('Height in the output must be positive.') + if self.outw is None: - self.outw = conv.get_deconv_outsize(w, kw, self.sx, self.pw, - d=self.dx) - assert self.outw > 0, 'Width in the output should be positive.' + self.outw = conv.get_deconv_outsize( + in_w, kw, self.sx, self.pw, d=self.dx) + if self.outw <= 0: + raise RuntimeError('Width in the output must be positive.') + + def forward_cpu(self, inputs): + self.retain_inputs((0, 1)) # only retain x and W + if len(inputs) == 2: + (x, W), b = inputs, None + else: + x, W, b = inputs + + self._calc_out_size(x, W) + + gcol = numpy.tensordot(W, x, (0, 1)).astype(x.dtype, copy=False) + gcol = numpy.rollaxis(gcol, 3) y = conv.col2im_cpu( gcol, self.sy, self.sx, self.ph, self.pw, self.outh, self.outw, dy=self.dy, dx=self.dx) @@ -147,91 +147,23 @@ def forward_cpu(self, inputs): def forward_gpu(self, inputs): self.retain_inputs((0, 1)) # only retain x and W - x, W = inputs[:2] - b = inputs[2] if len(inputs) == 3 else None - - if not all([isinstance(i, cuda.ndarray) for i in inputs]): - if b is not None: - raise ValueError('numpy and cupy must not be used together\n' - 'type(W): {0}, type(x): {1}, type(b): {2}' - .format(type(W), type(x), type(b))) - else: - raise ValueError('numpy and cupy must not be used together\n' - 'type(W): {0}, type(x): {1}' - .format(type(W), type(x))) - - kh, kw = W.shape[2:] - n, in_c, in_h, in_w = x.shape - c = W.shape[1] # out_c - if self.outh is None: - self.outh = conv.get_deconv_outsize(in_h, kh, self.sy, self.ph, - d=self.dy) - assert self.outh > 0, 'Height in the output should be positive.' - if self.outw is None: - self.outw = conv.get_deconv_outsize(in_w, kw, self.sx, self.pw, - d=self.dx) - assert self.outw > 0, 'Width in the output should be positive.' 
+ if len(inputs) == 2: + (x, W), b = inputs, None + else: + x, W, b = inputs + self._calc_out_size(x, W) self._set_cover_all(x, W) if (not self.cover_all and chainer.should_use_cudnn('>=auto') and x.dtype == W.dtype and ((self.dy == 1 and self.dx == 1) or _cudnn_version_ >= 6000)): - x = cuda.cupy.ascontiguousarray(x) - W = cuda.cupy.ascontiguousarray(W) - if b is not None: - b = cuda.cupy.ascontiguousarray(b) - - use_tensor_core = chainer.should_use_cudnn_tensor_core(x.dtype) - handle = cudnn.get_handle() - x_desc = cudnn.create_tensor_descriptor(x) - y = cuda.cupy.empty((n, c, self.outh, self.outw), - dtype=x.dtype) - y_desc = cudnn.create_tensor_descriptor(y) + # cuDNN implementation + return self._forward_cudnn(x, W, b) - filter_desc = cudnn.create_filter_descriptor(W) - conv_param = (self.ph, self.pw), (self.sy, self.sx), x.dtype - conv_desc = cudnn.create_convolution_descriptor( - *conv_param, dilation=(self.dy, self.dx), - use_tensor_core=use_tensor_core) - if b is not None: - bias_desc = cudnn.create_tensor_descriptor( - b[None, :, None, None]) - - oz_dtype = 'd' if x.dtype == 'd' else 'f' - one = numpy.array(1, dtype=oz_dtype).ctypes - zero = numpy.array(0, dtype=oz_dtype).ctypes - - workspace_size = cuda.get_max_workspace_size() - workspace = cuda.cupy.empty((workspace_size,), dtype='b') - - if configuration.config.cudnn_deterministic: - algo = libcudnn.CUDNN_CONVOLUTION_BWD_DATA_ALGO_1 - elif configuration.config.autotune and _cudnn_version_ >= 5000: - algo = get_algorithm(W, x, y, conv_param, handle, filter_desc, - x_desc, conv_desc, y_desc, workspace) - else: - algo = libcudnn.getConvolutionBackwardDataAlgorithm( - handle, filter_desc.value, x_desc.value, conv_desc.value, - y_desc.value, _bwd_data_pref, workspace_size) - - if use_tensor_core: - # Only CUDNN_CONVOLUTION_BWD_DATA_ALGO_1 supports - # Tensor-Core in cuDNN7 - algo = libcudnn.CUDNN_CONVOLUTION_BWD_DATA_ALGO_1 - - libcudnn.convolutionBackwardData_v3( - handle, one.data, filter_desc.value, W.data.ptr, - x_desc.value, x.data.ptr, conv_desc.value, - algo, workspace.data.ptr, workspace_size, - zero.data, y_desc.value, y.data.ptr) - - if b is not None: - cudnn.add_tensor( - handle, one.data, bias_desc.value, b.data.ptr, - one.data, y_desc.value, y.data.ptr) else: + # Implementation using col2im gcol = cuda.cupy.tensordot(W, x, (0, 1)).astype(x.dtype, copy=False) # - k, m, n: shape of out_channel @@ -244,6 +176,67 @@ def forward_gpu(self, inputs): dy=self.dy, dx=self.dx) if b is not None: y += b.reshape(1, b.size, 1, 1) + return y, + + def _forward_cudnn(self, x, W, b): + x = cuda.cupy.ascontiguousarray(x) + W = cuda.cupy.ascontiguousarray(W) + if b is not None: + b = cuda.cupy.ascontiguousarray(b) + + in_c, out_c, kh, kw = W.shape + n, _, in_h, in_w = x.shape + + use_tensor_core = chainer.should_use_cudnn_tensor_core(x.dtype) + + handle = cudnn.get_handle() + x_desc = cudnn.create_tensor_descriptor(x) + y = cuda.cupy.empty((n, out_c, self.outh, self.outw), + dtype=x.dtype) + y_desc = cudnn.create_tensor_descriptor(y) + + filter_desc = cudnn.create_filter_descriptor(W) + conv_param = (self.ph, self.pw), (self.sy, self.sx), x.dtype + conv_desc = cudnn.create_convolution_descriptor( + *conv_param, dilation=(self.dy, self.dx), + use_tensor_core=use_tensor_core) + if b is not None: + bias_desc = cudnn.create_tensor_descriptor( + b[None, :, None, None]) + + oz_dtype = 'd' if x.dtype == 'd' else 'f' + one = numpy.array(1, dtype=oz_dtype).ctypes + zero = numpy.array(0, dtype=oz_dtype).ctypes + + workspace_size = 
cuda.get_max_workspace_size() + workspace = cuda.cupy.empty((workspace_size,), dtype='b') + + if configuration.config.cudnn_deterministic: + algo = libcudnn.CUDNN_CONVOLUTION_BWD_DATA_ALGO_1 + elif configuration.config.autotune and _cudnn_version_ >= 5000: + algo = get_algorithm(W, x, y, conv_param, handle, filter_desc, + x_desc, conv_desc, y_desc, workspace) + else: + algo = libcudnn.getConvolutionBackwardDataAlgorithm( + handle, filter_desc.value, x_desc.value, conv_desc.value, + y_desc.value, _bwd_data_pref, workspace_size) + + if use_tensor_core: + # Only CUDNN_CONVOLUTION_BWD_DATA_ALGO_1 supports + # Tensor-Core in cuDNN7 + algo = libcudnn.CUDNN_CONVOLUTION_BWD_DATA_ALGO_1 + + libcudnn.convolutionBackwardData_v3( + handle, one.data, filter_desc.value, W.data.ptr, + x_desc.value, x.data.ptr, conv_desc.value, + algo, workspace.data.ptr, workspace_size, + zero.data, y_desc.value, y.data.ptr) + + if b is not None: + cudnn.add_tensor( + handle, one.data, bias_desc.value, b.data.ptr, + one.data, y_desc.value, y.data.ptr) + return y, def backward(self, indexes, grad_outputs): From ea3b114445c454c949ad27e82b5b5b4e61ccea91 Mon Sep 17 00:00:00 2001 From: niboshi Date: Mon, 13 Nov 2017 11:46:00 +0900 Subject: [PATCH 5/7] Allow None as inputs --- chainer/function_node.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chainer/function_node.py b/chainer/function_node.py index 9227ef10ba5c..f1073fc4daf7 100644 --- a/chainer/function_node.py +++ b/chainer/function_node.py @@ -285,7 +285,7 @@ def apply(self, inputs): def _check_data_type_forward(self, in_data): xp = cuda.get_array_module(*in_data) - if not all([isinstance(_, xp.ndarray) for _ in in_data]): + if not all([_ is None or isinstance(_, xp.ndarray) for _ in in_data]): raise ValueError('numpy and cupy must not be used together\n' '{}' .format(', '.join(str(type(_)) for _ in in_data))) From 42a45742376f203ab77253ee5aaa73cdf4d70098 Mon Sep 17 00:00:00 2001 From: niboshi Date: Tue, 14 Nov 2017 16:50:37 +0900 Subject: [PATCH 6/7] Update VariableNode.data if new data is assigned --- chainer/variable.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/chainer/variable.py b/chainer/variable.py index 5f5bc7f26938..de191180e65d 100644 --- a/chainer/variable.py +++ b/chainer/variable.py @@ -389,6 +389,9 @@ def _set_data_type(self, d): self.dtype = d.dtype self.shape = d.shape + if self._data is not None: + self._data = d + def _check_old_style_gradient(self): if self._old_style_grad_generator is not None: raise RuntimeError( From d773b9718c44236426659d4e15eddb579d8c91e7 Mon Sep 17 00:00:00 2001 From: niboshi Date: Fri, 10 Nov 2017 14:21:03 +0900 Subject: [PATCH 7/7] iDeep --- chainer/__init__.py | 48 ++++ chainer/_ideep.py | 64 +++++ chainer/cuda.py | 2 +- chainer/function_node.py | 3 +- chainer/functions/activation/relu.py | 90 ++++-- .../functions/connection/convolution_2d.py | 37 ++- .../functions/connection/deconvolution_2d.py | 36 ++- chainer/functions/connection/linear.py | 143 +++++++++- chainer/testing/attr.py | 2 + chainer/testing/backend.py | 6 + chainer/variable.py | 4 +- .../activation_tests/test_relu.py | 130 +++++---- .../connection_tests/test_convolution_2d.py | 270 ++++++++---------- .../connection_tests/test_deconvolution_2d.py | 14 +- 14 files changed, 590 insertions(+), 259 deletions(-) create mode 100644 chainer/_ideep.py diff --git a/chainer/__init__.py b/chainer/__init__.py index a82d26140bd1..faba15603640 100644 --- a/chainer/__init__.py +++ b/chainer/__init__.py @@ -3,6 +3,8 @@ import threading import 
warnings +import numpy + from chainer import _version from chainer import configuration # NOQA from chainer import cuda # NOQA @@ -62,6 +64,7 @@ from chainer import _environment_check +from chainer import _ideep # Check environment conditions @@ -71,6 +74,7 @@ __version__ = _version.__version__ _thread_local = threading.local() +_array_types = None def get_function_hooks(): @@ -82,6 +86,49 @@ def get_function_hooks(): return ret +def _load_array_types(): + global _array_types + global _cpu_array_types + + if _array_types is None: + _array_types = [numpy.ndarray] + _cpu_array_types = [numpy.ndarray] + + if cuda.available: + _array_types.append(cuda.ndarray) + + if _ideep.is_available(): + _array_types.append(_ideep.ideep.mdarray) + _cpu_array_types.append(_ideep.ideep.mdarray) + + _array_types = tuple(_array_types) + _cpu_array_types = tuple(_cpu_array_types) + + +def get_array_types(): + _load_array_types() + return _array_types + + +def get_cpu_array_types(): + _load_array_types() + return _cpu_array_types + + +def is_arrays_compatible(arrays): + arrays = [_ for _ in arrays if _ is not None] + if len(arrays) == 0: + return True + if type(arrays[0]) is cuda.ndarray: + types = cuda.ndarray + else: + if _ideep.is_available(): + types = (numpy.ndarray, _ideep.ideep.mdarray) + else: + types = numpy.ndarray + return all(isinstance(_, types) for _ in arrays) + + global_config.debug = bool(int(os.environ.get('CHAINER_DEBUG', '0'))) global_config.cudnn_deterministic = False global_config.enable_backprop = True @@ -92,6 +139,7 @@ def get_function_hooks(): global_config.use_cudnn = os.environ.get('CHAINER_USE_CUDNN', 'auto') global_config.use_cudnn_tensor_core = 'auto' global_config.autotune = False +global_config.use_ideep = os.environ.get('CHAINER_USE_IDEEP', 'auto') _SHOULD_USE_CUDNN = { diff --git a/chainer/_ideep.py b/chainer/_ideep.py new file mode 100644 index 000000000000..fe403514962a --- /dev/null +++ b/chainer/_ideep.py @@ -0,0 +1,64 @@ +_error = None + +from chainer.configuration import config + + +try: + import ideep # NOQA +except ImportError as e: + _error = e + + +_SHOULD_USE_IDEEP = { + '==always': {'always': True, 'auto': False, 'never': False}, + '>=auto': {'always': True, 'auto': True, 'never': False}, +} + + +def is_available(): + return _error is None + + +def check_available(): + """Checks if iDeep is available. + + When iDeep is correctly set up, nothing happens. + Otherwise it raises ``RuntimeError``. + """ + if _error is not None: + raise RuntimeError( + 'iDeep is not available.\n' + 'Reason: {}'.format(type(_error).__name__, str(_error))) + + +def should_use_ideep(level): + """Determines if we should use iDeep. + + This function checks ``chainer.config.use_ideep`` and availability + of ``ideep`` package. + + Args: + level (str): iDeep use level. It must be either ``'==always'`` or + ``'>=auto'``. ``'==always'`` indicates that the ``use_ideep`` + config must be ``'always'`` to use iDeep. + + Returns: + bool: ``True`` if the caller should use cuDNN. 
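    A minimal illustration (assuming the ``ideep`` package is importable,
    i.e. ``is_available()`` returns ``True``):

        with chainer.using_config('use_ideep', 'auto'):
            should_use_ideep('>=auto')    # True: 'auto' satisfies '>=auto'
            should_use_ideep('==always')  # False: requires 'always'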
+ + """ + if not is_available(): + return False + + if level not in _SHOULD_USE_IDEEP: + raise ValueError('invalid iDeep use level: %s ' + '(must be either of "==always" or ">=auto")' % + repr(level)) + + flags = _SHOULD_USE_IDEEP[level] + + use_ideep = config.use_ideep + if use_ideep not in flags: + raise ValueError('invalid use_ideep configuration: %s ' + '(must be either of "always", "auto", or "never")' % + repr(use_ideep)) + return flags[use_ideep] diff --git a/chainer/cuda.py b/chainer/cuda.py index cc739f81590b..dbced6ae3d9b 100644 --- a/chainer/cuda.py +++ b/chainer/cuda.py @@ -362,7 +362,7 @@ def _array_to_cpu(array, stream): return array.get(stream) elif isinstance(array, (numpy.number, numpy.bool_)): return numpy.asarray(array) - elif isinstance(array, numpy.ndarray): + elif isinstance(array, chainer.get_cpu_array_types()): return array else: raise TypeError( diff --git a/chainer/function_node.py b/chainer/function_node.py index f1073fc4daf7..e599b3aba357 100644 --- a/chainer/function_node.py +++ b/chainer/function_node.py @@ -284,8 +284,7 @@ def apply(self, inputs): return ret def _check_data_type_forward(self, in_data): - xp = cuda.get_array_module(*in_data) - if not all([_ is None or isinstance(_, xp.ndarray) for _ in in_data]): + if not chainer.is_arrays_compatible(in_data): raise ValueError('numpy and cupy must not be used together\n' '{}' .format(', '.join(str(type(_)) for _ in in_data))) diff --git a/chainer/functions/activation/relu.py b/chainer/functions/activation/relu.py index e430d755b3e4..fdd62aa084d0 100644 --- a/chainer/functions/activation/relu.py +++ b/chainer/functions/activation/relu.py @@ -1,6 +1,7 @@ import numpy import chainer +from chainer import _ideep from chainer import cuda from chainer import function_node from chainer import utils @@ -17,6 +18,7 @@ class ReLU(function_node.FunctionNode): """Rectified Linear Unit.""" _use_cudnn = False + _ideep_hint = None def check_type_forward(self, in_types): type_check.expect( @@ -24,33 +26,60 @@ def check_type_forward(self, in_types): in_types[0].dtype.kind == 'f', ) - def forward_cpu(self, x): + def forward_cpu(self, inputs): + x, = inputs + if (x.dtype == numpy.float32 + and (x.ndim == 2 or x.ndim == 4) + and _ideep.should_use_ideep('>=auto')): + + # iDeep implementation + return self.forward_ideep(inputs) + + self.retain_outputs((0,)) + return utils.force_array(numpy.maximum(x, 0, dtype=x.dtype)), + + def forward_ideep(self, inputs): + self.retain_inputs((0,)) self.retain_outputs((0,)) - return utils.force_array(numpy.maximum(x[0], 0, dtype=x[0].dtype)), - def forward_gpu(self, x): - if chainer.should_use_cudnn('==always') and x[0].flags.c_contiguous: + cc = _ideep.ideep.xnn.ReLUForward(inputs) + self._ideep_hint = cc.hint + + y, = cc.execute_on() + return y, + + def forward_gpu(self, inputs): + x, = inputs + if chainer.should_use_cudnn('==always') and x.flags.c_contiguous: # cupy.activation_backward requires the input. # So, we retain it for backward computation. 
self.retain_inputs((0,)) self._use_cudnn = True - y = cudnn.activation_forward(x[0], _mode) + y = cudnn.activation_forward(x, _mode) else: - y = cuda.cupy.maximum(x[0], 0) + y = cuda.cupy.maximum(x, 0) self.retain_outputs((0,)) return y, - def backward(self, indexes, gy): - y = self.get_retained_outputs()[0] + def backward(self, indexes, grad_outputs): + gy, = grad_outputs + y, = self.get_retained_outputs() + if self._ideep_hint is not None: + x, = self.get_retained_inputs() + return ReLUGradIdeep(x, y, self._ideep_hint).apply((gy,)) if chainer.should_use_cudnn('==always') and self._use_cudnn: - x = self.get_retained_inputs()[0] - return ReLUGrad3(x, y).apply((gy[0],)) + x, = self.get_retained_inputs() + return ReLUGradCudnn(x, y).apply((gy,)) else: - return ReLUGrad2(y).apply((gy[0],)) + return ReLUGrad2(y).apply((gy,)) def _heaviside(x): - return (x > 0).astype(x.dtype) + if _ideep.is_available() and isinstance(x, _ideep.ideep.mdarray): + # ideep.mdarray does not support __gt__ yet + return numpy.greater(x, 0).astype(x.dtype) + else: + return (x > 0).astype(x.dtype) class ReLUGrad2(function_node.FunctionNode): @@ -84,7 +113,7 @@ def backward(self, indexes, gy): return gy[0] * _heaviside(self.b), -class ReLUGrad3(function_node.FunctionNode): +class ReLUGrad3Base(function_node.FunctionNode): """Computes the gradient of the ReLU function. This function takes 3 variables a, b, and c, and @@ -96,20 +125,35 @@ class ReLUGrad3(function_node.FunctionNode): we do not backpropagate errors toward them for computational efficiency. """ - def __init__(self, a, b): - super(ReLUGrad3, self).__init__() - self.a = a.data - self.b = b.data + def __init__(self, x, y): + super(ReLUGrad3Base, self).__init__() + self.x = x.data + self.y = y.data - def forward_cpu(self, inputs): - return (self.b > 0) * inputs[0], + def backward(self, indexes, grad_outputs): + gy, = grad_outputs + ggx = gy * _heaviside(self.y) + return ggx, - def forward_gpu(self, inputs): + +class ReLUGradCudnn(ReLUGrad3Base): + + def forward(self, inputs): assert chainer.should_use_cudnn('==always') - return cudnn.activation_backward(self.a, self.b, inputs[0], _mode), + gy, = inputs + return cudnn.activation_backward(self.x, self.y, gy, _mode), - def backward(self, indexes, gy): - return gy[0] * _heaviside(self.b), + +class ReLUGradIdeep(ReLUGrad3Base): + + def __init__(self, x, y, hint): + super(ReLUGradIdeep, self).__init__(x, y) + self.hint = hint + + def forward(self, inputs): + cc = _ideep.ideep.xnn.ReLUBackward((self.x,), inputs, self.hint) + ggx, = cc.execute_on() + return ggx, def relu(x): diff --git a/chainer/functions/connection/convolution_2d.py b/chainer/functions/connection/convolution_2d.py index 820459881bd2..786036c2a60d 100644 --- a/chainer/functions/connection/convolution_2d.py +++ b/chainer/functions/connection/convolution_2d.py @@ -1,6 +1,7 @@ import numpy import chainer +from chainer import _ideep from chainer import configuration from chainer import cuda from chainer import function_node @@ -9,6 +10,7 @@ from chainer.utils import conv from chainer.utils import type_check + if cuda.cudnn_enabled: cudnn = cuda.cudnn libcudnn = cuda.cuda.cudnn @@ -58,6 +60,8 @@ def get_algorithm_bwd_filter( class Convolution2DFunction(function_node.FunctionNode): + ideep_hint = None + def __init__(self, stride=1, pad=0, cover_all=False, **kwargs): argument.check_unexpected_kwargs( kwargs, @@ -113,8 +117,13 @@ def _get_out_size(self, inputs): return out_h, out_w def forward_cpu(self, inputs): - self.retain_inputs((0, 1)) # retain only x and 
W + if (all(_.dtype == numpy.float32 for _ in inputs) + and _ideep.should_use_ideep('>=auto')): + + # iDeep implementation + return self.forward_ideep(inputs) + self.retain_inputs((0, 1)) # retain only x and W if len(inputs) == 2: (x, W), b = inputs, None else: @@ -130,6 +139,17 @@ def forward_cpu(self, inputs): y += b return numpy.rollaxis(y, 3, 1), + def forward_ideep(self, inputs): + self.retain_inputs((0, 1)) + + cc = _ideep.ideep.xnn.ConvolutionForward( + inputs, stride=(self.sy, self.sx), + pad=(self.ph, self.pw), cover_all=self.cover_all) + self.ideep_hint = cc.hint + + y, = cc.execute_on() + return y, + def forward_gpu(self, inputs): self.retain_inputs((0, 1)) # retain only x and W if len(inputs) == 2: @@ -251,8 +271,12 @@ def __init__(self, conv2d): self.dx = conv2d.dx self.cover_all = conv2d.cover_all self.W_dtype = W_node.dtype + self.ideep_hint = conv2d.ideep_hint def forward_cpu(self, inputs): + if self.ideep_hint is not None: + return self.forward_ideep(inputs) + self.retain_inputs((0, 1)) x, gy = inputs col = conv.im2col_cpu( @@ -270,6 +294,17 @@ def forward_cpu(self, inputs): gy, col, ((0, 2, 3), (0, 4, 5))).astype(self.W_dtype, copy=False) return gW, + def forward_ideep(self, inputs): + self.retain_inputs((0, 1)) + + cc = _ideep.ideep.xnn.ConvolutionBackwardWeights( + inputs, stride=(self.sy, self.sx), pad=(self.ph, self.pw), + outsize=(self.kh, self.kw), cover_all=self.cover_all, + hint=self.ideep_hint) + gW, gb = cc.execute_on() + + return gW, + def forward_gpu(self, inputs): self.retain_inputs((0, 1)) x, gy = inputs diff --git a/chainer/functions/connection/deconvolution_2d.py b/chainer/functions/connection/deconvolution_2d.py index b5c20beaf509..d36c680c96c2 100644 --- a/chainer/functions/connection/deconvolution_2d.py +++ b/chainer/functions/connection/deconvolution_2d.py @@ -1,6 +1,7 @@ import numpy import chainer +from chainer import _ideep from chainer import configuration from chainer import cuda from chainer import function_node @@ -45,6 +46,7 @@ def _pair(x): class Deconvolution2DFunction(function_node.FunctionNode): cover_all = None + ideep_hint = None def __init__(self, stride=1, pad=0, outsize=None, **kwargs): argument.check_unexpected_kwargs( @@ -135,14 +137,36 @@ def forward_cpu(self, inputs): self._calc_out_size(x, W) - gcol = numpy.tensordot(W, x, (0, 1)).astype(x.dtype, copy=False) - gcol = numpy.rollaxis(gcol, 3) - y = conv.col2im_cpu( - gcol, self.sy, self.sx, self.ph, self.pw, self.outh, self.outw, - dy=self.dy, dx=self.dx) - # b, k, h, w + if (all(_.dtype == numpy.float32 for _ in inputs) + and _ideep.should_use_ideep('>=auto')): + + # iDeep implementation + return self._forward_ideep(x, W, b) + + else: + gcol = numpy.tensordot(W, x, (0, 1)).astype(x.dtype, copy=False) + gcol = numpy.rollaxis(gcol, 3) + y = conv.col2im_cpu( + gcol, self.sy, self.sx, self.ph, self.pw, self.outh, self.outw, + dy=self.dy, dx=self.dx) + # b, k, h, w + if b is not None: + y += b.reshape(1, b.size, 1, 1) + return y, + + def _forward_ideep(self, x, W, b): + # bias is not supported yet + cc = _ideep.ideep.xnn.ConvolutionBackwardData( + (x, W), stride=(self.sy, self.sx), + pad=(self.ph, self.pw), outsize=(self.outh, self.outw), + cover_all=self.cover_all) + + self.ideep_hint = cc.hint + y, = cc.execute_on() + if b is not None: y += b.reshape(1, b.size, 1, 1) + return y, def forward_gpu(self, inputs): diff --git a/chainer/functions/connection/linear.py b/chainer/functions/connection/linear.py index 9170798a051d..ba0c6fdd3e0e 100644 --- a/chainer/functions/connection/linear.py 
+++ b/chainer/functions/connection/linear.py @@ -1,5 +1,7 @@ import numpy +from chainer import _ideep +from chainer import cuda from chainer import function_node import chainer.functions from chainer.utils import type_check @@ -7,6 +9,8 @@ class LinearFunction(function_node.FunctionNode): + _ideep_hint = None + def check_type_forward(self, in_types): n_in = in_types.size() type_check.expect(2 <= n_in, n_in <= 3) @@ -28,13 +32,18 @@ def check_type_forward(self, in_types): ) def forward(self, inputs): - x = inputs[0] - W = inputs[1] + if (all(_.dtype == numpy.float32 for _ in inputs) + and cuda.get_array_module(*inputs) is numpy + and _ideep.should_use_ideep('>=auto')): + + # iDeep implementation + return self._forward_ideep(inputs) - if not type_check.same_types(*inputs): - raise ValueError('numpy and cupy must not be used together\n' - 'type(W): {0}, type(x): {1}' - .format(type(W), type(x))) + # Generic implementation + if len(inputs) == 3: + x, W, b = inputs + else: + (x, W), b = inputs, None # NumPy raises an error when the array is not contiguous. # See: https://github.com/chainer/chainer/issues/2744 @@ -45,23 +54,129 @@ def forward(self, inputs): x = numpy.ascontiguousarray(x) y = x.dot(W.T).astype(x.dtype, copy=False) - if len(inputs) == 3: - b = inputs[2] + if b is not None: y += b self.retain_inputs((0, 1)) # b is not retained return y, + def _forward_ideep(self, inputs): + self.retain_inputs((0, 1)) + + cc = _ideep.ideep.xnn.LinearForward(inputs) + self._ideep_hint = cc.hint + self._ideep_W = cc.W + + y, = cc.execute_on() + y.reset_buf_order() + + return y, + def backward(self, indexes, grad_outputs): - x, W = self.get_retained_inputs() - gy, = grad_outputs + ret = [] + + # TODO(nibosh): If `2 in indexes`, it does not work. + if (self._ideep_hint is not None + and 2 not in indexes): + + # iDeep implementation + inputs = self.get_retained_inputs() + input_data = tuple([_.data for _ in inputs]) + + if 0 in indexes: + gx = LinearGradDIdeep( + input_data, self._ideep_hint, self._ideep_W).apply( + grad_outputs) + ret.append(gx[0]) + if 1 in indexes or 2 in indexes: + gW_b = LinearGradWIdeep( + input_data, self._ideep_hint).apply(grad_outputs) + if 1 in indexes: + ret.append(gW_b[0]) + if 2 in indexes: + ret.append(gW_b[1]) + else: + # Generic implementation + x, W = self.get_retained_inputs() + gy, = grad_outputs + if 0 in indexes: + gx = linear(gy, W.T) + ret.append(chainer.functions.cast(gx, x.dtype)) + if 1 in indexes: + gW = linear(gy.T, x.T) + ret.append(chainer.functions.cast(gW, W.dtype)) + if 2 in indexes: + gb = chainer.functions.sum(gy, axis=0) + ret.append(gb) + + return ret + + +class LinearGradDIdeep(function_node.FunctionNode): + + def __init__(self, inputs, hint, ccW): + super(LinearGradDIdeep, self).__init__() + + assert len(inputs) >= 2 + self.inputs = inputs + self.W = ccW + self.hint = hint + + def forward_cpu(self, inputs): + cc = _ideep.ideep.xnn.LinearBackwardData( + self.inputs, inputs, self.hint, self.W) + + gx = cc.execute_on() + gx[0].reset_buf_order() + + return gx + + def backward(self, indexes, gy): + x = self.inputs[0] + W = self.inputs[1] + + ret = [] + if 0 in indexes: + gx = linear(gy, W) + ret.append(gx) + if 1 in indexes: + gW = linear(gy, x) + ret.append(gW) + if 2 in indexes: + gb = chainer.functions.sum(gy, axis=0) + ret.append(gb) + + return ret + + +class LinearGradWIdeep(function_node.FunctionNode): + + def __init__(self, inputs, hint): + super(LinearGradWIdeep, self).__init__() + + assert len(inputs) >= 2 + self.inputs = inputs + self.hint = 
hint + + def forward_cpu(self, inputs): + cc = _ideep.ideep.xnn.LinearBackwardWeighs( + self.inputs, inputs, self.hint) + + gW_b = cc.execute_on() + gW_b[0].reset_buf_order() + + return gW_b + + def backward(self, indexes, gy): + x = self.inputs[0] + W = self.inputs[1] ret = [] if 0 in indexes: - gx = linear(gy, W.T) - ret.append(chainer.functions.cast(gx, x.dtype)) + gx = linear(gy, W) + ret.append(gx) if 1 in indexes: - gW = linear(gy.T, x.T) - ret.append(chainer.functions.cast(gW, W.dtype)) + gW = linear(gy, x) + ret.append(gW) if 2 in indexes: gb = chainer.functions.sum(gy, axis=0) ret.append(gb) diff --git a/chainer/testing/attr.py b/chainer/testing/attr.py index 6fe24e7392cb..1a779468eb41 100644 --- a/chainer/testing/attr.py +++ b/chainer/testing/attr.py @@ -29,6 +29,7 @@ def get_error(): _gpu_limit = int(os.getenv('CHAINER_TEST_GPU_LIMIT', '-1')) cudnn = pytest.mark.cudnn + ideep = pytest.mark.ideep slow = pytest.mark.slow else: @@ -37,6 +38,7 @@ def _dummy_callable(*args, **kwargs): assert False # Not reachable cudnn = _dummy_callable + ideep = _dummy_callable slow = _dummy_callable diff --git a/chainer/testing/backend.py b/chainer/testing/backend.py index 8026908d17b4..c50dd33f65f8 100644 --- a/chainer/testing/backend.py +++ b/chainer/testing/backend.py @@ -15,6 +15,7 @@ class BackendConfig(object): ('use_cudnn', 'never'), ('cudnn_deterministic', False), ('autotune', False), + ('use_ideep', 'never'), ] def __init__(self, params): @@ -45,6 +46,8 @@ def __enter__(self): 'cudnn_deterministic', self.cudnn_deterministic), chainer.using_config( 'autotune', self.autotune), + chainer.using_config( + 'use_ideep', self.use_ideep), ] for c in self._contexts: c.__enter__() @@ -80,6 +83,9 @@ def get_pytest_marks(self): marks.append(attr.gpu) if self.use_cudnn != 'never': marks.append(attr.cudnn) + else: + if self.use_ideep != 'never': + marks.append(attr.ideep) assert all(callable(_) for _ in marks) return marks diff --git a/chainer/variable.py b/chainer/variable.py index de191180e65d..84794a75bbd2 100644 --- a/chainer/variable.py +++ b/chainer/variable.py @@ -18,7 +18,7 @@ def _check_grad_type(func, x, gx): if x.data is None or gx is None: # ``x.data is None`` implies that the data array is not retained return - if not isinstance(gx, type(x.data)): + if not chainer.is_arrays_compatible((gx, x.data)): msg = ('Type of data and grad mismatch\n%s != %s' % (type(x.data), type(gx))) typ = TypeError @@ -451,7 +451,7 @@ def __init__(self, data=None, **kwargs): ('requires_grad', True)) if (data is not None and - not isinstance(data, (numpy.ndarray, cuda.ndarray))): + not isinstance(data, chainer.get_array_types())): msg = '''numpy.ndarray or cuda.ndarray are expected. 
Actual: {0}'''.format(type(data)) raise TypeError(msg) diff --git a/tests/chainer_tests/functions_tests/activation_tests/test_relu.py b/tests/chainer_tests/functions_tests/activation_tests/test_relu.py index f754fd1c33af..3e429b5fffe0 100644 --- a/tests/chainer_tests/functions_tests/activation_tests/test_relu.py +++ b/tests/chainer_tests/functions_tests/activation_tests/test_relu.py @@ -9,103 +9,111 @@ from chainer import gradient_check from chainer import testing from chainer.testing import attr +from chainer.testing import backend + + +def _to_noncontiguous(arrays): + xp = cuda.get_array_module(*arrays) + return [xp.asfortranarray(a) for a in arrays] @testing.parameterize(*testing.product({ 'shape': [(3, 2), ()], 'dtype': [numpy.float16, numpy.float32, numpy.float64], + 'c_contiguous': [True, False], })) @testing.fix_random() +@backend.inject_backend_tests( + ['test_forward', 'test_backward', 'test_double_backward'], + # CPU tests + testing.product({ + 'use_cuda': [False], + 'use_ideep': ['never', 'always'], + }) + # GPU tests + + testing.product({ + 'use_cuda': [True], + 'use_cudnn': ['never', 'always'], + })) class TestReLU(unittest.TestCase): def setUp(self): # Avoid unstability of numerical grad - self.x = numpy.random.uniform(-1, 1, self.shape).astype(self.dtype) - self.x[(-0.1 < self.x) & (self.x < 0.1)] = 0.5 - self.gy = numpy.random.uniform(-1, 1, self.shape).astype(self.dtype) - self.ggx = numpy.random.uniform(-1, 1, self.shape).astype(self.dtype) + x = numpy.random.uniform(-1, 1, self.shape).astype(self.dtype) + x[(-0.1 < x) & (x < 0.1)] = 0.5 + gy = numpy.random.uniform(-1, 1, self.shape).astype(self.dtype) + ggx = numpy.random.uniform(-1, 1, self.shape).astype(self.dtype) + self.inputs = [x] + self.grad_outputs = [gy] + self.grad_grad_inputs = [ggx] self.check_backward_options = {} self.check_double_backward_options = {} if self.dtype == numpy.float16: self.check_double_backward_options = {'atol': 1e-3, 'rtol': 1e-2} - def check_forward(self, x_data, use_cudnn='always'): - x = chainer.Variable(x_data) - with chainer.using_config('use_cudnn', use_cudnn): - y = functions.relu(x) - self.assertEqual(y.data.dtype, self.dtype) - - expected = self.x.copy() + def forward_cpu(self, inputs): + x, = inputs + expected = x.copy() expected[expected < 0] = 0 + return expected, + + def check_forward(self, inputs, backend_config): + y_expected, = self.forward_cpu(inputs) - testing.assert_allclose(expected, y.data) + if backend_config.use_cuda: + inputs = cuda.to_gpu(inputs) + if not self.c_contiguous: + inputs = _to_noncontiguous(inputs) - def test_forward_cpu(self): - self.check_forward(self.x) + x_data, = inputs + x = chainer.Variable(x_data) + with backend_config: + y = functions.relu(x) + assert y.data.dtype == self.dtype - @attr.gpu - def test_forward_gpu(self): - self.check_forward(cuda.to_gpu(self.x)) + testing.assert_allclose(y_expected, y.data) - @attr.gpu - def test_forward_gpu_no_cudnn(self): - self.check_forward(cuda.to_gpu(self.x), 'never') + def test_forward(self, backend_config): + self.check_forward(self.inputs, backend_config) - def check_backward(self, x_data, y_grad, use_cudnn='always'): - with chainer.using_config('use_cudnn', use_cudnn): + def check_backward(self, inputs, grad_outputs, backend_config): + x_data, = inputs + y_grad, = grad_outputs + with backend_config: gradient_check.check_backward( functions.relu, x_data, y_grad, dtype=numpy.float64, **self.check_backward_options) - def test_backward_cpu(self): - self.check_backward(self.x, self.gy) - - @attr.gpu - def 
test_backward_gpu(self): - self.check_backward(cuda.to_gpu(self.x), cuda.to_gpu(self.gy)) + def test_backward(self, backend_config): + self.check_backward(self.inputs, self.grad_outputs, backend_config) - @attr.gpu - def test_backward_gpu_non_contiguous(self): - self.check_backward(cuda.cupy.asfortranarray(cuda.to_gpu(self.x)), - cuda.cupy.asfortranarray(cuda.to_gpu(self.gy))) + def check_double_backward( + self, inputs, grad_outputs, grad_grad_inputs, backend_config): + if backend_config.use_cuda: + inputs = cuda.to_gpu(inputs) + grad_outputs = cuda.to_gpu(grad_outputs) + grad_grad_inputs = cuda.to_gpu(grad_grad_inputs) + if not self.c_contiguous: + inputs = _to_noncontiguous(inputs) + grad_outputs = _to_noncontiguous(grad_outputs) + grad_grad_inputs = _to_noncontiguous(grad_grad_inputs) - @attr.gpu - def test_backward_cpu_no_cudnn(self): - self.check_backward(cuda.to_gpu(self.x), cuda.to_gpu(self.gy), 'never') - - def check_double_backward(self, x_data, y_grad, x_grad_grad, - use_cudnn='always'): def f(x): x = functions.relu(x) return x * x - with chainer.using_config('use_cudnn', use_cudnn): + x, = inputs + gy, = grad_outputs + ggx, = grad_grad_inputs + with backend_config: gradient_check.check_double_backward( - f, x_data, y_grad, x_grad_grad, dtype=numpy.float64, + f, x, gy, ggx, dtype=numpy.float64, **self.check_double_backward_options) - def test_double_backward_cpu(self): - self.check_double_backward(self.x, self.gy, self.ggx) - - @attr.gpu - def test_double_backward_gpu(self): - self.check_double_backward(cuda.to_gpu(self.x), - cuda.to_gpu(self.gy), - cuda.to_gpu(self.ggx)) - - @attr.gpu - def test_double_backward_gpu_non_contiguous(self): + def test_double_backward(self, backend_config): self.check_double_backward( - cuda.cupy.asfortranarray(cuda.to_gpu(self.x)), - cuda.cupy.asfortranarray(cuda.to_gpu(self.gy)), - cuda.cupy.asfortranarray(cuda.to_gpu(self.ggx))) - - @attr.gpu - def test_double_backward_cpu_no_cudnn(self): - self.check_double_backward(cuda.to_gpu(self.x), - cuda.to_gpu(self.gy), - cuda.to_gpu(self.ggx), - 'never') + self.inputs, self.grad_outputs, self.grad_grad_inputs, + backend_config) @testing.parameterize(*testing.product({ diff --git a/tests/chainer_tests/functions_tests/connection_tests/test_convolution_2d.py b/tests/chainer_tests/functions_tests/connection_tests/test_convolution_2d.py index a73e5689cc2d..8275b32fe6af 100644 --- a/tests/chainer_tests/functions_tests/connection_tests/test_convolution_2d.py +++ b/tests/chainer_tests/functions_tests/connection_tests/test_convolution_2d.py @@ -9,6 +9,7 @@ from chainer import gradient_check from chainer import testing from chainer.testing import attr +from chainer.testing import backend from chainer.testing import condition @@ -17,26 +18,44 @@ 'cover_all': [True, False], 'x_dtype': [numpy.float32], 'W_dtype': [numpy.float32], - 'cudnn_deterministic': [True, False], 'dilate': [1], - 'autotune': [True, False], + 'nobias': [True, False], }) + testing.product({ 'c_contiguous': [False], 'cover_all': [False], - 'cudnn_deterministic': [False], 'x_dtype': [numpy.float16, numpy.float32, numpy.float64], 'W_dtype': [numpy.float16, numpy.float32, numpy.float64], 'dilate': [1], - 'autotune': [False], + 'nobias': [True, False], }) + testing.product({ 'c_contiguous': [False], 'cover_all': [False], - 'cudnn_deterministic': [False], 'x_dtype': [numpy.float16, numpy.float32, numpy.float64], 'W_dtype': [numpy.float16, numpy.float32, numpy.float64], 'dilate': [2], - 'autotune': [False], + 'nobias': [True, False], }))) 
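
Note: the test rewrites in this series (test_relu.py above, test_convolution_2d.py and test_deconvolution_2d.py below) replace hand-written *_cpu/*_gpu test methods with chainer.testing.backend.inject_backend_tests, which passes a BackendConfig object into each listed test method. As a reference only, a stripped-down test using the same mechanism could look like the sketch below; the test class and input values are hypothetical, while the decorator arguments, the backend_config parameter and the "with backend_config:" usage follow the diffs in this patch.

import unittest

import numpy

from chainer import cuda
from chainer import functions
from chainer import testing
from chainer.testing import backend


@backend.inject_backend_tests(
    ['test_forward'],
    # CPU backends only, with and without iDeep; GPU entries are omitted here.
    testing.product({
        'use_cuda': [False],
        'use_ideep': ['never', 'always'],
    }))
class TestReluBackendSketch(unittest.TestCase):

    def test_forward(self, backend_config):
        # backend_config is a BackendConfig; entering it applies the
        # corresponding chainer.using_config settings (use_ideep, use_cudnn, ...)
        # and the matching pytest marks (attr.ideep, attr.gpu) are attached.
        x = numpy.array([-1.0, 0.0, 2.0], dtype=numpy.float32)
        expected = numpy.array([0.0, 0.0, 2.0], dtype=numpy.float32)
        if backend_config.use_cuda:
            x = cuda.to_gpu(x)
        with backend_config:
            y = functions.relu(x)
        testing.assert_allclose(cuda.to_cpu(y.data), expected)
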
+@backend.inject_backend_tests( + ['test_forward', 'test_backward', 'test_double_backward'], + # CPU tests + testing.product({ + 'use_cuda': [False], + 'use_ideep': ['never', 'always'], + }) + # GPU tests + + testing.product([ + [{'use_cuda': True}], + + # Without cuDNN + testing.product({ + 'use_cudnn': ['never'], + }) + # With cuDNN + + testing.product({ + 'use_cudnn': ['always'], + 'cudnn_deterministic': [True, False], + 'autotune': [True, False], + })])) class TestConvolution2DFunction(unittest.TestCase): def setUp(self): @@ -46,79 +65,101 @@ def setUp(self): kh, kw = (3, 3) self.stride = 2 self.pad = (int(kh / 2) * self.dilate, int(kw / 2) * self.dilate) - self.use_cudnn = 'always' - self.W = numpy.random.normal( + W = numpy.random.normal( 0, numpy.sqrt(1. / (kh * kw * in_channels)), (out_channels, in_channels, kh, kw)).astype(self.W_dtype) - self.b = numpy.random.uniform( - -1, 1, out_channels).astype(self.x_dtype) - self.x = numpy.random.uniform( + if self.nobias: + b = None + else: + b = numpy.random.uniform( + -1, 1, out_channels).astype(self.x_dtype) + + x = numpy.random.uniform( -1, 1, (batches, in_channels, 4, 3)).astype(self.x_dtype) if self.cover_all: - self.gy = numpy.random.uniform(-1, 1, - (batches, out_channels, 3, 2) - ).astype(self.x_dtype) + gy = numpy.random.uniform( + -1, 1, (batches, out_channels, 3, 2)).astype(self.x_dtype) else: - self.gy = numpy.random.uniform( + gy = numpy.random.uniform( -1, 1, (batches, out_channels, 2, 2)).astype(self.x_dtype) - self.ggx = numpy.random.uniform(-1, 1, self.x.shape).astype( + ggx = numpy.random.uniform(-1, 1, x.shape).astype( self.x_dtype) - self.ggW = numpy.random.uniform(-1, 1, self.W.shape).astype( + ggW = numpy.random.uniform(-1, 1, W.shape).astype( self.W_dtype) - self.ggb = numpy.random.uniform(-1, 1, self.b.shape).astype( - self.x_dtype) - - @attr.gpu - def test_forward_consistency(self, nobias=False): - x_cpu = chainer.Variable(self.x) - W_cpu = chainer.Variable(self.W) - b_cpu = None if nobias else chainer.Variable(self.b) - with chainer.using_config('cudnn_deterministic', - self.cudnn_deterministic): + ggb = None if b is None else numpy.random.uniform( + -1, 1, b.shape).astype(self.x_dtype) + + self.inputs = [x, W, b] + self.grad_outputs = [gy] + self.grad_grad_inputs = [ggx, ggW, ggb] + + def _skip_if_unsupported(self, backend_config): + if backend_config.use_cudnn != 'never': + if backend_config.cudnn_deterministic: + if self.dilate != 1: + self.skipTest('unsupported') + + def forward_cpu(self, inputs): + x, W, b = inputs + x_cpu = chainer.Variable(x) + W_cpu = chainer.Variable(W) + b_cpu = None if b is None else chainer.Variable(b) + with chainer.using_config('use_ideep', 'never'): y_cpu = F.convolution_2d( x_cpu, W_cpu, b_cpu, stride=self.stride, pad=self.pad, cover_all=self.cover_all, dilate=self.dilate) + return y_cpu, - x_gpu = chainer.Variable(cuda.to_gpu(self.x)) - W_gpu = chainer.Variable(cuda.to_gpu(self.W)) - b_gpu = None if nobias else chainer.Variable(cuda.to_gpu(self.b)) - with chainer.using_config('use_cudnn', self.use_cudnn): - with chainer.using_config('cudnn_deterministic', - self.cudnn_deterministic): - with chainer.using_config('autotune', self.autotune): - y_gpu = F.convolution_2d( - x_gpu, W_gpu, b_gpu, stride=self.stride, pad=self.pad, - cover_all=self.cover_all, dilate=self.dilate) + def check_forward(self, inputs, backend_config): + y_expected, = self.forward_cpu(inputs) + + if backend_config.use_cuda: + inputs = cuda.to_gpu(inputs) + + x, W, b = inputs + x = chainer.Variable(x) + W = 
chainer.Variable(W) + b = None if b is None else chainer.Variable(b) + with backend_config: + y_actual = F.convolution_2d( + x, W, b, stride=self.stride, pad=self.pad, + cover_all=self.cover_all, dilate=self.dilate) testing.assert_allclose( - y_cpu.data, y_gpu.data.get(), atol=5e-4, rtol=5e-3) + y_expected.data, y_actual.data.get(), atol=5e-4, rtol=5e-3) + + def test_forward(self, backend_config): + self._skip_if_unsupported(backend_config) - @attr.gpu - def test_forward_consistency_im2col(self): - self.use_cudnn = 'never' - self.test_forward_consistency() + # Forward test does not currently target CPU backend. + # It only tests for consistency between GPU and CPU computation. + if not backend_config.use_cuda: + return + self.check_forward(self.inputs, backend_config) - @attr.gpu - def test_forward_consistency_im2col_nobias(self): - self.use_cudnn = 'never' - self.test_forward_consistency(nobias=True) + def check_backward(self, inputs, grad_outputs, backend_config): - def check_backward(self, x_data, W_data, b_data, y_grad): - xp = cuda.get_array_module(x_data) + xp = backend_config.xp + if backend_config.use_cuda: + inputs = cuda.to_gpu(inputs) + grad_outputs = cuda.to_gpu(grad_outputs) + + x_data, W_data, b_data = inputs + y_grad, = grad_outputs if not self.c_contiguous: x_data = xp.asfortranarray(x_data) W_data = xp.asfortranarray(W_data) y_grad = xp.asfortranarray(y_grad) - self.assertFalse(x_data.flags.c_contiguous) - self.assertFalse(W_data.flags.c_contiguous) - self.assertFalse(y_grad.flags.c_contiguous) + assert not x_data.flags.c_contiguous + assert not W_data.flags.c_contiguous + assert not y_grad.flags.c_contiguous if b_data is not None: - b = xp.empty((len(b_data) * 2,), dtype=self.b.dtype) + b = xp.empty((len(b_data) * 2,), dtype=b_data.dtype) b[::2] = b_data b_data = b[::2] - self.assertFalse(b_data.flags.c_contiguous) + assert not b_data.flags.c_contiguous args = (x_data, W_data) if b_data is not None: @@ -129,50 +170,27 @@ def f(*args): cover_all=self.cover_all, dilate=self.dilate) - with chainer.using_config('use_cudnn', self.use_cudnn): - with chainer.using_config('cudnn_deterministic', - self.cudnn_deterministic): - with chainer.using_config('autotune', self.autotune): - gradient_check.check_backward( - f, args, y_grad, dtype='d', atol=5e-4, rtol=5e-3) - - @condition.retry(3) - def test_backward_cpu(self): - self.check_backward(self.x, self.W, self.b, self.gy) - - @condition.retry(3) - def test_backward_cpu_nobias(self): - self.check_backward(self.x, self.W, None, self.gy) - - @attr.gpu - @condition.retry(3) - def test_backward_gpu(self): - self.check_backward(cuda.to_gpu(self.x), cuda.to_gpu(self.W), - cuda.to_gpu(self.b), cuda.to_gpu(self.gy)) + with backend_config: + gradient_check.check_backward( + f, args, y_grad, dtype='d', atol=5e-4, rtol=5e-3) - @attr.gpu @condition.retry(3) - def test_backward_gpu_nobias(self): - self.check_backward(cuda.to_gpu(self.x), cuda.to_gpu(self.W), - None, cuda.to_gpu(self.gy)) + def test_backward(self, backend_config): + self._skip_if_unsupported(backend_config) + self.check_backward(self.inputs, self.grad_outputs, backend_config) - @attr.gpu - @condition.retry(3) - def test_backward_gpu_im2col(self): - self.use_cudnn = 'never' - self.check_backward(cuda.to_gpu(self.x), cuda.to_gpu(self.W), - cuda.to_gpu(self.b), cuda.to_gpu(self.gy)) + def check_double_backward( + self, inputs, grad_outputs, grad_grad_inputs, backend_config): + xp = backend_config.xp - @attr.gpu - @condition.retry(3) - def test_backward_gpu_im2col_nobias(self): - 
self.use_cudnn = 'never' - self.check_backward(cuda.to_gpu(self.x), cuda.to_gpu(self.W), - None, cuda.to_gpu(self.gy)) + if backend_config.use_cuda: + inputs = cuda.to_gpu(inputs) + grad_outputs = cuda.to_gpu(grad_outputs) + grad_grad_inputs = cuda.to_gpu(grad_grad_inputs) - def check_double_backward(self, x_data, W_data, b_data, y_grad, - x_grad_grad, W_grad_grad, b_grad_grad): - xp = cuda.get_array_module(x_data) + x_data, W_data, b_data = inputs + y_grad, = grad_outputs + x_grad_grad, W_grad_grad, b_grad_grad = grad_grad_inputs if not self.c_contiguous: x_data = xp.asfortranarray(x_data) @@ -180,21 +198,21 @@ def check_double_backward(self, x_data, W_data, b_data, y_grad, y_grad = xp.asfortranarray(y_grad) x_grad_grad = xp.asfortranarray(x_grad_grad) W_grad_grad = xp.asfortranarray(W_grad_grad) - self.assertFalse(x_data.flags.c_contiguous) - self.assertFalse(W_data.flags.c_contiguous) - self.assertFalse(y_grad.flags.c_contiguous) - self.assertFalse(x_grad_grad.flags.c_contiguous) - self.assertFalse(W_grad_grad.flags.c_contiguous) + assert not x_data.flags.c_contiguous + assert not W_data.flags.c_contiguous + assert not y_grad.flags.c_contiguous + assert not x_grad_grad.flags.c_contiguous + assert not W_grad_grad.flags.c_contiguous if b_data is not None: - b = xp.empty((len(b_data) * 2,), dtype=self.b.dtype) + b = xp.empty((len(b_data) * 2,), dtype=b_data.dtype) b[::2] = b_data b_data = b[::2] - self.assertFalse(b_data.flags.c_contiguous) + assert not b_data.flags.c_contiguous - ggb = xp.empty((len(b_data) * 2,), dtype=self.b.dtype) + ggb = xp.empty((len(b_data) * 2,), dtype=b_data.dtype) ggb[::2] = b_grad_grad b_grad_grad = ggb[::2] - self.assertFalse(b_grad_grad.flags.c_contiguous) + assert not b_grad_grad.flags.c_contiguous args = (x_data, W_data) grad_grads = (x_grad_grad, W_grad_grad) @@ -207,51 +225,17 @@ def f(*args): cover_all=self.cover_all, dilate=self.dilate) return y * y # make the function nonlinear - with chainer.using_config('use_cudnn', self.use_cudnn): - with chainer.using_config('cudnn_deterministic', - self.cudnn_deterministic): - gradient_check.check_double_backward( - f, args, y_grad, grad_grads, - dtype='d', atol=5e-3, rtol=5e-2) - - @condition.retry(3) - def test_double_backward_cpu(self): - self.check_double_backward(self.x, self.W, self.b, self.gy, - self.ggx, self.ggW, self.ggb) + with backend_config: + gradient_check.check_double_backward( + f, args, y_grad, grad_grads, + dtype='d', atol=5e-3, rtol=5e-2) @condition.retry(3) - def test_double_backward_cpu_nobias(self): - self.check_double_backward(self.x, self.W, None, self.gy, - self.ggx, self.ggW, None) - - def check_double_backward_gpu(self, bias=True, im2col=False): - if im2col: - self.use_cudnn = 'never' + def test_double_backward(self, backend_config): + self._skip_if_unsupported(backend_config) self.check_double_backward( - cuda.to_gpu(self.x), cuda.to_gpu(self.W), - cuda.to_gpu(self.b) if bias else None, - cuda.to_gpu(self.gy), cuda.to_gpu(self.ggx), cuda.to_gpu(self.ggW), - cuda.to_gpu(self.ggb) if bias else None) - - @attr.gpu - @condition.retry(3) - def test_double_backward_gpu(self): - self.check_double_backward_gpu() - - @attr.gpu - @condition.retry(3) - def test_double_backward_gpu_nobias(self): - self.check_double_backward_gpu(bias=False) - - @attr.gpu - @condition.retry(3) - def test_double_backward_gpu_im2col(self): - self.check_double_backward_gpu(im2col=True) - - @attr.gpu - @condition.retry(3) - def test_double_backward_gpu_im2col_nobias(self): - self.check_double_backward_gpu(bias=False, 
im2col=True) + self.inputs, self.grad_outputs, self.grad_grad_inputs, + backend_config) @testing.parameterize(*(testing.product({ diff --git a/tests/chainer_tests/functions_tests/connection_tests/test_deconvolution_2d.py b/tests/chainer_tests/functions_tests/connection_tests/test_deconvolution_2d.py index 6b1e3bb8d560..f21919c85a08 100644 --- a/tests/chainer_tests/functions_tests/connection_tests/test_deconvolution_2d.py +++ b/tests/chainer_tests/functions_tests/connection_tests/test_deconvolution_2d.py @@ -42,9 +42,10 @@ def _pair(x): @backend.inject_backend_tests( ['test_forward', 'test_backward', 'test_double_backward'], # CPU tests - [{ - 'use_cuda': False, - }] + testing.product({ + 'use_cuda': [False], + 'use_ideep': ['never', 'always'], + }) # GPU tests + testing.product([ [{'use_cuda': True}], @@ -115,9 +116,10 @@ def forward_cpu(self, inputs): x_cpu = chainer.Variable(x) W_cpu = chainer.Variable(W) b_cpu = None if b is None else chainer.Variable(b) - y_cpu = F.deconvolution_2d( - x_cpu, W_cpu, b_cpu, stride=self.stride, pad=self.pad, - outsize=self.outsize) + with chainer.using_config('use_ideep', 'never'): + y_cpu = F.deconvolution_2d( + x_cpu, W_cpu, b_cpu, stride=self.stride, pad=self.pad, + outsize=self.outsize) return y_cpu, def check_forward(self, inputs, backend_config):
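
As a closing reference, the forward-path changes in convolution_2d.py, deconvolution_2d.py and linear.py above all share the same dispatch shape: take the iDeep path only when every input array is float32 and _ideep.should_use_ideep('>=auto') allows it, keep the returned hint for the backward pass, and otherwise fall back to the existing NumPy code (linear.py additionally checks that the inputs are NumPy arrays, since its generic forward also handles GPU inputs). A minimal standalone sketch of that shape follows; the function and helper names here are hypothetical, and only the float32 check and _ideep.should_use_ideep are taken from the diffs.

import numpy

from chainer import _ideep  # module these diffs import as `from chainer import _ideep`


def forward_cpu_with_ideep(inputs, ideep_path, numpy_path):
    """Hypothetical dispatch helper mirroring the pattern used above.

    ``ideep_path`` and ``numpy_path`` are callables implementing the two
    branches; the real FunctionNodes also store ``cc.hint`` so that the
    backward pass can reuse it.
    """
    if (all(x.dtype == numpy.float32 for x in inputs)
            and _ideep.should_use_ideep('>=auto')):
        # iDeep (MKL-DNN) implementation
        return ideep_path(inputs)
    # Plain NumPy implementation
    return numpy_path(inputs)
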