Skip to content

Commit

Permalink
Add a parameter soft_target_loss to softmax_cross_entropy test
Browse files Browse the repository at this point in the history
  • Loading branch information
anaruse committed Jul 16, 2019
1 parent 4d50de1 commit f847214
Showing 1 changed file with 7 additions and 3 deletions.
Expand Up @@ -585,6 +585,7 @@ def test_backward_gpu(self):
'shape': [(3,), (3, 2), (3, 2, 2)],
'dtype': [numpy.float16, numpy.float32, numpy.float64],
'reduce': ['mean', 'no'],
'soft_target_loss': ['cross-entropy', 'kl-divergence'],
})))
class TestSoftTargetCompareToHard(BaseSoftTarget, unittest.TestCase):

Expand All @@ -603,8 +604,9 @@ def check_forward(self, xp):
x = xp.asarray(self.x)
t = xp.asarray(self.t)
loss = functions.softmax_cross_entropy(x, t, reduce=self.reduce)
expect = functions.softmax_cross_entropy(x, xp.asarray(self.t_hard),
reduce=self.reduce)
expect = functions.softmax_cross_entropy(
x, xp.asarray(self.t_hard), reduce=self.reduce,
soft_target_loss=self.soft_target_loss)
testing.assert_allclose(loss.data, expect.data,
**self.check_forward_options)

Expand All @@ -614,6 +616,7 @@ def check_forward(self, xp):
'shape': [(3,), (3, 2), (3, 2, 2)],
'dtype': [numpy.float16, numpy.float32, numpy.float64],
'reduce': ['mean', 'no'],
'soft_target_loss': ['cross-entropy', 'kl-divergence'],
})))
class TestSoftTargetExpectNearZero(BaseSoftTarget, unittest.TestCase):

Expand All @@ -624,7 +627,8 @@ def setUp(self):
def check_forward(self, xp):
x = xp.asarray(self.x)
t = xp.asarray(self.t)
loss = functions.softmax_cross_entropy(x, t, reduce=self.reduce)
loss = functions.softmax_cross_entropy(
x, t, reduce=self.reduce, soft_target_loss=self.soft_target_loss)
if self.reduce == 'mean':
expect = 0.
else:
Expand Down

0 comments on commit f847214

Please sign in to comment.