diff --git a/src/array_api_stubs/_draft/manipulation_functions.py b/src/array_api_stubs/_draft/manipulation_functions.py index 3ea0975e7..6d977d870 100644 --- a/src/array_api_stubs/_draft/manipulation_functions.py +++ b/src/array_api_stubs/_draft/manipulation_functions.py @@ -159,7 +159,7 @@ def permute_dims(x: array, /, axes: Tuple[int, ...]) -> array: x: array input array. axes: Tuple[int, ...] - tuple containing a permutation of ``(0, 1, ..., N-1)`` where ``N`` is the number of axes in ``x``. + tuple containing a permutation of axes. A valid axis **must** be an integer on the interval ``[-N, N)``, where ``N`` is the number of axes in ``x``. If an axis is specified as a negative integer, the function **must** determine the respective axis index by counting backward from the last axis (where ``-1`` refers to the last axis). If provided an invalid axis, the function **must** raise an exception. Returns -------