Skip to content

feat: add support for specifying a tuple of axis positions in expand_dims#988

Open
kgryte wants to merge 5 commits intodata-apis:mainfrom
kgryte:feat/expand-dims-tuple
Open

feat: add support for specifying a tuple of axis positions in expand_dims#988
kgryte wants to merge 5 commits intodata-apis:mainfrom
kgryte:feat/expand-dims-tuple

Conversation

@kgryte
Copy link
Contributor

@kgryte kgryte commented Feb 5, 2026

This PR:

Notes

  • While PyTorch does not support a tuple of positions, this can be worked around in array-api-compat.

In data-apis#354, a regression
was introduced which reverted a change to the signature of `expand_dims`.
Namely, the `axis` argument should not have been made optional and
should not have had a default value.

Ref: data-apis#331
Ref: data-apis#354
@kgryte kgryte added this to the v2025 milestone Feb 5, 2026
@kgryte kgryte requested a review from ev-br February 5, 2026 10:27
@kgryte kgryte added API change Changes to existing functions or objects in the API. topic: Manipulation Array manipulation and transformation. Backport Changes involve backporting to previous versions. labels Feb 5, 2026
Copy link
Member

@ev-br ev-br left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be very useful to add a comment from #760 (comment)

This behavior is semantically equivalent to calling expand_dims repeatedly with a single axis, only when the axes tuple is normalized to positive values using the final shape, is sorted, and contains no duplicates.

If ``axis`` is a tuple,

- each entry of ``axis`` must resolve to a unique axis position. If an entry is a negative integer, the entry **must** resolve to a positive axis position according to the rules described above.
- if provided an invalid axis position, the function **must** raise an exception.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

numpy raises AxisError, which derives from IndexError (which torch.unsqueeze raises) and ValueError (which jax.numpy raises). So short of adding AxisError with a prescribed inheritance hierarchy we cannot be more specific on what exception to raise.

@kgryte
Copy link
Contributor Author

kgryte commented Feb 5, 2026

@ev-br Added the desired note. I believe this is ready for another review.



def expand_dims(x: array, /, *, axis: int = 0) -> array:
def expand_dims(x: array, /, axis: int) -> array:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@lucascolley
Copy link
Member

Let's open an issue on merge of this to plan a deprecation over at https://data-apis.org/array-api-extra/generated/array_api_extra.expand_dims.html. I'm not sure exactly what strategy is appropriate, maybe good to discuss.

axis: Union[int, Tuple[int, ...]]
axis position(s) (zero-based). If ``axis`` is an integer,

- a valid axis position **must** reside on the closed-interval ``[-N-1, N]``, where ``N`` is the number of dimensions in ``x``.
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One idea: would it be clearer here to talk about valid indices in terms of the output dimensions? Then this would change to

- a valid axis position **must** reside on the semi-open interval ``[-M, M)` where
  `M = x.ndim + 1` is the number of dimensions of the *output* array.

then the tuple version of this would be identical, except it would say M = ndim(x) + len(axis)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

API change Changes to existing functions or objects in the API. Backport Changes involve backporting to previous versions. topic: Manipulation Array manipulation and transformation.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

RFC: add support for a tuple of axes in expand_dims

4 participants