autoray transpose attribute-like function fails for torch tensors #1

Closed
mattorourke17 opened this issue Oct 22, 2019 · 3 comments

@mattorourke17

Problem

The API for the numpy.ndarray transpose method allows it to permute an arbitrary number of axes into an arbitrary order. However, the torch.Tensor transpose method only swaps a single pair of dimensions and therefore accepts exactly two indices. This means something like the following will fail:

import numpy
import torch
from autoray import do, transpose

Ttorch = torch.zeros([2,3,4,5])
Tnp = numpy.zeros([2,3,4,5])

print(Tnp.transpose([2,1,3,0]).shape)   # gives (4,3,5,2), as expected
print(transpose(Tnp, [2,1,3,0]).shape)  # also gives (4,3,5,2)
print(Ttorch.transpose([2,1,3,0]).size()) # this fails with a TypeError
print(transpose(Ttorch, [2,1,3,0]).size())  # which means this also fails

Solution

The correct torch.Tensor method is permute, which has exactly the same behavior as numpy.ndarray.transpose. This means that something like the following will do what we want:

import numpy
import torch
from autoray import do, transpose

Ttorch = torch.zeros([2,3,4,5])
Tnp = numpy.zeros([2,3,4,5])

print(Tnp.transpose([2,1,3,0]).shape)   # gives (4,3,5,2), as expected
print(transpose(Tnp, [2,1,3,0]).shape)  # also gives (4,3,5,2)
print(Ttorch.permute(2,1,3,0).size())  # also gives (4,3,5,2)
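
Note that numpy.ndarray.transpose also accepts the axes as separate integer arguments, which matches the permute(*dims) calling convention:

import numpy

Tnp = numpy.zeros([2, 3, 4, 5])
# same result whether the axes are passed as a sequence or separately
print(Tnp.transpose(2, 1, 3, 0).shape)   # (4, 3, 5, 2)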

Proposed code change

I'm not sure there is a way to incorporate this behavior in a clean, non-invasive manner. As far as I understand, the _module_aliases and _func_aliases dictionaries are not applicable, since permute exists only as a method of torch.Tensor (i.e. there is no torch.permute(tensor, *args) function). This therefore seems to necessitate directly modifying the autoray.transpose function (line 308). The following patch works, but it's not very clean:

current code:

def transpose(x, *args):
    try:
        return x.transpose(*args)
    except AttributeError:
        return do('transpose', x, *args)

patched code:

def transpose(x, *args):
    backend = infer_backend(x)
    if backend == 'torch':
        return x.permute(*args)
    else:
        try:
            return x.transpose(*args)
        except AttributeError:
            return do('transpose', x, *args)

The inherent challenge is that we need to alias x.transpose() to x.permute() when x is a torch.Tensor. If there is a better way than what I have suggested, let me know!
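
One possible generalisation of the patch above, purely as a sketch (the _METHOD_ALIASES dict and its use here are my own invention, not existing autoray machinery), would be to dispatch on a small lookup table of backend-specific method names:

# hypothetical sketch -- _METHOD_ALIASES is not part of autoray
_METHOD_ALIASES = {
    ('torch', 'transpose'): 'permute',
}

def transpose(x, *args):
    backend = infer_backend(x)
    # fall back to the numpy-style method name if no alias is registered
    method = _METHOD_ALIASES.get((backend, 'transpose'), 'transpose')
    try:
        return getattr(x, method)(*args)
    except AttributeError:
        return do('transpose', x, *args)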

(P.S.) I found this problem via an error I obtained in quimb: I was trying to fuse multiple bonds of a quimb Tensor while using PyTorch as the backend, and this problem arose.

@jcmgray
Owner

jcmgray commented Oct 23, 2019

Hey Matt, yeah, pytorch is annoyingly similar to, but different from, numpy. I'm about to push a few changes that will support transpose, as well as the rest of the unit tests, with torch as the backend.

@jcmgray
Owner

jcmgray commented Oct 23, 2019

3ef2d61 should fix the problem! Let me know if not. In general so far adding torch translations where necessary seems fairly easy.
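
For reference, a quick sanity check against the example from the original post (assuming the updated autoray with 3ef2d61 is installed):

import torch
from autoray import transpose

Ttorch = torch.zeros([2, 3, 4, 5])
# with the translation in place this should print torch.Size([4, 3, 5, 2])
print(transpose(Ttorch, [2, 1, 3, 0]).size())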

@mattorourke17
Author

Awesome! This looks great. Thanks Johnnie!
