rollaxis() has confusing error messages and should maybe interpret negative start argument differently (Trac #1441) #2039

@thouis

Original ticket http://projects.scipy.org/numpy/ticket/1441 on 2010-03-31 by trac user kbasye, assigned to unknown.

The rollaxis() function in numeric.py allows negative arguments for both the axis and the start. For the axis, the usual Python interpretation of a negative value applies; the axis is chosen by counting from the back. For the start, however, -1 is not the last possible position, but the second to last. I understand that fixing this is a change in the function's behavior, but at least the existing behavior is undocumented :-).
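To make the difference concrete, here is a minimal pure-Python sketch that computes only the resulting axis order, not an actual array. The helper name `axis_order` and its `wrap` parameter are hypothetical, introduced just for illustration. It assumes, based on the description above, that the existing behavior amounts to adding `n` to a negative start, whereas the suggested behavior adds `n + 1`.

```python
def axis_order(n, axis, start, wrap):
    """Return the axis permutation that rolls `axis` to lie before `start`.

    `wrap` is the amount added to a negative `start` (hypothetical knob):
    n mimics the behavior described in this ticket, n + 1 the suggested one.
    """
    if axis < 0:
        axis += n
    if start < 0:
        start += wrap
    if axis < start:
        # removing the axis from the list shifts the insert position down
        start -= 1
    axes = list(range(n))
    axes.remove(axis)
    axes.insert(start, axis)
    return axes


shape = (3, 4, 5, 6)
n = len(shape)

# Roll the first axis to start=-1 under both interpretations.
legacy = axis_order(n, 0, -1, wrap=n)
proposed = axis_order(n, 0, -1, wrap=n + 1)

legacy_shape = tuple(shape[i] for i in legacy)      # (4, 5, 3, 6): second to last
proposed_shape = tuple(shape[i] for i in proposed)  # (4, 5, 6, 3): last
```

With `wrap=n` the rolled axis lands in the second-to-last position; with `wrap=n + 1` it lands in the last one, which is what a reader used to Python's negative indexing would probably expect.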

In both cases, the error messages can be confusing when an argument is too negative: the error reports the already-modified value rather than the one that was passed in, and it suggests a legal range that is smaller than what the function actually accepts.
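One possible way to keep the messages honest is to check the range against the caller's original value before modifying it. The sketch below is only an illustration; the helper name `normalize` is hypothetical, and the message format is borrowed from the replacement suggested in this ticket.

```python
def normalize(name, value, size):
    """Map `value` from [-size, size) onto [0, size).

    On failure, the error reports the value the caller actually passed
    and the full range the function accepts.
    """
    if not (-size <= value < size):
        raise ValueError('rollaxis: %s (%d) must be >= %d and < %d'
                         % (name, value, -size, size))
    return value + size if value < 0 else value


n = 4  # number of dimensions of a hypothetical array

# axis is valid in [-n, n); start is valid in [-n-1, n+1)
axis = normalize('axis', -1, n)        # 3
start = normalize('start', -1, n + 1)  # 4

try:
    normalize('axis', -5, n)
except ValueError as exc:
    message = str(exc)  # 'rollaxis: axis (-5) must be >= -4 and < 4'
```

Because the check happens before the value is shifted, the message shows `-5` (what the user typed) and the accepted range `-4 <= axis < 4`.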

Since it's short and I've changed many lines, I'm just going to paste a suggested replacement here rather than a diff - hope that's OK.

def rollaxis(a, axis, start=0):
    """                                                                                                                                        
    Roll the specified axis until it lies in a given position.

    Parameters
    ----------
    a : ndarray
        Input array.
    axis : int
        The axis to roll.  The positions of the other axes do not
        change relative to one another.
    start : int, optional
        The axis is rolled until it lies before this position.

    Returns
    -------
    res : ndarray
        Output array.

    See Also
    --------
    roll : Roll the elements of an array by a number of positions along a
           given axis.

    Examples
    --------
    >>> a = np.ones((3,4,5,6))
    >>> rollaxis(a, 3, 1).shape
    (3, 6, 4, 5)
    >>> rollaxis(a, 2).shape
    (5, 3, 4, 6)
    >>> rollaxis(a, 1, 4).shape
    (3, 5, 6, 4)

    First axis becomes the last:

    >>> rollaxis(a, 0, -1).shape
    (4, 5, 6, 3)

    Last axis becomes the first:

    >>> rollaxis(a, -1).shape
    (6, 3, 4, 5)

    """
    n = a.ndim
    # keep the caller's values so error messages report what was passed in
    orig_axis, orig_start = axis, start
    if axis < 0:
        axis += n
    if start < 0:
        # n + 1 insert positions exist, so -1 maps to the last one (n)
        start += n + 1
    msg = 'rollaxis: %s (%d) must be >= %d and < %d'
    if not (0 <= axis < n):
        raise ValueError(msg % ('axis', orig_axis, -n, n))
    if not (0 <= start < n + 1):
        raise ValueError(msg % ('start', orig_start, -n - 1, n + 1))
    if axis < start:  # removing axis below will shift the start position
        start -= 1
    if axis == start:
        return a
    axes = list(range(n))  # a list, so that remove() and insert() work
    axes.remove(axis)
    axes.insert(start, axis)
    return a.transpose(axes)
