diff --git a/chainer/functions/loss/softmax_cross_entropy.py b/chainer/functions/loss/softmax_cross_entropy.py
index 1cb664c926fd..7e538e6040e4 100644
--- a/chainer/functions/loss/softmax_cross_entropy.py
+++ b/chainer/functions/loss/softmax_cross_entropy.py
@@ -68,9 +68,12 @@ class SoftmaxCrossEntropy(function_node.FunctionNode):
 
     # Coefficient of normalization. Only used if reduce='mean'.
     _coeff = None
+    soft_target = False
+    eps = 1e-7
 
     def __init__(self, normalize=True, cache_score=True, class_weight=None,
-                 ignore_label=-1, reduce='mean'):
+                 ignore_label=-1, reduce='mean',
+                 soft_target_loss='cross-entropy'):
         self.normalize = normalize
         self.cache_score = cache_score
         _check_class_weight_option(class_weight)
@@ -78,19 +81,27 @@ def __init__(self, normalize=True, cache_score=True, class_weight=None,
         self.ignore_label = ignore_label
         _check_reduce_option(reduce)
         self.reduce = reduce
+        self.soft_target_loss = soft_target_loss
 
     def check_type_forward(self, in_types):
         type_check._argname(in_types, ('x', 't'))
         x_type, t_type = in_types
 
-        type_check.expect(
-            x_type.dtype.kind == 'f',
-            t_type.dtype.kind == 'i',
-            t_type.ndim == x_type.ndim - 1,
-
-            x_type.shape[0] == t_type.shape[0],
-            x_type.shape[2:] == t_type.shape[1:],
-        )
+        if t_type.dtype.kind == 'i':
+            type_check.expect(
+                x_type.dtype.kind == 'f',
+                t_type.dtype.kind == 'i',
+                t_type.ndim == x_type.ndim - 1,
+                x_type.shape[0] == t_type.shape[0],
+                x_type.shape[2:] == t_type.shape[1:],
+            )
+        else:
+            # assume t is soft_target
+            type_check.expect(
+                x_type.dtype.kind == 'f',
+                t_type.dtype.kind == 'f',
+                x_type.shape == t_type.shape,
+            )
 
     def _is_chainerx_supported(self, input_arrays):
         # Determines if the specified configuration of inputs and parameters
@@ -127,12 +138,18 @@ def forward_cpu(self, inputs):
 
         self.retain_inputs((0, 1))
         x, t = inputs
-        if chainer.is_debug():
+        if x.ndim == t.ndim and x.shape == t.shape:
+            self.soft_target = True
+        if chainer.is_debug() and not self.soft_target:
             _check_input_values(x, t, self.ignore_label)
 
         log_y = log_softmax._log_softmax(x)
         if self.cache_score:
             self.y = numpy.exp(log_y)
+
+        if self.soft_target:
+            return self._soft_target_loss(numpy, x, t, log_y)
+
         if class_weight is not None:
             shape = [1 if d != 1 else -1 for d in six.moves.range(x.ndim)]
             log_y *= _broadcast_to(class_weight.reshape(shape), x.shape)
@@ -165,9 +182,11 @@ def forward_gpu(self, inputs):
         class_weight = backend.from_chx(self.class_weight)
 
         self.retain_inputs((0, 1))
-        cupy = cuda.cupy
         x, t = inputs
-        if chainer.is_debug():
+        if x.ndim == t.ndim and x.shape == t.shape:
+            self.soft_target = True
+        cupy = cuda.cupy
+        if chainer.is_debug() and not self.soft_target:
             _check_input_values(x, t, self.ignore_label)
 
         if x.size == 0:
@@ -181,6 +200,10 @@ def forward_gpu(self, inputs):
         log_y = log_softmax._log_softmax(x)
         if self.cache_score:
             self.y = cupy.exp(log_y)
+
+        if self.soft_target:
+            return self._soft_target_loss(cupy, x, t, log_y)
+
         if class_weight is not None:
             shape = [1 if d != 1 else -1 for d in six.moves.range(x.ndim)]
             log_y *= cupy.broadcast_to(class_weight.reshape(shape), x.shape)
@@ -223,9 +246,22 @@ def forward_gpu(self, inputs):
             ret = ret.reshape(t.shape)
         return ret,
 
+    def _soft_target_loss(self, xp, x, t, log_y):
+        if self.soft_target_loss == 'kl-divergence':
+            ret = xp.sum(t * (xp.log(t + self.eps) - log_y), axis=1)
+        else:
+            ret = -xp.sum(t * log_y, axis=1)
+        if self.reduce == 'mean':
+            self._coeff = 1.0 / (x.size / x.shape[1])
+            ret = ret.sum(keepdims=True) * self._coeff
+            return ret.reshape(()),
+        else:
+            return ret,
+
     def backward(self, input_indexes, grad_outputs):
         func_grad = _SoftmaxCrossEntropyGrad_NoDoubleBackprop(
-            self.ignore_label, self.class_weight, self.y, self._coeff)
+            self.ignore_label, self.class_weight, self.y, self._coeff,
+            self.soft_target)
         inputs = self.get_retained_inputs()
         return func_grad.apply(inputs + grad_outputs) + (None,)
 
@@ -233,11 +269,12 @@ class _SoftmaxCrossEntropyGrad_NoDoubleBackprop(function_node.FunctionNode):
 
     # A backward implementation which does not support double-backprop.
 
-    def __init__(self, ignore_label, class_weight, y, coeff):
+    def __init__(self, ignore_label, class_weight, y, coeff, soft_target):
         self.ignore_label = ignore_label
         self.class_weight = class_weight
         self.y = y
         self.coeff = coeff
+        self.soft_target = soft_target
 
     def forward_cpu(self, inputs_and_grad_outputs):
         x, t, gloss = inputs_and_grad_outputs
@@ -250,7 +287,9 @@ def forward_cpu(self, inputs_and_grad_outputs):
         numpy.exp(y, out=y)
         t_valid = t != self.ignore_label
         t = t * t_valid
-        if y.ndim == 2:
+        if self.soft_target:
+            gx = y - t
+        elif y.ndim == 2:
             gx = y
             gx[numpy.arange(len(t)), t] -= 1
             if self.class_weight is not None:
@@ -302,7 +341,9 @@ def forward_gpu(self, inputs_and_grad_outputs):
             gloss = gloss[:, None, ...]
 
         coeff = cupy.array(1, dtype=gloss.dtype)  # dtype does not matter
-        if self.class_weight is None:
+        if self.soft_target:
+            gx = gloss * coeff * (y - t)
+        elif self.class_weight is None:
             gx = cuda.elementwise(
                 'T y, S t, T gloss, U coeff, S n_channel, S n_unit, '
                 'S ignore_label',
@@ -405,7 +446,8 @@ def _double_backward_softmax_cross_entropy(x, t, normalize, class_weight,
 
 def softmax_cross_entropy(
         x, t, normalize=True, cache_score=True, class_weight=None,
-        ignore_label=-1, reduce='mean', enable_double_backprop=False):
+        ignore_label=-1, reduce='mean', enable_double_backprop=False,
+        soft_target_loss='cross-entropy'):
     """Computes cross entropy loss for pre-softmax activations.
 
     Args:
@@ -421,6 +463,11 @@ def softmax_cross_entropy(
             Variable holding a signed integer vector of ground truth
             labels. If ``t[i] == ignore_label``, corresponding ``x[i]`` is
             ignored.
+            When the dtype is float, this function treats ``t`` as an array
+            holding a probability distribution over labels, in other words,
+            soft targets. In this case, the shape of ``t`` must be the same
+            as the shape of ``x``, and the loss is computed as either cross
+            entropy or KL divergence, as selected by ``soft_target_loss``.
         normalize (bool): If ``True``, this function normalizes the cross
             entropy loss across all instances. If ``False``, it only
             normalizes along a batch size.
@@ -453,6 +500,10 @@
            This function use the single-backprop version because we expect
            it is faster. So, if you need second or higher derivatives,
            you need to turn it on explicitly.
+        soft_target_loss (str): Loss function used when ``t`` holds soft
+            targets. If ``'cross-entropy'``, cross entropy is used for the
+            loss calculation; if ``'kl-divergence'``, KL divergence is
+            used. The default is ``'cross-entropy'``.
 
     Returns:
         ~chainer.Variable: A variable holding a scalar array of the cross
@@ -486,6 +537,10 @@
     is_chainerx = (
         chainerx.is_available() and backend.get_array_module(x) is chainerx)
 
+    if soft_target_loss not in ('cross-entropy', 'kl-divergence'):
+        raise ValueError('soft_target_loss must be \'cross-entropy\' or '
+                         '\'kl-divergence\'.')
+
     if is_chainerx or not enable_double_backprop:
         # Optimized implementation.
         # For non-ChainerX, forward and backward are supported but
@@ -494,7 +549,8 @@
         # configuration of inputs and parameters, which is tested with
         # `SoftmaxCrossEntropy._is_chainerx_supported()`.
         func = SoftmaxCrossEntropy(
-            normalize, cache_score, class_weight, ignore_label, reduce)
+            normalize, cache_score, class_weight, ignore_label, reduce,
+            soft_target_loss)
 
         if not is_chainerx or func._is_chainerx_supported((x, t)):
             loss, = func.apply((x, t))
diff --git a/tests/chainer_tests/functions_tests/loss_tests/test_softmax_cross_entropy.py b/tests/chainer_tests/functions_tests/loss_tests/test_softmax_cross_entropy.py
index 5dd3d580f133..40cb222ec3bf 100644
--- a/tests/chainer_tests/functions_tests/loss_tests/test_softmax_cross_entropy.py
+++ b/tests/chainer_tests/functions_tests/loss_tests/test_softmax_cross_entropy.py
@@ -530,4 +530,137 @@ def test_consistency_gpu_never(self):
         self.check_consistency(cuda.cupy)
 
 
+class BaseSoftTarget(object):
+
+    def setUp(self):
+        x_shape = (self.nb,) + self.shape
+        self.x = numpy.random.uniform(-1, 1, x_shape).astype(self.dtype)
+        if self.reduce == 'mean':
+            self.gy = numpy.random.uniform(-1, 1, ()).astype(self.dtype)
+        else:
+            y_shape = (self.nb,) + self.shape[1:]
+            self.gy = numpy.random.uniform(-1, 1, y_shape).astype(self.dtype)
+        if self.dtype == numpy.float16:
+            self.check_forward_options = {'atol': 5e-3, 'rtol': 5e-2}
+            self.check_backward_options = {'atol': 5e-3, 'rtol': 5e-2}
+        else:
+            self.check_forward_options = {}
+            self.check_backward_options = {}
+
+    def check_forward(self, xp):
+        raise NotImplementedError
+
+    def test_forward_cpu(self):
+        self.check_forward(numpy)
+
+    @attr.gpu
+    def test_forward_gpu(self):
+        self.check_forward(cuda.cupy)
+
+    def check_backward(self, xp):
+        x = xp.asarray(self.x)
+        t = xp.asarray(self.t)
+        gy = None
+        if self.reduce == 'no':
+            gy = xp.asarray(self.gy)
+
+        def f(x_, t_):
+            return functions.softmax_cross_entropy(
+                x_, t_, reduce=self.reduce)
+
+        gradient_check.check_backward(f, (x, t), gy, dtype=numpy.float64,
+                                      no_grads=(False, True),
+                                      **self.check_backward_options)
+
+    def test_backward_cpu(self):
+        self.check_backward(numpy)
+
+    @attr.gpu
+    def test_backward_gpu(self):
+        self.check_backward(cuda.cupy)
+
+
+@testing.parameterize(*(testing.product({
+    'nb': [1, 2, 4],
+    'shape': [(3,), (3, 2), (3, 2, 2)],
+    'dtype': [numpy.float16, numpy.float32, numpy.float64],
+    'reduce': ['mean', 'no'],
+    'soft_target_loss': ['cross-entropy', 'kl-divergence'],
+})))
+class TestSoftTargetCompareToHard(BaseSoftTarget, unittest.TestCase):
+
+    def setUp(self):
+        BaseSoftTarget.setUp(self)
+        t_hard_shape = (self.nb,) + self.shape[1:]
+        self.t_hard = numpy.random.randint(
+            0, self.shape[0], t_hard_shape).astype(numpy.int32)
+        t = numpy.zeros(self.x.size).astype(self.dtype)
+        t = t.reshape(self.shape[0], -1)
+        t[[self.t_hard.ravel()], [range(t.shape[1])]] = 1.0
+        t = t.reshape((self.shape[0], self.nb,) + self.shape[1:])
+        self.t = t.swapaxes(0, 1)
+
+    def check_forward(self, xp):
+        x = xp.asarray(self.x)
+        t = xp.asarray(self.t)
+        loss = functions.softmax_cross_entropy(
+            x, t, reduce=self.reduce,
+            soft_target_loss=self.soft_target_loss)
+        expect = functions.softmax_cross_entropy(
+            x, xp.asarray(self.t_hard), reduce=self.reduce)
+        testing.assert_allclose(loss.data, expect.data,
+                                **self.check_forward_options)
+
+
+@testing.parameterize(*(testing.product({
+    'nb': [1, 2, 4],
+    'shape': [(3,), (3, 2), (3, 2, 2)],
+    'dtype': [numpy.float16, numpy.float32, numpy.float64],
+    'reduce': ['mean', 'no'],
+    'soft_target_loss': ['kl-divergence'],
+})))
+class TestSoftTargetKLDivergence(BaseSoftTarget, unittest.TestCase):
+
+    def setUp(self):
+        BaseSoftTarget.setUp(self)
+        self.t = functions.softmax(self.x).array
+
+    def check_forward(self, xp):
+        x = xp.asarray(self.x)
+        t = xp.asarray(self.t)
+        loss = functions.softmax_cross_entropy(
+            x, t, reduce=self.reduce,
+            soft_target_loss=self.soft_target_loss)
+        if self.reduce == 'mean':
+            expect = 0.
+        else:
+            expect = numpy.zeros(self.gy.shape, dtype=self.dtype)
+        testing.assert_allclose(loss.data, expect,
+                                **self.check_forward_options)
+
+
+@testing.parameterize(*(testing.product({
+    'nb': [1, 2, 4],
+    'shape': [(3,), (3, 2), (3, 2, 2)],
+    'dtype': [numpy.float16, numpy.float32, numpy.float64],
+    'reduce': ['mean', 'no'],
+    'soft_target_loss': ['cross-entropy'],
+})))
+class TestSoftTargetCrossEntropy(BaseSoftTarget, unittest.TestCase):
+
+    def setUp(self):
+        BaseSoftTarget.setUp(self)
+        self.t = functions.softmax(self.x).array
+        self.expect = numpy.sum(-self.t * functions.log_softmax(self.x).array,
+                                axis=1)
+        if self.reduce == 'mean':
+            self.expect = numpy.average(self.expect)
+
+    def check_forward(self, xp):
+        x = xp.asarray(self.x)
+        t = xp.asarray(self.t)
+        loss = functions.softmax_cross_entropy(
+            x, t, reduce=self.reduce,
+            soft_target_loss=self.soft_target_loss)
+        testing.assert_allclose(loss.data, self.expect,
+                                **self.check_forward_options)
+
+
 testing.run_module(__name__, __file__)
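
Reviewer note (not part of the patch): a minimal usage sketch of the new
behavior this diff adds. Shapes and values below are illustrative only.
When ``t`` is a float array with the same shape as ``x``, it is treated
as soft targets; integer ``t`` keeps the existing hard-label behavior.

    import numpy as np
    import chainer.functions as F

    x = np.random.randn(2, 3).astype(np.float32)  # pre-softmax activations

    # Hard targets: integer class labels, unchanged behavior.
    t_hard = np.array([0, 2], dtype=np.int32)
    loss_hard = F.softmax_cross_entropy(x, t_hard)

    # Soft targets: float array, same shape as x, rows summing to 1.
    t_soft = np.array([[0.7, 0.2, 0.1],
                       [0.1, 0.1, 0.8]], dtype=np.float32)
    loss_ce = F.softmax_cross_entropy(x, t_soft)  # default 'cross-entropy'
    loss_kl = F.softmax_cross_entropy(x, t_soft,
                                      soft_target_loss='kl-divergence')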
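
For checking the numbers, here is a NumPy-only sketch of what
``_soft_target_loss`` computes under ``reduce='mean'``. The helper name
``reference_soft_target_loss`` is hypothetical and exists only for this
note; the body mirrors the patch above.

    import numpy as np

    def reference_soft_target_loss(x, t, loss_type='cross-entropy',
                                   eps=1e-7):
        # Numerically stable log-softmax over the class axis (axis 1).
        log_y = x - x.max(axis=1, keepdims=True)
        log_y -= np.log(np.exp(log_y).sum(axis=1, keepdims=True))
        if loss_type == 'kl-divergence':
            # sum_c t_c * (log(t_c + eps) - log y_c), per distribution
            ret = np.sum(t * (np.log(t + eps) - log_y), axis=1)
        else:
            # sum_c -t_c * log y_c, per distribution
            ret = -np.sum(t * log_y, axis=1)
        # reduce='mean' divides by the number of distributions,
        # x.size / x.shape[1], matching self._coeff in the patch.
        return ret.sum() / (x.size / x.shape[1])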