ENH: np.argmax, np.argmin, np.take_along_axis, and np.put_along_axis should support tuple axis a la np.max and np.min #25623

Open
JasonGross opened this issue Jan 18, 2024 · 5 comments


@JasonGross

Proposed new feature or change:

scipy/scipy#19549 is blocked on having argmax support all axis arguments that max supports. The specification of tuple argmax is that np.take_along_axis(x, np.argmax(x, axis=axis, keepdims=True), axis=axis) should always equal np.max(x, axis=axis, keepdims=True).
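For a single integer axis, this identity can already be checked with existing NumPy functions. A minimal sketch, assuming NumPy 1.22 or later (where np.argmax accepts keepdims); the helper name satisfies_spec is hypothetical and not part of NumPy:

```python
import numpy as np


def satisfies_spec(x, axis):
    """Check take_along_axis(x, argmax(x)) == max(x) for one integer axis."""
    idx = np.argmax(x, axis=axis, keepdims=True)
    taken = np.take_along_axis(x, idx, axis=axis)
    return np.array_equal(taken, np.max(x, axis=axis, keepdims=True))


x = np.arange(24).reshape(2, 3, 4)

# The identity holds for every single axis of this array.
print(all(satisfies_spec(x, axis) for axis in range(x.ndim)))  # True

# np.max already accepts a tuple axis; the request is for argmax to match.
print(np.max(x, axis=(0, 2), keepdims=True).shape)  # (1, 3, 1)
```

The open question in this issue is what the left-hand side should look like when axis is a tuple.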

See also #25622 and #9283, where @eric-wieser eventually proposes this constraint in #9283 (comment)

@jakevdp
Contributor

jakevdp commented Jan 19, 2024

What would you propose argmax should return in the case of a tuple axis? Would it be a tuple of index arrays, similar to multi-dimensional jnp.nonzero?

@JasonGross JasonGross changed the title ENH: np.argmax and np.argmin should support tuple axis a la np.max and np.min ENH: np.argmax, np.argmin, np.take_along_axis, and np.put_along_axis should support tuple axis a la np.max and np.min Jan 19, 2024
@JasonGross
Author

JasonGross commented Jan 19, 2024

A tuple of index arrays sounds fine to me, and that's what's proposed at #9283 (comment), right?

Then again, how would this shape-aware result option work together with the existing axis keyword already resulting in a non-scalar return value?

I think there's one obvious way to deal with this. As an example:

>>> ret = argmax(np.empty((A, B, C, D, E)), axes=(0, 2))
>>> type(ret)
tuple
>>> len(ret)  # == len(axes)
2
>>> ret[0].shape
(B, D, E)
>>> ret[1].shape
(B, D, E)

With that and keepdims, you'd get arr[argmax(arr, axes, keepdims=True)] == max(arr, keepdims=True) for any dimensionality, which seems super-desirable to me

In pseudocode, I'd expect:

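def argmax(arr, axes, keepdims):
    # Accept either a single axis or a tuple of axes
    was_tuple = isinstance(axes, tuple)
    if not was_tuple:
        axes = (axes,)

    # Result shape with every reduced axis collapsed to length 1
    shape = np.array(arr.shape)
    axis_mask = np.array([i in axes for i in range(arr.ndim)])
    shape[axis_mask] = 1
    # One integer index array per reduced axis
    ret = tuple(np.empty(shape, dtype=np.intp) for _ in axes)

    # do the actual work

    if not keepdims:
        ret = tuple(r.reshape(shape[~axis_mask]) for r in ret)

    # A scalar axis yields one index array, a tuple axis yields a tuple
    if not was_tuple:
        return ret[0]
    return ret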

Originally posted by @eric-wieser in #9283 (comment)
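One possible way to fill in the "do the actual work" step, offered only as a sketch: move the reduced axes to the end, flatten them into one axis, take an ordinary single-axis argmax, and unravel the flat index into one index array per reduced axis. The name argmax_tuple_axis is hypothetical; this is not NumPy's implementation, nor necessarily what the quoted comment had in mind.

```python
import numpy as np


def argmax_tuple_axis(arr, axes, keepdims=False):
    """Hypothetical argmax over one axis or a tuple of axes (illustration only)."""
    arr = np.asarray(arr)
    was_tuple = isinstance(axes, tuple)
    if not was_tuple:
        axes = (axes,)
    axes = tuple(a % arr.ndim for a in axes)  # normalize negative axes
    kept = [i for i in range(arr.ndim) if i not in axes]

    # Move the reduced axes to the end and flatten them into a single axis.
    moved = np.transpose(arr, kept + list(axes))
    kept_shape = moved.shape[:len(kept)]
    reduced_shape = moved.shape[len(kept):]
    flat = moved.reshape(kept_shape + (-1,))

    # Single-axis argmax, then split the flat index into per-axis indices.
    flat_idx = np.argmax(flat, axis=-1)
    ret = tuple(np.asarray(r) for r in np.unravel_index(flat_idx, reduced_shape))

    if keepdims:
        full = tuple(1 if i in axes else n for i, n in enumerate(arr.shape))
        ret = tuple(r.reshape(full) for r in ret)
    return ret if was_tuple else ret[0]


arr = np.random.default_rng(0).random((2, 3, 4, 5, 6))
ret = argmax_tuple_axis(arr, (0, 2))
print(len(ret), ret[0].shape, ret[1].shape)  # 2 (3, 5, 6) (3, 5, 6)
```

Ties are resolved in C order over the flattened reduced axes, which is a choice made by this sketch rather than something the thread specifies.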

My only desideratum in requesting this is that np.take_along_axis(x, np.argmax(x, axis=axis, keepdims=True), axis=axis) is the same as np.max(x, axis=axis, keepdims=True); I need a version of x[x == np.max(x, axis=axis, keepdims=True)] = 0 that only sets a single value to 0 along the axis, even in cases where there are multiple copies of the max.
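For a single axis, this use case already works with np.put_along_axis. The sketch below contrasts it with the boolean-mask version on a made-up array that contains ties (NumPy 1.22 or later is assumed for the keepdims argument of np.argmax):

```python
import numpy as np

x = np.array([[3, 7, 7],
              [9, 1, 9]])

# Boolean mask: zeros out every copy of the row maximum.
masked = x.copy()
masked[masked == np.max(masked, axis=1, keepdims=True)] = 0
print(masked)
# [[3 0 0]
#  [0 1 0]]

# argmax + put_along_axis: zeros out only the first maximum in each row.
single = x.copy()
idx = np.argmax(single, axis=1, keepdims=True)
np.put_along_axis(single, idx, 0, axis=1)
print(single)
# [[3 0 7]
#  [0 1 9]]
```

The request is for the second pattern to work unchanged when axis is a tuple.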

@jakevdp
Contributor

jakevdp commented Jan 19, 2024

It sounds like your full proposal would require enhancing take_along_axis to support tuple indices and tuple axis as well.
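As a rough illustration of what a tuple-aware take_along_axis might do, the sketch below builds an open mesh of indices with np.ix_ and substitutes the supplied index arrays on the requested axes. The name take_along_axes and its signature are hypothetical, and the index arrays are assumed to have the keepdims=True shape (length 1 along every reduced axis).

```python
import numpy as np


def take_along_axes(arr, indices, axes):
    """Hypothetical tuple-axis analogue of np.take_along_axis (illustration only)."""
    arr = np.asarray(arr)
    axes = tuple(a % arr.ndim for a in axes)  # normalize negative axes
    # Open mesh: entry i has length arr.shape[i] on axis i and 1 elsewhere.
    index = list(np.ix_(*[np.arange(n) for n in arr.shape]))
    # Replace the mesh entries on the selected axes with the given indices.
    for ax, ind in zip(axes, indices):
        index[ax] = ind
    return arr[tuple(index)]


x = np.array([[0, 7, 2],
              [7, 5, 6]])

# Index of the first maximum over both axes, shaped as with keepdims=True.
flat = np.argmax(x)
indices = tuple(np.reshape(i, (1, 1)) for i in np.unravel_index(flat, x.shape))

result = take_along_axes(x, indices, axes=(0, 1))
print(result)  # [[7]]
print(np.array_equal(result, np.max(x, axis=(0, 1), keepdims=True)))  # True
```

A put_along_axis counterpart could reuse the same index construction and assign to arr[tuple(index)] instead of reading from it.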

@JasonGross
Author

Yes, I've updated the title to reflect this.

@JasonGross
Author

Notably, arr[argmax(arr, axes, keepdims=True)] == max(arr, keepdims=True) already fails to hold for a single axis, which demonstrates why plain indexing is not adequate and why we need take_along_axis and put_along_axis instead:

import numpy as np

# Example 1: the argmax indices are out of bounds when applied to axis 0
x = np.array([[0, 1, 2, 3], [4, 5, 6, 7]])
x_argmax = np.argmax(x, axis=1, keepdims=True)
print(x_argmax)
# [[3]
#  [3]]
print(x[x_argmax])
# IndexError: index 3 is out of bounds for axis 0 with size 2

# Example 2: the indices are in bounds, but they select whole rows
x = np.array([[0, 4], [1, 5], [2, 6], [3, 7]])
x_argmax = np.argmax(x, axis=1, keepdims=True)
print(x[x_argmax])
# [[[1 5]]
#  [[1 5]]
#  [[1 5]]
#  [[1 5]]]
print(np.max(x, axis=1, keepdims=True))
# [[4]
#  [5]
#  [6]
#  [7]]

# Example 3: the result has the wrong shape for either axis
x = np.array([[1000, 1], [1000, 1]])
x_argmax = np.argmax(x, axis=0, keepdims=True)
print(x[x_argmax])
# [[[1000    1]
#   [1000    1]]]
print(np.max(x, axis=0, keepdims=True))
# [[1000    1]]
x_argmax = np.argmax(x, axis=1, keepdims=True)
print(x[x_argmax])
# [[[1000    1]]
#  [[1000    1]]]
print(np.max(x, axis=1, keepdims=True))
# [[1000]
#  [1000]]
