From 057d8bf49db3d3b90901908ca8900f4c42a5bc24 Mon Sep 17 00:00:00 2001 From: davidrsch Date: Tue, 29 Jul 2025 00:32:10 +0200 Subject: [PATCH 1/3] Adjusting update to use the same environment as spec --- R/create_keras_spec.R | 2 +- R/create_keras_spec_helpers.R | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/R/create_keras_spec.R b/R/create_keras_spec.R index 5b6d40c..50355fe 100644 --- a/R/create_keras_spec.R +++ b/R/create_keras_spec.R @@ -117,7 +117,7 @@ create_keras_spec <- function( register_core_model(model_name, mode) register_model_args(model_name, args_info$parsnip_names) register_fit_predict(model_name, mode, layer_blocks) - register_update_method(model_name, args_info$parsnip_names) + register_update_method(model_name, args_info$parsnip_names, env = env) env_poke(env, model_name, spec_fun) invisible(NULL) diff --git a/R/create_keras_spec_helpers.R b/R/create_keras_spec_helpers.R index b5b5655..a35788a 100644 --- a/R/create_keras_spec_helpers.R +++ b/R/create_keras_spec_helpers.R @@ -529,8 +529,9 @@ register_fit_predict <- function(model_name, mode, layer_blocks) { #' @param model_name The name of the new model. #' @param parsnip_names A character vector of all argument names. #' @return Invisibly returns `NULL`. Called for its side effects. +#' @param env The environment in which to create the update method. #' @noRd -register_update_method <- function(model_name, parsnip_names) { +register_update_method <- function(model_name, parsnip_names, env) { # Build function signature update_args_list <- c( list(object = rlang::missing_arg(), parameters = rlang::expr(NULL)), @@ -572,6 +573,8 @@ register_update_method <- function(model_name, parsnip_names) { body = update_body ) method_name <- paste0("update.", model_name) - rlang::env_poke(environment(), method_name, update_func) - registerS3method("update", model_name, update_func, envir = environment()) + # Poke the function into the target environment (e.g., .GlobalEnv) so that + # S3 dispatch can find it. + rlang::env_poke(env, method_name, update_func) + registerS3method("update", model_name, update_func, envir = env) } From 2f8bdf258b5864edf2be320c5ecb741b69a1a0d8 Mon Sep 17 00:00:00 2001 From: davidrsch Date: Tue, 29 Jul 2025 00:32:48 +0200 Subject: [PATCH 2/3] Adjusting remove_keras_spec to edit parsnip registry --- R/remove_spec.R | 42 +++++++++++++++++++++++++++--------------- 1 file changed, 27 insertions(+), 15 deletions(-) diff --git a/R/remove_spec.R b/R/remove_spec.R index 5397c47..c9b0c42 100644 --- a/R/remove_spec.R +++ b/R/remove_spec.R @@ -29,23 +29,35 @@ #' !exists("my_temp_model") #' } remove_keras_spec <- function(model_name, env = parent.frame()) { - spec_found <- FALSE - if (exists(model_name, envir = env, inherits = FALSE)) { - obj <- get(model_name, envir = env) - if (is.function(obj)) { - remove(list = model_name, envir = env) - spec_found <- TRUE - } + # 1. Remove the spec + update fn from the user env + if ( + exists(model_name, envir = env, inherits = FALSE) && + is.function(get(model_name, envir = env)) + ) { + remove(list = model_name, envir = env) + } + update_fn <- paste0("update.", model_name) + if (exists(update_fn, envir = env, inherits = FALSE)) { + remove(list = update_fn, envir = env) + } + + # 2. Nuke every parsnip object whose name starts with model_name + model_env <- parsnip:::get_model_env() + all_regs <- ls(envir = model_env) + to_kill <- grep(paste0("^", model_name), all_regs, value = TRUE) + if (length(to_kill)) { + rm(list = to_kill, envir = model_env) + message( + "Removed from parsnip registry objects: ", + paste(to_kill, collapse = ", ") + ) } - # Also remove the associated update method - update_method_name <- paste0("update.", model_name) - # The update method is in the package namespace. `environment()` inside a - # package function returns the package namespace. - pkg_env <- environment() - if (exists(update_method_name, envir = pkg_env, inherits = FALSE)) { - remove(list = update_method_name, envir = pkg_env) + # 3. Remove the entry in get_model_env()$models + if ("models" %in% all_regs && model_name %in% model_env$models) { + model_env$models <- model_env$models[-which(model_name == model_env$models)] + message("Removed '", model_name, "' from parsnip:::get_model_env()$models") } - invisible(spec_found) + invisible(TRUE) } From 91af678501cc2d66b1c4aaeac6e3fc17f2df9349 Mon Sep 17 00:00:00 2001 From: davidrsch Date: Tue, 29 Jul 2025 00:33:18 +0200 Subject: [PATCH 3/3] Updating test to make it more robust --- tests/testthat/test-e2e-spec-removal.R | 40 ++++++++++++++------------ 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/tests/testthat/test-e2e-spec-removal.R b/tests/testthat/test-e2e-spec-removal.R index fbef708..ac04136 100644 --- a/tests/testthat/test-e2e-spec-removal.R +++ b/tests/testthat/test-e2e-spec-removal.R @@ -1,28 +1,30 @@ test_that("E2E: Model spec removal works", { - input_block_rm <- function(model, input_shape) { + skip_if_no_keras() + + model_name <- "removable_model" + + input_block <- function(model, input_shape) { keras3::keras_model_sequential(input_shape = input_shape) } - hidden_block_rm <- function(model, units = 16) { - model |> keras3::layer_dense(units = units, activation = "relu") + output_block <- function(model) { + model |> keras3::layer_dense(units = 1) } - output_block_rm <- function(model, num_classes) { - model |> keras3::layer_dense(units = num_classes, activation = "softmax") - } - - model_to_remove <- "e2e_mlp_to_remove" create_keras_spec( - model_name = model_to_remove, - layer_blocks = list( - input = input_block_rm, - hidden = hidden_block_rm, - output = output_block_rm - ), - mode = "classification" + model_name = model_name, + layer_blocks = list(input = input_block, output = output_block), + mode = "regression" ) - expect_true(exists(model_to_remove, inherits = FALSE)) - expect_true(remove_keras_spec(model_to_remove)) - expect_false(exists(model_to_remove, inherits = FALSE)) - expect_false(remove_keras_spec("a_non_existent_model")) + update_method_name <- paste0("update.", model_name) + + expect_true(exists(model_name, inherits = FALSE)) + expect_true(exists(update_method_name, inherits = FALSE)) + expect_error(parsnip:::check_model_doesnt_exist(model_name)) + + remove_keras_spec(model_name) + + expect_false(exists(model_name, inherits = FALSE)) + expect_false(exists(update_method_name, inherits = FALSE)) + expect_no_error(parsnip:::check_model_doesnt_exist(model_name)) })