diff --git a/chainerx_cc/chainerx/python/routines.cc b/chainerx_cc/chainerx/python/routines.cc index d767d1493639..843ed9f04b9b 100644 --- a/chainerx_cc/chainerx/python/routines.cc +++ b/chainerx_cc/chainerx/python/routines.cc @@ -105,6 +105,24 @@ std::vector<Array> ArrayBodiesToArrays(std::vector<ArrayBodyPtr> array_bodies) { return arrays; } +CastingMode ParseCastingMode(const std::string& casting) { + CastingMode mode{}; + if (casting == "no") { + mode = CastingMode::kNo; + } else if (casting == "equiv") { + throw NotImplementedError{"'equiv' casting is not yet implemented."}; + } else if (casting == "safe") { + throw NotImplementedError{"'safe' casting is not yet implemented."}; + } else if (casting == "same_kind") { + throw NotImplementedError{"'same_kind' casting is not yet implemented."}; + } else if (casting == "unsafe") { + throw NotImplementedError{"'unsafe' casting is not yet implemented."}; + } else { + throw py::value_error{"Casting must be one of 'no', 'equiv', 'safe', 'same_kind', or 'unsafe'."}; + } + return mode; +} + void InitChainerxCreation(pybind11::module& m) { // creation routines // TODO(niboshi): Accept CuPy ndarray in `array` and `asarray`. 
In principle it's CuPy's responsibility to provide some standard @@ -818,6 +836,22 @@ void InitChainerxManipulation(pybind11::module& m) { "a"_a, "source"_a = nullptr, "destination"_a = nullptr); + m.def("copyto", + [](const ArrayBodyPtr& dst, const ArrayBodyPtr& src, const std::string& casting, Scalar where) { + CopyTo(Array{dst}, Array{src}, ParseCastingMode(casting), Full({}, where, Dtype::kBool)); + }, + "dst"_a, + "src"_a, + "casting"_a = "no", + "where"_a = true); + m.def("copyto", + [](const ArrayBodyPtr& dst, const ArrayBodyPtr& src, const std::string& casting, const ArrayBodyPtr& where) { + CopyTo(Array{dst}, Array{src}, ParseCastingMode(casting), Array{where}); + }, + "dst"_a, + "src"_a, + "casting"_a = "no", + "where"_a); } void InitChainerxActivation(pybind11::module& m) { diff --git a/chainerx_cc/chainerx/routines/manipulation.cc b/chainerx_cc/chainerx/routines/manipulation.cc index d236a2429256..19f3de96cf18 100644 --- a/chainerx_cc/chainerx/routines/manipulation.cc +++ b/chainerx_cc/chainerx/routines/manipulation.cc @@ -24,9 +24,12 @@ #include "chainerx/error.h" #include "chainerx/graph.h" #include "chainerx/kernels/creation.h" +#include "chainerx/kernels/indexing.h" #include "chainerx/kernels/misc.h" #include "chainerx/macro.h" #include "chainerx/routines/creation.h" +#include "chainerx/routines/indexing.h" +#include "chainerx/routines/routines_util.h" #include "chainerx/routines/type_util.h" #include "chainerx/shape.h" #include "chainerx/strides.h" @@ -1002,4 +1005,26 @@ Array Moveaxis(const Array& a, const Axes& source, const Axes& destination) { return a.Transpose(order); } +void CopyTo(const Array& dst, const Array& src, CastingMode casting, const Array& where) { + internal::CheckNoUnsafeInplace(dst, {dst, src, where}); + + switch (casting) { + case CastingMode::kNo: + if (dst.dtype() != src.dtype()) { + throw DtypeError{"Source and destination must have same dtype."}; + } + break; + default: + CHAINERX_NEVER_REACH(); + } + + const Array& src_b 
= src.shape() != dst.shape() ? src.BroadcastTo(dst.shape()) : src; + const Array& where_b = where.shape() != dst.shape() ? where.BroadcastTo(dst.shape()) : where; + + { + NoBackpropModeScope scope; + dst.device().backend().CallKernel<WhereKernel>(where_b, src_b, dst, dst); + } +} + } // namespace chainerx diff --git a/chainerx_cc/chainerx/routines/manipulation.h b/chainerx_cc/chainerx/routines/manipulation.h index 4cb320757453..42e35ae60bad 100644 --- a/chainerx_cc/chainerx/routines/manipulation.h +++ b/chainerx_cc/chainerx/routines/manipulation.h @@ -91,4 +91,10 @@ Array DStack(const std::vector<Array>& arrays); Array Moveaxis(const Array& a, const Axes& source, const Axes& destination); +enum class CastingMode { + kNo, +}; + +void CopyTo(const Array& dst, const Array& src, CastingMode casting, const Array& where); + } // namespace chainerx diff --git a/tests/chainerx_tests/unit_tests/routines_tests/test_manipulation.py b/tests/chainerx_tests/unit_tests/routines_tests/test_manipulation.py index c43246469108..57a2500aae4d 100644 --- a/tests/chainerx_tests/unit_tests/routines_tests/test_manipulation.py +++ b/tests/chainerx_tests/unit_tests/routines_tests/test_manipulation.py @@ -1531,6 +1531,104 @@ def test_moveaxis_invalid(xp, shape, source, dst): return xp.moveaxis(a, source, dst) +@op_utils.op_test(['native:0', 'cuda:0']) +@chainer.testing.parameterize(*( + chainer.testing.product({'dst_shape,src_shape,where_shape': [ + # Same Shapes + ((2, 3), (2, 3), (2, 3)), + # Broadcast Shapes + ((2, 3), (1, 3), (1, 3)), + ((2, 3), (2, 1), (1, 3)), + ((2, 3), (2, 3), (1, 3)), + ((4, 5), (4, 1), (1, 5)), + ((1, 4, 5), (1, 4, 1), (1, 1, 5)), + ((2, 3), (2, 3), (2, 3)), + # Omit where + ((2, 3), (2, 3), None), + ], + 'in_dtypes,out_dtype': dtype_utils.result_numeric_dtypes_two_arrays, + 'casting': ['no'], + }) +)) +class TestCopyTo(op_utils.NumpyOpTest): + + skip_backward_test = True + skip_double_backward_test = True + check_numpy_strides_compliance = False + + forward_accept_errors = (TypeError, 
chainerx.DtypeError) + + def generate_inputs(self): + dst_dtype, src_dtype = self.in_dtypes + + dst = array_utils.uniform(self.dst_shape, dst_dtype) + src = array_utils.uniform(self.src_shape, src_dtype) + where = array_utils.uniform( + self.where_shape if self.where_shape is not None else (1,), + 'float32', 0, 1) > 0.5 + + return dst, src, where + + def forward_xp(self, inputs, xp): + dst, src, where = inputs + + if xp is chainerx: + dst = dst.as_grad_stopped().copy() + src = src.as_grad_stopped() + where = where.as_grad_stopped() + else: + dst = dst.copy() + + kwargs = {} + if self.casting is not None: + kwargs['casting'] = self.casting + if self.where_shape is not None: + kwargs['where'] = where + + xp.copyto(dst, src, **kwargs) + + return dst, + + +@op_utils.op_test(['native:0', 'cuda:0']) +@chainer.testing.parameterize(*( + chainer.testing.product({ + 'where': [True, False, 2, 1.2], + }) +)) +class TestCopyToScalarWhere(op_utils.NumpyOpTest): + + skip_backward_test = True + skip_double_backward_test = True + check_numpy_strides_compliance = False + + def generate_inputs(self): + dst = array_utils.uniform((2, 3), 'float32') + src = array_utils.uniform((2, 3), 'float32') + + return dst, src + + def forward_xp(self, inputs, xp): + dst, src = inputs + + if xp is chainerx: + dst = dst.as_grad_stopped().copy() + src = src.as_grad_stopped() + else: + dst = dst.copy() + + xp.copyto(dst, src, casting='no', where=self.where) + + return dst, + + +def test_copyto_invalid_casting(): + a = array_utils.create_dummy_ndarray(chainerx, (2, 3), 'float32') + b = array_utils.create_dummy_ndarray(chainerx, (3,), 'float32') + with pytest.raises(ValueError): + chainerx.copyto(a, b, casting='some_invalid_casting') + + @op_utils.op_test(['native:0', 'cuda:0']) @chainer.testing.parameterize_pytest('a_shape,b_shape', _reshape_shape) @chainer.testing.parameterize_pytest('is_module', [True, False])