Skip to content

Commit

Permalink
Merge pull request #7565 from hvy/chx-is-unchained
Browse files Browse the repository at this point in the history
Introduce `chainerx.ndarray._is_chained`
  • Loading branch information
niboshi committed Jun 24, 2019
2 parents 3c567d3 + cbc11e7 commit 3d47091
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 0 deletions.
12 changes: 12 additions & 0 deletions chainerx_cc/chainerx/python/array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <nonstd/optional.hpp>

#include "chainerx/array.h"
#include "chainerx/array_body.h"
#include "chainerx/array_index.h"
#include "chainerx/axes.h"
#include "chainerx/backend_util.h"
Expand Down Expand Up @@ -635,6 +636,17 @@ void InitChainerxArray(pybind11::module& m) {

return list;
});
// TODO(hvy): Rename `_is_chained` to a less ambiguous function name.
c.def("_is_chained",
[](const ArrayBodyPtr& self, const nonstd::optional<BackpropId>& backprop_id) {
BackpropId actual_backprop_id = internal::GetArrayBackpropId(Array{self}, backprop_id);
actual_backprop_id.CheckValid();
if (!self->HasArrayNode(actual_backprop_id)) {
throw ChainerxError{"Array is constant with respect to the computation for backprop ID: '", actual_backprop_id, "'."};
}
return self->GetArrayNode(actual_backprop_id)->creator_op_node() != nullptr;
},
"backprop_id"_a = nullptr);
}

} // namespace python_internal
Expand Down
12 changes: 12 additions & 0 deletions tests/chainerx_tests/unit_tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,3 +718,15 @@ def test_array_deepcopy(device):
chainerx.testing.assert_array_equal(
arr2,
chainerx.array([1, 2], chainerx.float32))


def test_is_chained():
arr = chainerx.array([1, 2], chainerx.float32)
with pytest.raises(chainerx.ChainerxError):
arr._is_chained()

arr.require_grad()
assert not arr._is_chained()

arr2 = 2 * arr
assert arr2._is_chained()

0 comments on commit 3d47091

Please sign in to comment.