
Commit

Add argument soft_target_loss to allow users to opt loss calculation method
anaruse committed Jul 9, 2019
1 parent e60dd3e commit 1ab2632
Showing 1 changed file with 24 additions and 9 deletions.
chainer/functions/loss/softmax_cross_entropy.py: 24 additions & 9 deletions
@@ -72,14 +72,16 @@ class SoftmaxCrossEntropy(function_node.FunctionNode):
     eps = 1e-7
 
     def __init__(self, normalize=True, cache_score=True, class_weight=None,
-                 ignore_label=-1, reduce='mean'):
+                 ignore_label=-1, reduce='mean',
+                 soft_target_loss='cross-entropy'):
         self.normalize = normalize
         self.cache_score = cache_score
         _check_class_weight_option(class_weight)
         self.class_weight = class_weight
         self.ignore_label = ignore_label
         _check_reduce_option(reduce)
         self.reduce = reduce
+        self.soft_target_loss = soft_target_loss
 
     def check_type_forward(self, in_types):
         type_check._argname(in_types, ('x', 't'))
@@ -245,13 +247,16 @@ def forward_gpu(self, inputs):
         return ret,
 
     def _soft_target_loss(self, xp, x, t, log_y):
-        kl_d = xp.sum(t * (xp.log(t + self.eps) - log_y), axis=1)
+        if self.soft_target_loss == 'kl-divergence':
+            ret = xp.sum(t * (xp.log(t + self.eps) - log_y), axis=1)
+        else:
+            ret = -xp.sum(t * log_y, axis=1)
         if self.reduce == 'mean':
             self._coeff = 1.0 / (x.size / x.shape[1])
-            kl_d = kl_d.sum(keepdims=True) * self._coeff
-            return kl_d.reshape(()),
+            ret = ret.sum(keepdims=True) * self._coeff
+            return ret.reshape(()),
         else:
-            return kl_d,
+            return ret,
 
     def backward(self, input_indexes, grad_outputs):
         func_grad = _SoftmaxCrossEntropyGrad_NoDoubleBackprop(
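
For reference, a minimal NumPy sketch (not part of the commit) of what the two branches above compute, assuming each row of t is a probability distribution over classes and log_y holds the corresponding log-softmax outputs:

import numpy as np

eps = 1e-7  # same small constant as SoftmaxCrossEntropy.eps

def _soft_target_loss_np(t, log_y, method='cross-entropy'):
    # Per-sample loss over the class axis, then mean over the batch
    # (mirrors reduce='mean' without the class-weight/normalize options).
    if method == 'kl-divergence':
        # KL(t || y) = sum_c t * (log t - log y)
        per_sample = np.sum(t * (np.log(t + eps) - log_y), axis=1)
    else:
        # Cross entropy: H(t, y) = -sum_c t * log y
        per_sample = -np.sum(t * log_y, axis=1)
    return per_sample.mean()

The two options differ only by the entropy of t (up to the eps term), which does not depend on x, so gradients with respect to x are identical; only the reported loss value changes.
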
@@ -441,7 +446,8 @@ def _double_backward_softmax_cross_entropy(x, t, normalize, class_weight,
 
 def softmax_cross_entropy(
         x, t, normalize=True, cache_score=True, class_weight=None,
-        ignore_label=-1, reduce='mean', enable_double_backprop=False):
+        ignore_label=-1, reduce='mean', enable_double_backprop=False,
+        soft_target_loss='cross-entropy'):
     """Computes cross entropy loss for pre-softmax activations.
 
     Args:
@@ -460,8 +466,8 @@ def softmax_cross_entropy(
             When the dtype is float, this function treats ``t`` as an array
             holding probability distribution of labels, in other words, soft
             targets. In this case, the shape of ``t`` must be the same as the
-            shape of ``x``. Note that the loss is calculated using KL
-            divergence, not cross entropy.
+            shape of ``x``. Note that the loss is calculated using cross
+            entropy or KL divergence.
         normalize (bool): If ``True``, this function normalizes the cross
             entropy loss across all instances. If ``False``, it only
             normalizes along a batch size.
@@ -494,6 +500,10 @@ def softmax_cross_entropy(
             This function use the single-backprop version because we expect
             it is faster. So, if you need second or higher derivatives,
             you need to turn it on explicitly.
+        soft_target_loss (str): A string that determines the method used
+            to calculate the soft target loss. If ``'cross-entropy'``,
+            cross entropy is used for loss calculation; if
+            ``'kl-divergence'``, KL divergence is used.
 
     Returns:
         ~chainer.Variable: A variable holding a scalar array of the cross
@@ -527,6 +537,10 @@ def softmax_cross_entropy(
     is_chainerx = (
         chainerx.is_available() and backend.get_array_module(x) is chainerx)
 
+    if soft_target_loss not in ('cross-entropy', 'kl-divergence'):
+        raise ValueError('soft_target_loss must be \'cross-entropy\' or '
+                         '\'kl-divergence\'.')
+
     if is_chainerx or not enable_double_backprop:
         # Optimized implementation.
         # For non-ChainerX, forward and backward are supported but
@@ -535,7 +549,8 @@ def softmax_cross_entropy(
         # configuration of inputs and parameters, which is tested with
         # `SoftmaxCrossEntropy._is_chainerx_supported()`.
         func = SoftmaxCrossEntropy(
-            normalize, cache_score, class_weight, ignore_label, reduce)
+            normalize, cache_score, class_weight, ignore_label, reduce,
+            soft_target_loss)
 
         if not is_chainerx or func._is_chainerx_supported((x, t)):
            loss, = func.apply((x, t))
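
A minimal usage sketch of the new argument (illustrative data, assuming the signature added in this commit; soft targets require a float ``t`` with the same shape as ``x``):

import numpy as np
import chainer.functions as F

x = np.random.randn(4, 10).astype(np.float32)  # pre-softmax activations
# Illustrative soft targets: each row is a probability distribution.
t = np.random.dirichlet(np.ones(10), size=4).astype(np.float32)

# With the default reduce='mean', both calls return scalar Variables.
loss_ce = F.softmax_cross_entropy(x, t, soft_target_loss='cross-entropy')
loss_kl = F.softmax_cross_entropy(x, t, soft_target_loss='kl-divergence')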
