torch.atleast_1d batching rule implementation #219

Closed
DiffeoInvariant opened this issue Oct 21, 2021 · 10 comments

@DiffeoInvariant

Hi functorch devs! I'm filing this issue because my code prints the following warning:

UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::atleast_1d. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at  /tmp/pip-req-build-ytawxmfk/functorch/csrc/BatchedFallback.cpp:106.)

Why Am I Using atleast_1d?

I'm subclassing torch.Tensor because my code needs to attach some extra data to the tensor, an attribute named _block_variable. (I'm integrating PyTorch's AD system with another AD system so that torch functions can be called from inside a PDE solve, which is why I also inherit from a class called OverloadedType.) The subclass looks like

class MyTensor(torch.Tensor, OverloadedType):
    _block_variable = None

    @staticmethod
    def __new__(cls, x, *args, **kwargs):
        return super().__new__(cls, x, *args, **kwargs)

    def __init__(self, x, block_var=None):
        super(OverloadedType, self).__init__()
        self._block_variable = block_var or BlockVariable(self)

    def to(self, *args, **kwargs):
        # build a fresh tensor that carries the extra attribute along
        new = torch.Tensor([])
        tmp = super(torch.Tensor, self).to(*args, **kwargs)
        new.data = tmp.data
        new.requires_grad = tmp.requires_grad
        new._block_variable = self._block_variable
        return new

    ...  # some subclass-specific methods, etc.

This causes problems when I have code that does something like torch.tensor([torch.trace(x), torch.trace(x @ x)]) where x is a square MyTensor; the torch.tensor() call raises an exception about taking the __len__ of a 0-dimensional tensor (the scalar traces). So instead I do torch.cat([torch.atleast_1d(torch.trace(x)), torch.atleast_1d(torch.trace(x @ x))]), which works. However, this function is functorch.vmap-ed, which triggers the performance warning above. It would be great if I could either get the naive implementation (using torch.tensor instead of torch.cat) to work, or if a batching rule for atleast_1d() were implemented.
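
For concreteness, here is a minimal sketch of the vmap-ed pattern (the function name and batch shape are illustrative, not the actual crikit code):

import torch
from functorch import vmap

def invariants(x):
    # torch.trace returns a 0-d tensor, so each result is wrapped with
    # atleast_1d before concatenating into a length-2 vector
    return torch.cat([torch.atleast_1d(torch.trace(x)),
                      torch.atleast_1d(torch.trace(x @ x))])

batch = torch.randn(8, 3, 3)   # a batch of square matrices (illustrative shape)
out = vmap(invariants)(batch)  # shape (8, 2); without a batching rule for
                               # atleast_1d, this hits the slow fallback warning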

Thank you for any help you can provide!

@Chillee (Contributor) commented Oct 21, 2021

I suspect that we can't make torch.tensor([torch.trace(x), torch.trace(x @ x)]) work - fundamentally, torch.tensor isn't an operator, not to mention that I don't know what the semantics of something like torch.tensor([BatchedTensor, NormalTensor]) should be.

We can definitely add a batching rule for torch.atleast_1d - we'll get on that :)

btw, does the rest of your code work other than this? I'm a little bit surprised that vmapping over your tensor subclasses works, although it's very cool that it does :)

@DiffeoInvariant (Author)

Thank you!

Yeah, the rest of my code works great! The reason it works is probably that the computations themselves don't actually require the tensors to be MyTensor (regular torch.Tensor is fine, and at the beginning and end of the computation, where it does matter, I make sure the relevant objects are MyTensor rather than base torch.Tensor), and also that, since Torch uses __torch_function__ to dispatch methods, most operations preserve subclasses just fine. The deeper reason for the slightly funny pattern is that I'm writing code that works with either JAX or Torch and can select at runtime which of those libraries performs the computations.

For what it's worth, in my code it's always something like torch.tensor([BatchedTensor, BatchedTensor]) (never a mix of BatchedTensor and NormalTensor), and I arrived at that idea for a pattern because in JAX, lambda x: jnp.array([jnp.trace(x), jnp.trace(x @ x)]) works as expected when jax.vmap-ed.
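
For reference, the JAX version of that pattern, written out as a small sketch (the batch shape is illustrative):

import jax
import jax.numpy as jnp

def invariants(x):
    # jnp.array stacks the two 0-d traces into a length-2 vector,
    # and this composes with jax.vmap without any special handling
    return jnp.array([jnp.trace(x), jnp.trace(x @ x)])

batch = jnp.ones((8, 3, 3))        # a batch of square matrices (illustrative)
out = jax.vmap(invariants)(batch)  # shape (8, 2)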

If you want to see what I'm doing, my code can be found at https://gitlab.com/crikit/crikit/-/merge_requests/99 and the specific lines I'm talking about here can be seen at https://gitlab.com/crikit/crikit/-/blob/emily/torch-adjoint/crikit/invariants/invariants.py#L1946 (note that when using the torch backend, backend.concatenate() is just torch.cat())

@zou3519 added the "actionable" label Oct 22, 2021
@Chillee (Contributor) commented Oct 25, 2021

Closed in this commit: f99ed9e

btw, atleast_1d can take a tuple of tensors at once, so you should be able to do something like

torch.cat(torch.atleast_1d((torch.trace(x), torch.trace(x @ x))))
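
That is, something along these lines (a sketch; the batch shape is illustrative):

import torch
from functorch import vmap

def invariants(x):
    # atleast_1d applied to a tuple of 0-d tensors returns a tuple of
    # 1-d tensors, which torch.cat can consume directly
    return torch.cat(torch.atleast_1d((torch.trace(x), torch.trace(x @ x))))

out = vmap(invariants)(torch.randn(8, 3, 3))  # shape (8, 2); with the new
                                              # batching rule, no fallback warning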

@zou3519 removed the "actionable" label Oct 25, 2021
@DiffeoInvariant (Author)

@Chillee I just reinstalled functorch with

pip uninstall functorch && pip install --user "git+https://github.com/facebookresearch/functorch.git"

and now when I run python -c "import functorch" I get

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/home/emily/.local/lib/python3.7/site-packages/functorch/__init__.py", line 25, in <module>
    from ._src.operator_authoring import pointwise_operator
  File "/home/emily/.local/lib/python3.7/site-packages/functorch/_src/operator_authoring.py", line 18, in <module>
    "sin": _te.sin,
AttributeError: module 'torch._C._te' has no attribute 'sin'

I figured I'd mention it here because it's related to the fix for this issue, but I'd be happy to open a new issue if that's preferable.

@Chillee (Contributor) commented Oct 25, 2021

@DiffeoInvariant what version of PyTorch do you have installed? Sounds like it might be a PyTorch versioning problem.

TBH though, I think we should just stop importing that file by default haha - it's not relevant to vmap.

@DiffeoInvariant (Author)

@Chillee I ran
pip install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
as your README.md suggests (which gives me version 1.11.0.dev20211015+cpu)

@vfdev-5 (Contributor) commented Oct 25, 2021

@DiffeoInvariant maybe you can add --upgrade to update over the previous nightly:

pip install --upgrade --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html

@Chillee (Contributor) commented Oct 25, 2021

@DiffeoInvariant I moved the op_authoring stuff to a different namespace - try reinstalling functorch?

@DiffeoInvariant (Author)

@Chillee That fixed it for me! Thank you for the quick response!

@Chillee (Contributor) commented Oct 27, 2021

Since I think the issues here are resolved, I'm going to close this issue :)

Please let us know if you run into any other issues. Thanks!

@Chillee closed this as completed Oct 27, 2021