Skip to content

Commit

Permalink
Merge pull request #7351 from aksub99/Add_leaky_relu
Browse files Browse the repository at this point in the history
Add `chainerx.leaky_relu`
  • Loading branch information
hvy committed Jun 17, 2019
2 parents 316a63e + a10f610 commit 4b0daf5
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 4 deletions.
8 changes: 8 additions & 0 deletions chainerx_cc/chainerx/python/routines.cc
Expand Up @@ -979,9 +979,17 @@ void InitChainerxPooling(pybind11::module& m) {
"pad_mode"_a = "ignore");
}

// Registers activation-function bindings on the `chainerx` Python module.
void InitChainerxActivation(pybind11::module& m) {
    // chainerx.leaky_relu(x, slope=0.2)
    // The default slope of 0.2 matches chainer.functions.leaky_relu.
    // Use the `"name"_a` literal for keyword arguments, consistent with the
    // other bindings in this file (e.g. `"pad_mode"_a` in the pooling init).
    m.def("leaky_relu",
          [](const ArrayBodyPtr& x, Scalar slope) { return MoveArrayBody(LeakyRelu(Array{x}, slope)); },
          "x"_a,
          "slope"_a = 0.2);
}

} // namespace

void InitChainerxRoutines(pybind11::module& m) {
InitChainerxActivation(m);
InitChainerxCreation(m);
InitChainerxIndexing(m);
InitChainerxLinalg(m);
Expand Down
8 changes: 8 additions & 0 deletions chainerx_cc/chainerx/routines/activation.cc
Expand Up @@ -17,6 +17,7 @@
#include "chainerx/routines/arithmetic.h"
#include "chainerx/routines/creation.h"
#include "chainerx/routines/explog.h"
#include "chainerx/routines/indexing.h"
#include "chainerx/routines/misc.h"
#include "chainerx/routines/type_util.h"
#include "chainerx/scalar.h"
Expand All @@ -36,4 +37,11 @@ Array Relu(const Array& x) {
return Maximum(0, x_cast);
}

// Element-wise leaky rectified linear unit:
// returns x where x >= 0, and slope * x elsewhere.
// Integral inputs are first promoted to the mathematical result dtype.
Array LeakyRelu(const Array& x, Scalar slope) {
    const Dtype out_dtype = internal::GetMathResultDtype(x.dtype());
    // Promote only when necessary; otherwise reuse the input as-is.
    const Array& promoted = x.dtype() == out_dtype ? x : x.AsType(out_dtype);
    // Select between the identity branch and the scaled (negative) branch.
    Array threshold = ZerosLike(promoted, promoted.device());
    Array scaled = slope * promoted;
    return Where(promoted >= threshold, promoted, scaled);
}

} // namespace chainerx
6 changes: 2 additions & 4 deletions chainerx_cc/chainerx/routines/activation.h
@@ -1,9 +1,5 @@
#pragma once

#include <cstdint>

#include <nonstd/optional.hpp>

#include "chainerx/array.h"
#include "chainerx/scalar.h"

Expand All @@ -13,4 +9,6 @@ Array Sigmoid(const Array& x);

Array Relu(const Array& x);

Array LeakyRelu(const Array& x, Scalar slope);

} // namespace chainerx
30 changes: 30 additions & 0 deletions tests/chainerx_tests/unit_tests/routines_tests/test_activation.py
Expand Up @@ -72,6 +72,36 @@ def forward_xp(self, inputs, xp):
]


@op_utils.op_test(['native:0', 'cuda:0'])
@chainer.testing.parameterize(*(
    # Special shapes
    chainer.testing.product({
        'shape': [(), (0,), (1,), (2, 0, 3), (1, 1, 1), (2, 3)],
        'in_dtypes,out_dtype': _in_out_dtypes_math_functions,
        'input': [-2, 2],
        'contiguous': [None, 'C'],
    })
    # Special values
    + chainer.testing.product({
        'shape': [(2, 3)],
        'in_dtypes,out_dtype': _in_out_float_dtypes_math_functions,
        'input': [0, float('inf'), -float('inf'), float('nan')],
        'skip_backward_test': [True],
        'skip_double_backward_test': [True],
    })
))
class TestLeakyRelu(UnaryMathTestBase, op_utils.NumpyOpTest):
    """Tests chainerx.leaky_relu against a NumPy reference."""

    slope = 0.2
    check_numpy_strides_compliance = False

    def func(self, xp, a):
        # Dispatch to the chainerx implementation under test.
        if xp is not numpy:
            return xp.leaky_relu(a, self.slope)
        # NumPy reference: identity for non-negative inputs, scaled otherwise.
        return numpy.where(a >= 0, a, a * self.slope)

@op_utils.op_test(['native:0', 'cuda:0'])
@chainer.testing.parameterize(*(
# Special shapes
Expand Down

0 comments on commit 4b0daf5

Please sign in to comment.