
PyTorch tensor factory methods should use torch namespace instead of at #1197

Closed
sbrunk opened this issue Jul 5, 2022 · 5 comments

@sbrunk
Contributor

sbrunk commented Jul 5, 2022

Many tensor factory methods defined in torch.java get mapped to functions in the at namespace (from the ATen tensor library underlying PyTorch) instead of the torch namespace.

See rand, for instance, but this is the case for most factory functions, and possibly for other functions as well.

@Namespace("at") public static native @ByVal Tensor rand(@ByVal @Cast({"int64_t*", "c10::ArrayRef<int64_t>", "std::vector<int64_t>&"}) @StdVector long[] size, @ByVal(nullValue = "at::TensorOptions{}") TensorOptions options);

Usually, though, we want to call the factory functions in the torch namespace, because only they give us things like variables and autodiff. For example, requires_grad has no effect on tensors created by ATen factory methods.

This is also stated in the docs:

https://pytorch.org/cppdocs/#c-frontend

Unless you have a particular reason to constrain yourself exclusively to ATen or the Autograd API, the C++ frontend is the recommended entry point to the PyTorch C++ ecosystem. While it is still in beta as we collect user feedback (from you!), it provides both more functionality and better stability guarantees than the ATen and Autograd APIs.

https://pytorch.org/cppdocs/notes/faq.html#i-created-a-tensor-using-a-function-from-at-and-get-errors

Replace at:: with torch:: for factory function calls. You should never use factory functions from the at:: namespace, as they will create tensors. The corresponding torch:: functions will create variables, and you should only ever deal with variables in your code.

One thing we might have to consider is backward compatibility, e.g. by using different names for colliding functions in the torch namespace.
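
For reference, the behavioral difference described in the docs quoted above can be sketched in the C++ frontend that these bindings wrap. This is a minimal sketch, assuming a libtorch build (torch/torch.h available); it is not runnable without that dependency.

```cpp
// Sketch: contrast at:: and torch:: factory functions.
// torch::rand comes from torch/csrc/autograd/generated/variable_factories.h
// and creates a variable; at::rand creates a plain ATen tensor.
#include <torch/torch.h>
#include <iostream>

int main() {
    // The torch:: factory respects autograd options.
    auto v = torch::rand({2, 3}, torch::requires_grad());
    std::cout << v.requires_grad() << '\n'; // true

    // The at:: factory does not, which is the problem this issue describes.
    auto t = at::rand({2, 3});
    std::cout << t.requires_grad() << '\n'; // false
    return 0;
}
```

A Java caller going through the at-mapped bindings ends up in the second case, which is why mapping the torch:: variants matters.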

@saudet
Member

saudet commented Jul 6, 2022

I see, thanks for the information! Are all those factory functions the ones found in "torch/csrc/autograd/generated/variable_factories.h"? from_blob() already gets mapped, so we don't need to worry about that one, but I count about 50 others, including rand(), that are not currently getting mapped because they have the same signatures as other functions in the at:: namespace. Do you feel it would make sense to map all those by prefixing them with torch_?

@sbrunk
Contributor Author

sbrunk commented Jul 6, 2022

Are all those factory functions the ones found in "torch/csrc/autograd/generated/variable_factories.h"?

I'm not too familiar with the PyTorch codebase yet but yes that looks right to me.

Do you feel it would make sense to map all those by prefixing them with torch_?

I can only speak for my use case, where it would be totally fine. Since I'm building an idiomatic Scala API on top of the JavaCPP PyTorch bindings, I'm wrapping these method calls anyway.

@saudet saudet added the bug label Jul 8, 2022
saudet added a commit that referenced this issue Jul 12, 2022
@saudet
Member

saudet commented Jul 12, 2022

Done in commit f093384! Please give it a try with the snapshots: http://bytedeco.org/builds/

@sbrunk
Contributor Author

sbrunk commented Jul 12, 2022

I just did a quick test after switching to the torch_... methods from the latest snapshot release:

[screenshot of the test output]

Looking good! Thanks! And also thanks for updating to PyTorch 1.12.

I haven't tested all factory methods yet, but I think we can close this issue and reopen it (or open a new one) if something is still missing.

@sbrunk sbrunk closed this as completed Jul 13, 2022
@saudet
Member

saudet commented Nov 3, 2022

This fix has been released as part of version 1.5.8. Thanks for reporting and for testing!
