fix(cloning): metadata, finalizer, and repeated cloning #1134

Merged
merged 28 commits into main from fix/cloning on Feb 20, 2024
Changes from 16 commits
Commits (28)
aa418c5
fix(Tensor): cloning preserves attributes
sebffischer Feb 7, 2024
0dfa9f3
fix(tensor): clone preserves requires_grad
sebffischer Feb 8, 2024
4b701d2
news
sebffischer Feb 8, 2024
a744a65
fix previous commit and fix nn-module's clone
sebffischer Feb 8, 2024
4ec08fd
news
sebffischer Feb 8, 2024
7882cb9
fix previous commit
sebffischer Feb 8, 2024
b0ae5f4
remove unneeded comment
sebffischer Feb 8, 2024
39dab3e
fix(tensor): cloning left BackwardClone in grad_fn
sebffischer Feb 8, 2024
8d3ca4b
fix(Module): clone works after () / ()
sebffischer Feb 8, 2024
1a30d71
CloneBackward required for cloning of tensor
sebffischer Feb 9, 2024
fa1971b
support clone finalizer and fix repeated cloning issue
sebffischer Feb 10, 2024
22bc164
fix: cloning of child modules
sebffischer Feb 10, 2024
7c8eafc
fix: various cloning issues
sebffischer Feb 12, 2024
2a0bc34
remove unneeded import
sebffischer Feb 12, 2024
3ddfe1d
fix and rename finalize_clone method
sebffischer Feb 12, 2024
11599d2
fix some previously introduced issues
sebffischer Feb 12, 2024
720b895
Apply suggestions from code review
sebffischer Feb 12, 2024
a528a82
remove leftover line
sebffischer Feb 12, 2024
e3af56d
insert whitespace in printer
sebffischer Feb 12, 2024
cd1c358
cleanup snapshot files
sebffischer Feb 12, 2024
b791f42
fix: cloning now works outside the torch package as well
sebffischer Feb 13, 2024
3c7148d
clarify comment
sebffischer Feb 13, 2024
d4dab9b
Merge branch 'main' into fix/cloning
dfalbel Feb 16, 2024
bbb1a5b
Rename to `clone`.
dfalbel Feb 16, 2024
f6bdad2
Merge branch 'fix/cloning' of https://github.com/sebffischer/torch in…
dfalbel Feb 16, 2024
ff505ce
refactor: remove unneeded code
sebffischer Feb 19, 2024
87ad38d
refactor clone method
sebffischer Feb 19, 2024
5c5b2d6
fix refactor
sebffischer Feb 19, 2024
3 changes: 2 additions & 1 deletion DESCRIPTION
@@ -10,6 +10,7 @@ Authors@R: c(
person("Christophe", "Regouby", role = c("ctb")),
person("Krzysztof", "Joachimiak", role = c("ctb")),
person("Hamada S.", "Badr", role = c("ctb")),
person("Sebastian", "Fischer", role = c("ctb")),
person(family = "RStudio", role = c("cph"))
)
Description: Provides functionality to define and train neural networks similar to
@@ -44,7 +45,7 @@ Imports:
desc,
safetensors (>= 0.1.1),
jsonlite
RoxygenNote: 7.2.3
RoxygenNote: 7.3.1
Roxygen: list(markdown = TRUE)
Suggests:
testthat (>= 3.0.0),
9 changes: 8 additions & 1 deletion NEWS.md
@@ -1,6 +1,13 @@
# torch (development version)

- Make sure deep cloning preserve state dict attributes. (#1129)
- Make sure deep cloning of tensor and nn_module preserves class attributes and the requires_grad field. (#1129)
- Fixed a bug where the parameters and buffers of child modules were not cloned.
- Cloned objects no longer reference the object from which they were cloned.
- Fixed a bug where an nn_module's patched clone method was invalid after a call to
  the internal `create_nn_module_callable()`.
- Printing of `grad_fn` now appends a new line at the end.
- Added support for a private `$finalize_deep_clone()` method for `nn_module`, which
  allows running custom code after a module is deep-cloned.

# torch 0.12.0

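The `$finalize_deep_clone()` hook from the NEWS entry above can be used like this — a minimal sketch (the module name and the `cache` field are hypothetical), assuming the semantics described in this PR:

```r
library(torch)

my_module <- nn_module(
  "my_module",
  initialize = function() {
    self$fc <- nn_linear(2, 2)
    # a non-tensor resource that must not be shared with a clone
    self$cache <- new.env()
  },
  forward = function(x) self$fc(x),
  private = list(
    finalize_deep_clone = function() {
      # called on the cloned object after all modules, parameters,
      # and buffers have been cloned
      self$cache <- new.env()
    }
  )
)

m <- my_module()
m2 <- m$clone(deep = TRUE)
# m2$cache is now a fresh environment rather than the one in m
```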
12 changes: 8 additions & 4 deletions R/autograd.R
@@ -125,7 +125,7 @@ Tensor$set("active", "requires_grad", function(requires_grad) {

Tensor$set("public", "backward", function(gradient = list(), retain_graph = create_graph,
create_graph = FALSE, inputs = NULL, ...) {

args <- list(...)
if (!is.null(args$keep_graph)) {
rlang::warn(c(
@@ -136,7 +136,7 @@ Tensor$set("public", "backward", function(gradient = list(), retain_graph = crea
)
retain_graph <- args$keep_graph
}

invisible(private$`__backward`(
gradient = gradient, inputs = inputs, retain_graph = retain_graph,
create_graph = create_graph
@@ -508,8 +508,12 @@ Node <- R6::R6Class(
initialize = function(ptr) {
self$ptr <- ptr
},
print = function() {
cat(cpp_autograd_node_name(self$ptr))
print = function(newline = TRUE) {
if (newline) {
cat(cpp_autograd_node_name(self$ptr), "\n", sep = "")
} else {
cat(cpp_autograd_node_name(self$ptr))
}
}
),
active = list(
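The newline change is easiest to see when printing a tensor's `grad_fn`; a small sketch (the exact node name depends on the libtorch version):

```r
library(torch)

x <- torch_tensor(1, requires_grad = TRUE)
y <- 2 * x

# print() now terminates the node name with "\n", so following
# output no longer runs into the same line:
print(y$grad_fn)
#> MulBackward0
```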
128 changes: 85 additions & 43 deletions R/nn.R
@@ -134,7 +134,7 @@ nn_Module <- R6::R6Class(
out[[paste0(prefix, param_name)]] <- keepvars_or_detach(param, keepvars)
}
}

for (buf_name in names(private$buffers_)) {
buf <- private$buffers_[[buf_name]]
if (!is.null(buf) && !(buf_name %in% private$non_persistent_buffers_)) {
@@ -173,15 +173,15 @@ nn_Module <- R6::R6Class(
if (!self$..refer_to_state_dict..) {
with_no_grad({
param$copy_(input_param)
})
} else {

# setting requires grad is ignored if param is not a valid pointer
# be careful!
if (!is_null_external_pointer(param)) {
input_param$requires_grad_(param$requires_grad)
}

if (name %in% names(persistent_buffers)) {
private$buffers_[[name]] <- input_param
} else {
@@ -249,13 +249,13 @@ nn_Module <- R6::R6Class(
module <- private$modules_[[i]]
private$modules_[[i]] <- table[[rlang::obj_address(module)]] %||% module
}

lapply(private$modules_, function(x) x$.replace_values_from_table(table))

for (i in seq_along(private$parameters_)) {
par <- private$parameters_[[i]]
# par or buf might not be available in `table` if, for some reason, they
# have already been replaced. This happens, for example, when a module
# has the same layer twice. The same applies to modules, which might be duplicated.
private$parameters_[[i]] <- table[[xptr_address(par)]] %||% par
}
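Keying the replacement table by address is what keeps objects that were shared by reference before the clone shared afterwards. A sketch of the behaviour this enables (the field names are illustrative):

```r
library(torch)

net <- nn_module(
  initialize = function() {
    shared <- nn_linear(4, 4)
    # the same layer is registered twice, so both fields reference
    # one and the same set of parameters
    self$a <- shared
    self$b <- shared
  },
  forward = function(x) self$b(self$a(x))
)()

copy <- net$clone(deep = TRUE)

# the shared weight is cloned only once, so a change made through $a
# stays visible through $b:
with_no_grad(copy$a$weight$add_(1))
torch_equal(copy$a$weight, copy$b$weight)  # TRUE
```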
@@ -449,6 +449,11 @@ is_nn_module <- function(x) {
#' computations depending on whether the model is training or not, for example if you
#' were implementing the dropout module.
#'
#' @section Cloning:
#' To finalize the cloning of a module, you can define a private `finalize_deep_clone()` method.
#' This method is called on the cloned object when deep-cloning a module, after all modules, parameters, and
#' buffers have been cloned.
#'
#' @param classname an optional name for the module
#' @param inherit an optional module to inherit from
#' @param ... methods implementation
@@ -496,7 +501,7 @@ nn_module <- function(classname = NULL, inherit = nn_Module, ...,
active = active,
parent_env = e
)

init <- get_init(Module)

fun <- rlang::new_function(
@@ -516,45 +521,82 @@ create_nn_module_callable <- function(instance) {

attr(f, "class") <- instance$.classes
attr(f, "module") <- instance

# clone method was already patched, so nothing to do
if (!is.null(instance$.__enclos_env__$private$.__clone_r6__)) {
return(f)
}

# As R6's clone method is quite restrictive, we assign the original public $clone() method to the private
# field $.__clone_r6__ and create a new patched $clone() method.
# This circumvents some restrictions in R6; see e.g. this discussion: https://github.com/r-lib/R6/issues/179
rlang::env_binding_unlock(instance, "clone")
on.exit({lockBinding("clone", instance)}, add = TRUE)
clone <- instance$clone

on.exit({lockBinding(".__clone_r6__", instance$.__enclos_env__$private)}, add = TRUE)
instance$.__enclos_env__$private$.__clone_r6__ = instance$clone

instance$clone <- function(deep = FALSE, ..., replace_values = TRUE) {
collect_state_dict <- function(instance, state_dict) {
# the parameters and buffers of child modules are retrieved below
private <- instance$.__enclos_env__$private
new_objs <- c(instance$named_parameters(recursive = FALSE), instance$named_buffers(recursive = FALSE))
if (length(new_objs)) {
names(new_objs) <- map_chr(new_objs, xptr_address)
state_dict <- append(state_dict, new_objs)
}
# We also need to append the child modules to this list.
# Child modules can be duplicated and can have the same name.
# Note that we store both the modules and their parameters and buffers in the state_dict.
children <- instance$children
if (!length(children)) {
return(state_dict)
}
for (child in children) {
state_dict = collect_state_dict(child, state_dict)
}
state_dict = append(state_dict, rlang::set_names(children, map_chr(children, rlang::obj_address)))
return(state_dict)
}
if (deep && replace_values) {
state_dict <- append(instance$parameters, instance$buffers)
if (length(state_dict) > 0) {
names(state_dict) <- sapply(state_dict, xptr_address)

state_dict <- state_dict[!duplicated(names(state_dict))]
state_dict <- lapply(state_dict, function(x) {
out <- x$detach()$clone()
# the state_dict contains all the objects that need to be cloned; for these we also ensure that objects
# that were previously equal by reference remain equal by reference
# To achieve this, the names of the state_dict are the (external pointer) addresses of the objects
# BEFORE cloning
state_dict <- collect_state_dict(self, list())
# each unique value must only be cloned once
state_dict <- state_dict[!duplicated(names(state_dict))]
state_dict <- map(state_dict, function(x) {
if (inherits(x, "nn_module")) {
# the values are replaced below, when calling .replace_values_from_table
# this will fail when different submodules contain the same object by reference, but
# this needs a solution in R6 and not here
x$clone(deep = deep, replace_values = FALSE)
} else { # torch_tensor
# without the detaching, the clone method adds a CloneBackward node, which is undesirable when cloning
# modules, as the cloned module should be independent from the original
out <- x$detach()$clone2()
# we need this, because of https://github.com/mlverse/torch/issues/1136
attributes(out) <- attributes(x)
# because of the detach() above, we now need to reset the requires_grad field
out$requires_grad_(x$requires_grad)
out
})

# also need to append a clone of the modules to this list.
# child modules can be duplicated - and have the same name
# child modules are also deep cloned, but we don't need to replace
# their values when cloning because we only have to do it once.
children <- instance$children
names(children) <- sapply(children, rlang::obj_address)
children <- children[!duplicated(names(children))]
children <- lapply(children, function(x) x$clone(deep = deep, replace_values = FALSE))

state_dict <- append(state_dict, children)
}
}
})
}

cloned_instance <- clone(deep = deep)

cloned_instance <- private$.__clone_r6__(deep = deep)

if (deep && replace_values) {
cloned_instance$.replace_values_from_table(state_dict)
cloned_private = cloned_instance$.__enclos_env__$private
if (!is.null(cloned_private$finalize_deep_clone)) {
cloned_private$finalize_deep_clone()
}
}

create_nn_module_callable(cloned_instance)
}

environment(instance$clone) <- instance$.__enclos_env__

f
}
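The detach-then-clone dance in the patched method can be reproduced on a bare tensor; a minimal sketch of why both extra steps are needed:

```r
library(torch)

x <- torch_tensor(c(1, 2, 3), requires_grad = TRUE)

# a plain $clone() records a CloneBackward node, tying the copy to the
# original tensor's autograd graph:
y <- x$clone()
y$grad_fn  # CloneBackward0

# detaching first yields an independent copy, but drops requires_grad,
# so it has to be restored on the result:
z <- x$detach()$clone()
z$requires_grad_(x$requires_grad)
z$grad_fn  # NULL: no node links z back to x
```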

@@ -718,11 +760,11 @@ length.nn_sequential <- function(x) {

#' Prune top layer(s) of a network
#'
#' Prune `head_size` last layers of a nn_module in order to
#' replace them by your own head, or in order to use the pruned module
#' as a sequential embedding module.
#' @param x nn_network to prune
#' @param head_size number of nn_layers to prune
#'
#' @return a nn_sequential network with the top nn_layer removed
#' @export
@@ -738,7 +780,7 @@ length.nn_sequential <- function(x) {
#' nn_batch_norm1d(10),
#' nn_tanh(),
#' nn_linear(10,3)
#' )
#' prune <- nn_prune_head(x, 3)
#' prune
#' }
Expand All @@ -756,7 +798,7 @@ nn_prune_head.nn_module <- nn_module(
classname = "nn_sequential",
initialize = function(x, head_size=1L) {
modules <- rlang::list2(!!!x$children[1:(length(x$children)-head_size)])
mod_names <- names(modules)
for (i in seq_along(modules)) {
self$add_module(name = mod_names[i], module = modules[[i]])
}
@@ -823,7 +865,7 @@ nn_module_list <- nn_module(
)

#' Container that allows named values
#'
#' @param dict A named list of submodules that will be saved in that module.
#' @examples
#' nn_module <- nn_module(
Expand All @@ -845,12 +887,12 @@ nn_module_dict <- nn_module(
if (!rlang::is_named(dict)) cli::cli_abort("All elements in {.arg dict} must be named.")
for(nm in names(dict)) {
self[[nm]] <- dict[[nm]]
}
},
forward = function(...) {
cli::cli_abort("{.fn nn_module_dict} has no {.fn forward} implementation.")
}
)

#' @export
`[[.nn_module_list` <- function(x, y) {
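The unlock-and-patch trick used in `create_nn_module_callable()` can be illustrated outside torch with a toy R6 class — a sketch only; the real code additionally stores the original method in the private field `.__clone_r6__` and rewires the method's enclosing environment:

```r
library(R6)

Counter <- R6Class("Counter",
  public = list(
    n = 0,
    add = function() {
      self$n <- self$n + 1
      invisible(self)
    }
  )
)

obj <- Counter$new()

# R6 locks the binding of $clone(), so it must be unlocked before the
# method can be swapped for a wrapper that adds post-clone behaviour:
original_clone <- obj$clone
rlang::env_binding_unlock(obj, "clone")
obj$clone <- function(deep = FALSE) {
  cloned <- original_clone(deep = deep)
  message("post-clone hook ran")  # stand-in for the finalizer logic
  cloned
}
lockBinding("clone", obj)

copy <- obj$clone(deep = TRUE)  # prints "post-clone hook ran"
```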
16 changes: 8 additions & 8 deletions R/package.R
@@ -14,21 +14,21 @@ globalVariables(c("..", "self", "private", "N"))
cpp_torch_namespace__store_main_thread_id()

install_success <- TRUE

is_interactive <- interactive() ||
"JPY_PARENT_PID" %in% names(Sys.getenv()) ||
identical(getOption("jupyter.in_kernel"), TRUE)

# we only autoinstall if it has not been explicitly disabled by setting
# TORCH_INSTALL = 0
autoinstall <- is_interactive && (Sys.getenv("TORCH_INSTALL", unset = 2) != 0)

# We can also auto install if TORCH_INSTALL is requested with TORCH_INSTALL=1
autoinstall <- autoinstall || (Sys.getenv("TORCH_INSTALL", unset = 2) == "1")

# we only autoinstall if installation doesn't yet exist.
autoinstall <- autoinstall && (!torch_is_installed())

if (autoinstall) {
install_success <- tryCatch(
{
@@ -37,8 +37,8 @@ globalVariables(c("..", "self", "private", "N"))
# in interactive environments we want to ask the user for permission to
# download and install stuff. That's not necessary otherwise because the
# user has explicitly asked for installation with `TORCH_INSTALL=1`.
if (is_interactive) {
get_confirmation() # this will error if response is not true.
}
install_torch(.inform_restart = FALSE)
TRUE
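The autoinstall logic above reduces to one environment variable; a short sketch of how a user drives it:

```r
# never auto-install, even in interactive sessions:
# Sys.setenv(TORCH_INSTALL = "0")

# download and install the backend without a prompt, even
# non-interactively:
Sys.setenv(TORCH_INSTALL = "1")
library(torch)        # .onLoad() consults TORCH_INSTALL

torch_is_installed()  # TRUE once the backend is present
# install_torch()     # explicit manual installation, if ever needed
```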