
fix(cloning): metadata, finalizer, and repeated cloning #1134

Merged: 28 commits into mlverse:main, Feb 20, 2024

Conversation

sebffischer
Collaborator

Addresses #1126, where you offered to take care of the renaming.
That is, the $clone2() method should simply be renamed to $clone(). Currently this renaming has no effect, because the find_method() function (iirc) still finds the auto-generated torch_clone() function.

@sebffischer
Collaborator Author

Also, the C++ code that clones tensors / buffers / parameters behaves differently, right?

@sebffischer
Collaborator Author

sebffischer commented Feb 8, 2024

I also think it would be nice if torch_clone(tensor) and tensor$clone() behaved identically.

@sebffischer
Collaborator Author

There is also another discrepancy between the original object and its clone:

library(torch)
lin = nn_linear(1, 1)
lin$parameters$weight$requires_grad
#> [1] TRUE
lin$clone(deep = TRUE)$parameters$weight$requires_grad
#> [1] FALSE

Created on 2024-02-08 with reprex v2.0.2

@sebffischer
Collaborator Author

one more:

library(torch)
nn_linear(1, 1)$train()$clone(deep = TRUE)
#> Error in FUN(X[[i]], ...): not an environment
nn_linear(1, 1)$eval()$clone(deep = TRUE)
#> Error in FUN(X[[i]], ...): not an environment

Created on 2024-02-08 with reprex v2.0.2

@sebffischer
Collaborator Author

Also, cloning repeatedly causes issues and builds up a chain of parent environments. This is because the clone retrieved here

torch/R/nn.R, line 521 in 0e9fdd7:

clone <- instance$clone

is, when cloning a second time, already the patched version and not the original R6 clone implementation. The patched version still has the original version as its enclosing environment, so it still works, but it repeatedly calls the clone() method of the enclosing environment (

torch/R/nn.R, line 549 in 0e9fdd7:

cloned_instance <- clone(deep = deep)

) until it reaches the top-level clone call, i.e. R6's clone method.

library(torch)

n = nn_linear(1, 1)
head(attr(n, "module")$clone, n = 1)
#>                                                       
#> 1 function (deep = FALSE, ..., replace_values = TRUE)
head(attr(n, "module")$clone |> environment() |> with(clone), n = 1)
#>                           
#> 1 function (deep = FALSE)

n1 = n$clone(deep = TRUE)
head(attr(n1, "module")$clone, n = 1)
#>                                                       
#> 1 function (deep = FALSE, ..., replace_values = TRUE)
head(attr(n1, "module")$clone |> environment() |> with(clone), n = 1)
#>                                                       
#> 1 function (deep = FALSE, ..., replace_values = TRUE)
head(attr(n1, "module")$clone |> environment() |> with(clone) |> environment() |> with(clone), n = 1)
#>                           
#> 1 function (deep = FALSE)

identical(
  attr(n1, "module")$clone |> environment(),
  attr(n1, "module")$clone |> environment() |> with(clone) |> environment()
)
#> [1] FALSE

attr(n1, "module")$clone |> environment() |> names()
#> [1] "clone"    "f"        "instance"
attr(n1, "module")$clone |> environment() |> with(clone) |> environment() |> names()
#> [1] "clone"    "f"        "instance"

Created on 2024-02-10 with reprex v2.0.2
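The mechanism can also be illustrated without torch. This is a minimal base-R sketch (not torch's actual code) of a clone function that is re-patched on every clone, where each patched version closes over the previous one:

```r
# Minimal base-R sketch of the problem: each patch wraps the previously
# retrieved `clone`, so cloning twice produces a chain of enclosing
# environments instead of reusing the original R6 implementation.
patch_clone <- function(instance_clone) {
  clone <- instance_clone          # captured in the enclosing environment
  function(deep = FALSE) clone(deep)
}

r6_clone <- function(deep = FALSE) "original R6 clone"
first  <- patch_clone(r6_clone)   # first cloning patches the original
second <- patch_clone(first)      # second cloning patches the patched version

# Walking the enclosing environments reveals the chain:
identical(environment(second)$clone, first)                       # TRUE
identical(environment(environment(second)$clone)$clone, r6_clone) # TRUE
```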

@sebffischer sebffischer changed the title fix(Tensor): cloning preserves attributes fix(cloing): metadata, finalizer, and repeated cloning Feb 10, 2024
@sebffischer
Collaborator Author

Also, I included support for a private clone finalizer method. In mlr3torch we need something like this because we have an nn_module that contains an R6 object holding modules, and these modules also need to be registered in the nn_module itself. The reference identity of these objects therefore needs to be preserved when cloning, and the only solution I came up with is to allow a hook that runs after clone(). Let me know what you think :)
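For context, R6 already offers a private `deep_clone(name, value)` hook that controls how individual fields are copied during `$clone(deep = TRUE)`; the finalizer discussed here would additionally run after the whole clone is complete. A rough sketch of the field-level hook (the `Shared` and `Holder` classes are made up for illustration):

```r
library(R6)

# Sketch: preserve reference identity for a shared field during deep
# clone. When a class defines a private deep_clone() method, R6 calls
# it once per field and uses its return value in the clone.
Shared <- R6Class("Shared", public = list(x = 1))

Holder <- R6Class("Holder",
  public = list(
    shared = NULL,
    initialize = function() self$shared <- Shared$new()
  ),
  private = list(
    deep_clone = function(name, value) {
      if (name == "shared") {
        value                      # keep the same object, skip the deep copy
      } else if (inherits(value, "R6")) {
        value$clone(deep = TRUE)   # deep-clone other R6 fields as usual
      } else {
        value
      }
    }
  )
)

h  <- Holder$new()
h2 <- h$clone(deep = TRUE)
identical(h$shared, h2$shared)  # TRUE: identity preserved
```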

@sebffischer sebffischer marked this pull request as draft February 10, 2024 08:36
@sebffischer sebffischer changed the title fix(cloing): metadata, finalizer, and repeated cloning fix(cloning): metadata, finalizer, and repeated cloning Feb 10, 2024
@sebffischer
Collaborator Author

I also just saw: r-lib/R6#273, which would make the post_clone hook officially supported by R6.

@sebffischer
Collaborator Author

sebffischer commented Feb 10, 2024

So there is at least one more issue, shown in the reprex below. When creating the state_dict, the parameters of the children are not collected here:

torch/R/nn.R, line 543 in 0e9fdd7:

children <- lapply(children, function(x) x$clone(deep = deep, replace_values = FALSE))

I think we need to recurse through the children and then things should work.

Setting replace_values = replace_values might also work in some cases, but this runs into issues when different submodules reference the same tensors.
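A rough sketch of the recursion, assuming plain list fields that mimic an nn_module's `parameters` and `children` (the function and field names are illustrative, not the PR's actual code):

```r
# Hypothetical sketch: collect parameters depth-first so that the state
# dict also contains the parameters of nested submodules, with names
# prefixed by the path to the child (e.g. "l.0.weight").
collect_parameters <- function(module, prefix = "") {
  state <- list()
  for (name in names(module$parameters)) {
    state[[paste0(prefix, name)]] <- module$parameters[[name]]
  }
  for (child_name in names(module$children)) {
    child <- module$children[[child_name]]
    state <- c(state, collect_parameters(child, paste0(prefix, child_name, ".")))
  }
  state
}
```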

library(torch)
nn_test = nn_module("test", initialize = function() {
  self$l = nn_module_list(list(nn_linear(1, 1)))
  },
  forward = function(x) {
    self$l[[1]](x)
  }
)()

nn_test1 = nn_test$clone(deep = TRUE)
nn_test$clone(deep = TRUE)
#> An `nn_module` containing 2 parameters.
#> 
#> ── Modules ─────────────────────────────────────────────────────────────────────
#> • l: <nn_module_list> #2 parameters
l1 = nn_test$l$modules[[2]]
l2 = nn_test1$l$modules[[2]]
identical(l1, l2)
#> [1] TRUE

Created on 2024-02-10 with reprex v2.0.2

still needs some cleanup but in principle it should be working
@sebffischer
Collaborator Author

The previous workaround for the clone method caused the cloned object to reference the original object. The culprit was this line:

torch/R/nn.R, line 523 in 0e9fdd7:

instance$clone <- function(deep = FALSE, ..., replace_values = TRUE) {

As a result, the cloned object was larger than the original:

library(torch)

pryr::object_size(nn_relu())
#> 484.47 kB
pryr::object_size(nn_relu()$clone(deep = TRUE))
#> 492.42 kB

Created on 2024-02-12 with reprex v2.0.2

@sebffischer sebffischer marked this pull request as ready for review February 12, 2024 10:09
@sebffischer
Collaborator Author

@dfalbel I am done here and would love to get your feedback whether you think these changes make sense :)

@sebffischer
Collaborator Author

Ok, now I think it is actually ready from my side

the encapsulation of the patched clone method of the nn_module
caused private functions like xptr_address to be inaccessible
@sebffischer sebffischer marked this pull request as ready for review February 13, 2024 16:45
@sebffischer
Collaborator Author

@dfalbel we can also have a call where I explain some of the changes, if you have the time and think this would be useful or necessary. Otherwise I can also give more details here.

@dfalbel
Member

dfalbel commented Feb 16, 2024

Edit: please see comment below, this no longer applies.

@sebffischer I wonder whether, instead of renaming clone2 to clone, we could find a different name that clearly states what it does. The torch documentation clearly states:

This function is differentiable, so gradients will flow back from the result of this operation to input. To create a tensor without an autograd relationship to input see detach().
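The quoted behavior is easy to check (a small sketch assuming the torch R package is installed):

```r
library(torch)

# torch_clone() keeps the autograd link: the gradient flows back
# through the clone to the original tensor, whereas detach() severs it.
x <- torch_tensor(3, requires_grad = TRUE)
y <- torch_clone(x)^2
y$backward()
x$grad                     # 2 * x = 6: the clone participated in autograd
x$detach()$requires_grad   # FALSE: the detached tensor is cut off
```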

Thus renaming these methods will certainly cause problems in other codebases that rely on that behavior. For instance, it breaks some torch optimizers that do use clone() for that.

Besides that, the PR looks great! Thank you very much for working on this.

(Review thread on R/tensor.R, marked outdated and resolved.)
@dfalbel
Member

dfalbel commented Feb 16, 2024

Ok, so renaming works fine; we just had to support the other arguments to clone() that were not previously specified.
The only thing I'm not convinced about is that we want to set requires_grad. Do you know why exactly we need this? Is it not respected?

@sebffischer
Collaborator Author

@dfalbel Thanks for the feedback. Regarding the call to $requires_grad_(): I think I mixed something up. Indeed, the $requires_grad field does not need to be modified manually in the tensor's $clone() method.

@sebffischer
Collaborator Author

Please don't merge yet, I want to add one more refactor.

@sebffischer
Collaborator Author

Done now

@dfalbel
Member

dfalbel commented Feb 20, 2024

Thanks @sebffischer ! Looks great!
I also ran luz and minhub tests and they all passed against this version.

@dfalbel dfalbel merged commit e4bfb5b into mlverse:main Feb 20, 2024
6 of 9 checks passed