Cloning of torch module behaves unexpectedly #1126

Closed

sebffischer opened this issue Dec 18, 2023 · 3 comments · Fixed by #1129
Comments

@sebffischer
Collaborator

Thanks a lot for already fixing the class of the cloned torch module!
I still observed some differences between the cloned object and the original; see below, where the weight of the cloned linear layer is missing the "nn_parameter" class.

library(mlr3torch)
#> Loading required package: mlr3
#> Loading required package: mlr3pipelines
#> Loading required package: torch

a = nn_linear(1, 1)

b = a$clone(deep = TRUE)

a
#> An `nn_module` containing 2 parameters.
#> 
#> ── Parameters ──────────────────────────────────────────────────────────────────
#> • weight: Float [1:1, 1:1]
#> • bias: Float [1:1]

b
#> An `nn_module` containing 2 parameters.
#> 
#> ── Parameters ──────────────────────────────────────────────────────────────────
#> • weight: Float [1:1, 1:1]
#> • bias: Float [1:1]

b$parameters$weight |> attributes()
#> $class
#> [1] "torch_tensor" "R7"

a$parameters$weight |> attributes()
#> $class
#> [1] "torch_tensor" "R7"           "nn_parameter"

Created on 2023-12-18 with reprex v2.0.2

@sebffischer
Collaborator Author

I tried to do this myself, but because clone is autogenerated, I had to change the Declarations file (I believe). However, when I try to autogenerate the functions, I get:

> torchgen::generate("~/gh/torch")
Starting code generation ...
Error in `purrr::map_chr()`:
In index: 1.
Caused by error:
! Result must be length 1, not 0.

@sebffischer
Collaborator Author

This still kind of breaks when one wants to do something like nn_parameter(torch_tensor(1))$clone().
While this is admittedly less important, I think the (autogenerated) torch_clone() method itself should just preserve the attributes.
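Until then, a minimal workaround sketch (assuming the only attribute that matters here is the "nn_parameter" class): re-wrap the cloned tensor with nn_parameter(), which restores it.

library(torch)

p <- nn_parameter(torch_tensor(1))

class(p$clone())              # currently just c("torch_tensor", "R7"), as shown above
p2 <- nn_parameter(p$clone()) # re-wrap the plain clone
class(p2)                     # c("torch_tensor", "R7", "nn_parameter") again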

@dfalbel
Member

dfalbel commented Jan 26, 2024

I see, you are right! Perhaps you can add a method like clone_with_attributes around here:

print = function(n = 30) {

with the behavior you desire, and then I can add the plumbing for it to become the default clone.
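Something like this standalone sketch, for illustration (the name comes from the suggestion above; copying the class over is just one way to get the intended behavior, not necessarily how it will end up being implemented):

library(torch)

clone_with_attributes <- function(x, ...) {
  out <- x$clone(...)     # plain clone, which currently drops extra classes
  class(out) <- class(x)  # carry over e.g. "nn_parameter"
  out
}

p <- nn_parameter(torch_tensor(1))
class(clone_with_attributes(p))  # includes "nn_parameter"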

TBH I haven't run torchgen like this very much; I almost always do load_all() and then generate() with the defaults. So maybe there's a hardcoded path that doesn't work properly when running from a different directory.
