Add mixed dtypes tests
asi1024 committed Sep 19, 2019
1 parent 5574933 commit 368681a
Showing 1 changed file with 44 additions and 17 deletions.
tests/chainerx_tests/unit_tests/routines_tests/test_loss.py (44 additions, 17 deletions)
@@ -3,6 +3,8 @@
 import chainer
 from chainer import functions as F
 import chainerx
+
+from chainerx_tests import dtype_utils
 from chainerx_tests import op_utils
 
 
@@ -14,28 +16,44 @@
 ]
 
 
+_in_out_loss_dtypes = dtype_utils._permutate_dtype_mapping([
+    (('float16', 'float16'), 'float16'),
+    (('float32', 'float32'), 'float32'),
+    (('float64', 'float64'), 'float64'),
+    (('float32', 'float16'), 'float32'),
+    (('float64', 'float16'), 'float64'),
+    (('float64', 'float32'), 'float64'),
+])
+
+
 class LossBase(op_utils.ChainerOpTest):
 
     def setup(self):
         super().setup()
-        if self.in_dtype == 'float16':
+        in_dtype1, in_dtype2 = self.in_dtypes
+        if in_dtype1 == 'float16' or in_dtype2 == 'float16':
             self.check_forward_options.update({'rtol': 5e-3, 'atol': 5e-3})
             self.check_backward_options.update({'rtol': 1e-2, 'atol': 5e-3})
             self.check_double_backward_options.update(
                 {'rtol': 1e-2, 'atol': 3e-1})
 
     def generate_inputs(self):
+        in_dtype1, in_dtype2 = self.in_dtypes
         y = numpy.random.normal(loc=0, scale=1.0, size=self.shape)
         targ = numpy.random.normal(loc=0, scale=1.0, size=self.shape) + \
             numpy.random.normal(loc=0, scale=0.5, size=self.shape)
-        return y.astype(self.in_dtype), targ.astype(self.in_dtype)
+        return y.astype(in_dtype1), targ.astype(in_dtype2)
 
     def forward_chainerx(self, inputs):
         out, = self.forward_xp(inputs, chainerx)
         return out,
 
     def forward_chainer(self, inputs):
-        return self.forward_xp(inputs, F)
+        dtype = numpy.result_type(*inputs)
+        inputs = [x.astype(dtype) for x in inputs]
+        output, = self.forward_xp(inputs, F)
+        output.array = output.array.astype(self.out_dtype)
+        return output,
 
     def forward_xp(self, inputs, xp):
         raise NotImplementedError(
@@ -46,7 +64,7 @@ def forward_xp(self, inputs, xp):
 @chainer.testing.parameterize(*(
     chainer.testing.product({
         'shape': _loss_shapes,
-        'in_dtype': chainerx.testing.float_dtypes,
+        'in_dtypes,out_dtype': _in_out_loss_dtypes,
     })
 ))
 class TestSquaredError(LossBase):
@@ -60,7 +78,7 @@ def forward_xp(self, inputs, xp):
 @chainer.testing.parameterize(*(
     chainer.testing.product({
         'shape': _loss_shapes,
-        'in_dtype': chainerx.testing.float_dtypes,
+        'in_dtypes,out_dtype': _in_out_loss_dtypes,
     })
 ))
 class TestAbsoluteError(LossBase):
@@ -77,7 +95,7 @@ def forward_xp(self, inputs, xp):
 @chainer.testing.parameterize(*(
     chainer.testing.product({
         'shape': _loss_shapes,
-        'in_dtype': chainerx.testing.float_dtypes,
+        'in_dtypes,out_dtype': _in_out_loss_dtypes,
     })
 ))
 class TestGaussianKLDivergence(LossBase):
@@ -95,7 +113,7 @@ def forward_xp(self, inputs, xp):
 @chainer.testing.parameterize(*(
     chainer.testing.product({
         'shape': _loss_shapes,
-        'in_dtype': chainerx.testing.float_dtypes,
+        'in_dtypes,out_dtype': _in_out_loss_dtypes,
         'delta': [0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5],
     })
 ))
@@ -119,27 +137,36 @@ def forward_xp(self, inputs, xp):
 @chainer.testing.parameterize(*(
     chainer.testing.product({
         'shape': _loss_shapes,
-        'in_dtype': chainerx.testing.float_dtypes,
+        'x_dtype': chainerx.testing.float_dtypes,
         't_dtype': ['int8', 'int16', 'int32', 'int64'],
     })
 ))
-class TestSigmoidCrossEntropy(LossBase):
+class TestSigmoidCrossEntropy(op_utils.ChainerOpTest):
 
+    def setup(self):
+        if self.x_dtype == 'float16':
+            self.check_forward_options.update({'rtol': 5e-3, 'atol': 5e-3})
+            self.check_backward_options.update({'rtol': 1e-2, 'atol': 5e-3})
+            self.check_double_backward_options.update(
+                {'rtol': 1e-2, 'atol': 3e-1})
+
     def generate_inputs(self):
         x = numpy.random.normal(loc=0, scale=1.0, size=self.shape)
         targ = numpy.random.normal(loc=0, scale=1.0, size=self.shape) + \
             numpy.random.normal(loc=0, scale=0.5, size=self.shape)
         self.t = targ.astype(self.t_dtype)
-        return x.astype(self.in_dtype),
+        return x.astype(self.x_dtype),
 
-    def forward_xp(self, inputs, xp):
+    def forward_chainerx(self, inputs):
         x, = inputs
         # TODO(aksub99): Improve implementation to avoid non-differentiability
         # wrt targets
-        if xp is chainerx:
-            t = self.backend_config.get_array(self.t)
-            out = xp.sigmoid_cross_entropy(x, t)
-        else:
-            t = self.t
-            out = xp.sigmoid_cross_entropy(x, t, normalize=False, reduce='no')
+        t = self.backend_config.get_array(self.t)
+        out = chainerx.sigmoid_cross_entropy(x, t)
         return out,
+
+    def forward_chainer(self, inputs):
+        x, = inputs
+        t = self.t
+        out = F.sigmoid_cross_entropy(x, t, normalize=False, reduce='no')
+        return out,
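
For context, the new _in_out_loss_dtypes table maps each pair of input dtypes to the dtype the loss is expected to return. Below is a minimal sketch of what dtype_utils._permutate_dtype_mapping presumably does, namely expanding every ((in1, in2), out) entry so both argument orders are covered; the helper's exact behavior and signature are an assumption here, not something visible in this diff.

import itertools


def permutate_dtype_mapping(dtype_mapping_list):
    # Hypothetical re-implementation, for illustration only: keep the
    # expected output dtype while generating every ordering of the inputs.
    return sorted({
        (in_dtypes, out_dtype)
        for dtypes, out_dtype in dtype_mapping_list
        for in_dtypes in itertools.permutations(dtypes)
    })


# For example, (('float32', 'float16'), 'float32') also yields
# (('float16', 'float32'), 'float32'), so mixed-dtype losses are
# exercised with the arguments in either order.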
