
unravel_index fails to restore the dimensions of its indices argument #28439

@gabriel-vanzandycke


Describe the issue:

I have an (N, M, M) array on which I would like to call argmax(array, axis=(1, 2)). Since argmax only accepts a single axis, I flatten the last two dimensions, call argmax on the resulting (N, M*M) array, and unravel the result to recover indices into the original dimensions.
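For reference, a small sketch of how np.unravel_index decodes a flat index: it uses C (row-major) order relative to whatever shape it is given, so the same flat index yields different coordinates for different shapes. The variable names below are illustrative only, and np.ravel_multi_index is used as the inverse mapping.

```python
import numpy as np

# The same flat index decodes differently depending on the shape passed in.
in_block = np.unravel_index(7, (3, 3))       # coordinates inside one 3x3 block
in_full = np.unravel_index(7, (2, 3, 3))     # coordinates inside a whole (2, 3, 3) array
print([int(i) for i in in_block])   # [2, 1]
print([int(i) for i in in_full])    # [0, 2, 1]

# Array input gives one coordinate array per dimension;
# np.ravel_multi_index maps them back to flat indices for the same shape.
rows, cols = np.unravel_index(np.array([7, 0]), (3, 3))
print(rows, cols)                                   # [2 0] [1 0]
print(np.ravel_multi_index((rows, cols), (3, 3)))   # [7 0]
```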

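The resulting unravelled indices do not match the expected ones (see the code below): I get [[0, 0], [2, 0], [1, 0]] instead of [[0, 1], [2, 0], [1, 0]], i.e. the second element gets batch index 0 instead of 1.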

Reproduce the code example:

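```python
import numpy as np
a = np.array([[[1, 1, 1], [1, 1, 1], [1, 4, 1]], [[5, 1, 1], [1, 1, 1], [1, 1, 1]]])  # shape (2, 3, 3)
# argmax over the flattened last two axes: one flat index per batch element
flat_indices = np.argmax(a.reshape(2, -1), axis=-1)
# unravel the flat indices against the full array shape (2, 3, 3)
indices = np.array(np.unravel_index(flat_indices, a.shape))
print(indices)
# expected: batch indices [0, 1], row indices [2, 0], column indices [1, 0]
assert not np.any(indices - np.array([[0, 1], [2, 0], [1, 0]])), "Wrong unravelled indices"
```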
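For comparison, a minimal sketch of the same flatten/argmax/unravel approach, written as a hypothetical helper `argmax_last_two` (not a NumPy function). It assumes an (N, M, M) input and passes only the flattened trailing shape, `arr.shape[1:]`, to `np.unravel_index`, because each flat index returned by argmax on the (N, M*M) view counts positions within a single M*M block; the batch coordinate is then simply `np.arange(N)`.

```python
import numpy as np

def argmax_last_two(arr):
    """Hypothetical helper: emulate argmax over axes (1, 2) of an (N, M, M) array.

    Returns (batch, rows, cols) index arrays locating the maximum of each arr[i].
    """
    n = arr.shape[0]
    block_shape = arr.shape[1:]                       # only the dimensions that were flattened
    flat = np.argmax(arr.reshape(n, -1), axis=-1)     # flat index within each M*M block
    rows, cols = np.unravel_index(flat, block_shape)  # decode against the block shape
    return np.arange(n), rows, cols

a = np.array([[[1, 1, 1], [1, 1, 1], [1, 4, 1]],
              [[5, 1, 1], [1, 1, 1], [1, 1, 1]]])
batch, rows, cols = argmax_last_two(a)
indices = np.stack([batch, rows, cols])
print(indices)
# [[0 1]
#  [2 0]
#  [1 0]]

# Sanity check: the located entries are the per-sample maxima.
assert np.array_equal(a[batch, rows, cols], a.max(axis=(1, 2)))
```

Returning separate index arrays (rather than one stacked array) lets the result be used directly for fancy indexing, as in the sanity check.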

Error message:

Python and NumPy Versions:

NumPy: 2.0.2
Python: 3.9.21 (main, Dec 11 2024, 16:24:11)
[GCC 11.2.0]

Runtime Environment:

No response

Context for the issue:

No response
