Skip to content

Commit

Permalink
Merge pull request #8199 from boku13/main
Browse files Browse the repository at this point in the history
add `cupy.put_along_axis` API
  • Loading branch information
asi1024 committed Mar 1, 2024
2 parents 20ccd63 + 0ff2e72 commit 2fd0b81
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 0 deletions.
1 change: 1 addition & 0 deletions cupy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ def is_available():
from cupy._functional.piecewise import piecewise # NOQA
from cupy._functional.vectorize import vectorize # NOQA
from cupy.lib._shape_base import apply_along_axis # NOQA
from cupy.lib._shape_base import put_along_axis # NOQA

# -----------------------------------------------------------------------------
# Array manipulation routines
Expand Down
69 changes: 69 additions & 0 deletions cupy/lib/_shape_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,72 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs):
buff = cupy.moveaxis(buff, -1, axis)

return buff


def _make_along_axis_idx(arr_shape, indices, axis):
# compute dimensions to iterate over

if not cupy.issubdtype(indices.dtype, cupy.integer):
raise IndexError('`indices` must be an integer array')
if len(arr_shape) != indices.ndim:
raise ValueError(
"`indices` and `arr` must have the same number of dimensions")

shape_ones = (1, ) * indices.ndim
dest_dims = list(range(axis)) + [None] + \
list(range(axis + 1, indices.ndim))

# build a fancy index, consisting of orthogonal aranges, with the
# requested index inserted at the right location
fancy_index = []
for dim, n in zip(dest_dims, arr_shape):
if dim is None:
fancy_index.append(indices)
else:
ind_shape = shape_ones[:dim] + (-1,) + shape_ones[dim+1:]
fancy_index.append(cupy.arange(n).reshape(ind_shape))

return tuple(fancy_index)


def put_along_axis(arr, indices, values, axis):
"""
Put values into the destination array by matching 1d index and data slices.
This iterates over matching 1d slices oriented along the specified axis in
the index and data arrays, and uses the former to place values into the
latter. These slices can be different lengths.
Functions returning an index along an axis, like `argsort` and
`argpartition`, produce suitable indices for this function.
Args:
arr : cupy.ndarray (Ni..., M, Nk...)
Destination array.
indices : cupy.ndarray (Ni..., J, Nk...)
Indices to change along each 1d slice of `arr`. This must match the
dimension of arr, but dimensions in Ni and Nj may be 1 to broadcast
against `arr`.
values : array_like (Ni..., J, Nk...)
values to insert at those indices. Its shape and dimension are
broadcast to match that of `indices`.
axis : int
The axis to take 1d slices along. If axis is None, the destination
array is treated as if a flattened 1d view had been created of it.
.. seealso:: :func:`numpy.put_along_axis`
"""

# normalize inputs
if axis is None:
if indices.ndim != 1:
raise NotImplementedError(
"Tuple setitem isn't supported for flatiter.")
# put is roughly equivalent to a.flat[ind] = values
cupy.put(arr, indices, values)
else:
axis = internal._normalize_axis_index(axis, arr.ndim)
arr_shape = arr.shape

# use the fancy index
arr[_make_along_axis_idx(arr_shape, indices, axis)] = values
1 change: 1 addition & 0 deletions docs/source/reference/indexing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ Inserting data into arrays

place
put
put_along_axis
putmask
fill_diagonal

Expand Down
58 changes: 58 additions & 0 deletions tests/cupy_tests/lib_tests/test_shape_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,61 @@ def test_apply_along_axis_invalid_axis():
for axis in [-3, 2]:
with pytest.raises(numpy.AxisError):
xp.apply_along_axis(xp.sum, axis, a)


class TestPutAlongAxis(unittest.TestCase):

@testing.for_all_dtypes()
@testing.numpy_cupy_array_equal()
def test_put_along_axis_empty(self, xp, dtype):
a = xp.array([], dtype=dtype).reshape(0, 10)
i = xp.array([], dtype=xp.int64).reshape(0, 10)
vals = xp.array([]).reshape(0, 10)
ret = xp.put_along_axis(a, i, vals, axis=0)
assert ret is None
return a

@testing.for_all_dtypes()
@testing.numpy_cupy_array_equal()
def test_simple(self, xp, dtype):
a = testing.shaped_arange((3, 3, 3), xp, dtype)
indices_max = xp.argmax(a, axis=0, keepdims=True)
ret = xp.put_along_axis(a, indices_max, 0, axis=0)
assert ret is None
return a

@testing.for_all_dtypes()
def test_indices_values_arr_diff_dims(self, dtype):
for xp in [numpy, cupy]:
a = testing.shaped_arange((3, 3, 3), xp, dtype)
i_max = xp.argmax(a, axis=0, keepdims=False)
with pytest.raises(ValueError):
xp.put_along_axis(a, i_max, -99, axis=1)


@testing.parameterize(*testing.product({
'axis': [0, 1],
}))
class TestPutAlongAxes(unittest.TestCase):

def test_replace_max(self):
arr = cupy.array([[10, 30, 20], [60, 40, 50]])
indices_max = cupy.argmax(arr, axis=self.axis, keepdims=True)
# replace the max with a small value
cupy.put_along_axis(arr, indices_max, -99, axis=self.axis)
# find the new minimum, which should max
indices_min = cupy.argmin(arr, axis=self.axis, keepdims=True)
testing.assert_array_equal(indices_min, indices_max)


class TestPutAlongAxisNone(unittest.TestCase):

@testing.for_all_dtypes()
@testing.numpy_cupy_array_equal()
def test_axis_none(self, xp, dtype):
a = testing.shaped_arange((3, 3), xp, dtype)
i = xp.array([1, 3])
val = xp.array([99, 100])
ret = xp.put_along_axis(a, i, val, axis=None)
assert ret is None
return a

0 comments on commit 2fd0b81

Please sign in to comment.