tch-rs update for PyTorch 1.9 #163

Closed
LaurentMazare opened this issue Jun 19, 2021 · 3 comments

@LaurentMazare

Hi, and thanks for all the work on this amazing crate.
We're about to merge a PR updating tch-rs to use PyTorch 1.9 (LaurentMazare/tch-rs#378).
As part of this change, we also took the opportunity to change how overloaded functions are named: instead of an index, they now use the "overload-name" specified in the PyTorch API, e.g. sum1 would become sum_dim_intlist. Hopefully this will make the naming more consistent and more stable, but I wanted to give a heads up before merging in case you have feedback on this change.
The version will be bumped to 0.5.0 as part of this, since it breaks API compatibility.
Thanks

@guillaume-be
Owner

guillaume-be commented Jun 19, 2021

Hello @LaurentMazare,

Thank you very much for the heads up and the opportunity to have a look before the merge. I think this is a great change, and if my understanding is correct it will expose some functions that until now were not available in the Rust bindings to Libtorch. I agree the names are now far more meaningful. The cost is a slightly stronger deviation from PyTorch (but a closer match to Libtorch) and potentially more verbose code. I believe this should not affect readability to a large extent.

I had a look at the examples that have already been refactored, and noted 2 main cases where multiple versions of a function exist:

  1. Different input type (e.g. g_sub1)
  2. Additional arguments (e.g. arange1, squeeze1, mean1)

For the second case, it would be great to keep the name of the default version (the one with the minimum number of arguments?) unchanged, and to append the additional inputs to the function name for the other versions. This would mean that arange remains arange, but arange1 becomes arange_start. This helps a lot with disambiguation, and allows exposing additional methods in the Rust bindings. I believe this is the current proposal, and I think it is a great change!
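To make the naming scheme concrete, here is a minimal sketch. The two functions below are toy stand-ins that return a plain Vec<i64>; they are not the tch-rs functions and only borrow the names arange and arange_start from the discussion above to show how the suffix describes the extra argument.

```rust
// Toy stand-ins illustrating the naming convention only.
// These are NOT the tch-rs bindings; they operate on plain vectors.

/// Default variant: fewest arguments, keeps the unsuffixed name.
pub fn arange(end: i64) -> Vec<i64> {
    arange_start(0, end)
}

/// Variant taking an extra `start` argument: the suffix names that input
/// instead of using an opaque index such as `arange1`.
pub fn arange_start(start: i64, end: i64) -> Vec<i64> {
    (start..end).collect()
}
```

With this layout, existing call sites that use the default variant keep compiling, and only callers of the indexed names need to be updated.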

For the first case (different types of inputs), I find the change to be a bit inconvenient at times, especially for functions that accept either a scalar or a 0-dimensional tensor (and which in the vast majority of cases are called with scalars), e.g.

  • fill_ -> fill_scalar_
  • masked_fill -> masked_fill_scalar

I am wondering whether there is a more elegant way to handle a fixed number of inputs of varying types. I am thinking of Rust generics, which are typically used for this, but I have no experience using them in the context of bindings. Maybe you could define traits covering combinations of types (e.g. SCALAR_TENSOR) and have an implementation for gt of the form:

fn gt<T: SCALAR_TENSOR>(&self, other: T) -> Tensor {
    self.f_gt(other).unwrap()
}
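The generic idea above can be sketched without any bindings at all. Everything in the block below is hypothetical: Tensor is a toy struct holding a flat Vec<f64>, ScalarOrTensor is an invented trait name, and gt returns a Vec<bool> rather than a tensor. It only shows how one method name can accept either a scalar or a tensor argument, not how tch-rs is or would be implemented.

```rust
/// Toy tensor holding a flat list of values (a stand-in, not tch::Tensor).
#[derive(Debug, Clone, PartialEq)]
pub struct Tensor {
    data: Vec<f64>,
}

/// Hypothetical trait for "either a scalar or a tensor" arguments.
pub trait ScalarOrTensor {
    /// Expand the argument to exactly `len` values, or fail on a size mismatch.
    fn broadcast(&self, len: usize) -> Result<Vec<f64>, String>;
}

impl ScalarOrTensor for f64 {
    fn broadcast(&self, len: usize) -> Result<Vec<f64>, String> {
        // A scalar is repeated for every element.
        Ok(vec![*self; len])
    }
}

impl ScalarOrTensor for &Tensor {
    fn broadcast(&self, len: usize) -> Result<Vec<f64>, String> {
        if self.data.len() == 1 {
            // A single-element tensor behaves like a scalar.
            Ok(vec![self.data[0]; len])
        } else if self.data.len() == len {
            Ok(self.data.clone())
        } else {
            Err(format!("size mismatch: {} vs {}", self.data.len(), len))
        }
    }
}

impl Tensor {
    pub fn new(data: Vec<f64>) -> Self {
        Tensor { data }
    }

    /// Fallible element-wise `>`; the `f_` prefix follows the naming used in the thread.
    pub fn f_gt<T: ScalarOrTensor>(&self, other: T) -> Result<Vec<bool>, String> {
        let rhs = other.broadcast(self.data.len())?;
        Ok(self.data.iter().zip(rhs.iter()).map(|(x, y)| x > y).collect())
    }

    /// Panicking wrapper, same shape as the `gt` sketch above.
    pub fn gt<T: ScalarOrTensor>(&self, other: T) -> Vec<bool> {
        self.f_gt(other).unwrap()
    }
}
```

With this in place, t.gt(2.0) and t.gt(&other) both resolve to the same method, so no _scalar suffix is needed at the call site. The trade-off is that one trait implementation is needed per accepted argument type.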

@LaurentMazare
Author

Thanks, those are some great ideas.
For (1), I just pushed some changes that should preserve the old non-indexed functions most of the time, and this indeed makes the number of changes much smaller.
Re (2), this indeed seems like the idiomatic way to handle parameters that can have multiple possible types in Rust. That said, the implementation is a bit tedious, so I'll probably punt on this for the moment.

@LaurentMazare
Author

The 1.9 update to tch-rs has now been merged. Thanks again for the feedback, and don't hesitate to reach out if you see any way we can make tch-rs better.
