
Add rollaxis method to larry? #23

Open
WeatherGod opened this Issue · 5 comments

2 participants

@WeatherGod

I have come across a situation where I would like to roll an axis of a larry, much as one can with a regular NumPy array (np.rollaxis). Would that be possible?

@kwgoodman
Owner

There is currently a swapaxes method, but a rollaxis method sounds like a good addition.

la includes the NumPy license, so we can copy and paste np.rollaxis and then add label support to it.

I'll take a look when I get a chance.

@kwgoodman
Owner

Oh, wait, sorry. Of course we'd just code the label support in larry.rollaxis and call np.rollaxis on the data part of the larry (lar.x). We could still copy the docstring, however.
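A minimal sketch of that split, assuming labels are stored as one list per axis (as lar.label is): compute the order in which np.rollaxis leaves the original axes, then reorder the labels with that same order. `rolled_axis_order` is a hypothetical helper modeled on the steps inside np.rollaxis; it is not part of la or NumPy.

```python
def rolled_axis_order(ndim, axis, start=0):
    """Return the original axis indices in the order rollaxis leaves them.

    Hypothetical helper, modeled on the bookkeeping inside np.rollaxis.
    """
    # Normalize negative indices the same way np.rollaxis does.
    if axis < 0:
        axis += ndim
    if start < 0:
        start += ndim
    if not 0 <= axis < ndim:
        raise ValueError("axis %d out of range for %d dims" % (axis, ndim))
    if not 0 <= start <= ndim:
        raise ValueError("start %d out of range for %d dims" % (start, ndim))
    # Once `axis` is removed, every later position shifts left by one.
    if axis < start:
        start -= 1
    order = list(range(ndim))
    order.remove(axis)
    order.insert(start, axis)
    return order


# Labels for a hypothetical 3-d larry: one list per axis.
label = [["a", "b"], [0, 1, 2], ["x", "y", "z", "w"]]
order = rolled_axis_order(len(label), 2, 0)
rolled_label = [label[i] for i in order]
print(order)         # [2, 0, 1]
print(rolled_label)  # [['x', 'y', 'z', 'w'], ['a', 'b'], [0, 1, 2]]
```

np.rollaxis finishes by transposing the data with this same axis order, so reordering the label list with it keeps each label attached to its axis.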

@kwgoodman
Owner

Here's a quick hack that borrows heavily from np.rollaxis. Does it behave the way you want?

import numpy as np

def rollaxis(lar, axis, start=0):
    # Roll the data. Note: this assigns to lar.x, so the caller's larry
    # is modified in place.
    lar.x = np.rollaxis(lar.x, axis, start)
    n = lar.ndim
    # Normalize negative indices the same way np.rollaxis does.
    if axis < 0:
        axis += n
    if start < 0:
        start += n
    msg = 'rollaxis: %s (%d) must be >=0 and < %d'
    if not (0 <= axis < n):
        raise ValueError(msg % ('axis', axis, n))
    if not (0 <= start < n+1):
        raise ValueError(msg % ('start', start, n+1))
    if axis < start:  # axis is removed first, so later positions shift left
        start -= 1
    if axis == start:
        return lar
    # Reorder the labels exactly as np.rollaxis reorders the axes.
    axes = list(range(n))
    axes.remove(axis)
    axes.insert(start, axis)
    lar.label = [lar.label[i] for i in axes]
    return lar

Demo:

>> lar = la.rand(3,4,5,6)
>> lar.shape
   (3, 4, 5, 6)

>> np.rollaxis(lar.x, 3, 1).shape
   (3, 6, 4, 5)
>> rollaxis(lar, 3, 1).shape
   (3, 6, 4, 5)

>> np.rollaxis(lar.x, 2).shape
   (4, 3, 6, 5)
>> rollaxis(lar, 2).shape
   (4, 3, 6, 5)

>> np.rollaxis(lar.x, 1, 4).shape
   (4, 6, 5, 3)
>> rollaxis(lar, 1, 4).shape
   (4, 6, 5, 3)

BTW, np.rollaxis (NumPy 1.6.0) does not return what its docstring examples say it returns.
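In the demo above, each `np.rollaxis(lar.x, ...)` call runs after the preceding `rollaxis(lar, ...)` call has already replaced lar.x, so only the first call sees a (3, 4, 5, 6) array. A pure-Python sketch that rolls a shape tuple makes the difference visible; `roll_shape` is a hypothetical helper that follows the same steps as the hack, and the two lists compare starting each call from a fresh (3, 4, 5, 6) shape with chaining the calls as the demo does.

```python
def roll_shape(shape, axis, start=0):
    """Return the shape rollaxis would produce (hypothetical helper)."""
    n = len(shape)
    if axis < 0:
        axis += n
    if start < 0:
        start += n
    if axis < start:  # axis is removed first, so later positions shift left
        start -= 1
    order = list(range(n))
    order.remove(axis)
    order.insert(start, axis)
    return tuple(shape[i] for i in order)


calls = [(3, 1), (2, 0), (1, 4)]

# Each call applied to a fresh (3, 4, 5, 6) shape.
fresh = [roll_shape((3, 4, 5, 6), axis, start) for axis, start in calls]

# Each call applied to the result of the previous one, as in the demo.
chained = []
shape = (3, 4, 5, 6)
for axis, start in calls:
    shape = roll_shape(shape, axis, start)
    chained.append(shape)

print(fresh)    # [(3, 6, 4, 5), (5, 3, 4, 6), (3, 5, 6, 4)]
print(chained)  # [(3, 6, 4, 5), (4, 3, 6, 5), (4, 6, 5, 3)]
```

The chained shapes reproduce the demo output; the fresh ones are what rolling an unmodified (3, 4, 5, 6) array gives.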

@kwgoodman
Owner

I guess my hack doesn't behave the same way as np.rollaxis: it modifies the larry passed in instead of leaving it unchanged.

>> a = np.ones((3,4,5,6))
>> b = np.rollaxis(a, 1, 3)
>> a.shape
   (3, 4, 5, 6)
>> b.shape
   (3, 5, 4, 6)
>> 
>> lar = la.larry(a)
>> lar.shape
   (3, 4, 5, 6)
>> lar2 = rollaxis(lar, 1, 3)
>> lar.shape
   (3, 5, 4, 6)
>> lar2.shape
   (3, 5, 4, 6)
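The surprise comes from the first hack assigning to lar.x and lar.label on the object it receives, so the caller's larry changes as well. A minimal pure-Python sketch contrasting the two styles, using a hypothetical stand-in class `Box` in place of la.larry (it holds only a shape and one label per axis, no data):

```python
class Box:
    """Hypothetical stand-in for a larry: a shape plus one label per axis."""

    def __init__(self, shape, label):
        self.shape = tuple(shape)
        self.label = list(label)


def roll_order(n, axis, start):
    # Non-negative axis and start assumed, for brevity.
    if axis < start:
        start -= 1
    order = list(range(n))
    order.remove(axis)
    order.insert(start, axis)
    return order


def roll_in_place(box, axis, start=0):
    # Rebinds attributes on the caller's object (what the first hack does).
    order = roll_order(len(box.shape), axis, start)
    box.shape = tuple(box.shape[i] for i in order)
    box.label = [box.label[i] for i in order]
    return box


def roll_new(box, axis, start=0):
    # Builds and returns a new object; the input is left untouched.
    order = roll_order(len(box.shape), axis, start)
    return Box([box.shape[i] for i in order], [box.label[i] for i in order])


a = Box((3, 4, 5, 6), ["w", "x", "y", "z"])
b = roll_in_place(a, 1, 3)
print(a.shape, b is a)  # (3, 5, 4, 6) True

c = Box((3, 4, 5, 6), ["w", "x", "y", "z"])
d = roll_new(c, 1, 3)
print(c.shape, d.shape)  # (3, 4, 5, 6) (3, 5, 4, 6)
```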
@kwgoodman
Owner

Sorry for the rapid-fire comments and the bugs and mistakes that came with them. I'm trying to crank something out quickly.

Second attempt, which builds the result from the rolled data instead of modifying the input larry:

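import numpy as np
import la

def rollaxis(lar, axis, start=0):
    n = lar.ndim
    # Normalize negative indices the same way np.rollaxis does.
    if axis < 0:
        axis += n
    if start < 0:
        start += n
    msg = 'rollaxis: %s (%d) must be >=0 and < %d'
    if not (0 <= axis < n):
        raise ValueError(msg % ('axis', axis, n))
    if not (0 <= start < n+1):
        raise ValueError(msg % ('start', start, n+1))
    # Roll the data; lar itself is not modified.
    x = np.rollaxis(lar.x, axis, start)
    if axis < start:  # axis is removed first, so later positions shift left
        start -= 1
    # Reorder the labels exactly as np.rollaxis reorders the axes. When
    # axis == start this is the identity order, so a new larry is returned
    # in that case too, rather than lar itself.
    axes = list(range(n))
    axes.remove(axis)
    axes.insert(start, axis)
    label = [lar.label[i] for i in axes]
    return la.larry(x, label, integrity=False)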
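One way to gain confidence in the label bookkeeping is to compare the remove/insert steps used above with an independent description of rollaxis: when axis < start, the axes between them shift left and axis lands just before start; otherwise axis moves to position start and the axes in between shift right. A small exhaustive check in pure Python (both helper names are hypothetical):

```python
def order_remove_insert(n, axis, start):
    # Same steps as rollaxis above, for non-negative axis and start.
    if axis < start:
        start -= 1
    order = list(range(n))
    order.remove(axis)
    order.insert(start, axis)
    return order


def order_by_slices(n, axis, start):
    # Independent formulation built from explicit ranges.
    if axis < start:
        return (list(range(axis)) + list(range(axis + 1, start))
                + [axis] + list(range(start, n)))
    return (list(range(start)) + [axis] + list(range(start, axis))
            + list(range(axis + 1, n)))


mismatches = [
    (n, axis, start)
    for n in range(1, 7)
    for axis in range(n)
    for start in range(n + 1)
    if order_remove_insert(n, axis, start) != order_by_slices(n, axis, start)
]
print(mismatches)  # []
```

An empty list means the two formulations agree for every non-negative axis and start up to six dimensions; negative arguments reduce to these cases once `axis += n` and `start += n` have been applied.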