Merge pull request #6821 from hvy/bp-6779-add-logical-ops
[backport] [chainerx] add `logical_and` `logical_or`
asi1024 committed Apr 11, 2019
2 parents 0310666 + 7614899 commit 27e940b
Showing 9 changed files with 255 additions and 0 deletions.
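
For orientation, a minimal sketch of the API this commit adds (assuming a ChainerX build that includes it):

import chainerx

a = chainerx.array([True, True, False])
b = chainerx.array([True, False, False])

chainerx.logical_and(a, b)  # array of dtype bool: [ True, False, False]
chainerx.logical_or(a, b)   # array of dtype bool: [ True,  True, False]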
6 changes: 6 additions & 0 deletions chainerx/__init__.pyi
@@ -656,6 +656,12 @@ def log_softmax(
        axis: tp.Optional[tp.Union[int, tp.List[int]]]=None) -> ndarray: ...


def logical_and(x1: ndarray, x2: ndarray) -> ndarray: ...


def logical_or(x1: ndarray, x2: ndarray) -> ndarray: ...


def logical_not(x: ndarray) -> ndarray: ...


36 changes: 36 additions & 0 deletions chainerx/_docs/routines.py
@@ -495,6 +495,42 @@ def _docs_logic():
.. seealso:: :data:`numpy.logical_not`
""")

    _docs.set_doc(
        chainerx.logical_and,
        """logical_and(x1, x2)
Returns the element-wise logical AND of x1 and x2.

Args:
    x1 (~chainerx.ndarray): Input array.
    x2 (~chainerx.ndarray): Input array.

Returns:
    :class:`~chainerx.ndarray`: Output array of type bool.

Note:
    During backpropagation, this function does not propagate gradients.

.. seealso:: :data:`numpy.logical_and`
""")

    _docs.set_doc(
        chainerx.logical_or,
        """logical_or(x1, x2)
Returns the element-wise logical OR of x1 and x2.

Args:
    x1 (~chainerx.ndarray): Input array.
    x2 (~chainerx.ndarray): Input array.

Returns:
    :class:`~chainerx.ndarray`: Output array of type bool.

Note:
    During backpropagation, this function does not propagate gradients.

.. seealso:: :data:`numpy.logical_or`
""")

    _docs.set_doc(
        chainerx.greater,
        """greater(x1, x2)
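As the docstrings above note via their NumPy counterparts, non-zero elements count as true and the output is always a bool array. A small sketch (assuming the functions behave as documented):

import chainerx

x1 = chainerx.array([0, 1, 2])
x2 = chainerx.array([1, 1, 0])

chainerx.logical_and(x1, x2)  # [False,  True, False]
chainerx.logical_or(x1, x2)   # [ True,  True,  True]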
48 changes: 48 additions & 0 deletions chainerx_cc/chainerx/cuda/cuda_device/comparison.cu
@@ -135,6 +135,54 @@ public:

CHAINERX_REGISTER_OP_CUDA(LogicalNotOp, CudaLogicalNotOp);

template <typename T>
struct LogicalAndImpl {
    using CudaType = cuda_internal::DataType<T>;
    __device__ void operator()(int64_t /*i*/, CudaType x1, CudaType x2, bool& out) { out = x1 && x2; }
};

class CudaLogicalAndOp : public LogicalAndOp {
public:
    void Call(const Array& x1, const Array& x2, const Array& out) override {
        Device& device = x1.device();
        device.CheckDevicesCompatible(x1, x2, out);
        Dtype dtype = PromoteTypes(x1.dtype(), x2.dtype());
        const Array& x1_cast = x1.dtype() == dtype ? x1 : x1.AsType(dtype);
        const Array& x2_cast = x2.dtype() == dtype ? x2 : x2.AsType(dtype);
        CudaSetDeviceScope scope{device.index()};
        VisitDtype(dtype, [&](auto pt) {
            using T = typename decltype(pt)::type;
            Elementwise<const T, const T, bool>(LogicalAndImpl<T>{}, x1_cast, x2_cast, out);
        });
    }
};

CHAINERX_REGISTER_OP_CUDA(LogicalAndOp, CudaLogicalAndOp);

template <typename T>
struct LogicalOrImpl {
    using CudaType = cuda_internal::DataType<T>;
    __device__ void operator()(int64_t /*i*/, CudaType x1, CudaType x2, bool& out) { out = x1 || x2; }
};

class CudaLogicalOrOp : public LogicalOrOp {
public:
    void Call(const Array& x1, const Array& x2, const Array& out) override {
        Device& device = x1.device();
        device.CheckDevicesCompatible(x1, x2, out);
        Dtype dtype = PromoteTypes(x1.dtype(), x2.dtype());
        const Array& x1_cast = x1.dtype() == dtype ? x1 : x1.AsType(dtype);
        const Array& x2_cast = x2.dtype() == dtype ? x2 : x2.AsType(dtype);
        CudaSetDeviceScope scope{device.index()};
        VisitDtype(dtype, [&](auto pt) {
            using T = typename decltype(pt)::type;
            Elementwise<const T, const T, bool>(LogicalOrImpl<T>{}, x1_cast, x2_cast, out);
        });
    }
};

CHAINERX_REGISTER_OP_CUDA(LogicalOrOp, CudaLogicalOrOp);

} // namespace
} // namespace cuda
} // namespace chainerx
40 changes: 40 additions & 0 deletions chainerx_cc/chainerx/native/native_device/comparison.cc
@@ -110,6 +110,46 @@ class NativeLogicalNotOp : public LogicalNotOp {

CHAINERX_REGISTER_OP_NATIVE(LogicalNotOp, NativeLogicalNotOp);

class NativeLogicalAndOp : public LogicalAndOp {
public:
    void Call(const Array& x1, const Array& x2, const Array& out) override {
        Device& device = x1.device();
        device.CheckDevicesCompatible(x1, x2, out);
        Dtype dtype = PromoteTypes(x1.dtype(), x2.dtype());
        const Array& x1_cast = x1.dtype() == dtype ? x1 : x1.AsType(dtype);
        const Array& x2_cast = x2.dtype() == dtype ? x2 : x2.AsType(dtype);
        VisitDtype(dtype, [&](auto pt) {
            using T = typename decltype(pt)::type;
            struct Impl {
                void operator()(int64_t /*i*/, T x1, T x2, bool& out) { out = x1 && x2; }
            };
            Elementwise<const T, const T, bool>(Impl{}, x1_cast, x2_cast, out);
        });
    }
};

CHAINERX_REGISTER_OP_NATIVE(LogicalAndOp, NativeLogicalAndOp);

class NativeLogicalOrOp : public LogicalOrOp {
public:
    void Call(const Array& x1, const Array& x2, const Array& out) override {
        Device& device = x1.device();
        device.CheckDevicesCompatible(x1, x2, out);
        Dtype dtype = PromoteTypes(x1.dtype(), x2.dtype());
        const Array& x1_cast = x1.dtype() == dtype ? x1 : x1.AsType(dtype);
        const Array& x2_cast = x2.dtype() == dtype ? x2 : x2.AsType(dtype);
        VisitDtype(dtype, [&](auto pt) {
            using T = typename decltype(pt)::type;
            struct Impl {
                void operator()(int64_t /*i*/, T x1, T x2, bool& out) { out = x1 || x2; }
            };
            Elementwise<const T, const T, bool>(Impl{}, x1_cast, x2_cast, out);
        });
    }
};

CHAINERX_REGISTER_OP_NATIVE(LogicalOrOp, NativeLogicalOrOp);

} // namespace
} // namespace native
} // namespace chainerx
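
Both the CUDA and native ops above follow the same recipe: promote the two input dtypes with PromoteTypes, cast whichever operand differs, then run an elementwise kernel that writes bool. Illustrated in NumPy terms (an analogy only, not ChainerX code):

import numpy

x1 = numpy.array([0, 1], dtype=numpy.int32)
x2 = numpy.array([0.5, 2.0], dtype=numpy.float64)

dtype = numpy.promote_types(x1.dtype, x2.dtype)              # float64
out = numpy.logical_and(x1.astype(dtype), x2.astype(dtype))  # [False,  True]
# The output dtype is bool regardless of the promoted input dtype.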
8 changes: 8 additions & 0 deletions chainerx_cc/chainerx/python/routines.cc
@@ -335,6 +335,14 @@ void InitChainerxLogic(pybind11::module& m) {
          [](const ArrayBodyPtr& x1, const ArrayBodyPtr& x2) { return MoveArrayBody(LessEqual(Array{x1}, Array{x2})); },
          py::arg("x1"),
          py::arg("x2"));
    m.def("logical_and",
          [](const ArrayBodyPtr& x1, const ArrayBodyPtr& x2) { return MoveArrayBody(LogicalAnd(Array{x1}, Array{x2})); },
          py::arg("x1"),
          py::arg("x2"));
    m.def("logical_or",
          [](const ArrayBodyPtr& x1, const ArrayBodyPtr& x2) { return MoveArrayBody(LogicalOr(Array{x1}, Array{x2})); },
          py::arg("x1"),
          py::arg("x2"));
    m.def("logical_not", [](const ArrayBodyPtr& x) { return MoveArrayBody(LogicalNot(Array{x})); }, py::arg("x"));
}

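Since the bindings declare py::arg("x1") and py::arg("x2"), both positional and keyword calls should work; a sketch assuming an installed chainerx:

import chainerx

a = chainerx.array([True, False])
b = chainerx.array([False, False])

chainerx.logical_or(a, b)        # positional
chainerx.logical_or(x1=a, x2=b)  # equivalent keyword form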
12 changes: 12 additions & 0 deletions chainerx_cc/chainerx/routines/logic.cc
@@ -78,4 +78,16 @@ Array LogicalNot(const Array& x) {
return out;
}

Array LogicalAnd(const Array& x1, const Array& x2) {
    CheckLogicDtypes(x1, x2);
    auto func = [](const Array& x1, const Array& x2, Array& out) { return x1.device().backend().CallOp<LogicalAndOp>(x1, x2, out); };
    return BroadcastComparison(func, x1, x2);
}

Array LogicalOr(const Array& x1, const Array& x2) {
    CheckLogicDtypes(x1, x2);
    auto func = [](const Array& x1, const Array& x2, Array& out) { return x1.device().backend().CallOp<LogicalOrOp>(x1, x2, out); };
    return BroadcastComparison(func, x1, x2);
}

} // namespace chainerx
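
BroadcastComparison gives both routines NumPy-style broadcasting, which the tests below also exercise with mismatched shapes. A sketch:

import chainerx

a = chainerx.array([[0, 1], [2, 3]])
b = chainerx.array([1, 0])

chainerx.logical_and(a, b)  # b broadcasts across rows; result shape (2, 2):
# [[False, False],
#  [ True, False]]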
18 changes: 18 additions & 0 deletions chainerx_cc/chainerx/routines/logic.h
@@ -42,6 +42,20 @@ class LogicalNotOp : public Op {
    virtual void Call(const Array& x, const Array& out) = 0;
};

class LogicalAndOp : public Op {
public:
    static const char* name() { return "LogicalAnd"; }

    virtual void Call(const Array& x1, const Array& x2, const Array& out) = 0;
};

class LogicalOrOp : public Op {
public:
    static const char* name() { return "LogicalOr"; }

    virtual void Call(const Array& x1, const Array& x2, const Array& out) = 0;
};

// Returns an elementwise equality array.
//
// Dtype casting is not supported: if x1 and x2 have different types, DtypeError is thrown.
@@ -75,4 +89,8 @@ inline Array LessEqual(const Array& x1, const Array& x2) { return GreaterEqual(x2, x1); }
// Returns an elementwise logical negation of an array.
Array LogicalNot(const Array& x);

Array LogicalAnd(const Array& x1, const Array& x2);

Array LogicalOr(const Array& x1, const Array& x2);

} // namespace chainerx
2 changes: 2 additions & 0 deletions docs/source/chainerx/reference/routines.rst
@@ -96,6 +96,8 @@ Logic functions
   chainerx.isinf
   chainerx.isnan

   chainerx.logical_and
   chainerx.logical_or
   chainerx.logical_not

   chainerx.greater
85 changes: 85 additions & 0 deletions tests/chainerx_tests/unit_tests/routines_tests/test_logic.py
@@ -205,3 +205,88 @@ def forward_xp(self, inputs, xp):
        a, = inputs
        b = xp.logical_not(a)
        return b,


def logical_and(xp, a, b):
return xp.logical_and(a, b)


def logical_or(xp, a, b):
return xp.logical_or(a, b)


_binary_logical_params = \
    chainer.testing.product({
        'dtypes': _expected_all_dtypes_comparison,
        'func': [
            logical_and, logical_or
        ],
        'inputs': [
            ([], []),
            ([True], [True]),
            ([True], [False]),
        ]
    }) + chainer.testing.product({
        'dtypes': _expected_numeric_dtypes_comparison,
        'func': [
            logical_and, logical_or
        ],
        'inputs': [
            ([0], [0]),
            ([0], [-0]),
            ([0], [1]),
            ([0, 1, 2], [0, 1, 2]),
            ([1, 1, 2], [0, 1, 2]),
            ([0, 1, 2], [1, 2, 3]),
            ([[0, 1], [2, 3]], [[0, 1], [2, 3]]),
            ([[0, 1], [2, 3]], [[0, 1], [2, -2]]),
            ([[0, 1], [2, 3]], [[1, 2], [3, 4]]),
            (0, [0]),
            (1, [0]),
            ([], [0]),
            ([0], [[0, 1, 2], [3, 4, 5]]),
            ([[0], [1]], [0, 1, 2]),
            ([0.2], [0.2]),
            ([0.2], [-0.3]),
        ],
    }) + chainer.testing.product({
        'dtypes': _expected_float_dtypes_comparison,
        'func': [
            logical_and, logical_or
        ],
        'inputs': [
            ([0., numpy.nan], [0., 1.]),
            ([0., numpy.nan], [0., numpy.nan]),
            ([0., numpy.inf], [0., 1.]),
            ([0., -numpy.inf], [0., 1.]),
            ([numpy.inf, 1.], [numpy.inf, 1.]),
            ([-numpy.inf, 1.], [-numpy.inf, 1.]),
            ([numpy.inf, 1.], [-numpy.inf, 1.]),
            ([numpy.inf, 1.], [-numpy.inf, numpy.nan]),
        ]
    })


@op_utils.op_test(['native:0', 'cuda:0'])
@chainer.testing.parameterize(*(
    _binary_logical_params
))
# Ignore warnings from numpy for NaN comparisons.
@pytest.mark.filterwarnings('ignore:invalid value encountered in ')
class TestLogicalBinary(op_utils.NumpyOpTest):

    skip_backward_test = True
    skip_double_backward_test = True

    def generate_inputs(self):
        a_object, b_object = self.inputs
        a_dtype, b_dtype = self.dtypes
        a = numpy.array(a_object, a_dtype)
        b = numpy.array(b_object, b_dtype)
        return a, b

    def forward_xp(self, inputs, xp):
        a, b = inputs
        y1 = self.func(xp, a, b)
        y2 = self.func(xp, b, a)
        return y1, y2
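
Note that forward_xp evaluates each case in both argument orders (y1 and y2), so commutativity is checked for free. The new cases can presumably be run in isolation with pytest's -k selector:

$ pytest tests/chainerx_tests/unit_tests/routines_tests/test_logic.py -k TestLogicalBinary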
