diff --git a/chainer/functions/loss/softmax_cross_entropy.py b/chainer/functions/loss/softmax_cross_entropy.py
index ff04db0d3ebc..c53c9ecdca86 100644
--- a/chainer/functions/loss/softmax_cross_entropy.py
+++ b/chainer/functions/loss/softmax_cross_entropy.py
@@ -72,7 +72,8 @@ class SoftmaxCrossEntropy(function_node.FunctionNode):
     eps = 1e-7
 
     def __init__(self, normalize=True, cache_score=True, class_weight=None,
-                 ignore_label=-1, reduce='mean'):
+                 ignore_label=-1, reduce='mean',
+                 soft_target_loss='cross-entropy'):
         self.normalize = normalize
         self.cache_score = cache_score
         _check_class_weight_option(class_weight)
@@ -80,6 +81,7 @@ def __init__(self, normalize=True, cache_score=True, class_weight=None,
         self.ignore_label = ignore_label
         _check_reduce_option(reduce)
         self.reduce = reduce
+        self.soft_target_loss = soft_target_loss
 
     def check_type_forward(self, in_types):
         type_check._argname(in_types, ('x', 't'))
@@ -245,13 +247,16 @@ def forward_gpu(self, inputs):
         return ret,
 
     def _soft_target_loss(self, xp, x, t, log_y):
-        kl_d = xp.sum(t * (xp.log(t + self.eps) - log_y), axis=1)
+        if self.soft_target_loss == 'kl-divergence':
+            ret = xp.sum(t * (xp.log(t + self.eps) - log_y), axis=1)
+        else:
+            ret = -xp.sum(t * log_y, axis=1)
         if self.reduce == 'mean':
             self._coeff = 1.0 / (x.size / x.shape[1])
-            kl_d = kl_d.sum(keepdims=True) * self._coeff
-            return kl_d.reshape(()),
+            ret = ret.sum(keepdims=True) * self._coeff
+            return ret.reshape(()),
         else:
-            return kl_d,
+            return ret,
 
     def backward(self, input_indexes, grad_outputs):
         func_grad = _SoftmaxCrossEntropyGrad_NoDoubleBackprop(
@@ -441,7 +446,8 @@ def _double_backward_softmax_cross_entropy(x, t, normalize, class_weight,
 
 def softmax_cross_entropy(
         x, t, normalize=True, cache_score=True, class_weight=None,
-        ignore_label=-1, reduce='mean', enable_double_backprop=False):
+        ignore_label=-1, reduce='mean', enable_double_backprop=False,
+        soft_target_loss='cross-entropy'):
     """Computes cross entropy loss for pre-softmax activations.
 
     Args:
@@ -460,8 +466,8 @@ def softmax_cross_entropy(
             When the dtype is float, this function treats ``t`` as an array
             holding probability distribution of labels, in other words, soft
             targets. In this case, the shape of ``t`` must be the same as the
-            shape of ``x``. Note that the loss is calculated using KL
-            divergence, not cross entropy.
+            shape of ``x``. Note that the loss is calculated using cross
+            entropy or KL divergence.
         normalize (bool): If ``True``, this function normalizes the cross
             entropy loss across all instances. If ``False``, it only
             normalizes along a batch size.
@@ -494,6 +500,10 @@
             This function use the single-backprop version because we expect
             it is faster. So, if you need second or higher derivatives,
             you need to turn it on explicitly.
+        soft_target_loss (str): A string that determines what type of
+            method is used to calculate soft target loss. If
+            ``'cross-entropy'`` or ``'kl-divergence'``, cross entropy or
+            KL divergence is used for loss calculation, respectively.
 
     Returns:
         ~chainer.Variable: A variable holding a scalar array of the cross
@@ -527,6 +537,10 @@ def softmax_cross_entropy(
     is_chainerx = (
         chainerx.is_available()
         and backend.get_array_module(x) is chainerx)
+    if soft_target_loss not in ('cross-entropy', 'kl-divergence'):
+        raise ValueError('soft_target_loss must be \'cross-entropy\' or '
+                         '\'kl-divergence\'.')
+
     if is_chainerx or not enable_double_backprop:
         # Optimized implementation.
         # For non-ChainerX, forward and backward are supported but
@@ -535,7 +549,8 @@
         # configuration of inputs and parameters, which is tested with
         # `SoftmaxCrossEntropy._is_chainerx_supported()`.
         func = SoftmaxCrossEntropy(
-            normalize, cache_score, class_weight, ignore_label, reduce)
+            normalize, cache_score, class_weight, ignore_label, reduce,
+            soft_target_loss)
         if not is_chainerx or func._is_chainerx_supported((x, t)):
             loss, = func.apply((x, t))
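For illustration only (not part of the patch): a minimal sketch of how the new `soft_target_loss` argument could be exercised once this change is applied. The array shapes and values below are made up; only the keyword argument and its two accepted values come from the diff.

```python
import numpy as np

import chainer.functions as F

# Pre-softmax activations: a batch of 4 samples over 3 classes.
x = np.random.randn(4, 3).astype(np.float32)

# Soft targets: a float array with the same shape as x, each row a
# probability distribution (float dtype is what selects the soft-target path).
t = np.full((4, 3), 1.0 / 3, dtype=np.float32)

# New default for soft targets: plain cross entropy.
loss_ce = F.softmax_cross_entropy(x, t, soft_target_loss='cross-entropy')

# Pre-patch behaviour, now selectable explicitly: KL divergence.
loss_kl = F.softmax_cross_entropy(x, t, soft_target_loss='kl-divergence')

# Any other string raises ValueError via the added check.
```

Note that the default is 'cross-entropy', so soft-target losses computed after this patch differ from the previous behaviour (which always used KL divergence) unless soft_target_loss='kl-divergence' is passed explicitly.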