Skip to content

Commit

Permalink
Merge pull request #8314 from hvy/add-copyto
Browse files Browse the repository at this point in the history
Add `chainerx.copyto`
  • Loading branch information
asi1024 committed Oct 23, 2019
2 parents 6926d81 + 8313d4d commit b87fb8c
Show file tree
Hide file tree
Showing 4 changed files with 163 additions and 0 deletions.
34 changes: 34 additions & 0 deletions chainerx_cc/chainerx/python/routines.cc
Expand Up @@ -105,6 +105,24 @@ std::vector<Array> ArrayBodiesToArrays(std::vector<ArrayBodyPtr> array_bodies) {
return arrays;
}

// Translates a NumPy-style casting string into a CastingMode value.
//
// Only "no" is currently supported. The other NumPy casting rules are
// recognized but rejected with NotImplementedError so that callers get a
// precise diagnostic; any unrecognized string raises a Python ValueError.
CastingMode ParseCastingMode(const std::string& casting) {
    if (casting == "no") {
        return CastingMode::kNo;
    }
    if (casting == "equiv") {
        throw NotImplementedError{"'equiv' casting is not yet implemented."};
    }
    if (casting == "safe") {
        throw NotImplementedError{"'safe' casting is not yet implemented."};
    }
    if (casting == "same_kind") {
        throw NotImplementedError{"'same_kind' casting is not yet implemented."};
    }
    if (casting == "unsafe") {
        throw NotImplementedError{"'unsafe' casting is not yet implemented."};
    }
    throw py::value_error{"Casting must be one of 'no', 'equiv', 'safe', 'same_kind', or 'unsafe'."};
}

void InitChainerxCreation(pybind11::module& m) {
// creation routines
// TODO(niboshi): Accept CuPy ndarray in `array` and `asarray`. In principle it's CuPy's responsibility to provide some standard
Expand Down Expand Up @@ -818,6 +836,22 @@ void InitChainerxManipulation(pybind11::module& m) {
"a"_a,
"source"_a = nullptr,
"destination"_a = nullptr);
// chainerx.copyto overload taking a Python scalar `where`: the scalar is
// materialized as a 0-dim boolean array (Full with an empty shape) so both
// overloads funnel into the same CopyTo routine.
m.def("copyto",
      [](const ArrayBodyPtr& dst, const ArrayBodyPtr& src, const std::string& casting, Scalar where) {
          CopyTo(Array{dst}, Array{src}, ParseCastingMode(casting), Full({}, where, Dtype::kBool));
      },
      "dst"_a,
      "src"_a,
      "casting"_a = "no",
      "where"_a = true);
// chainerx.copyto overload taking an ndarray `where` mask. Note `where` has
// no default here, so calls that omit it resolve to the scalar overload
// above (where=true, i.e. copy everywhere).
m.def("copyto",
      [](const ArrayBodyPtr& dst, const ArrayBodyPtr& src, const std::string& casting, const ArrayBodyPtr& where) {
          CopyTo(Array{dst}, Array{src}, ParseCastingMode(casting), Array{where});
      },
      "dst"_a,
      "src"_a,
      "casting"_a = "no",
      "where"_a);
}

void InitChainerxActivation(pybind11::module& m) {
Expand Down
25 changes: 25 additions & 0 deletions chainerx_cc/chainerx/routines/manipulation.cc
Expand Up @@ -24,9 +24,12 @@
#include "chainerx/error.h"
#include "chainerx/graph.h"
#include "chainerx/kernels/creation.h"
#include "chainerx/kernels/indexing.h"
#include "chainerx/kernels/misc.h"
#include "chainerx/macro.h"
#include "chainerx/routines/creation.h"
#include "chainerx/routines/indexing.h"
#include "chainerx/routines/routines_util.h"
#include "chainerx/routines/type_util.h"
#include "chainerx/shape.h"
#include "chainerx/strides.h"
Expand Down Expand Up @@ -1002,4 +1005,26 @@ Array Moveaxis(const Array& a, const Axes& source, const Axes& destination) {
return a.Transpose(order);
}

void CopyTo(const Array& dst, const Array& src, CastingMode casting, const Array& where) {
internal::CheckNoUnsafeInplace(dst, {dst, src, where});

switch (casting) {
case CastingMode::kNo:
if (dst.dtype() != src.dtype()) {
throw DtypeError{"Source and destination must have same dtype."};
}
break;
default:
CHAINERX_NEVER_REACH();
}

const Array& src_b = src.shape() != dst.shape() ? src.BroadcastTo(dst.shape()) : src;
const Array& where_b = where.shape() != dst.shape() ? where.BroadcastTo(dst.shape()) : where;

{
NoBackpropModeScope scope;
dst.device().backend().CallKernel<WhereKernel>(where_b, src_b, dst, dst);
}
}

} // namespace chainerx
6 changes: 6 additions & 0 deletions chainerx_cc/chainerx/routines/manipulation.h
Expand Up @@ -91,4 +91,10 @@ Array DStack(const std::vector<Array>& arrays);

Array Moveaxis(const Array& a, const Axes& source, const Axes& destination);

// Casting rules for CopyTo, mirroring NumPy's `casting` argument.
// Only "no" (source and destination dtypes must match exactly) is
// implemented so far; see ParseCastingMode in the Python bindings.
enum class CastingMode {
    kNo,
};

// Copies `src` into `dst` element-wise where `where` is true, in place,
// broadcasting `src` and `where` to `dst`'s shape as needed.
void CopyTo(const Array& dst, const Array& src, CastingMode casting, const Array& where);

} // namespace chainerx
Expand Up @@ -1531,6 +1531,104 @@ def test_moveaxis_invalid(xp, shape, source, dst):
return xp.moveaxis(a, source, dst)


@op_utils.op_test(['native:0', 'cuda:0'])
@chainer.testing.parameterize(*(
    chainer.testing.product({'dst_shape,src_shape,where_shape': [
        # Same shapes
        ((2, 3), (2, 3), (2, 3)),
        # Broadcast shapes
        ((2, 3), (1, 3), (1, 3)),
        ((2, 3), (2, 1), (1, 3)),
        ((2, 3), (2, 3), (1, 3)),
        ((4, 5), (4, 1), (1, 5)),
        ((1, 4, 5), (1, 4, 1), (1, 1, 5)),
        # Omit where
        ((2, 3), (2, 3), None),
    ],
        'in_dtypes,out_dtype': dtype_utils.result_numeric_dtypes_two_arrays,
        'casting': ['no'],
    })
))
class TestCopyTo(op_utils.NumpyOpTest):
    """Tests `copyto` against `numpy.copyto` with an ndarray `where` mask.

    Covers equal shapes, broadcastable shapes, and omission of `where`.
    """

    # copyto mutates `dst` in place and records no graph, so there is
    # nothing to differentiate.
    skip_backward_test = True
    skip_double_backward_test = True
    check_numpy_strides_compliance = False

    # Under casting='no' mismatched dtypes are rejected: numpy raises
    # TypeError, chainerx raises DtypeError; either counts as agreement.
    forward_accept_errors = (TypeError, chainerx.DtypeError)

    def generate_inputs(self):
        dst_dtype, src_dtype = self.in_dtypes

        dst = array_utils.uniform(self.dst_shape, dst_dtype)
        src = array_utils.uniform(self.src_shape, src_dtype)
        # Random boolean mask. A dummy (1,) mask is generated even when
        # `where` is omitted so the input signature stays fixed.
        where = array_utils.uniform(
            self.where_shape if self.where_shape is not None else (1,),
            'float32', 0, 1) > 0.5

        return dst, src, where

    def forward_xp(self, inputs, xp):
        dst, src, where = inputs

        if xp is chainerx:
            # Copy so the in-place write does not clobber the shared input,
            # and stop gradients since copyto is not differentiable.
            dst = dst.as_grad_stopped().copy()
            src = src.as_grad_stopped()
            where = where.as_grad_stopped()
        else:
            dst = dst.copy()

        kwargs = {}
        if self.casting is not None:
            kwargs['casting'] = self.casting
        if self.where_shape is not None:
            kwargs['where'] = where

        xp.copyto(dst, src, **kwargs)

        return dst,


@op_utils.op_test(['native:0', 'cuda:0'])
@chainer.testing.parameterize(*(
    chainer.testing.product({
        'where': [True, False, 2, 1.2],
    })
))
class TestCopyToScalarWhere(op_utils.NumpyOpTest):
    """Tests `copyto` when `where` is a Python scalar rather than an array."""

    # copyto is an in-place, non-differentiable operation.
    skip_backward_test = True
    skip_double_backward_test = True
    check_numpy_strides_compliance = False

    def generate_inputs(self):
        # Fixed float32 operands; only the scalar `where` value varies.
        return (
            array_utils.uniform((2, 3), 'float32'),
            array_utils.uniform((2, 3), 'float32'))

    def forward_xp(self, inputs, xp):
        dst, src = inputs

        if xp is chainerx:
            # Detach from the graph and copy dst so the in-place write
            # does not modify the shared test input.
            src = src.as_grad_stopped()
            dst = dst.as_grad_stopped().copy()
        else:
            dst = dst.copy()

        xp.copyto(dst, src, casting='no', where=self.where)

        return dst,


def test_copyto_invalid_casting():
    """An unrecognized `casting` string must raise ValueError."""
    dst = array_utils.create_dummy_ndarray(chainerx, (2, 3), 'float32')
    src = array_utils.create_dummy_ndarray(chainerx, (3,), 'float32')
    with pytest.raises(ValueError):
        chainerx.copyto(dst, src, casting='some_invalid_casting')


@op_utils.op_test(['native:0', 'cuda:0'])
@chainer.testing.parameterize_pytest('a_shape,b_shape', _reshape_shape)
@chainer.testing.parameterize_pytest('is_module', [True, False])
Expand Down

0 comments on commit b87fb8c

Please sign in to comment.