Skip to content

Commit

Permalink
Merge pull request #6541 from aksub99/Implement_array_vs_array_min
Browse files Browse the repository at this point in the history
Implement array vs array functionality to chainerx.minimum
  • Loading branch information
asi1024 committed Apr 9, 2019
1 parent f0dd1cc commit 2d784df
Show file tree
Hide file tree
Showing 9 changed files with 121 additions and 0 deletions.
1 change: 1 addition & 0 deletions chainerx_cc/chainerx/cuda/cuda_device.h
Expand Up @@ -123,6 +123,7 @@ class CudaDevice : public Device {
void IfLessElseASSA(const Array& x1, Scalar x2, Scalar pos, const Array& neg, const Array& out) override;

void IfGreaterElseASSA(const Array& x1, Scalar x2, Scalar pos, const Array& neg, const Array& out) override;
void IfGreaterElseAAAA(const Array& x1, const Array& x2, const Array& pos, const Array& neg, const Array& out) override;

void Tanh(const Array& x, const Array& out) override;

Expand Down
22 changes: 22 additions & 0 deletions chainerx_cc/chainerx/cuda/cuda_device/activation.cu
Expand Up @@ -63,6 +63,18 @@ struct IfGreaterElseASSAImpl {

} // namespace

namespace {

template <typename T>
struct IfGreaterElseAAAAImpl {
using CudaType = cuda_internal::DataType<T>;
__device__ void operator()(int64_t /*i*/, CudaType x1, CudaType x2, CudaType pos, CudaType neg, CudaType& out) {
out = x1 > x2 ? pos : neg;
}
};

} // namespace

void CudaDevice::IfGreaterElseASSA(const Array& x1, Scalar x2, Scalar pos, const Array& neg, const Array& out) {
CheckDevicesCompatible(x1, neg, out);
Dtype x_dtype = ResultType(x1, x2);
Expand All @@ -81,6 +93,16 @@ void CudaDevice::IfGreaterElseASSA(const Array& x1, Scalar x2, Scalar pos, const
});
}

void CudaDevice::IfGreaterElseAAAA(const Array& x1, const Array& x2, const Array& pos, const Array& neg, const Array& out) {
CheckDevicesCompatible(x1, x2, pos, neg, out);
CudaSetDeviceScope scope{index()};
VisitDtype(out.dtype(), [&](auto pt) {
using T = typename decltype(pt)::type;
using CudaType = cuda_internal::DataType<T>;
Elementwise<const T, const T, const T, const T, T>(IfGreaterElseAAAAImpl<T>{}, x1, x2, pos, neg, out);
});
}

namespace {

template <typename T>
Expand Down
1 change: 1 addition & 0 deletions chainerx_cc/chainerx/device.h
Expand Up @@ -183,6 +183,7 @@ class Device {
//
// Formally, it calculates: out = x1 > x2 ? pos : neg
virtual void IfGreaterElseASSA(const Array& x1, Scalar x2, Scalar pos, const Array& neg, const Array& out) = 0;
virtual void IfGreaterElseAAAA(const Array& x1, const Array& x2, const Array& pos, const Array& neg, const Array& out) = 0;

virtual void Tanh(const Array& x, const Array& out) = 0;

Expand Down
1 change: 1 addition & 0 deletions chainerx_cc/chainerx/native/native_device.h
Expand Up @@ -85,6 +85,7 @@ class NativeDevice : public Device {
void IfLessElseASSA(const Array& x1, Scalar x2, Scalar pos, const Array& neg, const Array& out) override;

void IfGreaterElseASSA(const Array& x1, Scalar x2, Scalar pos, const Array& neg, const Array& out) override;
void IfGreaterElseAAAA(const Array& x1, const Array& x2, const Array& pos, const Array& neg, const Array& out) override;

void Tanh(const Array& x, const Array& out) override;

Expand Down
11 changes: 11 additions & 0 deletions chainerx_cc/chainerx/native/native_device/activation.cc
Expand Up @@ -52,6 +52,17 @@ void NativeDevice::IfGreaterElseASSA(const Array& x1, Scalar x2, Scalar pos, con
});
}

void NativeDevice::IfGreaterElseAAAA(const Array& x1, const Array& x2, const Array& pos, const Array& neg, const Array& out) {
CheckDevicesCompatible(x1, x2, pos, neg, out);
VisitDtype(out.dtype(), [&](auto pt) {
using T = typename decltype(pt)::type;
struct Impl {
void operator()(int64_t /*i*/, T x1, T x2, T pos, T neg, T& out) { out = x1 > x2 ? pos : neg; }
};
Elementwise<const T, const T, const T, const T, T>(Impl{}, x1, x2, pos, neg, out);
});
}

void NativeDevice::Tanh(const Array& x, const Array& out) {
CheckDevicesCompatible(x, out);
const Array& x_cast = x.dtype() == out.dtype() ? x : x.AsType(out.dtype());
Expand Down
4 changes: 4 additions & 0 deletions chainerx_cc/chainerx/python/routines.cc
Expand Up @@ -578,6 +578,10 @@ void InitChainerxMath(pybind11::module& m) {
m.def("maximum", [](Scalar x1, const ArrayBodyPtr& x2) { return MoveArrayBody(Maximum(x1, Array{x2})); }, py::arg("x1"), py::arg("x2"));
m.def("minimum", [](const ArrayBodyPtr& x1, Scalar x2) { return MoveArrayBody(Minimum(Array{x1}, x2)); }, py::arg("x1"), py::arg("x2"));
m.def("minimum", [](Scalar x1, const ArrayBodyPtr& x2) { return MoveArrayBody(Minimum(x1, Array{x2})); }, py::arg("x1"), py::arg("x2"));
m.def("minimum",
[](const ArrayBodyPtr& x1, const ArrayBodyPtr& x2) { return MoveArrayBody(Minimum(Array{x1}, Array{x2})); },
py::arg("x1"),
py::arg("x2"));
m.def("exp", [](const ArrayBodyPtr& x) { return MoveArrayBody(Exp(Array{x})); }, py::arg("x"));
m.def("log", [](const ArrayBodyPtr& x) { return MoveArrayBody(Log(Array{x})); }, py::arg("x"));
m.def("logsumexp",
Expand Down
41 changes: 41 additions & 0 deletions chainerx_cc/chainerx/routines/math.cc
Expand Up @@ -592,6 +592,42 @@ Array IfGreaterElse(const Array& x1, Scalar x2, Scalar pos, const Array& neg) {

} // namespace

namespace {

void IfGreaterElseImpl(const Array& x1, const Array& x2, const Array& pos, const Array& neg, const Array& out) {
CheckEqual(x1.shape(), x2.shape());
Array mask = Greater(x1, x2);
Array not_mask = LogicalNot(mask);
{
NoBackpropModeScope scope{};
x1.device().IfGreaterElseAAAA(x1, x2, pos, neg, out);
}
{
BackwardBuilder bb{"if_greater_else", {pos, neg}, out};
if (BackwardBuilder::Target bt = bb.CreateTarget(0)) {
bt.Define([mask = std::move(mask)](BackwardContext& bctx) {
const Array& gout = *bctx.output_grad();
bctx.input_grad() = gout * mask;
});
}
if (BackwardBuilder::Target bt = bb.CreateTarget(1)) {
bt.Define([not_mask = std::move(not_mask)](BackwardContext& bctx) {
const Array& gout = *bctx.output_grad();
bctx.input_grad() = gout * not_mask;
});
}
bb.Finalize();
}
}

} // namespace

namespace {

void MinimumImpl(const Array& x1, const Array& x2, const Array& out) { IfGreaterElseImpl(x1, x2, x2, x1, out); }

} // namespace

Array Maximum(const Array& x1, Scalar x2) {
// TODO(niboshi): IfLessElse redundantly casts x1 twice.
return IfLessElse(x1, x2, x2, x1); // x1 < x2 ? x2 : x1
Expand All @@ -606,6 +642,11 @@ Array Minimum(const Array& x1, Scalar x2) {

Array Minimum(Scalar x1, const Array& x2) { return Minimum(x2, x1); }

Array Minimum(const Array& x1, const Array& x2) {
Dtype dtype = GetArithmeticResultDtype(x1, x2);
return BroadcastBinary(&MinimumImpl, x1, x2, dtype); // x1 > x2 ? x2 : x1
}

Array Exp(const Array& x) {
Dtype dtype = GetMathResultDtype(x.dtype());
Array out = Empty(x.shape(), dtype, x.device());
Expand Down
1 change: 1 addition & 0 deletions chainerx_cc/chainerx/routines/math.h
Expand Up @@ -81,6 +81,7 @@ Array Maximum(Scalar x1, const Array& x2);

Array Minimum(const Array& x1, Scalar x2);
Array Minimum(Scalar x1, const Array& x2);
Array Minimum(const Array& x1, const Array& x2);

Array Exp(const Array& x);
Array Log(const Array& x);
Expand Down
39 changes: 39 additions & 0 deletions tests/chainerx_tests/unit_tests/routines_tests/test_math.py
Expand Up @@ -2008,3 +2008,42 @@ def test_max_invalid_shapes_and_axis(device, array, axis, dtype, is_module):
chainerx.max(a, axis)
else:
a.max(axis)


@op_utils.op_test(['native:0', 'cuda:0'])
@chainer.testing.parameterize(*(
# Special shapes
chainer.testing.product({
'shape': [(), (0,), (1,), (2, 0, 3), (1, 1, 1), (2, 3)],
'in_dtypes,out_dtype': (
_make_same_in_out_dtypes(2, chainerx.testing.numeric_dtypes)),
'input_lhs': ['random'],
'input_rhs': ['random'],
'is_module': [False],
})
# is_module
+ chainer.testing.product({
'shape': [(2, 3)],
'in_dtypes,out_dtype': (
_make_same_in_out_dtypes(2, chainerx.testing.numeric_dtypes)),
'input_lhs': ['random'],
'input_rhs': ['random'],
'is_module': [True, False],
})
# TODO(aksub99): Add tests for inf and NaN.
))
class TestMinimum(BinaryMathTestBase, op_utils.NumpyOpTest):

def func(self, xp, a, b):
return xp.minimum(a, b)


@pytest.mark.parametrize_device(['native:0', 'cuda:0'])
@pytest.mark.parametrize('dtypes', _in_out_dtypes_arithmetic_invalid)
def test_minimum_invalid_dtypes(device, dtypes):
(in_dtype1, in_dtype2), _ = dtypes
shape = (3, 2)
a = chainerx.array(array_utils.uniform(shape, in_dtype1))
b = chainerx.array(array_utils.uniform(shape, in_dtype2))
with pytest.raises(chainerx.DtypeError):
chainerx.minimum(a, b)

0 comments on commit 2d784df

Please sign in to comment.