Skip to content

Commit

Permalink
Merge pull request #2856 from emcastillo/bp-2815-v6-fix-split-unordered
Browse files Browse the repository at this point in the history
[backport] Fix split and array_split with unordered indices supplied
  • Loading branch information
takagi committed Dec 23, 2019
2 parents 0eb71b1 + 21a6df6 commit 07bc22b
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
2 changes: 1 addition & 1 deletion cupy/core/_routines_manipulation.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ cpdef array_split(ndarray ary, indices_or_sections, Py_ssize_t axis):
stride = 0
for index in indices:
index = min(index, size)
shape[axis] = index - prev
shape[axis] = max(index - prev, 0)
v = ary.view()
v.data = ary.data + prev * stride
# TODO(niboshi): Confirm update_x_contiguity flags
Expand Down
10 changes: 10 additions & 0 deletions tests/cupy_tests/manipulation_tests/test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ def test_array_split_out_of_bound2(self, xp):
a = testing.shaped_arange((0,), xp)
return xp.array_split(a, [1])

@testing.numpy_cupy_array_list_equal()
def test_array_split_unordered_sections(self, xp):
a = testing.shaped_arange((5,), xp)
return xp.array_split(a, [4, 2])

@testing.numpy_cupy_array_list_equal()
def test_array_split_non_divisible(self, xp):
a = testing.shaped_arange((5, 3), xp)
Expand Down Expand Up @@ -84,6 +89,11 @@ def test_split_out_of_bound2(self, xp):
a = testing.shaped_arange((0,), xp)
return xp.split(a, [1])

@testing.numpy_cupy_array_list_equal()
def test_split_unordered_sections(self, xp):
a = testing.shaped_arange((5,), xp)
return xp.split(a, [4, 2])

@testing.numpy_cupy_array_list_equal()
def test_vsplit(self, xp):
a = testing.shaped_arange((12, 3), xp)
Expand Down

0 comments on commit 07bc22b

Please sign in to comment.