
ENH: np.argmax and np.argmin with axis=None, keepdims=True should be compatible with np.take_along_axis and np.put_along_axis #25622

Open
JasonGross opened this issue Jan 18, 2024 · 0 comments

@JasonGross

Proposed new feature or change:

Consider the code

import numpy as np
x = np.array([[0, 1], [2, 3]])
x_argmax1 = np.argmax(x, axis=None, keepdims=True)
x_argmax2 = np.argmax(x, axis=None, keepdims=False)
print(np.take_along_axis(x, x_argmax1, axis=None))
print(np.take_along_axis(x, x_argmax2, axis=None))

Both calls fail with

File ~/.local64/mambaforge/envs/default/lib/python3.11/site-packages/numpy/lib/shape_base.py:170, in take_along_axis(arr, indices, axis)
    167     arr_shape = arr.shape
    169 # use the fancy index
--> 170 return arr[_make_along_axis_idx(arr_shape, indices, axis)]

File ~/.local64/mambaforge/envs/default/lib/python3.11/site-packages/numpy/lib/shape_base.py:32, in _make_along_axis_idx(arr_shape, indices, axis)
     30     raise IndexError('`indices` must be an integer array')
     31 if len(arr_shape) != indices.ndim:
---> 32     raise ValueError(
     33         "`indices` and `arr` must have the same number of dimensions")
     34 shape_ones = (1,) * indices.ndim
     35 dest_dims = list(range(axis)) + [None] + list(range(axis+1, indices.ndim))

ValueError: `indices` and `arr` must have the same number of dimensions
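The failure comes down to the index shapes. The sketch below (which reuses only the `x` from the example above) prints the shapes `argmax` returns for `axis=None`. With `keepdims=True` the result keeps both axes as length 1, giving shape `(1, 1)`. With `keepdims=False` it is a 0-d scalar. `take_along_axis(..., axis=None)` treats `arr` as flattened 1-D, so it rejects both.

```python
import numpy as np

x = np.array([[0, 1], [2, 3]])

# With axis=None, argmax reduces over the flattened array; keepdims=True
# then keeps every original axis as a length-1 dimension.
idx_keep = np.argmax(x, axis=None, keepdims=True)
idx_flat = np.argmax(x, axis=None, keepdims=False)
print(idx_keep.shape, np.ndim(idx_flat))  # (1, 1) 0

# take_along_axis(..., axis=None) treats `arr` as 1-D, so it only accepts
# 1-D `indices`; the 2-D and 0-d index arrays above are both rejected.
errors = []
for idx in (idx_keep, idx_flat):
    try:
        np.take_along_axis(x, idx, axis=None)
    except ValueError as exc:
        errors.append(str(exc))
print(len(errors))  # 2
```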

This can be worked around by flattening the indices:

print(np.take_along_axis(x, x_argmax1.flatten(), axis=None))
print(np.take_along_axis(x, x_argmax2.flatten(), axis=None))
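The workaround runs, but its result does not have the shape that a `keepdims=True` reduction would give. A quick check, assuming the same `x` as above: the flattened take returns shape `(1,)`, while `np.max(x, axis=None, keepdims=True)` has shape `(1, 1)`. Reshaping back to the index shape recovers the expected array.

```python
import numpy as np

x = np.array([[0, 1], [2, 3]])
idx = np.argmax(x, axis=None, keepdims=True)  # shape (1, 1)

# The flatten() workaround succeeds but yields a 1-D result ...
workaround = np.take_along_axis(x, idx.flatten(), axis=None)
# ... whose shape differs from the keepdims reduction it should mirror.
reference = np.max(x, axis=None, keepdims=True)
print(workaround.shape, reference.shape)  # (1,) (1, 1)

# Reshaping back to the shape of the indices recovers the expected result.
restored = workaround.reshape(idx.shape)
print(np.array_equal(restored, reference))  # True
```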

but it would be nice if np.take_along_axis(x, np.argmax(x, axis=axis, keepdims=True), axis=axis) were always equal to np.max(x, axis=axis, keepdims=True), including for axis=None.
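One way to picture the requested behaviour is a pair of user-level wrappers. `take_along_axis_keepdims` and `put_along_axis_keepdims` are hypothetical names, not NumPy functions. This sketch assumes the fix amounts to flattening the indices for `axis=None` and then restoring their shape. It checks the proposed invariant for `axis` in `None, 0, 1, -1` and shows the `put_along_axis` counterpart. It is not how NumPy itself would implement the change.

```python
import numpy as np

# Hypothetical helpers (not part of NumPy): for axis=None they accept
# index arrays of any shape (e.g. argmax(..., keepdims=True)) by flattening
# the indices for the NumPy call and restoring their shape afterwards.
def take_along_axis_keepdims(arr, indices, axis):
    indices = np.asarray(indices)
    if axis is None:
        flat = np.take_along_axis(arr, indices.reshape(-1), axis=None)
        return flat.reshape(indices.shape)
    return np.take_along_axis(arr, indices, axis=axis)


def put_along_axis_keepdims(arr, indices, values, axis):
    indices = np.asarray(indices)
    if axis is None:
        # Broadcast values against the unflattened index shape first, so
        # scalars and keepdims-shaped values are both handled.
        flat_values = np.broadcast_to(values, indices.shape).reshape(-1)
        np.put_along_axis(arr, indices.reshape(-1), flat_values, axis=None)
    else:
        np.put_along_axis(arr, indices, values, axis=axis)


x = np.array([[0, 1], [2, 3]])

# The proposed invariant: taking at the keepdims argmax reproduces the
# keepdims max for every axis, including axis=None.
for axis in (None, 0, 1, -1):
    idx = np.argmax(x, axis=axis, keepdims=True)
    got = take_along_axis_keepdims(x, idx, axis)
    expected = np.max(x, axis=axis, keepdims=True)
    assert got.shape == expected.shape
    assert np.array_equal(got, expected)

# The put counterpart: overwrite the global maximum in place.
y = x.copy()
put_along_axis_keepdims(y, np.argmax(y, axis=None, keepdims=True), -1, axis=None)
print(y.tolist())  # [[0, 1], [2, -1]]
```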
