Problem
The API of numpy.ndarray.transpose lets it permute an arbitrary number of axes into an arbitrary order. However, torch.Tensor.transpose only swaps two dimensions at a time and therefore accepts exactly two indices. This means something like the following will fail:
```python
import numpy
import torch
from autoray import transpose

Ttorch = torch.zeros([2, 3, 4, 5])
Tnp = numpy.zeros([2, 3, 4, 5])

print(Tnp.transpose([2, 1, 3, 0]).shape)  # gives (4, 3, 5, 2), as expected
print(transpose(Tnp, [2, 1, 3, 0]).shape)  # also gives (4, 3, 5, 2)
print(Ttorch.transpose([2, 1, 3, 0]).size())  # this fails with a TypeError
print(transpose(Ttorch, [2, 1, 3, 0]).size())  # which means this also fails
```
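For reference, the rule numpy applies is that axis `i` of the result is axis `axes[i]` of the input, so the output shape is just the input shape reordered by the permutation. A minimal pure-Python sketch of that rule (the helper `permuted_shape` is hypothetical, not part of numpy or autoray):

```python
def permuted_shape(shape, axes):
    """Hypothetical helper: the shape numpy.ndarray.transpose(axes) would produce."""
    # axes must be a full permutation of 0..ndim-1
    if sorted(axes) != list(range(len(shape))):
        raise ValueError("axes must be a permutation of range(len(shape))")
    # axis i of the result is axis axes[i] of the input
    return tuple(shape[a] for a in axes)


print(permuted_shape((2, 3, 4, 5), [2, 1, 3, 0]))  # (4, 3, 5, 2)
```

A two-index transpose can only express permutations that swap a single pair of axes, which is why `[2, 1, 3, 0]` cannot be passed to `torch.Tensor.transpose`.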
Solution
The equivalent torch.Tensor method is permute, which, given a full permutation of the dimensions, behaves exactly like numpy.ndarray.transpose. This means that something like the following will do what we want:
```python
import numpy
import torch
from autoray import transpose

Ttorch = torch.zeros([2, 3, 4, 5])
Tnp = numpy.zeros([2, 3, 4, 5])

print(Tnp.transpose([2, 1, 3, 0]).shape)  # gives (4, 3, 5, 2), as expected
print(transpose(Tnp, [2, 1, 3, 0]).shape)  # also gives (4, 3, 5, 2)
print(Ttorch.permute(2, 1, 3, 0).size())  # also gives (4, 3, 5, 2)
```
Proposed code change
I'm not sure there is a clean, non-invasive way to incorporate this behavior. As far as I understand, the _module_aliases and _func_aliases dictionaries don't apply here, since permute is only a method of torch.Tensor (i.e. there is no module-level function torch.permute(torch.Tensor, *args)). This seems to require modifying the autoray.transpose function directly (line 308). The following patch works, but it's not very clean:
The inherent challenge is that we need to alias x.transpose() to x.permute() when x is a torch.Tensor. If there is a better way than what I have suggested, let me know!
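One way to express that aliasing is to dispatch on the array type inside the transpose wrapper. The sketch below is a hypothetical `transpose_any` function, not autoray's actual code or the patch above. It assumes that detecting torch tensors by the top-level module of their class (`"torch"`) is acceptable, which avoids importing torch:

```python
import numpy as np


def transpose_any(x, axes=None):
    """Hypothetical backend-agnostic transpose (a sketch, not autoray's code).

    numpy.ndarray.transpose accepts a full permutation of the axes, while
    torch.Tensor.transpose only swaps two dimensions, so torch tensors are
    routed to torch.Tensor.permute instead.
    """
    # Detect torch tensors (and subclasses such as torch.nn.Parameter) by the
    # top-level package their class is defined in; no torch import needed.
    if type(x).__module__.split(".")[0] == "torch":
        if axes is None:
            # Mimic numpy's default of reversing the order of the axes.
            axes = tuple(reversed(range(x.ndim)))
        return x.permute(*axes)
    # numpy arrays: transpose(None) also reverses the axes.
    return x.transpose(axes)


x = np.zeros((2, 3, 4, 5))
print(transpose_any(x, (2, 1, 3, 0)).shape)  # (4, 3, 5, 2)
print(transpose_any(x).shape)  # (5, 4, 3, 2)
```

If a hard dependency on torch is acceptable, an `isinstance(x, torch.Tensor)` check would do the same job more explicitly.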
(p.s.) I ran into this problem through an error in quimb: I was trying to fuse multiple bonds of a quimb Tensor with PyTorch as the backend.
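Fusing bonds is where a general transpose matters: the axes to be merged have to be brought next to each other before they can be reshaped into one. The following numpy sketch illustrates the idea; `fuse_axes` is a hypothetical helper and is not quimb's implementation:

```python
import math

import numpy as np


def fuse_axes(x, group):
    """Hypothetical helper: merge the axes in `group` into a single trailing axis.

    Illustrates why fusing needs a general transpose; not quimb's implementation.
    """
    group = list(group)
    rest = [ax for ax in range(x.ndim) if ax not in group]
    # Bring the axes to be fused next to each other (at the end) with a
    # full permutation, which torch.Tensor.transpose cannot express ...
    y = x.transpose(rest + group)
    # ... then collapse them into one axis whose size is the product of theirs.
    fused_size = math.prod(x.shape[ax] for ax in group)
    return y.reshape([x.shape[ax] for ax in rest] + [fused_size])


x = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
print(fuse_axes(x, (0, 2)).shape)  # (3, 5, 8)
```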
Hey Matt, yeah, pytorch is annoyingly similar to, but different from, numpy. I'm about to push a few changes that will add support for transpose and get the rest of the unit tests passing with torch as the backend.