Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion R/create_keras_spec.R
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ create_keras_spec <- function(
register_core_model(model_name, mode)
register_model_args(model_name, args_info$parsnip_names)
register_fit_predict(model_name, mode, layer_blocks)
register_update_method(model_name, args_info$parsnip_names)
register_update_method(model_name, args_info$parsnip_names, env = env)

env_poke(env, model_name, spec_fun)
invisible(NULL)
Expand Down
9 changes: 6 additions & 3 deletions R/create_keras_spec_helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -529,8 +529,9 @@ register_fit_predict <- function(model_name, mode, layer_blocks) {
#' @param model_name The name of the new model.
#' @param parsnip_names A character vector of all argument names.
#' @return Invisibly returns `NULL`. Called for its side effects.
#' @param env The environment in which to create the update method.
#' @noRd
register_update_method <- function(model_name, parsnip_names) {
register_update_method <- function(model_name, parsnip_names, env) {
# Build function signature
update_args_list <- c(
list(object = rlang::missing_arg(), parameters = rlang::expr(NULL)),
Expand Down Expand Up @@ -572,6 +573,8 @@ register_update_method <- function(model_name, parsnip_names) {
body = update_body
)
method_name <- paste0("update.", model_name)
rlang::env_poke(environment(), method_name, update_func)
registerS3method("update", model_name, update_func, envir = environment())
# Poke the function into the target environment (e.g., .GlobalEnv) so that
# S3 dispatch can find it.
rlang::env_poke(env, method_name, update_func)
registerS3method("update", model_name, update_func, envir = env)
}
42 changes: 27 additions & 15 deletions R/remove_spec.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,23 +29,35 @@
#' !exists("my_temp_model")
#' }
remove_keras_spec <- function(model_name, env = parent.frame()) {
spec_found <- FALSE
if (exists(model_name, envir = env, inherits = FALSE)) {
obj <- get(model_name, envir = env)
if (is.function(obj)) {
remove(list = model_name, envir = env)
spec_found <- TRUE
}
# 1. Remove the spec + update fn from the user env
if (
exists(model_name, envir = env, inherits = FALSE) &&
is.function(get(model_name, envir = env))
) {
remove(list = model_name, envir = env)
}
update_fn <- paste0("update.", model_name)
if (exists(update_fn, envir = env, inherits = FALSE)) {
remove(list = update_fn, envir = env)
}

# 2. Nuke every parsnip object whose name starts with model_name
model_env <- parsnip:::get_model_env()
all_regs <- ls(envir = model_env)
to_kill <- grep(paste0("^", model_name), all_regs, value = TRUE)
if (length(to_kill)) {
rm(list = to_kill, envir = model_env)
message(
"Removed from parsnip registry objects: ",
paste(to_kill, collapse = ", ")
)
}

# Also remove the associated update method
update_method_name <- paste0("update.", model_name)
# The update method is in the package namespace. `environment()` inside a
# package function returns the package namespace.
pkg_env <- environment()
if (exists(update_method_name, envir = pkg_env, inherits = FALSE)) {
remove(list = update_method_name, envir = pkg_env)
# 3. Remove the entry in get_model_env()$models
if ("models" %in% all_regs && model_name %in% model_env$models) {
model_env$models <- model_env$models[-which(model_name == model_env$models)]
message("Removed '", model_name, "' from parsnip:::get_model_env()$models")
}

invisible(spec_found)
invisible(TRUE)
}
40 changes: 21 additions & 19 deletions tests/testthat/test-e2e-spec-removal.R
Original file line number Diff line number Diff line change
@@ -1,28 +1,30 @@
test_that("E2E: Model spec removal works", {
input_block_rm <- function(model, input_shape) {
skip_if_no_keras()

model_name <- "removable_model"

input_block <- function(model, input_shape) {
keras3::keras_model_sequential(input_shape = input_shape)
}
hidden_block_rm <- function(model, units = 16) {
model |> keras3::layer_dense(units = units, activation = "relu")
output_block <- function(model) {
model |> keras3::layer_dense(units = 1)
}
output_block_rm <- function(model, num_classes) {
model |> keras3::layer_dense(units = num_classes, activation = "softmax")
}

model_to_remove <- "e2e_mlp_to_remove"

create_keras_spec(
model_name = model_to_remove,
layer_blocks = list(
input = input_block_rm,
hidden = hidden_block_rm,
output = output_block_rm
),
mode = "classification"
model_name = model_name,
layer_blocks = list(input = input_block, output = output_block),
mode = "regression"
)

expect_true(exists(model_to_remove, inherits = FALSE))
expect_true(remove_keras_spec(model_to_remove))
expect_false(exists(model_to_remove, inherits = FALSE))
expect_false(remove_keras_spec("a_non_existent_model"))
update_method_name <- paste0("update.", model_name)

expect_true(exists(model_name, inherits = FALSE))
expect_true(exists(update_method_name, inherits = FALSE))
expect_error(parsnip:::check_model_doesnt_exist(model_name))

remove_keras_spec(model_name)

expect_false(exists(model_name, inherits = FALSE))
expect_false(exists(update_method_name, inherits = FALSE))
expect_no_error(parsnip:::check_model_doesnt_exist(model_name))
})
Loading