Skip to content

Commit

Permalink
Merge 7ba05e7 into 11c237e
Browse files Browse the repository at this point in the history
  • Loading branch information
okuta committed Apr 7, 2018
2 parents 11c237e + 7ba05e7 commit 2219472
Showing 1 changed file with 21 additions and 13 deletions.
Expand Up @@ -571,19 +571,13 @@ def test_double_backward_gpu(self):


@testing.parameterize(*(testing.product({
'shape': [None, (2, 3), (2, 3, 2), (2, 3, 2, 2)],
'normalize': [True, False],
'ignore_index': [None, (slice(None),), (0,), (0, 1), (0, 1, 0)],
'dtype': [numpy.float32],
'weight_apply': [False, True],
'use_cudnn': ['always', 'auto', 'never'],
}) + testing.product({
'shape': [None, (2, 3), (2, 3, 2), (2, 3, 2, 2)],
'shape_ignore': [(None, None),
((2, 3), (slice(None),)),
((2, 3, 2), (0,)),
((2, 3, 2, 2), (0, 1, 0))],
'normalize': [True, False],
'ignore_index': [(0, 1)],
'dtype': [numpy.float16, numpy.float32, numpy.float64],
'weight_apply': [False, True],
'use_cudnn': ['always', 'auto', 'never'],
})))
class TestForwardConsistency(unittest.TestCase):

Expand All @@ -592,6 +586,7 @@ class TestForwardConsistency(unittest.TestCase):
# agree.

def setUp(self):
self.shape, self.ignore_index = self.shape_ignore
if self.shape is None:
if self.dtype == numpy.float16:
self.x = numpy.array([[-5, 1]], dtype=self.dtype)
Expand Down Expand Up @@ -634,16 +629,29 @@ def f(enable_double_backprop):
loss_single = f(False)
loss_double = f(True)

check_forward_options = {'atol': 5e-4, 'rtol': 5e-3}
check_forward_options = {}
if self.dtype == numpy.float16:
check_forward_options = {'atol': 5e-4, 'rtol': 5e-3}
testing.assert_allclose(
loss_single, loss_double, **check_forward_options)

    def test_consistency_cpu(self):
        """Run the forward-consistency check with NumPy arrays (CPU path)."""
        self.check_consistency(numpy)

    @attr.gpu
    def test_consistency_gpu(self):
        """Run the forward-consistency check with CuPy arrays (GPU path)."""
        self.check_consistency(cuda.cupy)
def test_consistency_gpu_always(self):
with chainer.using_config('use_cudnn', 'always'):
self.check_consistency(cuda.cupy)

    @attr.gpu
    def test_consistency_gpu_auto(self):
        """GPU forward-consistency check with ``use_cudnn`` set to 'auto'."""
        with chainer.using_config('use_cudnn', 'auto'):
            self.check_consistency(cuda.cupy)

    @attr.gpu
    def test_consistency_gpu_never(self):
        """GPU forward-consistency check with cuDNN disabled ('never')."""
        with chainer.using_config('use_cudnn', 'never'):
            self.check_consistency(cuda.cupy)


testing.run_module(__name__, __file__)

0 comments on commit 2219472

Please sign in to comment.