diff --git a/chainerx_cc/chainerx/python/routines.cc b/chainerx_cc/chainerx/python/routines.cc
index 8ecc620438bc..d277cdb2b7d6 100644
--- a/chainerx_cc/chainerx/python/routines.cc
+++ b/chainerx_cc/chainerx/python/routines.cc
@@ -979,9 +979,17 @@ void InitChainerxPooling(pybind11::module& m) {
           "pad_mode"_a = "ignore");
 }

+void InitChainerxActivation(pybind11::module& m) {
+    m.def("leaky_relu",
+          [](const ArrayBodyPtr& x, Scalar slope) { return MoveArrayBody(LeakyRelu(Array{x}, slope)); },
+          py::arg("x"),
+          py::arg("slope") = 0.2);
+}
+
 }  // namespace

 void InitChainerxRoutines(pybind11::module& m) {
+    InitChainerxActivation(m);
     InitChainerxCreation(m);
     InitChainerxIndexing(m);
     InitChainerxLinalg(m);
diff --git a/chainerx_cc/chainerx/routines/activation.cc b/chainerx_cc/chainerx/routines/activation.cc
index 5ec59b3abffe..41b27b48412c 100644
--- a/chainerx_cc/chainerx/routines/activation.cc
+++ b/chainerx_cc/chainerx/routines/activation.cc
@@ -17,6 +17,7 @@
 #include "chainerx/routines/arithmetic.h"
 #include "chainerx/routines/creation.h"
 #include "chainerx/routines/explog.h"
+#include "chainerx/routines/indexing.h"
 #include "chainerx/routines/misc.h"
 #include "chainerx/routines/type_util.h"
 #include "chainerx/scalar.h"
@@ -36,4 +37,11 @@ Array Relu(const Array& x) {
     return Maximum(0, x_cast);
 }

+Array LeakyRelu(const Array& x, Scalar slope) {
+    Dtype dtype = internal::GetMathResultDtype(x.dtype());
+    const Array& x_cast = x.dtype() == dtype ? x : x.AsType(dtype);
+    Array zero = ZerosLike(x_cast, x_cast.device());
+    return Where(x_cast >= zero, x_cast, slope * x_cast);
+}
+
 }  // namespace chainerx
diff --git a/chainerx_cc/chainerx/routines/activation.h b/chainerx_cc/chainerx/routines/activation.h
index d1019294fc38..4024b493bc46 100644
--- a/chainerx_cc/chainerx/routines/activation.h
+++ b/chainerx_cc/chainerx/routines/activation.h
@@ -1,9 +1,5 @@
 #pragma once

-#include
-
-#include
-
 #include "chainerx/array.h"
 #include "chainerx/scalar.h"

@@ -13,4 +9,6 @@ Array Sigmoid(const Array& x);

 Array Relu(const Array& x);

+Array LeakyRelu(const Array& x, Scalar slope);
+
 }  // namespace chainerx
diff --git a/tests/chainerx_tests/unit_tests/routines_tests/test_activation.py b/tests/chainerx_tests/unit_tests/routines_tests/test_activation.py
index b47214cd32f4..59062ac1c09c 100644
--- a/tests/chainerx_tests/unit_tests/routines_tests/test_activation.py
+++ b/tests/chainerx_tests/unit_tests/routines_tests/test_activation.py
@@ -72,6 +72,36 @@ def forward_xp(self, inputs, xp):
 ]


+@op_utils.op_test(['native:0', 'cuda:0'])
+@chainer.testing.parameterize(*(
+    # Special shapes
+    chainer.testing.product({
+        'shape': [(), (0,), (1,), (2, 0, 3), (1, 1, 1), (2, 3)],
+        'in_dtypes,out_dtype': _in_out_dtypes_math_functions,
+        'input': [-2, 2],
+        'contiguous': [None, 'C'],
+    })
+    # Special values
+    + chainer.testing.product({
+        'shape': [(2, 3)],
+        'in_dtypes,out_dtype': _in_out_float_dtypes_math_functions,
+        'input': [0, float('inf'), -float('inf'), float('nan')],
+        'skip_backward_test': [True],
+        'skip_double_backward_test': [True],
+    })
+))
+class TestLeakyRelu(UnaryMathTestBase, op_utils.NumpyOpTest):
+
+    slope = 0.2
+    check_numpy_strides_compliance = False
+
+    def func(self, xp, a):
+        if xp is numpy:
+            expected = numpy.where(a >= 0, a, a * self.slope)
+            return expected
+        return xp.leaky_relu(a, self.slope)
+
+
 @op_utils.op_test(['native:0', 'cuda:0'])
 @chainer.testing.parameterize(*(
     # Special shapes
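
For reference, a minimal usage sketch (not part of the patch) of the chainerx.leaky_relu binding introduced above; it assumes a ChainerX build that already contains this change:

# Usage sketch only -- assumes ChainerX was built with the patch above.
import chainerx as chx

x = chx.array([-1.0, 0.0, 2.0], dtype=chx.float32)
y = chx.leaky_relu(x, 0.2)  # the slope argument defaults to 0.2 when omitted
# Negative elements are scaled by the slope: expected [-0.2, 0.0, 2.0]
print(y)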