Skip to content

Commit

Permalink
Implement python binding for chainerx::SoftmaxCrossEntropy
Browse files Browse the repository at this point in the history
  • Loading branch information
takagi committed Oct 7, 2019
1 parent 94e1d5a commit 6138818
Showing 1 changed file with 19 additions and 0 deletions.
19 changes: 19 additions & 0 deletions chainerx_cc/chainerx/python/routines.cc
Expand Up @@ -1242,6 +1242,25 @@ void InitChainerxLoss(pybind11::module& m) {
[](const ArrayBodyPtr& x1, const ArrayBodyPtr& x2) { return MoveArrayBody(SigmoidCrossEntropy(Array{x1}, Array{x2})); },
"x1"_a,
"x2"_a);
m.def("softmax_cross_entropy",
[](const ArrayBodyPtr& x1, const ArrayBodyPtr& x2, const std::string& reduce) {
Array x1_array{x1};
Array x2_array{x2};

SoftmaxCrossEntropyReduceMode reduce_mode{};
if (reduce == "mean") {
reduce_mode = SoftmaxCrossEntropyReduceMode::kMean;
} else if (reduce == "no") {
reduce_mode = SoftmaxCrossEntropyReduceMode::kNo;
} else {
throw py::value_error{"reduce_mode must be either of 'mean' or 'no'"};
}

return MoveArrayBody(SoftmaxCrossEntropy(x1_array, x2_array, reduce_mode));
},
"x1"_a,
"x2"_a,
"reduce"_a = "mean");
m.def("hinge",
[](const ArrayBodyPtr& x1, const ArrayBodyPtr& x2, double norm) { return MoveArrayBody(Hinge(Array{x1}, Array{x2}, norm)); },
"x1"_a,
Expand Down

0 comments on commit 6138818

Please sign in to comment.