Original ticket http://projects.scipy.org/numpy/ticket/1441 on 2010-03-31 by trac user kbasye, assigned to unknown.
The rollaxis() function in numeric.py accepts negative values for both axis and start. For axis, the usual Python interpretation applies: the axis is counted from the back. For start, however, -1 does not denote the last possible position but the second to last, because a negative start is offset by the number of axes rather than by the number of insertion positions (one more than the number of axes). I understand that fixing this changes the function's behavior, but at least the existing behavior is undocumented :-).
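A minimal sketch of the two normalization rules for a negative start. It assumes, as the description above implies, that the existing code adds n (the number of axes) to a negative start, while the proposal adds n + 1 (the number of insertion slots 0..n). The function names are hypothetical and exist only for this illustration.

```python
def normalize_start_current(start, n):
    # Behaviour described in the ticket: a negative start is shifted by n,
    # so -1 maps to n - 1 (the second-to-last insertion slot).
    return start + n if start < 0 else start


def normalize_start_proposed(start, n):
    # Proposed behaviour: start indexes the n + 1 insertion slots 0..n,
    # so a negative start is shifted by n + 1 and -1 maps to n (the end).
    return start + n + 1 if start < 0 else start


n = 4  # number of axes, e.g. for np.ones((3, 4, 5, 6))
for start in range(-1, -n - 2, -1):
    # columns: start, slot under the current rule, slot under the proposed rule
    print(start, normalize_start_current(start, n), normalize_start_proposed(start, n))
```

With n = 4, start=-1 maps to slot 3 under the current rule and slot 4 under the proposed one. Under the current rule no negative start ever reaches slot n, and start=-5 maps to -1, which the range check then rejects.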
In both cases, the error messages can be confusing when an overly negative value is passed: the message reports the argument after it has already been adjusted, and it states a legal range narrower than what the function actually accepts.
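The difference in the axis error message can be sketched in isolation. The old message wording below is an assumption modelled on the description above, not copied from a specific NumPy release; the new wording is the one used in the proposed replacement. `check_axis_old`, `check_axis_new` and `error_text` are hypothetical helpers.

```python
# Assumed wording of the existing message (modelled on the ticket's
# description, not copied from a specific NumPy release).
OLD_MSG = 'rollaxis: %s (%d) must be >=0 and < %d'
# Wording used in the proposed replacement.
NEW_MSG = 'rollaxis: %s (%d) must be >= %d and < %d'


def check_axis_old(axis, n):
    # Shift first, then report the shifted value: -7 on a 4-d array becomes -3.
    if axis < 0:
        axis += n
    if not (0 <= axis < n):
        raise ValueError(OLD_MSG % ('axis', axis, n))
    return axis


def check_axis_new(axis, n):
    # Keep the caller's value for the message and state the full range [-n, n).
    orig_axis = axis
    if axis < 0:
        axis += n
    if not (0 <= axis < n):
        raise ValueError(NEW_MSG % ('axis', orig_axis, -n, n))
    return axis


def error_text(check, axis, n):
    """Return the ValueError message raised by check(axis, n), or None."""
    try:
        check(axis, n)
    except ValueError as exc:
        return str(exc)
    return None


print(error_text(check_axis_old, -7, 4))  # rollaxis: axis (-3) must be >=0 and < 4
print(error_text(check_axis_new, -7, 4))  # rollaxis: axis (-7) must be >= -4 and < 4
```

The same idea applies to start, with n + 1 slots instead of n axes.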
Since the function is short and I've changed many lines, I'm pasting a suggested replacement here rather than a diff; I hope that's OK.
```python
def rollaxis(a, axis, start=0):
    """
    Roll the specified axis until it lies in a given position.

    Parameters
    ----------
    a : ndarray
        Input array.
    axis : int
        The axis to roll. The positions of the other axes do not
        change relative to one another.
    start : int, optional
        The axis is rolled until it lies before this position.

    Returns
    -------
    res : ndarray
        Output array.

    See Also
    --------
    roll : Roll the elements of an array by a number of positions along a
        given axis.

    Examples
    --------
    >>> a = np.ones((3,4,5,6))
    >>> rollaxis(a, 3, 1).shape
    (3, 6, 4, 5)
    >>> rollaxis(a, 2).shape
    (5, 3, 4, 6)
    >>> rollaxis(a, 1, 4).shape
    (3, 5, 6, 4)

    First axis becomes the last:

    >>> rollaxis(a, 0, -1).shape
    (4, 5, 6, 3)

    Last axis becomes the first:

    >>> rollaxis(a, -1).shape
    (6, 3, 4, 5)

    """
    n = a.ndim
    orig_axis, orig_start = axis, start
    # A negative axis counts back over the n axes; a negative start counts
    # back over the n + 1 insertion slots, so start=-1 means "at the end".
    if axis < 0:
        axis += n
    if start < 0:
        start += n + 1
    # Report the caller's original values and the full legal range.
    msg = 'rollaxis: %s (%d) must be >= %d and < %d'
    if not (0 <= axis < n):
        raise ValueError(msg % ('axis', orig_axis, -n, n))
    if not (0 <= start < n + 1):
        raise ValueError(msg % ('start', orig_start, -n - 1, n + 1))
    if axis < start:  # removing axis below will shift the start position
        start -= 1
    if axis == start:
        return a
    axes = list(range(n))
    axes.remove(axis)
    axes.insert(start, axis)
    return a.transpose(axes)
```
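The docstring examples can be checked without NumPy by applying the same index logic to a shape tuple. `rolled_shape` below is a hypothetical helper that mirrors the proposed normalization and permutation; it is a sketch for checking the proposed semantics, not part of numeric.py.

```python
def rolled_shape(shape, axis, start=0):
    """Return the shape the proposed rollaxis would give an array of `shape`.

    Hypothetical helper: mirrors the index logic of the proposal on a plain
    tuple so it can be checked without NumPy.
    """
    n = len(shape)
    if axis < 0:
        axis += n
    if start < 0:
        start += n + 1
    if not (0 <= axis < n) or not (0 <= start < n + 1):
        raise ValueError('axis or start out of range')
    if axis < start:  # removing axis shifts later insertion slots left by one
        start -= 1
    order = list(range(n))
    order.remove(axis)
    order.insert(start, axis)
    return tuple(shape[i] for i in order)


shape = (3, 4, 5, 6)
# The docstring examples from the proposal.
assert rolled_shape(shape, 3, 1) == (3, 6, 4, 5)
assert rolled_shape(shape, 2) == (5, 3, 4, 6)
assert rolled_shape(shape, 1, 4) == (3, 5, 6, 4)
assert rolled_shape(shape, 0, -1) == (4, 5, 6, 3)
assert rolled_shape(shape, -1) == (6, 3, 4, 5)
# Under the proposal, start=-1 and start=n name the same slot for every axis.
assert all(rolled_shape(shape, ax, -1) == rolled_shape(shape, ax, len(shape))
           for ax in range(len(shape)))
```

With NumPy available, the same assertions could be run against `rollaxis(np.ones(shape), ...).shape` once the replacement is in place.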