Skip to content

Commit

Permalink
Merge pull request #4227 from niboshi/refactor-test-split-axis
Browse files Browse the repository at this point in the history
Refactor test_split_axis
  • Loading branch information
kmaehashi committed Jan 25, 2018
2 parents 57cbe13 + fe0cc87 commit 0920bc2
Showing 1 changed file with 93 additions and 68 deletions.
161 changes: 93 additions & 68 deletions tests/chainer_tests/functions_tests/array_tests/test_split_axis.py
Expand Up @@ -6,7 +6,17 @@
from chainer import cuda
from chainer import functions
from chainer import testing
from chainer.testing import attr
from chainer.testing import backend


def inject_backend_tests(method_names):
decorator = backend.inject_backend_tests(
method_names,
# CPU tests
[{'use_cuda': False}]
# GPU tests
+ [{'use_cuda': True}])
return decorator


@testing.parameterize(*testing.product_dict(
Expand Down Expand Up @@ -39,107 +49,122 @@
{'dtype': numpy.float64},
],
))
@inject_backend_tests(['test_forward', 'test_backward'])
class TestSplitAxis(unittest.TestCase):

def setUp(self):
self.x = numpy.arange(
numpy.prod(self.shape), dtype=self.dtype).reshape(self.shape)
self.ys = [self.x[s] for s in self.slices]

def check_forward(self, x_data, ys_data, indices_or_sections, axis):
x = chainer.Variable(x_data)
ys = functions.split_axis(
x, indices_or_sections, axis, force_tuple=True)
for yd, y in zip(ys_data, ys):
self.assertEqual(y.data.dtype, self.dtype)
self.assertIsInstance(y.data.shape, tuple)
shape = self.shape
dtype = self.dtype

x = numpy.arange(numpy.prod(shape), dtype=dtype).reshape(shape)
self.ys_expected = [x[s] for s in self.slices]
self.inputs = [x]

def check_forward(self, inputs, backend_config):
if backend_config.use_cuda:
inputs = cuda.to_gpu(inputs)

x, = inputs
x = chainer.Variable(x)

with backend_config:
ys = functions.split_axis(
x, self.ys_section, self.axis, force_tuple=True)

for yd, y in zip(self.ys_expected, ys):
assert y.data.dtype == self.dtype
assert isinstance(y.data.shape, tuple)
testing.assert_allclose(yd, y.data, atol=0, rtol=0)

def test_forward_cpu(self):
self.check_forward(self.x, self.ys, self.ys_section, self.axis)
def test_forward(self, backend_config):
self.check_forward(self.inputs, backend_config)

@attr.gpu
def test_forward_gpu(self):
self.check_forward(
cuda.to_gpu(self.x),
[cuda.to_gpu(y.copy()) for y in self.ys],
self.ys_section, axis=self.axis)
def check_backward(self, inputs, backend_config):
if backend_config.use_cuda:
inputs = cuda.to_gpu(inputs)

def check_backward(self, x_data, indices_or_sections, axis):
x = chainer.Variable(x_data)
ys = functions.split_axis(
x, indices_or_sections, axis, force_tuple=True)
for y in ys:
y.grad = y.data
ys[0].backward()
x, = inputs
x = chainer.Variable(x)

testing.assert_allclose(x.data, x.grad, atol=0, rtol=0)
with backend_config:
ys = functions.split_axis(
x, self.ys_section, self.axis, force_tuple=True)
for y in ys:
y.grad = y.data
ys[0].backward()

def test_backward_cpu(self):
self.check_backward(self.x, self.ys_section, axis=self.axis)
testing.assert_allclose(x.data, x.grad, atol=0, rtol=0)

@attr.gpu
def test_backward_gpu(self):
self.check_backward(
cuda.to_gpu(self.x), self.ys_section, axis=self.axis)
def test_backward(self, backend_config):
self.check_backward(self.inputs, backend_config)


@inject_backend_tests(['test_backward'])
class TestSplitAxisNone(unittest.TestCase):

def setUp(self):
self.x = numpy.array([1, 2], dtype=numpy.float32)
self.ys_section = [1]
self.axis = 0

def check_backward(self, x_data, indices_or_sections, axis):
x = chainer.Variable(x_data)
ys = functions.split_axis(x, indices_or_sections, axis)
# Only set ys[0]
ys[0].grad = ys[0].data
ys[0].backward()
self.inputs = [numpy.array([1, 2], dtype=numpy.float32)]

def check_backward(self, inputs, backend_config):
if backend_config.use_cuda:
inputs = cuda.to_gpu(inputs)

x, = inputs
x = chainer.Variable(x)

with backend_config:
ys = functions.split_axis(x, self.ys_section, self.axis)
# Only set ys[0]
ys[0].grad = ys[0].data
ys[0].backward()

gx = numpy.array([1, 0])
testing.assert_allclose(gx, x.grad, atol=0, rtol=0)

def test_backward_cpu(self):
self.check_backward(self.x, self.ys_section, axis=self.axis)

@attr.gpu
def test_backward_gpu(self):
self.check_backward(
cuda.to_gpu(self.x), self.ys_section, axis=self.axis)
def test_backward(self, backend_config):
self.check_backward(self.inputs, backend_config)


@inject_backend_tests(['test_forward_force_tuple', 'test_forward_single'])
class TestSplitAxisForceArray(unittest.TestCase):

def setUp(self):
self.x = numpy.arange(42, dtype=numpy.float32).reshape(2, 7, 3)
self.axis = 1
self.inputs = [numpy.arange(42, dtype=numpy.float32).reshape(2, 7, 3)]

def check_forward_force_tuple(self, inputs, backend_config):
if backend_config.use_cuda:
inputs = cuda.to_gpu(inputs)

x, = self.inputs
x = chainer.Variable(x)

with backend_config:
ys = functions.split_axis(x, 1, self.axis, force_tuple=True)

assert isinstance(ys, tuple)
assert len(ys) == 1

def check_forward_force_tuple(self, x_data, axis):
x = chainer.Variable(x_data)
ys = functions.split_axis(x, 1, axis, force_tuple=True)
self.assertIsInstance(ys, tuple)
self.assertEqual(len(ys), 1)
def test_forward_force_tuple(self, backend_config):
self.check_forward_force_tuple(self.inputs, backend_config)

def test_forward_force_tuple_cpu(self):
self.check_forward_force_tuple(self.x, self.axis)
def check_forward_single(self, inputs, backend_config):
if backend_config.use_cuda:
inputs = cuda.to_gpu(inputs)

@attr.gpu
def test_forward_force_tuple_gpu(self):
self.check_forward_force_tuple(cuda.to_gpu(self.x), axis=self.axis)
x, = self.inputs
x = chainer.Variable(x)

def check_forward_single(self, x_data, axis):
x = chainer.Variable(x_data)
ys = functions.split_axis(x, 1, axis, force_tuple=False)
self.assertIsInstance(ys, chainer.Variable)
with backend_config:
ys = functions.split_axis(x, 1, self.axis, force_tuple=False)

def test_forward_single_cpu(self):
self.check_forward_single(self.x, self.axis)
assert isinstance(ys, chainer.Variable)

@attr.gpu
def test_forward_single_gpu(self):
self.check_forward_single(cuda.to_gpu(self.x), axis=self.axis)
def test_forward_single(self, backend_config):
self.check_forward_single(self.inputs, backend_config)


class TestSplitAxisInvalidSections(unittest.TestCase):
Expand Down

0 comments on commit 0920bc2

Please sign in to comment.