Skip to content

Commit

Permalink
Merge pull request #2314 from asi1024/take_along_axis
Browse files Browse the repository at this point in the history
Add `take_along_axis`
  • Loading branch information
takagi committed Jul 19, 2019
2 parents 4e76f37 + 6d17aa1 commit 364c8cb
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 0 deletions.
1 change: 1 addition & 0 deletions cupy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,7 @@ def result_type(*arrays_and_dtypes):
from cupy.indexing.indexing import choose # NOQA
from cupy.indexing.indexing import diagonal # NOQA
from cupy.indexing.indexing import take # NOQA
from cupy.indexing.indexing import take_along_axis # NOQA

from cupy.indexing.insert import place # NOQA
from cupy.indexing.insert import put # NOQA
Expand Down
47 changes: 47 additions & 0 deletions cupy/indexing/indexing.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import cupy
from cupy.core import _errors


def take(a, indices, axis=None, out=None):
"""Takes elements of an array at specified indices along an axis.
Expand All @@ -24,6 +28,49 @@ def take(a, indices, axis=None, out=None):
return a.take(indices, axis, out)


def take_along_axis(a, indices, axis):
"""Take values from the input array by matching 1d index and data slices.
Args:
a (cupy.ndarray): Array to extract elements.
indices (cupy.ndarray): Indices to take along each 1d slice of ``a``.
axis (int): The axis to take 1d slices along.
Returns:
cupy.ndarray: The indexed result.
.. seealso:: :func:`numpy.take_along_axis`
"""

if indices.dtype.kind not in ('i', 'u'):
raise IndexError('`indices` must be an integer array')

if axis is None:
a = a.ravel()
axis = 0

ndim = a.ndim

if not (-ndim <= axis < ndim):
raise _errors._AxisError('Axis overrun')

axis %= a.ndim

if ndim != indices.ndim:
raise ValueError(
'`indices` and `a` must have the same number of dimensions')

fancy_index = []
for i, n in enumerate(a.shape):
if i == axis:
fancy_index.append(indices)
else:
ind_shape = (1,) * i + (-1,) + (1,) * (ndim - i - 1)
fancy_index.append(cupy.arange(n).reshape(ind_shape))

return a[fancy_index]


def choose(a, choices, out=None, mode='raise'):
return a.choose(choices, out, mode)

Expand Down
1 change: 1 addition & 0 deletions docs/source/reference/indexing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ Indexing Routines
cupy.ix_
cupy.unravel_index
cupy.take
cupy.take_along_axis
cupy.choose
cupy.diag
cupy.diagonal
Expand Down
14 changes: 14 additions & 0 deletions tests/cupy_tests/indexing_tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,20 @@ def test_take_no_axis(self, xp):
b = xp.array([[10, 5], [3, 20]])
return a.take(b)

@testing.with_requires('numpy>=1.15')
@testing.numpy_cupy_array_equal()
def test_take_along_axis(self, xp):
a = testing.shaped_random((2, 4, 3), xp, dtype='float32')
b = testing.shaped_random((2, 6, 3), xp, dtype='int64', scale=4)
return xp.take_along_axis(a, b, axis=-2)

@testing.with_requires('numpy>=1.15')
@testing.numpy_cupy_array_equal()
def test_take_along_axis_none_axis(self, xp):
a = testing.shaped_random((2, 4, 3), xp, dtype='float32')
b = testing.shaped_random((30,), xp, dtype='int64', scale=24)
return xp.take_along_axis(a, b, axis=None)

@testing.for_all_dtypes()
@testing.numpy_cupy_array_equal()
def test_diagonal(self, xp, dtype):
Expand Down

0 comments on commit 364c8cb

Please sign in to comment.