Skip to content

Commit

Permalink
Fix unit tests of softmax_cross_entropy for soft target
Browse files Browse the repository at this point in the history
  • Loading branch information
anaruse committed Oct 9, 2019
1 parent f847214 commit 620b55d
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 4 deletions.
3 changes: 1 addition & 2 deletions chainer/functions/loss/softmax_cross_entropy.py
Expand Up @@ -447,8 +447,7 @@ def _double_backward_softmax_cross_entropy(x, t, normalize, class_weight,
def softmax_cross_entropy(
x, t, normalize=True, cache_score=True, class_weight=None,
ignore_label=-1, reduce='mean', enable_double_backprop=False,
soft_target_loss='kl-divergence'):
# soft_target_loss='cross-entropy'):
soft_target_loss='cross-entropy'):
"""Computes cross entropy loss for pre-softmax activations.
Args:
Expand Down
Expand Up @@ -616,9 +616,9 @@ def check_forward(self, xp):
'shape': [(3,), (3, 2), (3, 2, 2)],
'dtype': [numpy.float16, numpy.float32, numpy.float64],
'reduce': ['mean', 'no'],
'soft_target_loss': ['cross-entropy', 'kl-divergence'],
'soft_target_loss': ['kl-divergence'],
})))
class TestSoftTargetExpectNearZero(BaseSoftTarget, unittest.TestCase):
class TestSoftTargetKLDivergence(BaseSoftTarget, unittest.TestCase):

def setUp(self):
BaseSoftTarget.setUp(self)
Expand All @@ -637,4 +637,30 @@ def check_forward(self, xp):
**self.check_forward_options)


@testing.parameterize(*(testing.product({
'nb': [1, 2, 4],
'shape': [(3,), (3, 2), (3, 2, 2)],
'dtype': [numpy.float16, numpy.float32, numpy.float64],
'reduce': ['mean', 'no'],
'soft_target_loss': ['cross-entropy'],
})))
class TestSoftTargetCrossEntropy(BaseSoftTarget, unittest.TestCase):

def setUp(self):
BaseSoftTarget.setUp(self)
self.t = functions.softmax(self.x).array
self.expect = numpy.sum(-self.t * functions.log_softmax(self.x).array,
axis=1)
if self.reduce == 'mean':
self.expect = numpy.average(self.expect)

def check_forward(self, xp):
x = xp.asarray(self.x)
t = xp.asarray(self.t)
loss = functions.softmax_cross_entropy(
x, t, reduce=self.reduce, soft_target_loss=self.soft_target_loss)
testing.assert_allclose(loss.data, self.expect,
**self.check_forward_options)


testing.run_module(__name__, __file__)

0 comments on commit 620b55d

Please sign in to comment.