Implement chainerx::SoftmaxCrossEntropy and chainerx.softmax_cross_entropy #8250

Merged (13 commits, Oct 23, 2019)
25 changes: 11 additions & 14 deletions chainer/functions/loss/softmax_cross_entropy.py
@@ -109,22 +109,19 @@ def _is_chainerx_supported(self, input_arrays):
return True

def forward_chainerx(self, inputs):
-        x, t = inputs
-        n_classes = x.shape[1]
-        score = chainerx.log_softmax(x, axis=1)
-        mask = (t[:, chainerx.newaxis] == chainerx.arange(
-            n_classes, dtype=t.dtype, device=x.device)).astype(score.dtype)
-        neg_log_p = -score * mask
-        if self.reduce == 'mean':
-            if self.normalize:
-                count = mask.sum()
-            else:
-                count = x.shape[0]
-            y = neg_log_p.sum() * (1 / count)
-        else:
-            y = neg_log_p.sum(axis=1)
+        if self.reduce == 'mean' and self.normalize:
+            x, t = inputs
+            n_classes = x.shape[1]
+            score = chainerx.log_softmax(x, axis=1)
+            mask = (t[:, chainerx.newaxis] == chainerx.arange(
+                n_classes, dtype=t.dtype, device=x.device)).astype(score.dtype)
+            y = (score * mask).sum() * (-1 / mask.sum())
+            return y,
+
+        x, t = inputs
+        y = chainerx.softmax_cross_entropy(x, t)
+        if self.reduce == 'mean':
+            return y.mean(),

         return y,

def forward_cpu(self, inputs):
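As a quick sketch of how the new forward_chainerx dispatch behaves from the Python API (illustrative only, not part of the diff; assumes a ChainerX-enabled build with both inputs living on a ChainerX device and default arguments otherwise):

    import chainer.functions as F
    import chainerx

    x = chainerx.array([[0.3, 1.2], [2.0, -1.0]], dtype='float32')  # (N=2, C=2) scores
    t = chainerx.array([1, 0], dtype='int64')                       # ground-truth labels

    # reduce='mean' with normalize=False routes through the new chainerx.softmax_cross_entropy
    loss = F.softmax_cross_entropy(x, t, normalize=False)

    # reduce='mean' with normalize=True (the default) keeps the log_softmax/mask path above
    loss_normalized = F.softmax_cross_entropy(x, t)

    # reduce='no' also uses the new kernel and returns per-sample losses
    per_sample = F.softmax_cross_entropy(x, t, reduce='no')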
4 changes: 4 additions & 0 deletions chainerx/__init__.pyi
@@ -933,6 +933,10 @@ def softmax(
x: ndarray,
axis: tp.Optional[tp.Union[int, tp.List[int]]]=None) -> ndarray: ...

def softmax_cross_entropy(
x1: ndarray,
x2: ndarray) -> ndarray: ...

def softplus(x: ndarray, beta: double=1.0) -> ndarray: ...

def split(
21 changes: 21 additions & 0 deletions chainerx/_docs/routines.py
@@ -1223,6 +1223,27 @@ def _docs_loss():
Returns:
:class:`~chainerx.ndarray`: An array of the cross entropy.

Note:
During backpropagation, this function propagates the gradient of the output
array to the input array ``x1`` only.
""")

_docs.set_doc(
chainerx.softmax_cross_entropy,
"""softmax_cross_entropy(x1, x2)

Element-wise cross entropy loss for pre-softmax activations.

Args:
x1 (~chainerx.ndarray): An array whose elements indicate unnormalized log
probabilities: the first axis of the array represents the number of
samples, and the second axis represents the number of classes.
x2 (~chainerx.ndarray): A signed integer vector of ground truth labels. If
``x2[i] == -1``, the corresponding ``x1[i]`` is ignored.

Returns:
:class:`~chainerx.ndarray`: An array of the cross entropy.

Note:
During backpropagation, this function propagates the gradient of the output
array to the input array ``x1`` only.
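A minimal usage sketch matching the docstring above (assumes a ChainerX build; the values are illustrative):

    import chainerx

    x1 = chainerx.array([[0.2, 1.5, -0.3],
                         [1.0, 0.0, 0.5]], dtype='float32')  # (N=2, C=3) unnormalized scores
    x2 = chainerx.array([1, 0], dtype='int64')                # ground-truth labels

    loss = chainerx.softmax_cross_entropy(x1, x2)  # per-sample cross entropy, shape (2,)
    mean_loss = loss.mean()                        # reduce manually if a scalar is wanted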
4 changes: 4 additions & 0 deletions chainerx_cc/chainerx/python/routines.cc
@@ -1305,6 +1305,10 @@ void InitChainerxLoss(pybind11::module& m) {
[](const ArrayBodyPtr& x1, const ArrayBodyPtr& x2) { return MoveArrayBody(SigmoidCrossEntropy(Array{x1}, Array{x2})); },
"x1"_a,
"x2"_a);
m.def("softmax_cross_entropy",
[](const ArrayBodyPtr& x1, const ArrayBodyPtr& x2) { return MoveArrayBody(SoftmaxCrossEntropy(Array{x1}, Array{x2})); },
"x1"_a,
"x2"_a);
m.def("hinge",
[](const ArrayBodyPtr& x1, const ArrayBodyPtr& x2, double norm) { return MoveArrayBody(Hinge(Array{x1}, Array{x2}, norm)); },
"x1"_a,
16 changes: 16 additions & 0 deletions chainerx_cc/chainerx/routines/loss.cc
@@ -8,6 +8,7 @@
#include "chainerx/routines/logic.h"
#include "chainerx/routines/manipulation.h"
#include "chainerx/routines/misc.h"
#include "chainerx/routines/reduction.h"
#include "chainerx/scalar.h"

namespace chainerx {
@@ -33,6 +34,21 @@ Array SigmoidCrossEntropy(const Array& x1, const Array& x2) {
return -(ignore_mask * (x1 * (x2 - (GreaterEqual(x1, ZerosLike(x1, x1.device()))).AsType(x1.dtype())) - Log1p(Exp(-Absolute(x1)))));
}

Array SoftmaxCrossEntropy(const Array& x1, const Array& x2) {
if (x1.ndim() != 2) {
throw DimensionError{"Input array must be 2 dimensional."};
}
if (x2.ndim() != 1) {
throw DimensionError{"Target array must be 1 dimensional."};
}
if (x1.shape()[0] != x2.shape()[0]) {
throw DimensionError{"x1.shape[0] must be equal to x2.shape[0]"};
}
Array score = LogSoftmax(x1, 1);
Array mask = (x2.At({Slice{}, NewAxis{}}) == Arange(score.shape()[1], x2.dtype(), x1.device())).AsType(score.dtype());
return -(score * mask).Sum({1});
}

Array Hinge(const Array& x, const Array& t, double norm) {
if (x.ndim() != 2) {
throw DimensionError{"Input array must be 2 dimensional."};
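For reference, a NumPy rendering of what the SoftmaxCrossEntropy body added above computes; this is a sketch only, and the function name here is hypothetical:

    import numpy

    def softmax_cross_entropy_ref(x1, x2):
        # Mirrors the C++ body: LogSoftmax along axis 1, a one-hot mask built from
        # the labels, then a negated sum over the class axis (no batch reduction).
        m = x1.max(axis=1, keepdims=True)
        score = x1 - (numpy.log(numpy.exp(x1 - m).sum(axis=1, keepdims=True)) + m)
        mask = (x2[:, None] == numpy.arange(x1.shape[1])).astype(score.dtype)
        # x2[i] == -1 matches no class, so that row contributes 0.
        return -(score * mask).sum(axis=1)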
2 changes: 2 additions & 0 deletions chainerx_cc/chainerx/routines/loss.h
@@ -15,6 +15,8 @@ Array HuberLoss(const Array& x1, const Array& x2, Scalar delta);

Array SigmoidCrossEntropy(const Array& x1, const Array& x2);

Array SoftmaxCrossEntropy(const Array& x1, const Array& x2);

Array Hinge(const Array& x, const Array& t, double norm = 1.0);

} // namespace chainerx
1 change: 1 addition & 0 deletions docs/source/chainerx/reference/routines.rst
@@ -165,6 +165,7 @@ Loss functions
chainerx.huber_loss
chainerx.gaussian_kl_divergence
chainerx.sigmoid_cross_entropy
chainerx.softmax_cross_entropy

Mathematical functions
----------------------
36 changes: 36 additions & 0 deletions tests/chainerx_tests/unit_tests/routines_tests/test_loss.py
@@ -172,6 +172,42 @@ def forward_chainer(self, inputs):
return out,


@op_utils.op_test(['native:0', 'cuda:0'])
@chainer.testing.parameterize(*(
chainer.testing.product({
'x_dtype': chainerx.testing.float_dtypes,
't_dtype': chainerx.testing.signed_integral_dtypes,
})
))
class TestSoftmaxCrossEntropy(op_utils.ChainerOpTest):

def setup(self):
self.shape = (2, 2)

t_shape = self.shape[0],
t = numpy.random.randint(0, self.shape[1], t_shape)
self.t = t.astype(self.t_dtype)

if self.x_dtype == 'float16':
self.check_forward_options.update({'rtol': 5e-3, 'atol': 5e-4})
self.check_backward_options.update({'rtol': 5e-3, 'atol': 5e-4})

def generate_inputs(self):
x = numpy.random.normal(loc=0, scale=1.0, size=self.shape)
return x.astype(self.x_dtype),

def forward_chainerx(self, inputs):
x, = inputs
t = self.backend_config.get_array(self.t)
out = chainerx.softmax_cross_entropy(x, t)
return out,

def forward_chainer(self, inputs):
x, = inputs
out = F.softmax_cross_entropy(x, self.t, reduce='no')
return out,


@op_utils.op_test(['native:0', 'cuda:0'])
@chainer.testing.parameterize(*(
chainer.testing.product({
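To exercise only the new parameterized cases from a source checkout, something like the following should work (a sketch; assumes pytest is installed and the listed backends are available):

    import pytest

    # Select only the new test class in the loss-routine test module.
    pytest.main([
        'tests/chainerx_tests/unit_tests/routines_tests/test_loss.py',
        '-k', 'TestSoftmaxCrossEntropy',
    ])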