torch.atleast_1d batching rule implementation #219
Hi functorch devs! I'm filing this issue because my code prints vmap's performance warning about the missing batching rule for `atleast_1d`.

Why am I using `atleast_1d`? I'm subclassing `torch.Tensor` because my code needs to be able to attach some extra data, named `_block_variable`, to that class (I'm integrating PyTorch's AD system with another AD system to be able to call torch functions from inside a PDE solve, which is why the subclass also inherits from a class called `OverloadedType`); e.g. the subclass looks something like the sketch below.

This causes problems when I have code that does things like `torch.tensor([torch.trace(x), torch.trace(x @ x)])`, where `x` is a square `MyTensor`; the `torch.tensor()` call raises an exception related to taking the `__len__` of a 0-dimensional tensor (the scalar traces). So instead I do `torch.cat([torch.atleast_1d(torch.trace(x)), torch.atleast_1d(torch.trace(x @ x))])`, which works. However, this function is `functorch.vmap`-ed, which triggers the performance warning. It would be great if I could either get the naive implementation (using `torch.tensor` instead of `torch.cat`) to work, or if a batching rule for `atleast_1d()` were to be implemented.

Thank you for any help you can provide!
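A minimal sketch of the pattern being described, assuming a simplified subclass: `OverloadedType` is stubbed out here, `_block_variable` is reduced to a plain attribute, and the `invariants` function name and `torch.randn` inputs are illustrative stand-ins, not the issue author's actual code.

```python
import torch
from functorch import vmap  # in recent PyTorch, torch.func.vmap also works


class OverloadedType:
    """Stub standing in for the other AD system's base class."""


class MyTensor(torch.Tensor, OverloadedType):
    # Extra data attached by the other AD system.
    _block_variable = None


def invariants(x):
    # The naive version, torch.tensor([torch.trace(x), torch.trace(x @ x)]),
    # fails because torch.tensor() ends up taking __len__ of the
    # 0-dimensional trace results. atleast_1d + cat sidesteps that.
    return torch.cat([
        torch.atleast_1d(torch.trace(x)),
        torch.atleast_1d(torch.trace(x @ x)),
    ])


# vmap over a batch of square matrices; before the fix in this issue,
# the missing batching rule for atleast_1d made functorch fall back to a
# per-example loop and emit the performance warning.
xs = torch.randn(8, 3, 3)
print(vmap(invariants)(xs).shape)  # torch.Size([8, 2])
```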
Comments

I suspect that we can't make `torch.tensor` work here, but we can definitely add a batching rule for `atleast_1d`. btw, does the rest of your code work other than this? I'm a little bit surprised that vmapping over your tensor subclasses works, although it's very cool that it does :)

Thank you! Yeah, the rest of my code works great! The reason for it working might have something to do with the fact that the computations themselves don't actually require the extra data on the subclass. For what it's worth, in my code it's always something like `torch.cat([torch.atleast_1d(...), ...])`. If you want to see what I'm doing, my code can be found at https://gitlab.com/crikit/crikit/-/merge_requests/99, and the specific lines I'm talking about here can be seen at https://gitlab.com/crikit/crikit/-/blob/emily/torch-adjoint/crikit/invariants/invariants.py#L1946.

Closed in this commit: f99ed9e

@Chillee I just reinstalled functorch, and now importing it raises an error from one of its files. I figured I'd mention it here because it's related to the fix for this issue, but I'd be happy to open a new issue if that is preferable.

@DiffeoInvariant what version of PyTorch do you have installed? Sounds like it might be a PyTorch versioning problem. TBH though, I think we should just stop importing that file by default haha - it's not relevant to vmap.

@Chillee I ran …

@DiffeoInvariant maybe you can add …

@DiffeoInvariant I moved the op_authoring stuff to a different namespace - try reinstalling functorch?

@Chillee That fixed it for me! Thank you for the quick response!

Since I think the issues here are resolved, I'm going to close this issue :) Please let us know if you run into any other issues. Thanks!
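As a closing note on why this rule is cheap to add: under vmap, `atleast_1d` only needs to act on the logical (per-example) dimensionality, so it reduces to a conditional unsqueeze. Below is a rough Python-level sketch of the semantics — an assumption-laden illustration, not the actual C++ rule added in f99ed9e.

```python
import torch


def atleast_1d_semantics(x: torch.Tensor) -> torch.Tensor:
    # What atleast_1d does per example: promote 0-d tensors to 1-d and
    # leave everything else untouched. A batching rule only has to apply
    # this to the logical shape (ignoring the prepended batch dim), which
    # is why no slow per-example fallback loop is actually needed.
    return x.unsqueeze(0) if x.dim() == 0 else x


# Matches torch.atleast_1d on plain tensors:
assert torch.equal(atleast_1d_semantics(torch.tensor(3.0)),
                   torch.atleast_1d(torch.tensor(3.0)))
```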