tch-rs update for PyTorch 1.9 #163
Hello @LaurentMazare, thank you very much for the heads-up and the opportunity to have a look before the merge. I think this is a great change, and if my understanding is correct, it will expose some functions that until now were not available in the Rust bindings to Libtorch. I agree the naming is now far more meaningful. The cost is a slightly stronger deviation from PyTorch (but closer alignment with Libtorch) and potentially more verbose code, though I believe this should not affect readability to a large extent. I had a look at the examples already refactored and noted 2 main cases where multiple versions of a function exist:
1. functions accepting different types of inputs (e.g. a scalar or a tensor);
2. functions accepting a different number of arguments.
For the second case, it would be great to keep the default version (corresponding to the minimum number of arguments?) under the same name, and to add the additional inputs to the function name for the other versions. This would mean that …

For the first case (different types of inputs), I find the change at times a bit inconvenient, especially for functions that take a scalar but also accept a 0-dimensional tensor as input (in the vast majority of cases they are called with scalars), e.g. a single generic version such as:

```rust
// SCALAR_TENSOR is a hypothetical trait (not part of tch) covering scalars and 0-dim tensors
fn gt<T: SCALAR_TENSOR>(&self, other: T) -> Tensor {
    self.f_gt(other).unwrap()
}
```
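The idea of one generic entry point that dispatches on the argument type can be sketched with an ordinary Rust trait. This is a self-contained toy under stated assumptions, not tch code: `Tensor` here is a hypothetical stand-in wrapping a `Vec<f64>`, and `ScalarOrTensor`, `gt_scalar`, `gt_tensor` and `gt_with` are illustrative names that only mimic the explicitly named overloads discussed above.

```rust
// Toy stand-in for a tensor (NOT tch::Tensor): a flat Vec<f64>.
#[derive(Debug, Clone, PartialEq)]
struct Tensor(Vec<f64>);

impl Tensor {
    // Explicitly named overloads, one per argument type (hypothetical names).
    fn gt_scalar(&self, other: f64) -> Vec<bool> {
        self.0.iter().map(|&x| x > other).collect()
    }

    fn gt_tensor(&self, other: &Tensor) -> Vec<bool> {
        assert_eq!(self.0.len(), other.0.len(), "shape mismatch");
        self.0.iter().zip(&other.0).map(|(&x, &y)| x > y).collect()
    }

    // Single generic entry point: the argument's type selects the overload.
    fn gt<T: ScalarOrTensor>(&self, other: T) -> Vec<bool> {
        other.gt_with(self)
    }
}

// Hypothetical trait playing the role of SCALAR_TENSOR in the snippet above.
trait ScalarOrTensor {
    fn gt_with(self, lhs: &Tensor) -> Vec<bool>;
}

impl ScalarOrTensor for f64 {
    fn gt_with(self, lhs: &Tensor) -> Vec<bool> {
        lhs.gt_scalar(self)
    }
}

impl ScalarOrTensor for &Tensor {
    fn gt_with(self, lhs: &Tensor) -> Vec<bool> {
        lhs.gt_tensor(self)
    }
}

fn main() {
    let a = Tensor(vec![1.0, 2.0, 3.0]);
    let b = Tensor(vec![0.0, 5.0, 1.0]);
    println!("{:?}", a.gt(1.5)); // [false, true, true]
    println!("{:?}", a.gt(&b)); // [true, false, true]
}
```

With this design the explicitly named overloads stay available for callers who want them, while the common scalar case keeps a short `gt(...)` call site.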
Thanks, those are some great ideas.
The 1.9 update to tch-rs has now been merged. Thanks again for the feedback (and don't hesitate to reach out if you see any way we can make tch-rs better).
Hi, and thanks for all the work on this amazing crate.
We're about to merge a PR updating `tch-rs` to use PyTorch 1.9: LaurentMazare/tch-rs#378.

As part of this change, we also took the opportunity to change the way overloaded functions are named, from using an index to using the "overload-name" specified in the PyTorch API, e.g. `sum1` would become `sum_dim_intlist`. Hopefully this will result in more consistent and more stable naming, but I wanted to give a heads-up before merging in case you have any feedback about this change.

The version will be bumped to 0.5.0 as part of this, as it breaks API compatibility.
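Why overload names should be more stable than indices can be illustrated with a small, self-contained sketch. It assumes, for illustration only, that the old scheme gave the first overload the bare name and appended a positional index to later ones in declaration order (`sum`, `sum1`, `sum2`, ...), and that the new scheme appends the lowercased overload name (assuming the overload is declared as `dim_IntList`, which matches the `sum1` → `sum_dim_intlist` example). This is not the actual tch-rs code generator, and the `new_overload` declaration is hypothetical.

```rust
use std::collections::HashMap;

// Assumed index-based scheme: first overload keeps the bare name,
// later ones get a positional suffix (sum, sum1, sum2, ...).
fn index_based(decls: &[(&str, &str)]) -> Vec<String> {
    let mut seen: HashMap<&str, usize> = HashMap::new();
    decls
        .iter()
        .map(|&(name, _overload)| {
            let n = seen.entry(name).or_insert(0);
            let out = if *n == 0 {
                name.to_string()
            } else {
                format!("{}{}", name, *n)
            };
            *n += 1;
            out
        })
        .collect()
}

// Overload-name scheme: append the lowercased overload name, if any.
fn overload_based(decls: &[(&str, &str)]) -> Vec<String> {
    decls
        .iter()
        .map(|&(name, overload)| {
            if overload.is_empty() {
                name.to_string()
            } else {
                format!("{}_{}", name, overload.to_lowercase())
            }
        })
        .collect()
}

fn main() {
    // (function name, overload name) pairs in declaration order.
    let before = [("sum", ""), ("sum", "dim_IntList")];
    // Hypothetical: a new overload is declared ahead of dim_IntList.
    let after = [("sum", ""), ("sum", "new_overload"), ("sum", "dim_IntList")];

    println!("{:?}", index_based(&before)); // ["sum", "sum1"]
    println!("{:?}", index_based(&after)); // ["sum", "sum1", "sum2"]
    println!("{:?}", overload_based(&after)); // ["sum", "sum_new_overload", "sum_dim_intlist"]
}
```

Under the index scheme the `dim_IntList` variant silently moves from `sum1` to `sum2` when an overload is inserted ahead of it, whereas its overload-based name stays `sum_dim_intlist`.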
Thanks