From 73b074923f1224ba9f53872a4180d26ebe4f7b41 Mon Sep 17 00:00:00 2001 From: davidrsch Date: Sat, 16 Aug 2025 12:31:40 +0200 Subject: [PATCH 01/32] refactoring process to ensure there x and y for sequential and functional --- R/utils.R | 143 +++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 141 insertions(+), 2 deletions(-) diff --git a/R/utils.R b/R/utils.R index ae0e69a..a130fde 100644 --- a/R/utils.R +++ b/R/utils.R @@ -176,6 +176,141 @@ loss_function_keras <- function(values = NULL) { ) } +#' Process Predictor Input for Keras (Functional API) +#' +#' @description +#' Preprocesses predictor data (`x`) into a format suitable for Keras models +#' built with the Functional API. Handles both tabular data and list-columns +#' of arrays (e.g., for images), supporting multiple inputs. +#' +#' @param x A data frame or matrix of predictors. +#' @return A list containing: +#' - `x_proc`: The processed predictor data (matrix or array, or list of arrays). +#' - `input_shape`: The determined input shape(s) for the Keras model. +#' @keywords internal +#' @export +process_x_functional <- function(x) { + if (is.data.frame(x)) { + # Check if it's a multi-input scenario (multiple list-columns) + if (all(sapply(x, is.list)) && ncol(x) > 1) { + x_proc_list <- lapply(x, function(col) { + do.call(abind::abind, c(col, list(along = 0))) + }) + # For multi-input, input_shape should be a list of shapes + input_shape_list <- lapply(x_proc_list, function(arr) { + if (length(dim(arr)) > 2) dim(arr)[-1] else ncol(arr) + }) + # Add names to the lists + names(x_proc_list) <- names(x) + names(input_shape_list) <- names(x) + return(list(x_proc = x_proc_list, input_shape = input_shape_list)) + } else if (ncol(x) == 1 && is.list(x[[1]])) { + # Original case: single predictor column containing a list of arrays. + x_proc <- do.call(abind::abind, c(x[[1]], list(along = 0))) + } else { + x_proc <- as.matrix(x) + } + } + input_shape <- if (length(dim(x_proc)) > 2) dim(x_proc)[-1] else ncol(x_proc) + list(x_proc = x_proc, input_shape = input_shape) +} + +#' Process Outcome Input for Keras (Functional API) +#' +#' @description +#' Preprocesses outcome data (`y`) into a format suitable for Keras models +#' built with the Functional API. Handles both regression (numeric) and +#' classification (factor) outcomes, including one-hot encoding for classification, +#' and supports multiple outputs. +#' +#' @param y A vector or data frame of outcomes. +#' @param is_classification Logical, optional. If `TRUE`, treats `y` as +#' classification. If `FALSE`, treats as regression. If `NULL` (default), +#' it's determined from `is.factor(y)`. +#' @param class_levels Character vector, optional. The factor levels for +#' classification outcomes. If `NULL` (default), determined from `levels(y)`. +#' @return A list containing: +#' - `y_proc`: The processed outcome data (matrix or one-hot encoded array, +#' or list of these for multiple outputs). +#' - `is_classification`: Logical, indicating if `y` was treated as classification. +#' - `num_classes`: Integer, the number of classes for classification, or `NULL`. +#' - `class_levels`: Character vector, the factor levels for classification, or `NULL`. +#' @importFrom keras3 to_categorical +#' @keywords internal +#' @export +process_y_functional <- function( + y, + is_classification = NULL, + class_levels = NULL +) { + # If y is a data frame/tibble with one column, extract it to ensure it's + # processed by the single-output logic path. + if (is.data.frame(y) && ncol(y) == 1) { + y <- y[[1]] + } + + if (is.data.frame(y)) { + # Handle multiple output columns + y_proc_list <- list() # This will store the processed y for each output + class_levels_list <- list() # To store class levels for each output + + for (col_name in names(y)) { + current_y <- y[[col_name]] + current_is_classification <- is_classification %||% is.factor(current_y) + current_class_levels <- class_levels %||% levels(current_y) + + y_proc_single <- NULL + num_classes_single <- NULL + + if (current_is_classification) { + if (is.null(current_class_levels)) { + current_class_levels <- levels(current_y) + } + num_classes_single <- length(current_class_levels) + y_factored <- factor(current_y, levels = current_class_levels) + y_proc_single <- keras3::to_categorical( + as.numeric(y_factored) - 1, + num_classes = num_classes_single + ) + } else { + y_proc_single <- as.matrix(current_y) + } + y_proc_list[[col_name]] <- y_proc_single + class_levels_list[[col_name]] <- current_class_levels # Store class levels for each output + } + # Return a list containing y_proc_list and class_levels_list + return(list(y_proc = y_proc_list, class_levels = class_levels_list)) + } else { + # Original single output case + if (is.null(is_classification)) { + is_classification <- is.factor(y) + } + + y_proc <- NULL + num_classes <- NULL + if (is_classification) { + if (is.null(class_levels)) { + class_levels <- levels(y) + } + num_classes <- length(class_levels) + y_factored <- factor(y, levels = class_levels) + y_proc <- keras3::to_categorical( + as.numeric(y_factored) - 1, + num_classes = num_classes + ) + } else { + y_proc <- as.matrix(y) + } + return(list( + y_proc = y_proc, + class_levels = class_levels, + is_classification = is_classification, + num_classes = num_classes + )) + } +} + + #' Process Predictor Input for Keras #' #' @description @@ -188,7 +323,7 @@ loss_function_keras <- function(values = NULL) { #' - `input_shape`: The determined input shape for the Keras model. #' @keywords internal #' @export -process_x <- function(x) { +process_x_sequential <- function(x) { if (is.data.frame(x) && ncol(x) == 1 && is.list(x[[1]])) { # Assumes a single predictor column containing a list of arrays. # We stack them into a single higher-dimensional array. @@ -221,7 +356,11 @@ process_x <- function(x) { #' @importFrom keras3 to_categorical #' @keywords internal #' @export -process_y <- function(y, is_classification = NULL, class_levels = NULL) { +process_y_sequential <- function( + y, + is_classification = NULL, + class_levels = NULL +) { # If y is a data frame/tibble, extract the first column if (is.data.frame(y)) { y <- y[[1]] From 738c738478e2628b6f6d3f8bdbd4e7633828270e Mon Sep 17 00:00:00 2001 From: davidrsch Date: Sat, 16 Aug 2025 12:32:46 +0200 Subject: [PATCH 02/32] Ensuring build and compile sequential uses the correct process --- R/build_and_compile_model.R | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/R/build_and_compile_model.R b/R/build_and_compile_model.R index 0b18165..3e25486 100644 --- a/R/build_and_compile_model.R +++ b/R/build_and_compile_model.R @@ -1,3 +1,18 @@ +#' Build and Compile a Keras Sequential Model +#' +#' @description +#' This internal helper function constructs and compiles a Keras sequential model +#' based on a list of layer blocks and other parameters. It handles data +#' processing, dynamic architecture construction, and model compilation. +#' +#' @param x A data frame or matrix of predictors. +#' @param y A vector or data frame of outcomes. +#' @param layer_blocks A named list of functions that define the layers of the +#' model. The order of the list determines the order of the layers. +#' @param ... Additional arguments passed to the function, including layer +#' hyperparameters, repetition counts for blocks, and compile/fit arguments. +#' +#' @return A compiled Keras model object. #' @noRd build_and_compile_sequential_model <- function( x, @@ -11,16 +26,18 @@ build_and_compile_sequential_model <- function( verbose <- all_args$verbose %||% 0 # Process x input - x_processed <- process_x(x) + x_processed <- process_x_sequential(x) x_proc <- x_processed$x_proc input_shape <- x_processed$input_shape # Process y input - y_processed <- process_y(y) - y_mat <- y_processed$y_proc + y_processed <- process_y_sequential(y) + + # Determine is_classification, class_levels, and num_classes is_classification <- y_processed$is_classification class_levels <- y_processed$class_levels num_classes <- y_processed$num_classes + y_mat <- y_processed$y_proc # Determine default compile arguments based on mode default_loss <- if (is_classification) { @@ -93,6 +110,7 @@ build_and_compile_sequential_model <- function( } # --- 3. Model Compilation --- + # Collect all arguments starting with "compile_" from `...` compile_args <- collect_compile_args( all_args, learn_rate, From 2b29c8756b6d8d4439432baa0344c7c9bf38def2 Mon Sep 17 00:00:00 2001 From: davidrsch Date: Sat, 16 Aug 2025 12:33:29 +0200 Subject: [PATCH 03/32] Ensuring build and compile functional uses the correct process and handles multi input and output --- R/build_and_compile_model.R | 217 ++++++++++++++++++++++++++++++------ 1 file changed, 185 insertions(+), 32 deletions(-) diff --git a/R/build_and_compile_model.R b/R/build_and_compile_model.R index 3e25486..0d84efb 100644 --- a/R/build_and_compile_model.R +++ b/R/build_and_compile_model.R @@ -122,6 +122,24 @@ build_and_compile_sequential_model <- function( return(model) } +#' Build and Compile a Keras Functional Model +#' +#' @description +#' This internal helper function constructs and compiles a Keras functional model +#' based on a list of layer blocks and other parameters. It handles data +#' processing, dynamic architecture construction (including multiple inputs and +#' branches), and model compilation. +#' +#' @param x A data frame or matrix of predictors. For multiple inputs, this is +#' often a data frame with list-columns. +#' @param y A vector or data frame of outcomes. Can handle multiple outputs if +#' provided as a data frame with multiple columns. +#' @param layer_blocks A named list of functions that define the building blocks +#' of the model graph. Connections are defined by referencing other block names. +#' @param ... Additional arguments passed to the function, including layer +#' hyperparameters, repetition counts for blocks, and compile/fit arguments. +#' +#' @return A compiled Keras model object. #' @noRd build_and_compile_functional_model <- function( x, @@ -135,44 +153,117 @@ build_and_compile_functional_model <- function( verbose <- all_args$verbose %||% 0 # Process x input - x_processed <- process_x(x) + x_processed <- process_x_functional(x) x_proc <- x_processed$x_proc input_shape <- x_processed$input_shape # Process y input - y_processed <- process_y(y) - y_mat <- y_processed$y_proc - is_classification <- y_processed$is_classification - class_levels <- y_processed$class_levels - num_classes <- y_processed$num_classes + y_processed <- process_y_functional(y) # Determine default compile arguments based on mode - default_loss <- if (is_classification) { - if (num_classes > 2) { - "categorical_crossentropy" - } else { - "binary_crossentropy" + default_losses <- list() + default_metrics_list <- list() + + # Check if y_processed$y_proc is a list (indicating multiple outputs) + if (is.list(y_processed$y_proc) && !is.null(names(y_processed$y_proc))) { + # Multiple outputs + for (output_name in names(y_processed$y_proc)) { + # We need to determine is_classification and num_classes for each output + # based on the class_levels for that output. + current_class_levels <- y_processed$class_levels[[output_name]] + current_is_classification <- !is.null(current_class_levels) && + length(current_class_levels) > 0 + current_num_classes <- if (current_is_classification) { + length(current_class_levels) + } else { + NULL + } + + default_losses[[output_name]] <- if (current_is_classification) { + if (current_num_classes > 2) { + "categorical_crossentropy" + } else { + "binary_crossentropy" + } + } else { + "mean_squared_error" + } + default_metrics_list[[output_name]] <- if (current_is_classification) { + "accuracy" + } else { + "mean_absolute_error" + } } } else { - "mean_squared_error" - } - default_metrics <- if (is_classification) { - "accuracy" - } else { - "mean_absolute_error" + # Single output case + # Determine is_classification and num_classes from the top-level class_levels + is_classification <- !is.null(y_processed$class_levels) && + length(y_processed$class_levels) > 0 + num_classes <- if (is_classification) { + length(y_processed$class_levels) + } else { + NULL + } + + default_losses <- if (is_classification) { + if (num_classes > 2) { + "categorical_crossentropy" + } else { + "binary_crossentropy" + } + } else { + "mean_squared_error" + } + default_metrics_list <- if (is_classification) { + "accuracy" + } else { + "mean_absolute_error" + } } # --- 2. Dynamic Model Architecture Construction (DIFFERENT from sequential) --- # Create a list to store the output tensors of each block. The names of the # list elements correspond to the block names. block_outputs <- list() - # The first block MUST be the input layer and MUST NOT have `input_from`. - first_block_name <- names(layer_blocks)[1] - first_block_fn <- layer_blocks[[first_block_name]] - block_outputs[[first_block_name]] <- first_block_fn(input_shape = input_shape) + model_input_tensors <- list() # To collect all input tensors for keras_model + + # Identify and process input layers based on names matching input_shape + # This assumes that if input_shape is a named list, the corresponding + # input blocks in layer_blocks will have matching names. + if (is.list(input_shape) && !is.null(names(input_shape))) { + input_block_names_in_spec <- intersect( + names(layer_blocks), + names(input_shape) + ) + + if (length(input_block_names_in_spec) != length(input_shape)) { + stop( + "Mismatch between named inputs from process_x and named input blocks in layer_blocks. ", + "Ensure all processed inputs have a corresponding named input block in your model specification." + ) + } + + for (block_name in input_block_names_in_spec) { + block_fn <- layer_blocks[[block_name]] + current_input_tensor <- block_fn(input_shape = input_shape[[block_name]]) + block_outputs[[block_name]] <- current_input_tensor + model_input_tensors[[block_name]] <- current_input_tensor + } + remaining_layer_blocks_names <- names(layer_blocks)[ + !(names(layer_blocks) %in% input_block_names_in_spec) + ] + } else { + # Single input case (original logic, but now also collecting for model_input_tensors) + first_block_name <- names(layer_blocks)[1] + first_block_fn <- layer_blocks[[first_block_name]] + current_input_tensor <- first_block_fn(input_shape = input_shape) + block_outputs[[first_block_name]] <- current_input_tensor + model_input_tensors[[first_block_name]] <- current_input_tensor + remaining_layer_blocks_names <- names(layer_blocks)[-1] + } # Iterate through the remaining blocks, connecting and repeating them as needed. - for (block_name in names(layer_blocks)[-1]) { + for (block_name in remaining_layer_blocks_names) { block_fn <- layer_blocks[[block_name]] block_fmls <- rlang::fn_fmls(block_fn) block_fml_names <- names(block_fmls) @@ -207,8 +298,42 @@ build_and_compile_functional_model <- function( ) # Add special engine-supplied arguments if the block can accept them - if (is_classification && "num_classes" %in% block_fml_names) { - block_hyperparams$num_classes <- num_classes + # Add special engine-supplied arguments if the block can accept them + # This is primarily for output layers that might need num_classes + if ("num_classes" %in% block_fml_names) { + # Check if this block is an output block and if it's a classification task + if (is.list(y_processed$y_proc) && !is.null(names(y_processed$y_proc))) { # Multi-output case + # Find the corresponding output in y_processed based on block_name + y_names <- names(y_processed$y_proc) + # If there is only one output, and this block is named 'output', + # connect them automatically. + if (length(y_names) == 1 && block_name == "output") { + y_name <- y_names[1] + is_cls <- !is.null(y_processed$class_levels[[y_name]]) && + length(y_processed$class_levels[[y_name]]) > 0 + if (is_cls) { + block_hyperparams$num_classes <- length(y_processed$class_levels[[y_name]]) + } + } else if (block_name %in% y_names) { + # Standard case: block name matches an output name + current_y_info <- list( + is_classification = !is.null(y_processed$class_levels[[block_name]]) && + length(y_processed$class_levels[[block_name]]) > 0, + num_classes = if (!is.null(y_processed$class_levels[[block_name]])) { + length(y_processed$class_levels[[block_name]]) + } else { + NULL + } + ) + if (current_y_info$is_classification) { + block_hyperparams$num_classes <- current_y_info$num_classes + } + } + } else { # Single output case + if (is_classification) { + block_hyperparams$num_classes <- num_classes + } + } } # --- Get Input Tensors for this block --- @@ -250,14 +375,42 @@ build_and_compile_functional_model <- function( block_outputs[[block_name]] <- current_tensor } - # The last layer must be named 'output' - output_tensor <- block_outputs[["output"]] - if (is.null(output_tensor)) { - stop("An 'output' block must be defined in layer_blocks.") + # The last layer must be named 'output' or match the names of y_processed outputs + final_output_tensors <- list() + + # Check if y_processed$y_proc is a named list, indicating multiple outputs) + if (is.list(y_processed$y_proc) && !is.null(names(y_processed$y_proc))) { + # Multiple outputs + for (output_name in names(y_processed$y_proc)) { + # Iterate over the names of the actual outputs + if (is.null(block_outputs[[output_name]])) { + stop(paste0( + "An output block named '", + output_name, + "' must be defined in layer_blocks for multi-output models." + )) + } + final_output_tensors[[output_name]] <- block_outputs[[output_name]] + } + } else { + # Single output case + output_tensor <- block_outputs[["output"]] + if (is.null(output_tensor)) { + stop("An 'output' block must be defined in layer_blocks.") + } + final_output_tensors <- output_tensor + } + + # If there's only one input, it shouldn't be a list for keras_model + final_model_inputs <- if (length(model_input_tensors) == 1) { + model_input_tensors[[1]] + } else { + model_input_tensors } + model <- keras3::keras_model( - inputs = block_outputs[[first_block_name]], - outputs = output_tensor + inputs = final_model_inputs, + outputs = final_output_tensors # This will now be a list if multiple outputs ) # --- 3. Model Compilation --- @@ -265,8 +418,8 @@ build_and_compile_functional_model <- function( compile_args <- collect_compile_args( all_args, learn_rate, - default_loss, - default_metrics + default_losses, + default_metrics_list ) rlang::exec(keras3::compile, model, !!!compile_args) From e362717a67f1f88258091a135c204cde9dffc5da Mon Sep 17 00:00:00 2001 From: davidrsch Date: Sat, 16 Aug 2025 12:34:09 +0200 Subject: [PATCH 04/32] Modifying collect compile args to handle multi loss and metrics --- R/generic_fit_helpers.R | 58 +++++++++++++++++++++++++++++++---------- 1 file changed, 44 insertions(+), 14 deletions(-) diff --git a/R/generic_fit_helpers.R b/R/generic_fit_helpers.R index 85070d1..43ffa27 100644 --- a/R/generic_fit_helpers.R +++ b/R/generic_fit_helpers.R @@ -14,8 +14,8 @@ #' #' @param all_args The list of all arguments passed to the fitting function's `...`. #' @param learn_rate The top-level `learn_rate` parameter. -#' @param default_loss The default loss function to use if not provided. -#' @param default_metrics The default metric(s) to use if not provided. +#' @param default_loss The default loss function to use if not provided. Can be a single value or a named list. +#' @param default_metrics The default metric(s) to use if not provided. Can be a single value or a named list of vectors/single values. #' @return A named list of arguments ready to be passed to `keras3::compile()`. #' @noRd collect_compile_args <- function( @@ -53,19 +53,49 @@ collect_compile_args <- function( ) } - # Resolve loss: use user-provided, otherwise default. Resolve string if needed. - loss_arg <- user_compile_args$loss %||% default_loss - if (is.character(loss_arg)) { - final_compile_args$loss <- get_keras_object(loss_arg, "loss") - } else { - final_compile_args$loss <- loss_arg + # Handle loss: can be single or multiple outputs + if (is.list(default_loss) && !is.null(names(default_loss))) { # Multiple outputs + # User can provide a single loss for all outputs, or a named list + loss_arg <- user_compile_args$loss %||% default_loss + if (is.character(loss_arg) && length(loss_arg) == 1) { # Single loss string for all outputs + final_compile_args$loss <- get_keras_object(loss_arg, "loss") + } else if (is.list(loss_arg) && !is.null(names(loss_arg))) { # Named list of losses + final_compile_args$loss <- lapply(loss_arg, function(l) { + if (is.character(l)) get_keras_object(l, "loss") else l + }) + } else { + stop("For multiple outputs, 'compile_loss' must be a single string or a named list of losses.") + } + } else { # Single output + loss_arg <- user_compile_args$loss %||% default_loss + if (is.character(loss_arg)) { + final_compile_args$loss <- get_keras_object(loss_arg, "loss") + } else { + final_compile_args$loss <- loss_arg + } } - # Resolve metrics: user‐supplied or default - metrics_arg <- user_compile_args$metrics %||% default_metrics - # Keras' `compile()` can handle a single string or a list/vector of strings. - # This correctly passes along either the default string or a user-provided vector. - final_compile_args$metrics <- metrics_arg + # Handle metrics: can be single or multiple outputs + if (is.list(default_metrics) && !is.null(names(default_metrics))) { # Multiple outputs + # User can provide a single metric for all outputs, or a named list + metrics_arg <- user_compile_args$metrics %||% default_metrics + if (is.character(metrics_arg) && length(metrics_arg) == 1) { # Single metric string for all outputs + final_compile_args$metrics <- get_keras_object(metrics_arg, "metric") + } else if (is.list(metrics_arg) && !is.null(names(metrics_arg))) { # Named list of metrics + final_compile_args$metrics <- lapply(metrics_arg, function(m) { + if (is.character(m)) get_keras_object(m, "metric") else m + }) + } else { + stop("For multiple outputs, 'compile_metrics' must be a single string or a named list of metrics.") + } + } else { # Single output + metrics_arg <- user_compile_args$metrics %||% default_metrics + if (is.character(metrics_arg)) { + final_compile_args$metrics <- lapply(metrics_arg, get_keras_object, "metric") + } else { + final_compile_args$metrics <- metrics_arg + } + } # Add any other user-provided compile arguments (e.g., `weighted_metrics`) other_args <- user_compile_args[ @@ -133,4 +163,4 @@ collect_fit_args <- function( ) ] merged_args -} +} \ No newline at end of file From 6f65be4480a6c5bb36855be5742c4d3d25c5cc63 Mon Sep 17 00:00:00 2001 From: davidrsch Date: Sat, 16 Aug 2025 12:34:52 +0200 Subject: [PATCH 05/32] Ensuring correct use of new process in corresponding fit engines --- R/generic_functional_fit.R | 8 +++++--- R/generic_sequential_fit.R | 8 +++++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/R/generic_functional_fit.R b/R/generic_functional_fit.R index c1e7649..2be84d9 100644 --- a/R/generic_functional_fit.R +++ b/R/generic_functional_fit.R @@ -90,9 +90,9 @@ generic_functional_fit <- function( # --- 2. Model Fitting --- all_args <- list(...) verbose <- all_args$verbose %||% 0 - x_processed <- process_x(x) + x_processed <- process_x_functional(x) x_proc <- x_processed$x_proc - y_processed <- process_y(y) + y_processed <- process_y_functional(y) y_mat <- y_processed$y_proc fit_args <- collect_fit_args( @@ -109,6 +109,8 @@ generic_functional_fit <- function( list( fit = model, # The raw Keras model object history = history, # The training history - lvl = y_processed$class_levels # Factor levels for classification, NULL for regression + lvl = y_processed$class_levels, # Factor levels for classification, NULL for regression + process_x = process_x_functional, + process_y = process_y_functional ) } diff --git a/R/generic_sequential_fit.R b/R/generic_sequential_fit.R index 42d7107..46d0f20 100644 --- a/R/generic_sequential_fit.R +++ b/R/generic_sequential_fit.R @@ -90,9 +90,9 @@ generic_sequential_fit <- function( # --- 2. Model Fitting --- all_args <- list(...) verbose <- all_args$verbose %||% 0 - x_processed <- process_x(x) + x_processed <- process_x_sequential(x) x_proc <- x_processed$x_proc - y_processed <- process_y(y) + y_processed <- process_y_sequential(y) y_mat <- y_processed$y_proc fit_args <- collect_fit_args( @@ -109,6 +109,8 @@ generic_sequential_fit <- function( list( fit = model, # The raw Keras model object history = history, # The training history - lvl = y_processed$class_levels # Factor levels for classification, NULL for regression + lvl = y_processed$class_levels, # Factor levels for classification, NULL for regression + process_x = process_x_sequential, + process_y = process_y_sequential ) } From 912878906f687d7d8e9239721281156f09188147 Mon Sep 17 00:00:00 2001 From: davidrsch Date: Sat, 16 Aug 2025 12:35:17 +0200 Subject: [PATCH 06/32] Ensuring evaluate retireve process from fit --- R/keras_tools.R | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/R/keras_tools.R b/R/keras_tools.R index f7bc77e..0877d38 100644 --- a/R/keras_tools.R +++ b/R/keras_tools.R @@ -60,14 +60,28 @@ #' } #' @export keras_evaluate <- function(object, x, y = NULL, ...) { - # 1. Preprocess predictor data (x) - x_processed <- process_x(x) + # 1. Get the correct processing functions from the fit object + process_x_fun <- object$fit$process_x + process_y_fun <- object$fit$process_y + + if (is.null(process_x_fun) || is.null(process_y_fun)) { + stop( + "Could not find processing functions in the model fit object. ", + "Please ensure the model was fitted with a recent version of kerasnip.", + call. = FALSE + ) + } + + # 2. Preprocess predictor data (x) + x_processed <- process_x_fun(x) x_proc <- x_processed$x_proc - # 2. Preprocess outcome data (y) + # 3. Preprocess outcome data (y) y_proc <- NULL if (!is.null(y)) { - y_processed <- process_y( + # Note: For evaluation, we pass the class levels from the trained model + # to ensure consistent encoding of the new data. + y_processed <- process_y_fun( y, is_classification = !is.null(object$fit$lvl), class_levels = object$fit$lvl @@ -75,7 +89,7 @@ keras_evaluate <- function(object, x, y = NULL, ...) { y_proc <- y_processed$y_proc } - # 3. Call the underlying Keras evaluate method + # 4. Call the underlying Keras evaluate method keras_model <- object$fit$fit keras3::evaluate(keras_model, x = x_proc, y = y_proc, ...) } From ff683458086dcdbf1acfcd8a3f14648c994d8e8a Mon Sep 17 00:00:00 2001 From: davidrsch Date: Sat, 16 Aug 2025 12:35:59 +0200 Subject: [PATCH 07/32] Ensuring predict uses correct process and handles multi output correctly --- R/register_fit_predict.R | 107 ++++++++++++++++++++++++++++++++------- 1 file changed, 88 insertions(+), 19 deletions(-) diff --git a/R/register_fit_predict.R b/R/register_fit_predict.R index 2fa6c48..d362ff0 100644 --- a/R/register_fit_predict.R +++ b/R/register_fit_predict.R @@ -57,7 +57,11 @@ register_fit_predict <- function(model_name, mode, layer_blocks, functional) { func = c(fun = "predict"), args = list( object = rlang::expr(object$fit$fit), - x = rlang::expr(process_x(new_data)$x_proc) + x = if (functional) { + rlang::expr(process_x_functional(new_data)$x_proc) + } else { + rlang::expr(process_x_sequential(new_data)$x_proc) + } ) ) ) @@ -74,7 +78,11 @@ register_fit_predict <- function(model_name, mode, layer_blocks, functional) { func = c(fun = "predict"), args = list( object = rlang::expr(object$fit$fit), - x = rlang::expr(process_x(new_data)$x_proc) + x = if (functional) { + rlang::expr(process_x_functional(new_data)$x_proc) + } else { + rlang::expr(process_x_sequential(new_data)$x_proc) + } ) ) ) @@ -89,14 +97,18 @@ register_fit_predict <- function(model_name, mode, layer_blocks, functional) { func = c(fun = "predict"), args = list( object = rlang::expr(object$fit$fit), - x = rlang::expr(process_x(new_data)$x_proc) + x = if (functional) { + rlang::expr(process_x_functional(new_data)$x_proc) + } else { + rlang::expr(process_x_sequential(new_data)$x_proc) + } ) ) ) } } -#' Post-process Keras Numeric Predictions +##' Post-process Keras Numeric Predictions #' #' @description #' Formats raw numeric predictions from a Keras model into a tibble with the @@ -110,7 +122,22 @@ register_fit_predict <- function(model_name, mode, layer_blocks, functional) { #' @return A tibble with a `.pred` column. #' @noRd keras_postprocess_numeric <- function(results, object) { - tibble::tibble(.pred = as.vector(results)) + if (is.list(results) && !is.null(names(results))) { + # Multi-output case: results is a named list of arrays/matrices + # Combine them into a single tibble with appropriate column names + combined_preds <- tibble::as_tibble(results) + # Rename columns to .pred_output_name if there are multiple outputs + if (length(results) > 1) { + colnames(combined_preds) <- paste0(".pred_", names(results)) + } else { + # If only one output, but still a list, name it .pred + colnames(combined_preds) <- ".pred" + } + return(combined_preds) + } else { + # Single output case: results is a matrix/array + tibble::tibble(.pred = as.vector(results)) + } } #' Post-process Keras Probability Predictions @@ -127,9 +154,25 @@ keras_postprocess_numeric <- function(results, object) { #' @return A tibble with named columns for each class probability. #' @noRd keras_postprocess_probs <- function(results, object) { - # The levels are now nested inside the fit object - colnames(results) <- object$fit$lvl - tibble::as_tibble(results) + if (is.list(results) && !is.null(names(results))) { + # Multi-output case: results is a named list of arrays/matrices + combined_preds <- purrr::map2_dfc(results, names(results), function(res, name) { + lvls <- object$fit$lvl[[name]] # Assuming object$fit$lvl is a named list of levels + if (is.null(lvls)) { + # Fallback if levels are not specifically named for this output + lvls <- paste0("class", 1:ncol(res)) + } + colnames(res) <- lvls + tibble::as_tibble(res, .name_repair = "unique") %>% + dplyr::rename_with(~ paste0(".pred_", name, "_", .x)) + }) + return(combined_preds) + } else { + # Single output case: results is a matrix/array + # The levels are now nested inside the fit object + colnames(results) <- object$fit$lvl + tibble::as_tibble(results) + } } #' Post-process Keras Class Predictions @@ -147,17 +190,43 @@ keras_postprocess_probs <- function(results, object) { #' @return A tibble with a `.pred_class` column containing factor predictions. #' @noRd keras_postprocess_classes <- function(results, object) { - # The levels are now nested inside the fit object - lvls <- object$fit$lvl - if (ncol(results) == 1) { - # Binary classification - pred_class <- ifelse(results[, 1] > 0.5, lvls[2], lvls[1]) - pred_class <- factor(pred_class, levels = lvls) + if (is.list(results) && !is.null(names(results))) { + # Multi-output case: results is a named list of arrays/matrices + combined_preds <- purrr::map2_dfc(results, names(results), function(res, name) { + lvls <- object$fit$lvl[[name]] # Assuming object$fit$lvl is a named list of levels + if (is.null(lvls)) { + # Fallback if levels are not specifically named for this output + lvls <- paste0("class", 1:ncol(res)) # This might not be correct for classes, but a placeholder + } + + if (ncol(res) == 1) { + # Binary classification + pred_class <- ifelse(res[, 1] > 0.5, lvls[2], lvls[1]) + pred_class <- factor(pred_class, levels = lvls) + } else { + # Multiclass classification + pred_class_int <- apply(res, 1, which.max) + pred_class <- lvls[pred_class_int] + pred_class <- factor(pred_class, levels = lvls) + } + tibble::tibble(.pred_class = pred_class) %>% + dplyr::rename_with(~ paste0(".pred_class_", name)) + }) + return(combined_preds) } else { - # Multiclass classification - pred_class_int <- apply(results, 1, which.max) - pred_class <- lvls[pred_class_int] - pred_class <- factor(pred_class, levels = lvls) + # Single output case: results is a matrix/array + # The levels are now nested inside the fit object + lvls <- object$fit$lvl + if (ncol(results) == 1) { + # Binary classification + pred_class <- ifelse(results[, 1] > 0.5, lvls[2], lvls[1]) + pred_class <- factor(pred_class, levels = lvls) + } else { + # Multiclass classification + pred_class_int <- apply(results, 1, which.max) + pred_class <- lvls[pred_class_int] + pred_class <- factor(pred_class, levels = lvls) + } + tibble::tibble(.pred_class = pred_class) } - tibble::tibble(.pred_class = pred_class) } From b7d1844ccc65031322d7720092a5c560f1bfd497 Mon Sep 17 00:00:00 2001 From: davidrsch Date: Sat, 16 Aug 2025 12:36:39 +0200 Subject: [PATCH 08/32] Updating documentation --- NAMESPACE | 6 ++-- man/process_x_functional.Rd | 24 +++++++++++++ man/{process_x.Rd => process_x_sequential.Rd} | 6 ++-- man/process_y_functional.Rd | 35 +++++++++++++++++++ man/{process_y.Rd => process_y_sequential.Rd} | 6 ++-- 5 files changed, 69 insertions(+), 8 deletions(-) create mode 100644 man/process_x_functional.Rd rename man/{process_x.Rd => process_x_sequential.Rd} (87%) create mode 100644 man/process_y_functional.Rd rename man/{process_y.Rd => process_y_sequential.Rd} (90%) diff --git a/NAMESPACE b/NAMESPACE index 21f5c93..6f66623 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -16,8 +16,10 @@ export(keras_metrics) export(keras_optimizers) export(loss_function_keras) export(optimizer_function) -export(process_x) -export(process_y) +export(process_x_functional) +export(process_x_sequential) +export(process_y_functional) +export(process_y_sequential) export(register_keras_loss) export(register_keras_metric) export(register_keras_optimizer) diff --git a/man/process_x_functional.Rd b/man/process_x_functional.Rd new file mode 100644 index 0000000..6b05569 --- /dev/null +++ b/man/process_x_functional.Rd @@ -0,0 +1,24 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/utils.R +\name{process_x_functional} +\alias{process_x_functional} +\title{Process Predictor Input for Keras (Functional API)} +\usage{ +process_x_functional(x) +} +\arguments{ +\item{x}{A data frame or matrix of predictors.} +} +\value{ +A list containing: +\itemize{ +\item \code{x_proc}: The processed predictor data (matrix or array, or list of arrays). +\item \code{input_shape}: The determined input shape(s) for the Keras model. +} +} +\description{ +Preprocesses predictor data (\code{x}) into a format suitable for Keras models +built with the Functional API. Handles both tabular data and list-columns +of arrays (e.g., for images), supporting multiple inputs. +} +\keyword{internal} diff --git a/man/process_x.Rd b/man/process_x_sequential.Rd similarity index 87% rename from man/process_x.Rd rename to man/process_x_sequential.Rd index f464bbc..4a8059b 100644 --- a/man/process_x.Rd +++ b/man/process_x_sequential.Rd @@ -1,10 +1,10 @@ % Generated by roxygen2: do not edit by hand % Please edit documentation in R/utils.R -\name{process_x} -\alias{process_x} +\name{process_x_sequential} +\alias{process_x_sequential} \title{Process Predictor Input for Keras} \usage{ -process_x(x) +process_x_sequential(x) } \arguments{ \item{x}{A data frame or matrix of predictors.} diff --git a/man/process_y_functional.Rd b/man/process_y_functional.Rd new file mode 100644 index 0000000..8294f38 --- /dev/null +++ b/man/process_y_functional.Rd @@ -0,0 +1,35 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/utils.R +\name{process_y_functional} +\alias{process_y_functional} +\title{Process Outcome Input for Keras (Functional API)} +\usage{ +process_y_functional(y, is_classification = NULL, class_levels = NULL) +} +\arguments{ +\item{y}{A vector or data frame of outcomes.} + +\item{is_classification}{Logical, optional. If \code{TRUE}, treats \code{y} as +classification. If \code{FALSE}, treats as regression. If \code{NULL} (default), +it's determined from \code{is.factor(y)}.} + +\item{class_levels}{Character vector, optional. The factor levels for +classification outcomes. If \code{NULL} (default), determined from \code{levels(y)}.} +} +\value{ +A list containing: +\itemize{ +\item \code{y_proc}: The processed outcome data (matrix or one-hot encoded array, +or list of these for multiple outputs). +\item \code{is_classification}: Logical, indicating if \code{y} was treated as classification. +\item \code{num_classes}: Integer, the number of classes for classification, or \code{NULL}. +\item \code{class_levels}: Character vector, the factor levels for classification, or \code{NULL}. +} +} +\description{ +Preprocesses outcome data (\code{y}) into a format suitable for Keras models +built with the Functional API. Handles both regression (numeric) and +classification (factor) outcomes, including one-hot encoding for classification, +and supports multiple outputs. +} +\keyword{internal} diff --git a/man/process_y.Rd b/man/process_y_sequential.Rd similarity index 90% rename from man/process_y.Rd rename to man/process_y_sequential.Rd index 4d1187d..05ee206 100644 --- a/man/process_y.Rd +++ b/man/process_y_sequential.Rd @@ -1,10 +1,10 @@ % Generated by roxygen2: do not edit by hand % Please edit documentation in R/utils.R -\name{process_y} -\alias{process_y} +\name{process_y_sequential} +\alias{process_y_sequential} \title{Process Outcome Input for Keras} \usage{ -process_y(y, is_classification = NULL, class_levels = NULL) +process_y_sequential(y, is_classification = NULL, class_levels = NULL) } \arguments{ \item{y}{A vector of outcomes.} From 2d43e244b9f85e9797aa072ebb2ea50149afd969 Mon Sep 17 00:00:00 2001 From: davidrsch Date: Sat, 16 Aug 2025 12:37:06 +0200 Subject: [PATCH 09/32] Refactoring and adding tests --- tests/testthat/helper_keras.R | 1 + tests/testthat/test_compile_keras_grid.R | 61 +++++-- ...ional.R => test_e2e_func_classification.R} | 149 ++++++--------- tests/testthat/test_e2e_func_regression.R | 172 ++++++++++++++++++ ...cation.R => test_e2e_seq_classification.R} | 0 ...regression.R => test_e2e_seq_regression.R} | 0 6 files changed, 273 insertions(+), 110 deletions(-) rename tests/testthat/{test_e2e_functional.R => test_e2e_func_classification.R} (59%) create mode 100644 tests/testthat/test_e2e_func_regression.R rename tests/testthat/{test_e2e_classification.R => test_e2e_seq_classification.R} (100%) rename tests/testthat/{test_e2e_regression.R => test_e2e_seq_regression.R} (100%) diff --git a/tests/testthat/helper_keras.R b/tests/testthat/helper_keras.R index fc1aded..b67d286 100644 --- a/tests/testthat/helper_keras.R +++ b/tests/testthat/helper_keras.R @@ -7,6 +7,7 @@ library(rsample) library(dials) library(tune) library(purrr) +library(dplyr) skip_if_no_keras <- function() { testthat::skip_if_not_installed("keras3") diff --git a/tests/testthat/test_compile_keras_grid.R b/tests/testthat/test_compile_keras_grid.R index 5045cf3..b4ac6ed 100644 --- a/tests/testthat/test_compile_keras_grid.R +++ b/tests/testthat/test_compile_keras_grid.R @@ -1,6 +1,7 @@ # --- Test Data --- x_train <- as.matrix(iris[, 1:4]) y_train <- iris$Species +train_df <- tibble(x = I(x_train), y = y_train) # --- Tests --- test_that("compile_keras_grid works for sequential models", { @@ -36,7 +37,12 @@ test_that("compile_keras_grid works for sequential models", { learn_rate = c(0.01, 0.001) ) - results <- compile_keras_grid(spec, grid, x_train, y_train) + results <- compile_keras_grid( + spec, + grid, + select(train_df, x), + select(train_df, y) + ) expect_s3_class(results, "tbl_df") expect_equal(nrow(results), 2) @@ -63,32 +69,54 @@ test_that("compile_keras_grid works for functional models", { model_name <- "test_func_spec_compile" on.exit(suppressMessages(remove_keras_spec(model_name)), add = TRUE) + input_block <- function(input_shape) { + keras3::layer_input(shape = input_shape, name = "x") + } + + dense_block <- function(tensor, units = 32) { + tensor |> keras3::layer_dense(units = units, activation = "relu") + } + + output_block <- function(tensor, num_classes) { + tensor |> + keras3::layer_dense( + units = num_classes, + activation = "softmax", + name = "y" + ) + } + create_keras_functional_spec( model_name = model_name, mode = "classification", layer_blocks = list( - input = function(input_shape) { - keras3::layer_input(shape = input_shape) - }, - dense = function(input, units = 32) { - input |> keras3::layer_dense(units = units, activation = "relu") - }, - output = function(dense, num_classes) { - dense |> - keras3::layer_dense(units = num_classes, activation = "softmax") - } + input = input_block, + dense = inp_spec(dense_block, "input"), + output = inp_spec(output_block, "dense") ) ) spec <- test_func_spec_compile() |> set_engine("keras") + rec <- recipe(y ~ x, data = train_df) # Recipe for two outputs + wf <- workflow() |> + add_recipe(rec) |> + add_model(spec) + + fit_obj <- fit(wf, data = train_df) + grid <- tibble::tibble( dense_units = c(16, 32), learn_rate = c(0.01, 0.001) ) - results <- compile_keras_grid(spec, grid, x_train, y_train) + results <- compile_keras_grid( + spec, + grid, + select(train_df, x), + select(train_df, y) + ) expect_s3_class(results, "tbl_df") expect_equal(nrow(results), 2) @@ -120,7 +148,7 @@ test_that("compile_keras_grid handles errors gracefully", { mode = "classification", layer_blocks = list( input = function(input_shape) { - keras3::layer_input(shape = input_shape) + keras3::layer_input(shape = input_shape, name = "x") }, dense1 = function(input, units = 32) { input |> keras3::layer_dense(units = units, activation = "relu") @@ -142,7 +170,12 @@ test_that("compile_keras_grid handles errors gracefully", { grid <- tibble::tibble(dense1_units = 16) expect_warning( - results <- compile_keras_grid(spec, grid, x_train, y_train), + results <- compile_keras_grid( + spec, + grid, + select(train_df, x), + select(train_df, y) + ), "Block 'dense2' has no inputs from other blocks." ) diff --git a/tests/testthat/test_e2e_functional.R b/tests/testthat/test_e2e_func_classification.R similarity index 59% rename from tests/testthat/test_e2e_functional.R rename to tests/testthat/test_e2e_func_classification.R index 2b90eb6..6a245e9 100644 --- a/tests/testthat/test_e2e_functional.R +++ b/tests/testthat/test_e2e_func_classification.R @@ -1,57 +1,3 @@ -test_that("E2E: Functional spec (regression) works", { - skip_if_no_keras() - - # Define blocks for a simple forked functional model - input_block <- function(input_shape) keras3::layer_input(shape = input_shape) - path_block <- function(tensor, units = 8) { - tensor |> keras3::layer_dense(units = units, activation = "relu") - } - concat_block <- function(input_a, input_b) { - keras3::layer_concatenate(list(input_a, input_b)) - } - output_block_reg <- function(tensor) keras3::layer_dense(tensor, units = 1) - - model_name <- "e2e_func_reg" - on.exit(suppressMessages(remove_keras_spec(model_name)), add = TRUE) - - # Create a spec with two parallel paths that are then concatenated - create_keras_functional_spec( - model_name = model_name, - layer_blocks = list( - main_input = input_block, - path_a = inp_spec(path_block, "main_input"), - path_b = inp_spec(path_block, "main_input"), - concatenated = inp_spec( - concat_block, - c(path_a = "input_a", path_b = "input_b") - ), - output = inp_spec(output_block_reg, "concatenated") - ), - mode = "regression" - ) - - spec <- e2e_func_reg( - path_a_units = 32, - path_b_units = 16, - fit_epochs = 2 - ) |> - set_engine("keras") - - data <- mtcars - rec <- recipe(mpg ~ ., data = data) - wf <- workflows::workflow(rec, spec) - - expect_no_error(fit_obj <- parsnip::fit(wf, data = data)) - expect_s3_class(fit_obj, "workflow") - - preds <- predict(fit_obj, new_data = data[1:5, ]) - expect_s3_class(preds, "tbl_df") - expect_equal(names(preds), ".pred") - expect_equal(nrow(preds), 5) - expect_true(is.numeric(preds$.pred)) -}) - - test_that("E2E: Functional spec (classification) works", { skip_if_no_keras() @@ -114,7 +60,6 @@ test_that("E2E: Functional spec (classification) works", { expect_true(all(abs(rowSums(preds_prob) - 1) < 1e-5)) }) - test_that("E2E: Functional spec tuning (including repetition) works", { skip_if_no_keras() @@ -154,7 +99,7 @@ test_that("E2E: Functional spec tuning (including repetition) works", { tune_wf <- workflows::workflow(rec, tune_spec) folds <- rsample::vfold_cv(iris, v = 2) - params <- extract_parameter_set_dials(tune_wf) |> + params <- extract_parameter_set_dials(tune_wf) |> update( num_dense_path = num_terms(c(1, 2)), dense_path_units = hidden_units(c(4, 8)) @@ -183,56 +128,68 @@ test_that("E2E: Functional spec tuning (including repetition) works", { expect_true(all(c("num_dense_path", "dense_path_units") %in% names(metrics))) }) -test_that("E2E: Block repetition works for functional models", { +test_that("E2E: Multi-input, single-output functional classification works", { skip_if_no_keras() - - input_block <- function(input_shape) keras3::layer_input(shape = input_shape) - dense_block <- function(tensor, units = 8) { - tensor |> keras3::layer_dense(units = units, activation = "relu") + options(kerasnip.show_removal_messages = FALSE) + on.exit(options(kerasnip.show_removal_messages = TRUE), add = TRUE) + + # Define layer blocks + input_block_1 <- function(input_shape) layer_input(shape = input_shape, name = "input_1") + input_block_2 <- function(input_shape) layer_input(shape = input_shape, name = "input_2") + flatten_block <- function(tensor) layer_flatten(tensor) + dense_path <- function(tensor, units = 16) { + tensor |> layer_dense(units = units, activation = "relu") + } + concat_block <- function(in_1, in_2) layer_concatenate(list(in_1, in_2)) + output_block_class <- function(tensor, num_classes) { + layer_dense(tensor, units = num_classes, activation = "softmax") } - output_block <- function(tensor) keras3::layer_dense(tensor, units = 1) - model_name <- "e2e_func_repeat" + model_name <- "multi_in_class" on.exit(suppressMessages(remove_keras_spec(model_name)), add = TRUE) create_keras_functional_spec( model_name = model_name, layer_blocks = list( - main_input = input_block, - dense_path = inp_spec(dense_block, "main_input"), - output = inp_spec(output_block, "dense_path") + input_a = input_block_1, + input_b = input_block_2, + flatten_a = inp_spec(flatten_block, "input_a"), + flatten_b = inp_spec(flatten_block, "input_b"), + path_a = inp_spec(dense_path, "flatten_a"), + path_b = inp_spec(dense_path, "flatten_b"), + concatenated = inp_spec(concat_block, c(path_a = "in_1", path_b = "in_2")), + output = inp_spec(output_block_class, "concatenated") ), - mode = "regression" + mode = "classification" ) - # --- Test with 1 repetition --- - spec_1 <- e2e_func_repeat(num_dense_path = 1, fit_epochs = 1) |> - set_engine("keras") - fit_1 <- fit(spec_1, mpg ~ ., data = mtcars) - model_1_layers <- fit_1 |> - extract_keras_model() |> - pluck("layers") + spec <- multi_in_class(fit_epochs = 2) |> set_engine("keras") - # Expect 3 layers: Input, Dense, Output - expect_equal(length(model_1_layers), 3) + # Prepare dummy data + set.seed(123) + x1 <- matrix(rnorm(100 * 5), ncol = 5) + x2 <- matrix(rnorm(100 * 3), ncol = 3) + y <- factor(sample(c("a", "b"), 100, replace = TRUE)) - # --- Test with 2 repetitions --- - spec_2 <- e2e_func_repeat(num_dense_path = 2, fit_epochs = 1) |> - set_engine("keras") - fit_2 <- fit(spec_2, mpg ~ ., data = mtcars) - model_2_layers <- fit_2 |> - extract_keras_model() |> - pluck("layers") - # Expect 4 layers: Input, Dense, Dense, Output - expect_equal(length(model_2_layers), 4) - - # --- Test with 0 repetitions --- - spec_3 <- e2e_func_repeat(num_dense_path = 0, fit_epochs = 1) |> - set_engine("keras") - fit_3 <- fit(spec_3, mpg ~ ., data = mtcars) - model_3_layers <- fit_3 |> - extract_keras_model() |> - pluck("layers") - # Expect 2 layers: Input, Output - expect_equal(length(model_3_layers), 2) -}) + train_df <- tibble::tibble( + input_a = lapply(seq_len(nrow(x1)), function(i) x1[i, , drop = FALSE]), + input_b = lapply(seq_len(nrow(x2)), function(i) x2[i, , drop = FALSE]), + outcome = y + ) + + rec <- recipe(outcome ~ input_a + input_b, data = train_df) + wf <- workflows::workflow(rec, spec) + + expect_no_error(fit_obj <- parsnip::fit(wf, data = train_df)) + + new_data_df <- tibble::tibble( + input_a = lapply(seq_len(5), function(i) matrix(rnorm(5), ncol = 5)), + input_b = lapply(seq_len(5), function(i) matrix(rnorm(3), ncol = 3)) + ) + preds <- predict(fit_obj, new_data = new_data_df) + + expect_s3_class(preds, "tbl_df") + expect_equal(names(preds), c(".pred_class")) + expect_equal(nrow(preds), 5) + expect_true(is.factor(preds$.pred_class)) +}) \ No newline at end of file diff --git a/tests/testthat/test_e2e_func_regression.R b/tests/testthat/test_e2e_func_regression.R new file mode 100644 index 0000000..653ca37 --- /dev/null +++ b/tests/testthat/test_e2e_func_regression.R @@ -0,0 +1,172 @@ +test_that("E2E: Functional spec (regression) works", { + skip_if_no_keras() + + # Define blocks for a simple forked functional model + input_block <- function(input_shape) keras3::layer_input(shape = input_shape) + path_block <- function(tensor, units = 8) { + tensor |> keras3::layer_dense(units = units, activation = "relu") + } + concat_block <- function(input_a, input_b) { + keras3::layer_concatenate(list(input_a, input_b)) + } + output_block_reg <- function(tensor) keras3::layer_dense(tensor, units = 1) + + model_name <- "e2e_func_reg" + on.exit(suppressMessages(remove_keras_spec(model_name)), add = TRUE) + + # Create a spec with two parallel paths that are then concatenated + create_keras_functional_spec( + model_name = model_name, + layer_blocks = list( + main_input = input_block, + path_a = inp_spec(path_block, "main_input"), + path_b = inp_spec(path_block, "main_input"), + concatenated = inp_spec( + concat_block, + c(path_a = "input_a", path_b = "input_b") + ), + output = inp_spec(output_block_reg, "concatenated") + ), + mode = "regression" + ) + + spec <- e2e_func_reg( + path_a_units = 32, + path_b_units = 16, + fit_epochs = 2 + ) |> + set_engine("keras") + + data <- mtcars + rec <- recipe(mpg ~ ., data = data) + wf <- workflows::workflow(rec, spec) + + expect_no_error(fit_obj <- parsnip::fit(wf, data = data)) + expect_s3_class(fit_obj, "workflow") + + preds <- predict(fit_obj, new_data = data[1:5, ]) + expect_s3_class(preds, "tbl_df") + expect_equal(names(preds), ".pred") + expect_equal(nrow(preds), 5) + expect_true(is.numeric(preds$.pred)) +}) + +test_that("E2E: Block repetition works for functional models", { + skip_if_no_keras() + + input_block <- function(input_shape) keras3::layer_input(shape = input_shape) + dense_block <- function(tensor, units = 8) { + tensor |> keras3::layer_dense(units = units, activation = "relu") + } + output_block <- function(tensor) keras3::layer_dense(tensor, units = 1) + + model_name <- "e2e_func_repeat" + on.exit(suppressMessages(remove_keras_spec(model_name)), add = TRUE) + + create_keras_functional_spec( + model_name = model_name, + layer_blocks = list( + main_input = input_block, + dense_path = inp_spec(dense_block, "main_input"), + output = inp_spec(output_block, "dense_path") + ), + mode = "regression" + ) + + # --- Test with 1 repetition --- + spec_1 <- e2e_func_repeat(num_dense_path = 1, fit_epochs = 1) |> + set_engine("keras") + fit_1 <- fit(spec_1, mpg ~ ., data = mtcars) + model_1_layers <- fit_1 |> + extract_keras_model() |> + pluck("layers") + + # Expect 3 layers: Input, Dense, Output + expect_equal(length(model_1_layers), 3) + + # --- Test with 2 repetitions --- + spec_2 <- e2e_func_repeat(num_dense_path = 2, fit_epochs = 1) |> + set_engine("keras") + fit_2 <- fit(spec_2, mpg ~ ., data = mtcars) + model_2_layers <- fit_2 |> + extract_keras_model() |> + pluck("layers") + # Expect 4 layers: Input, Dense, Dense, Output + expect_equal(length(model_2_layers), 4) + + # --- Test with 0 repetitions --- + spec_3 <- e2e_func_repeat(num_dense_path = 0, fit_epochs = 1) |> + set_engine("keras") + fit_3 <- fit(spec_3, mpg ~ ., data = mtcars) + model_3_layers <- fit_3 |> + extract_keras_model() |> + pluck("layers") + # Expect 2 layers: Input, Output + expect_equal(length(model_3_layers), 2) +}) + +test_that("E2E: Multi-input, multi-output functional regression works", { + skip_if_no_keras() + options(kerasnip.show_removal_messages = FALSE) + on.exit(options(kerasnip.show_removal_messages = TRUE), add = TRUE) + + # Define layer blocks + input_block_1 <- function(input_shape) layer_input(shape = input_shape, name = "input_1") + input_block_2 <- function(input_shape) layer_input(shape = input_shape, name = "input_2") + dense_path <- function(tensor, units = 16) { + tensor |> layer_dense(units = units, activation = "relu") + } + concat_block <- function(in_1, in_2) layer_concatenate(list(in_1, in_2)) + output_block_1 <- function(tensor) layer_dense(tensor, units = 1, name = "output_1") + output_block_2 <- function(tensor) layer_dense(tensor, units = 1, name = "output_2") + + model_name <- "multi_in_out_reg" + on.exit(suppressMessages(remove_keras_spec(model_name)), add = TRUE) + + create_keras_functional_spec( + model_name = model_name, + layer_blocks = list( + input_a = input_block_1, + input_b = input_block_2, + path_a = inp_spec(dense_path, "input_a"), + path_b = inp_spec(dense_path, "input_b"), + concatenated = inp_spec(concat_block, c(path_a = "in_1", path_b = "in_2")), + output_1 = inp_spec(output_block_1, "concatenated"), + output_2 = inp_spec(output_block_2, "concatenated") + ), + mode = "regression" + ) + + spec <- multi_in_out_reg(fit_epochs = 2) |> set_engine("keras") + + # Prepare dummy data + set.seed(123) + x1 <- matrix(rnorm(100 * 5), ncol = 5) + x2 <- matrix(rnorm(100 * 3), ncol = 3) + y1 <- rnorm(100) + y2 <- rnorm(100) + + train_df <- tibble::tibble( + input_a = lapply(seq_len(nrow(x1)), function(i) x1[i, , drop = FALSE]), + input_b = lapply(seq_len(nrow(x2)), function(i) x2[i, , drop = FALSE]), + output_1 = y1, + output_2 = y2 + ) + + rec <- recipe(output_1 + output_2 ~ input_a + input_b, data = train_df) + wf <- workflows::workflow(rec, spec) + + expect_no_error(fit_obj <- parsnip::fit(wf, data = train_df)) + + new_data_df <- tibble::tibble( + input_a = lapply(seq_len(5), function(i) matrix(rnorm(5), ncol = 5)), + input_b = lapply(seq_len(5), function(i) matrix(rnorm(3), ncol = 3)) + ) + preds <- predict(fit_obj, new_data = new_data_df) + + expect_s3_class(preds, "tbl_df") + expect_equal(names(preds), c(".pred_output_1", ".pred_output_2")) + expect_equal(nrow(preds), 5) + expect_true(is.numeric(preds$.pred_output_1)) + expect_true(is.numeric(preds$.pred_output_2)) +}) \ No newline at end of file diff --git a/tests/testthat/test_e2e_classification.R b/tests/testthat/test_e2e_seq_classification.R similarity index 100% rename from tests/testthat/test_e2e_classification.R rename to tests/testthat/test_e2e_seq_classification.R diff --git a/tests/testthat/test_e2e_regression.R b/tests/testthat/test_e2e_seq_regression.R similarity index 100% rename from tests/testthat/test_e2e_regression.R rename to tests/testthat/test_e2e_seq_regression.R From 1039a9b329f2368511c5d45bc4a75046851f325e Mon Sep 17 00:00:00 2001 From: davidrsch Date: Sat, 16 Aug 2025 12:38:55 +0200 Subject: [PATCH 10/32] Removing sections that exeeds the scope of the package --- vignettes/images/model_plot_shapes_s.png | Bin 0 -> 32669 bytes vignettes/sequential_model.Rmd | 120 +++-------------------- 2 files changed, 12 insertions(+), 108 deletions(-) create mode 100644 vignettes/images/model_plot_shapes_s.png diff --git a/vignettes/images/model_plot_shapes_s.png b/vignettes/images/model_plot_shapes_s.png new file mode 100644 index 0000000000000000000000000000000000000000..0d72c05f1d64d7d724d4a41539182b332c409def GIT binary patch literal 32669 zcmeFZc{r4P_&z+XiV8(amMGaNYnUhvib5r_L<-4nRE#V=p-_@+S)!=O(i9wPl%zu&PG;VRsw;r z@#qoKNdjRd9lsLSt;BcUUpe&%FKf^1>5vG^%zqvgr$ypB>)no+dJqU~o0-2>d~5ex zAP{yEj*<=;c_j|DUp&WkdR}Stn)17W^&uR`u8RcTzpk+2@Rki6PgvFnlDs)1rbynW zjSh0%yBi$*=!mQ?w^x+@@0~|NCN^vC6ndbYBc3qs_hvcKv+a1xcm7{%0Cr_m$^Y>T(|Cj&!AAu^?Y)NF#o{Fa91j2P;qYNYQ zFzpM2#>8t1n(NF9RxD}ol@SPa1yNC4IyyS$1>G$HY`b!7`-DO^A8o!8j;g9oqEQ)+6Cd3sKZ8YFRiG+nce;!Dd@)5{YDJ_%+Kkw`6I~kxu(`i@nrmHix5o>TfIF z#KY4w_O_3&@AObh?!`KRYuD!SL!ZmacJJAfwBNn^X{=>8x7VdieK`K_SM(c-VLs}%P$#tL|@<7cn1%U zXyD9DPq_5-?}33TMFsn>*Y&)v4mM@mInK?@j0y*QxL4Z0;?Y;@1>L1b{t?T zju;r!QNH}Mibct_BQe;lz$LrhzNolZ)~e(fDO^xP$y{DSBIA5{s9E;ap&zeEy?_3g zrRnnvo0a&eaH^=N6c!ea#lLv*LPR%eXUa(FatAqlu$d~Fa_7#SEu#9pZTacxVo7^< z%&#QW37E6&nc*=P+(cOJFALdXED%}rv$rbJHsr&(Umt(wIYy@VKXqu#D7ta`He2i+ zfdlb+oA_0C2R=~uRju=zYPqNx94TsGHte#rFk9sO>E^~SKhL$43iDhUZfp6ev&pdR zYw3NB-&QNGOAU1tCy2f_%TyP+`uC?u+Ad*XukHYLb;B3q)1Cg~f%G6PEU0LT{zKI< z!2NA5Y^dl~OV>3%m- zlIs`Vw!NX$e;JXJz`Uqbi@BeUROz#Cqgv%JRBVtOFO(d;QHWV24V0Jr9o zXU`a{jAUG=yx?#6@18({ z+=O)Szy6|sad7RyrlxH3=0}HCnHQ|wxxBEj;Pvs`UOv^!mHA;SUS=P@wP|vo;d4{A zSH2K6J*uF;7! z%5!XsA3qOW(ONI_psvD8EudwE{*~qO~ zyOzfI`qfZU(x^%4(zo5F9j`9vn07qXizO=N9E+2qQVoX$TW+`Rd&hHPZhfDsAjR_e($FEHfI~Qu6Jb8ao zx~CuvP3T zmQ80$k8&z%XEsg^x2qW>aMWX$-F~Lh)Lj?Pxl4OA%SRW5k*9c35Dk_{Oke{zj4Gr}7 z>yVI@cr;(z*WU7At9TH#iNJ)UgO}z8WhDhKsWWKzMfC8TtE{+n1J-WL zFnoL?S}l)k6?7(mmjR%0p(V$b5 zn06DT1%D>0cpMqqqP1>pAT8OO8AucheN)a#mq8;mTk zv$xgD@%dCbKm3ly9d+$QiL|e;FWs}Q{5DT?f`Uui1OMkNYt|rJS~Gg9$o0woOX8ld z$#h*)EUE00Og}$%l4!MPSwllZZLR)rY_P=XjERR(o%D!iro*J%f zXponaJL&SGVE}1cG*)7shz*kFV~EwRTneK7yf#v-Lwfx9F+-(4O+;Ni$?6%g(M~Ro z8SSh>i4Px!RN5=B3f+&4JUSK>u}YXRI=cVLpD4QC@@Us#VFXuHbXbJJ?P!XKynJM^ zOWPakLNP_f#@gCpTyun?>(AFmWe)f+`NXhD$BwkWGuc|3bRglKa}-Me#|uD*J(&@i^b{*%A9p<@qU$Awnv{tNG zf&bMtFbEG1ANl(DaCs=V))Ym{rn8g*$T&YgkLH0*Xlb@TdAR*u>D-@8+FHOI^L1l0 zV;B7Wmm_xRFka{8j^htcojMi70bD6<@}dnP9>{U%&>`>Q63b$5)W7Z9x9i2qh@`X@ zxUvc1lKV$S{1@l0`0anFdc}MEi?WBG()sh}gVVLccxkszJwN$$V&bx>Xc``+j8Pvq zMe{A*IN660!k6PQJ$!aHST{;+tX@ZOj4COt>fRm5p}A?(HbsZ;Z`^xkifZB&9m$jr zTz9ln!^7D*$6VF)|2g~a>f%storH0gSn|=qF_r=XBa)taqIdBAnU(0n}&rvonh4V6W^7y{s^2sQ&Up| zg!}vQ+lDlQ{)O*hJCyGRFU|C@C((<91An}}w1qS1)~)KRG#zE@kE?|x&c2;M0@!ox zfnlkiuW-hxw7XR{-9;bS9RL%3LP;)_{ieLiE(=wn$wkXc^Q7LszKXOXS(kr*6|xP| zPOV4Fh7}afOs(Zb%8eP(6$+$MUFo!F=~E|O%t&8wbuG32w6^*w78utllrr95=dnI~ zZ?&AYxu&J~{$B1Y;$D{9m0c1OvAv1{0s<@)Y$`blD{bVWrA!&vH?x(JykUEG?T%UFyQ|JaY7C*AJnw>uN%1i|)O9cg@@Tt%ZVX2T|bCA=30nr?O)L z!?9v>a;VGq^aygKZnQ*5+V-G#SEDIG{yi_wzB6_q4}25x_Liv_qHj~M=it2Z=lfCO z-}m=Q)ex;AN7TETMi~s@@=NhZ+AJ(Alnqf*rr+i=${MH_z2jfoux&5pXv$s4svG1v zXgSWB{F>%+30Ty|^;8l=*U2eMO2wmZ!-2mcRgRXA)B4x_Xn*H<;r#jD6kuEN=3Xhk z3zDG%0ghHaPEmXV7)cK7aGW#k{f#kcR?y(5?Q*Co4tdvf%u0b9=zA+*q#Ie-+*OwY9$l}Yd4UPAE z4qFJ%b=?xodd^3A`peDt>-bIo6?+W~43;+A5|^%D$}>||R$j`g%2e|lcIqrSn_9n` zg@tNS$Usu!mN2@f7wua)JUmRCYpkz-b@Hk1^XJd8A{yQkzX7z^`Sw3@N+WMuo*yZd zz1CM9C7ya>X9^pSOz+oxY(-8^&npkQI&U1lxnZAV_?))WpHq*x+xF? zB?bL8DJm+8MDn>$Mr zU|zj?g`_4Y((8mBwjk+&db&11NLu03iX_H${yg#{k| z7gh`NukS|1#Z_InHx7(;N?)>Wj(&P^Zc@@XE4|*nB3#fPTLO62bP^S7fuy#y^vD*qXqF>Ow+?C^gM5&L$Yd_MXEvk&K=l z+n!4f)6Tur@C;8^%pjDr9$8;rULGIzdw5ueEmq1@lXQEl^vV3tV6UBAJ%^qpB_&x~ zTl@SN=WtcAQM|~)E&vK*GpFy|KpOGwLbYt^ov^S2SBD-XBqS&*iakwBD+S;Hv_b;v zkF-T=tl>L1fznFxa<%-60=8}+xwjJWN;NOgk4by>%sAUj5e>IZ)qR@Ytr^1JYo?EX zr+3Q{#RO>VF2dg&@^tR^y-qzl923nDr4 znL%=G1G7sQ8!Nmv{lY1Pe4K}er!ZJeGpbj7*=Oc&K_{B-Y19wf$H$Cv${7EMy*p4- zG*+>L;(W8-J`3r_$Hyl@$r)LI7Id3Oc2CNQRPC|GErLFyAKw473R&12%WP3Q+18=omA9+enF37M@AJVQ{Lgsb7nRC`&$vVLYDfWmS28%k?194LzF*- z+b%CH62F>vcdl1K3AA!^+sm1K=2ePoXNg7S<1a&T_4W$5ScisZ$7{KjoPLY~Isyc+ zvaxwwxgsJguyLwZZRA0tl6!XWrW2JPEZ2VNTp%Cia$ZE#FPd- z^4du)I*1Hd(VUm_`t@r>*SM_w(b+*E7SM1^f+QksY-|ic#j=U-7kjGI*<$Z2fQWk{ zk7`Gv-5tFs$QifzTLXxwry`8<_R3%rW$t;DNBkwwP=#c=rQfXWl6aMFjMVlWJ9cc} zZkq38D)WkQ^qh9;zFP{BB0M~?C{yUM=*`~{bXCG{>l;cE`4*PuNB9mM+Y`CqqVBh_ ze{nuD{@Tmo&eG(xXIn+p)ZPKnU;7fT==1IP05C!`)ry%KV|K)^F_g!Ztu5c4H zZ)y>P)hZzJDZkp)nfZCELb6s2N}YAhgMB`Vab>=uz(O=@9e@%*V|T#+eAmvTw5>@* z%6$z%RO(TX3;mZ`rJ`l6)WJ+Rx8|2FPMNRB{Q2s_=U6Mhs#W2^!7Q)czlsWjoST@K z04!pHhUzcZCVroIzxH~b6Ge_ysi!&B=H5wlSIY2MmCd(ql`N*AF_aT`L5TjSihaC9%T_d<8 zJQT(;m1O^jfVDYg;hBgLTK?D!cSU>ALRh5;ku?Wf2YXjD_`!p1CBkz^&OBSw-p^jO zAr8fwca-}3mqhK9lgl|8xd(x`MbiR7fl80o@|BYidxrF>XEnM(7e3z|Gw6`BrH%fM z)L34K^!yS%?huoAz-Q*3zlcG+!ckM9(g_rXg))3tw3@UR(eP@XxT&8y*W?VhP+9S{ z>BPjunu;|#BvO>5iN5KNr+PeeTV0~D(Xdg>_)IhaPt|=iAJYR3x0Fp^onONfe{9d) z!Wc!d!nbc}#{)S4%PcJ|OD0WJ%@fA3LTRK9Z3GM)#{p5QpbTKlwPMF zFAq_O7oT6OOM0@cVE1S05rnbgd`-XdubMdd$jC@^9K8>sXJ))qvVBOzY95FBhF8S{ z0|Q9Be_9JhkkGZJKA-mSAqx7TCpQv!8sjK2oq)@4cR(k=%~|;7jhJKW8}|#F^5|cQ zmVXfeoCdNVQ_=z)kS&jru)S#fRVPM~R1{4QhzazVux_Ggoefo>ULNUaZ*DF?)xg55 za*k=SAaU`AJp+P4`kMUH*u$>gVN;u^!pt@WiRYswO{T`j**0x*=_sOrx-Ko%NDCce z=z=19_ijFuTEff6SK}+%CbBbd9V;vO%9SfuuexCG;0dA5c~1>hC5(?~0I&c0bx!@S z9BRKM7k(o4!hI9 z2?4_78`o>;=M2#+s9OtfwuHU2G|elfP)^85Msb zX8ea`bz{zs0D4bP(kJ2f;)!N~xgDLZ_xWOTm=6T-XU zCbp8K?1#PEHZUC4$@}nHsvivaaPBIl9CY!C}n!LddI7 z1V)&o`<^|PB|fG~ib_gN=k{J6ZZj2ou_Io?zvRQY9}psPZlZR2B3*xdbVxH;&~N5Y zlN%s*z=!wET?t2Zb?2s9T{o`%gxVQxrDA6g_2JyT<~wa~+&`V$``V%L4>rAKFrb7x zJD>8<%veuj#%c4{_MBF|QU|;z?Eid!zBa%bktmm5Rd7e7mILihpQ&H>n75b4GIXDv?N-gBspGY|7TM=a?MrUE8^f96$pi78(DV?%)?6|z}G zO7ad+LXJ|2ufKf8YW^|$Kf>Y7f?%ciKxy~x zqMotJWo)?X?Pe)n#!8WTBB#|t0_zyvxnd0$Kfm*$=!F+=cqFLiV;bu#)yDfWc2b@i zicvj`6Es~P#Ky+jvK-#m6CoIMrmHey7XrofOkTEWZr|c`r)(Ckg0$P!En2cS`o^cp zn$e=J1ocI%y^_s%KF|*)l%@T+qg@uTefsev zb#;3>Q$Y&yBnyi#o$c366B0M-PG6NO1?Ac_xGg=&Og65 z9oHHUgc|=JUI2pF9`sk(HoN(zMsL(R)StR1xLSQrh(PlV10g~s_elY~MLIu& z=|K3u_~l`3(?-H|o=T9+DJdylsxmTnMG*rcCg}hS0KraQ_7 z52d*bn7bASyr4yXVI-mR|E2zFNCeQesRs53E|P3zdv5>+P%9hYwxaIpK5Ds_xxQ!m}ClT z-Jm1EqlRgxO1YEC$bI?Nk%WVE=E{a9JbZW-pfXBi%bsH|PhTMXio3qVO?dea*bXMA zMxLq&+rceq{Mz~&egcBYdnEPhFt(52pvK?UoFdu>&_DZMWSy#M_gKC2-oKx zLrit*0a8XGB}K)@JAmaM>?2&y5TX*UJKNaVc~dCc!zvIqKcG83x;o=;M|in0+qgAV zCw!~4*^jqEgwYKLmqWu2c#ZB968g}TO?Vk*Uf?o2GqYb+^(X!~{{(m^zNHmhQTI%o zve2z)EKWpbn`74~?S-HRpxE-3yKb!oLXkjlxK+g5RsVA@X}MQc5gt8yb}UXf4L2#3 zmTC0-8ydnS%eR*Y05n-ksO%aGFP&O)XqzLeUIV#$^w-86dv>E)_V|OnAF8TCxWa=4 z`%J8SqZ1P?g4Kjm0Ba|H*B?%6aE2leT`8vs^#Ukg(VT+j5{dLELZnTJa8PFT4w?uc=2R&H?dRCj)COmkv8k??v{i2tLe0X&|B8@&w$pgC2q0(>M`hfWCFMCJ0EDO zA~~_dZEKBQ{LM=c1VK4Gv<&~JVu|X$X6s&g^y?>~ZOr@Eoj{bFH;GDsk`38{O$c-; zYk}T+%Y(nIUI|pRt|W@~)g>E{@L^r1o}QkL+75O|PVXGfS?I$QzLY}kI8ZF<;=U-HQw9`D?HZdzCa*Mc;>RXpmz@ZqT2T#xNz zD?(e#5dp`-OX~m&UJed7aP5zeT)BX59_5`Wgea-1#s|Yb;o(vN8vqdBr1#=W>rdn} zTL`SDy0io<*(az6Ka9QlUPSRo`t=d*> zifR@#?c)aY2gD*jhQ)#N^z)NJ7jJFc@vPAy&+IP(_RSOc2#QYI*gWCr#?66DBN@H0 zP_eWI(N%OKnxdutZ@)M6VBZC=Ak`~>+>V0!$ur9a@9$c=SRrEC7ma3?@y)Y?a!Zd~ zDaqumIfNkt{xMm9FOieZBw2)`qo*@w4R(vOQ0p>@MkxiM3PlD~C-W_U;a4 zebXiISe1trK$sPZy&DV&%ACcJ0gNTl{ zmIFWYGFw>X4p+LaEYyS}r?5;}e^1sNi&v=1>P5$I`oa`VKZvW%^)N`F@yKeS zK77c}xd`~xq{KyFUE3_bil`x=;pa0y^%p5Y$+789FiN0pVJ&WW=B6vkLMLDZpu2Q) z{otf`aZ=i`{bKE@9bO~Q%`+EE z4Yp+)SFgIHF+U_P3~GZVw)y!f9hXV4H~p7n3weQCO^0$DaL}Fqv$iMisn5;9eZUpo3@%b%`>0G;Zt=O{EsYR}2Y#3s*0SOzb z@02_U)PI>!tRJ&ECMke{CjOOVfwQAG1jZoCw1#MD^Q(B$c7s0C?Vd#Uu8%>{tSyZL zLqlz#D?t!)oFI|P#(F-pW#$p?v6-_!qSRX#Ni+AGU9QH`Oe#6%bE~(P*8~24K)Z5j zR?6#;gYMnqVA8)`E$r$u*G(!J-gbI|{tK_@^|hY&ZPEBcEO%N%Z!V4QaUN{UyqJH3 z5&sZfs^nlDbiefiu7kgR-D}?wsa!Ygx1bsPaWQwdm$h|K_T<;nm7_k)Mn;;gr@PzX zd%AmXC8w~e&y1U)p&?Z45PM{JpShb`K4B#bC50ssEI}z7LKYs$ zF@1d{-Y_OP1p?<^QEsL`*&jM7^ag*X3aJYE_ny#UUL~h9k3Ks+c5HqDan(od4S1&q zVNcO4*s1>^f^kQt{QLJWTUrR=oQOHgGai0arO;rwK2Q{*3O+UrPA-)6Y_TLLp?C~3 z0ZEcUTwEMrTxe+M9J&fglNUoHBc(3wsxqIiRQ`vb$xB^JM+Zi)#NcQ84DdX`wO{ zNC1G81-9J;udU1r_@DI9{EOF@J%7DlVXnO5(UBo!ZPXiLSP$RM{ri8|$T88;wq9=T z^{lKpFdv4Np<>e&?s^cijc*s-H(^5f|2DwHIQGq6LCeSku?@VJWwi`FSN3g1?@>JP z;-Vrivs94I@#wiEO~kB9iZ&*AgG>fZ>YxzM=3zP*C@J&X7L^&uc!;DCd#Gn7jPKbN4O z#XGGH$e$<43aHf$jMQ%va*t4k^wDu*KXMM{`Q}y z?1VH?57e;P{k3(;8Y9EQ0^7E|fiDmE1@@kI_&>D8;Il)b6L?0slhr(3EFBF$wRVAK zBPk$BBt_RPK-G!Pgz$WLhvcbe zcdBe?g#zdu=zDCdW9VeW#gzWo=xA_A2&X2*{#M#_{FCOv!uyb`OiLTtYNuSYdH5K* zk7s@but$B_nezDXt<6jU%l0wMGcAMFxis1KeHSXu_-`9^fF{`AvROozqaN9h&Po%E zDD=B2bM~!KTEpbu*4T?ccX}lFfAXC~>!9!%t{dwK*VxQCef7*PL8e*bRBAo*60(Y* zse&YY{p==>5;gGVO)~*kD5-c)lNWZGMfc<43cV*Cs(TJi4KbX^2ZS^7L)s@<_?NOkD}wEu_~dEJ9^?X2JF!1J@S?vtQcI zde*`M7K?13wW}GrB|nz3CvQlevz&z$B1DWxTqvII*p;$k)mpa>9!*H2J2u>ea|*uO zZ{_8-Wx<3U*>_W@4lWA3xlYmFkOI_E_-z+ST-&!-`Z|mc80PIP8NIYZgPzOCJB$D z=gvL4_3xJ=rF=u_UDSiknlRywf{m7QoI(c_$N`;#^c|FpWabwg%YoJc>cY>xhe_bZ zPt?XsZ``<%F0iu+-a8qCqer(ZI&g93U24$2nf>O?{%T&oSvn8B)q1$C(9ZeH+sl`L z|3lKzzM-rexg^BLqro`?z6We0DpsXdeSJNUewS1c*w zK-$|%6cYL;p&n=f)keZo8Duq>dd?Gova_R=+f8QX2B`^o@+vA6v@H;vXv6>XT6FPP z`B>*K0Rl1`s*9gKeq>x;%_@ZMU7rNuB$uNbj3P?4!u+Fqd9!nKo9hAova+^QL;T7COQKtQ4FW#n(*5U20yoSb8^lD3g1GmF4q6)#`z%{{;{ zh=jBSlgT11sHE49%`#xrqa;o{AU{GsJ|3?yHZ|3ETpflHX!gz>MJiHKOr0DKeAadA z?!AM=(5L%GG7ru+KB<{o%^+#;`6a*+!mRA<{n=9HS*v=7J#DLHZCq-^S zz))GYDGLh={p9dJkgfKm+QisNLB<$_ge+XJ|Ge)q(;0{!WCVif_#Bbz60@x|77Kji z)^)?j-n4|@IP1)z&@JzL0%c8e%HnBjkd<~)f1G*T%WYaKx^kwSn~yK6(E;X=UZz$2 z>cBTGn`9HxRDZych#i3(2wG;R^m4yoS#mV-EV2kvYLymbU|KPZoHQ+1c%Yy`Eu%*y z(H4{Nl$tNvLH;F8ASFwfR}023oUIL#qarnNL_j_pVtwn;e-}Pmdjch$rV;FNKb}}toKAd2iHV`A}Vh>^G#1SqESO^Ji>&0#sl?ZVq|Gm z?#9E14+oXc?wp%0@G>&u;sozC#0afi^5Nj*hkn{3(9(E_^sl+d?6v(q@clw%sQJ(n zT#c*HN|iU59N+#K1W(ocBB~m2X+~QHh$GM8wyHjq!F2*G*~I!?$yfc-Fr!gPwhVix zy1#helFB{+pTfc#I`Kw=U^RA2a5{voA~CRO`0D7eY&jbrz;~BL<@YsOlz5YoNFc=O z1-OV+0@?NSVx+nz@w}Z2qpalUWBzm+w=gek%+c9emQ^gh+&d)Oiay?+(NhMN{QOy` zPz>5E;JvUFTI^W<_vU7D+qwN!MGD5qJDq&PLk8+W^k@}u08O~ zRI5OLee&habp_1 zve4zi;}|GR9i3wkIR{&EMGq`4I_IC7rF`k!7o)&?yjI$V_$V+bDY_>f?SI{=qVH0+ zc4;8Ui~K-?Q-&oR=eP#p=Z%KADp$^mw77~g*MiwHyUaToZph2o1p$wlVxc|3@luX#A#Z4 z`&pOt8p#;vEz1HY_f#E;7g1-+;wyOU^Anb=V+QkNvH7ADsowQ1?e>M0O#ctr-a2OE zDLA)K68~4WYG4t;jszjW11tpL$Z2IJKKxnv#Ubr1Ebmq_cnhzWzQ|r z%i!Q(tY0Hpk-n&uliDx62Me{YjypKQ2u11Ueri1|}H_{@hk5eDuwg!5H%KrEh=Q3OzQd-T&uP7KQxY^(q1_ z!-Oip!_$n^h0vuNUWaQVJ?#xi-$rL$J*qE2!FQ@vr@aDgM8j_#0FA|t0tp$JVaTj1 z->0A=0=v>oH^t&%>}C!hVSh5!Ea(hqf@p2xDr#!|Z9jcrqI@eSEXkbxLb;qxk1X|z zvZTSbwO@q72 zS_dmlv%Z9x9l((CG->jgF!qmNsrnWk0V?Q*E=a?9EZx5`H}b)Q^ya~2*p(pPl~q*G zWG8)qc1J*M_jh$cl(nX?@ajQ!D}f&|kzSBdVaDQbLY;%IkC`KX|TttkqEXO%^UEIBT`08Ctu<%1VWOB zOKeigk(?vaVlmTIzRz zv!#{kR_cszjrX0X*TFI}NsVbe(~|;PD7(*%vw1n;$7fdL!k@S~FGG{7UPj_p@1vh0 zaF-zZx7B~#AL|4~C1!)5dSWnhq^0M=Txroan8i$UkY7Oxs~g0dfLnn?U0qqZ4M-bK zZ^ns$I4lU-a?DNvjgPN?W>E|)9lW?LAnxiI+s#X47@&%LCfm?~piOqp7b^7uMwIf( z6*6Bl{P>@h;@`#(k&F;hz}85QYuBv7l;W5wG2;TzEYPqO>?>Az)wln^M2L2pzK*+s z=deplPPx4T)05fm+1WU7S@`q#A@sCzo)X)_D$qhO3U2V7$!K&yg%#YgwJQ`%qc6(IG-wk77M3hoa4 z0h=N{;rAzY4X%GV?_i`)+a{ubz0*<4I;Z?yNyF%yvrMaUZe2IMdurc~();cEDhzV& zTT*#6ckkY9Q@E?;c`c=qW?n!a@pDJ}(WSOhM@tJ30nP@uqG0B1iQUDE&49I#k=jap z6L!favv(ibwtoG3p9l@kOL;#rCD3B80IQmnx(dLKS#~~p5$)755R856D*FRwRzC%s za!n`kz_plQOf)(v@M2~Z*zk)hOT^;rPJvQbkHu5#y};(VYzQYs1Sm6!3I$#!MVDko~t%Kf2 zZ6j8XANtrm?C@k;RnlO!bN-DDU+i+OaoI`LvDljKpuR8O1rkLeb2pLL=46{~oZ#r( zJy0Hvv{v&tNO!+r6#H*0O49ZXH_^ny;Yj=Z`EzTl`O7i6ciU|F#x*{9W9W0v*$$#HSFZV$9$mY_L1bAxafR?bL%vRGnuw zC~lRdc>=CGjXOlzQcc@-fE`uG{f?UjNYh&9J>E?4{PPGN8OcHrO!;|xsS`}_BP z7?p5kn3-Z^?kO)Xr@(+-Sa@Zm<7`V<|KGpuaKJ*(DrvhJj1^&$)VLD^Xr!ZEpL|NU zz|n{Xvi$Ssaha(0Yv8I_H*LBKJ*ZaeIc7i+9B1FSJ;uZjd`b@=KFqC-IEmTa5{wRm zT*e!Lx<_^Q0mbG?ElofF|9Jt1+ae_KC=@VXSzqstGRNe>&7Y%5TEf#)S67ES+s)I{ z;P~rJxF4oQqg-I%x zj3S?qDUDp@d}vg0Jm`$|(h#;8WrfcQ$OO=;p1#Dq3kS0nvt}v@3JT(Z54Yy8oikxT zrXMB#nr|8JeD;F|?+R$-1x$}|>pK63@tMhvJpQgtLLb-H8p*@Dq*xVj_MkHcb&f3y zu-I)uZ__hg9+ZNuOe@V6dxTOB#eB>w3X{OMZ-<6rQs-66m)(!ugsAVwA3^<=EuJ{} z;?t_z-=R}UEq~1ZMLci-+ng;^>EGeH$9Js%#Q-qGfDkh8;Nt%NOJjM~06;94Xl|4U z=C69!X}0nr%J@v*!JrhFt%*g?4VkLM#K?+jL$1q()RlG5@Z)DFeR(tOrT$iyj;F*X zyca<`)V%1Fv5Q&7&i1jPA(z83H7(VLHOM%#7}S4l^OuZ!eBsZPK#)nb<&SaiYidh- zbVPhP;N`EG8_z$U%(1D$0=&Pz0t`cKXFC5&0S3)O?~CWkd(+yIq`J1|{#?Dx#;xU& zXTBS?+Lc^9*VXjAuV8sR&8>mmhB+4dUpL2zxmSsksoER=<;%;;%d0N&^%F7{<;@fi zTb_u3n%(ktJwa>JXB9&4Dp($Hw82lf8JI7Z0y*+sRuPJVhdYV~2M6^?Gh8dLL$3Ru zW&Hn7r*xT00pUHH4U~UqVMKl&p3hqlchQnYS!c$w3p#QRRTB=jLo0hU6L{;EL9%*# zmK!=A@*E&$d8;tNOUeHKy?-$v5(qiD<{1qGm~(_3HS|CEUSN&%E>?f{6fKYT8BBZQ z;T?Ph9sTUp+v;<_Q@fqk6NrcZ^`x9eE-NlBwk_08(}Q>5KN(*pynR!(!=$QnrB)Np z;VcltwkV#}N3r`Qmtt^Kz)A$fqc_NjkeJh+KCRlExMAh>MCL+-CIke0lBA*^*be(Z zHt#Az-RcQ=Q8x0ake`4(CMyw%MNo8)Lfl=wlP3_dK5M-u~cb;_^FD68Pn%xPH z%B;pAD9iBOtR(z_umoEQ_U78vt2s_Uw1unAshy8ts0fkhWf;zIFfnPyDHlGhD-OOG zU9KUV+kx_o()8vyAK}a@Sb5IL$I4p0Fnp}JY5B(n!uzkKjAQqCj1g13Z{ECN9?Aff zM*e@v!}6azGlk}*e#MpAV92PSa$#|9GBL-Y9v3fOv<&~=U|}O88`#ZJK)BRsCT0< zAp;b#vSPp$S+xYf6kQ`G>Qmsp>fu~@(497%sE{6)9nS03RXHZ9lJ!XvFDN+s)Lg(M zSKda4U`IKU5RE~r`@yJndb!N0eA)q~Ov0^+(C@_Tb>XJDOOIw?ujcx7n4ioHfXY?X z>=9BXnGk7<;;;RzYx;dKbt?q9#$n2a&&h7D2x&JjdE!prFM21g9nIr)Bn z&n9|nFrZhJSiLAYH8q9NBF4lQaGlcZXJt;TWs@m%>sncd>5&ub#(Jx^6n;dl4dTGK zQlIIt6{jYK3_LA0-(X8L3}gjcoz573ICC1^ILeDFl>x zQVE7_AjZ49R}GfD8Pv*bY|1n-Fdgpihf@^_k0YHh>cl+2M9c^3tZ)jR#=xYvl~pHm z%q4YmhlIHJ(F>37C9n}Pm~S0Kg)fK1#mR}jOX%YGpKy((ox+&oK>?Q84jlmRG4oA` z;PrSX5=@tx_k{ILQ}*GE&76B44rbT6THJ`9GKEfnbH=Jo>th_(lj z2$z3hlGg35f7=SjbL-S4n^;y6B(DGa>G0Q(asI+{%rN*Al3^ycnFSuI*+N!$QFMrn zc>+k>u|2qYR5S@k0gQfvw&@gh+)qf4(*D)`<040G=PE*aDC+pah=CX@&gVc8`L8v6 z{-3%>8?Z+Wcf!8hg`wt)gT@4|5Xg7qBv3WX(<54y)&jf!>j1fpLuWJ(T2k_7XJ=8{ zYhch5GZWE^DRbZS=QS)B=E_}K|NJ+e>NDft$5Rf2Gs1fxm{3isijoqJHp#=4=K4j# z<$D=o3h;4I#{(i$7Xzf6M>E^n(gMDM>9-~+_|9RDSG2GeGpMi4{=%5++*pawbB5Z5 z7YJcdL1^kR=oFbt@iqnbkdT!0qHVXL;uV~OLY{)i<_NcUZ8S~_Ve<6S(qR|pCkDp< zWKwBQG5T|_0dE26>;NY`jORenqi_M=CQfF5Nye!y%o)(ys~Fq^gGzJu$FQ8b?;NCM z`2T4Qtz8|(*TCD!gihNfjaq-DT9tyIu5aMkSG}q(|F~k@F+}N}a_i)nm?;?6D$2{p zhR|}iwY7Pv0+$2T5ix8zuzt&)N^fJTCFbicCB48oW(pR~Lxyrs4kJ=3xD?!hkE8#oc^g2Jky9= zQFK1SFzIRAuQft?zkj^C@S*uneJTs5v}MtLz?^;i+|h)Qj7HB+%$Dw{Gh4|*@As+i zS9b3v3470s3Te=#sSN{B)fnt!*-HHbKZ?=(OSD2{|MAh$UjtB=eo@0Gpj`fe)!M73)~XpjmyENE zpz^)|m#>|ALeX*b*9X!By6vS+prKVIBqU&T_4x-3A{t6q@v|M~AkxF72I`~mHJBC7 zdT>fLT$8`+cL@dFafd>y$Qy>`5KTdFfc|)io{5NEVE)A3g)*330Q%|dXqe$TAX#)! zUS?)LZwdPS?qhy!-MZCo?=~33ropzCevOqGK+}lcsHdw-C=Crme{b(G;JzDqCtDz2 z;nQwS@#JIBy&jLMk6ZHN&MfJV-GWKa-f$G)M{hKK;ESttcYiuhp2rMs9@(?Tgs ziTNLdqzaw}z{2hM;ty65FV3yjI_l>W0?n3N=~=*%-au$gh3 z2j;WErO~sX4=H#0G}P6F2j94H1J2X9$jDvIFhu*|`2?qfvSjKoj^G@?@6FB6;{(^Q zuxR)#oO!`$buFbA9s(6tyM}Sf#j0x$9bN&kObLS;a8m!Vp_Mpi=jSo;wU3z#jwcc0 zg5~7oK7RV->*E7hjD|bgrw~2(Yf$$jBjh1S^u4AcPeGswXPD;hDbSN_HDNl%W=5@N zzkmJ=27IQAWqb|e;_gL|k9+!{n>Vh6<8n}!N2RS1A?%lqDwh}U<_vmo&rLT`Z zfN1G8gMGuEsr%xj2izWD%>pPpx)G{Pk3T=NkUrOQk2$dj-j0}SRNvy+U_I@qSR z!+SvaylKORAt0`#({kKAJQrPEW#3JCokxQnEQY;ynbMf{5*`d;Rj<)UmCkE`!MHTt zZbayjZLD6~pIG4-G!=b5py;IV0e_rB!F1aRoxgAa;s#NW;{;Fvg74JHlb8ki_51e^ zu*)iXraSw&BmDipLJYjD+76rK!rZ(S2w}aDoV@(v!UD8{ zSf&|vKb$bQF6M+~a|X-}FtrA8z#@>JpU>?1P4VEd4L#xVf!@Qk^D@r@scAZ*r?*R3 zR#sLeXgridc=ztWU>lAIidMb{-0+}%sy2~pwG&N(D?v$p`4V>R44~fUAb3$|-2D9f zA?YyB^r=e=VKS~xS`15QLg#*(m8smR^Wzr z7KF5&*X6Ut{QJzMW5;Ub`J0jzscs&39MdyAeGU}{Xvoyml#frv%-z!6&20uekk)WXSu1ls*1LEK)VgLc0I~XMk4uz4+2C9GSEETML4Ai;5!GA3yyMbI?0crJ<(jjrU{*gDAm8 zDL=QDqaKB*_cqzG%W@p<&VvGHJpcYI=hW>@x`4yX-l+V+#Dx!OE0bDI&vyW27yKJi zvYmriTop4B5clRIB-e=RmZUD+E$5TPfd%s&cbKR6FmNIm#-vVsDfZ~!R7baZtr^_g zYr|AYuRY7yN~8|86*l--c4_Y%yF&@Ow($od6BZ4^xg85r%dTaVj=VGb%xqj5BdF^*r(L%TJ9$N1rBfC zy_3A?F1)+kz$4;ac~(|d(({b1b&@pA`}gm|=q%>Mw4N}>=#V=KMQ_C2Ef7h;DT9rF z5XS9q+*9SKuU$JW3*V!vA+zAb6&Q@3EvT+uyB6GRpQ%X5tt?uJ$>T~%zT;k=(G)rk zmfFociBFQv6o(0MN}qXou-ES@)CFQDpfO1S!~Q^lH-njHK^AVkH4g_w-182+d$2uR zEjcb3MIWV=;Y@g3&TnqVVD-Fk{GsY3619c1if;2|t?_#%p&Kh&OHwe$3e?v9MCSUpewF|ZsS|J9xGw8CGMQ`lu zg&RrWA{u}2a=S3{a&l-M!vOfm-;5L-B!Y|+azur<{_+)<%_ecmt^%CQ!`Djw<5(}C znVRoki2)xB;j_n4U?6gAY;B*v2p0=P`=(_7^?gN!9ct?u!C0E$c{Ieq!T!p*;G+=^ zX8%unUm6ef{{B7bh^XwGw8)fWErpULqLB_+ifIr-i)f)to3XZ-niLX4mMDd?O_4%E zg&0NUs4UrLHo0=`+F~%PsPItBo7BCimjICL6GV$_${Yq@BNn!n>Wu9jN~6!jP%)6C~k&t*E#d7`x7E&=_={`2v=Nczd=mlkfKpbi;MmpWcq@Ij9iR0R^?c zHy!YZ4@2tO%y=-OAi&|x2!_$`0*0<+eAFOvWyD;ejEX2qj!ep%O64W21PmJtn?bWH zD2NC5?)BVD&yrYb5m=Yq|NV@+Q|)UZ=GuRKiy>c&-;KRFs1GYTOv`|u3JVLd{9|Cu zO(iD2kRS-aP_jbDdd`DjVrFq?%hV6KFhOT01N_3JOx1%Em0SjpgR#cn*{M&#$$d+j zUtr(Ct16w>`1z*-eU*hP9D7c@gvGs6nK%6kOEeqaYNrz8jcb@U@(l;i)Wr-i3#7Iy zXsEzr-q4RDL#t-`qqDVj9IO>1p|Y}aeC9eRrur{#-aU)f9g_azr)TW6#pKTc+SfQ! zvrPpvoxub3u6W-^2M3`^z?^4%E&BXty^8Q_N%;ZD8Y1^j=zE958=$WK95Jb4$w$Ja zg)^X@uKn|=FZPn@8`k96HW->hirkm^=qy#QTd>r$#JfR}*>aAVHEj`c1$42i#SX*F ziQb0yGD=FcD%lc?T=e`?@jKH>UO=KS^XKEkU1rS+57bHF&2j5FDeeIdS%w@wLUH&^-NR+Z{!jfedFyt|swJgoC;yVX&bC7WK$ zLTjo#Z;t=;^mFr>C$?eW%_#=pFH|Xz@X1^=I(p8hUi(6c7!M+<1?s&QIngTSt0FGH zZk|$%d!^!h=rYabmJ-%}1j=C$0sqJZ^*5ZN8v%122cZbKqnIBdR{sj6z{oh_I-YiQ znKxZ@#wv*y^=n;S9r5}?K#u!l?Z!`}A@^kO;lt+4VSZiRmWoqnVwM;BgHBTk{l0DQ z-W5{)iH08CH+uk9!-%79pZdAGyJ}AaLW@Vr%Q%ntSPS{`RJx312qI(4V+Erxcsol9 zVaxw^!!V9U@h{xQIB_C6Dr)+Zx$sV;UY0YHWj&DG`z1+|kxE?|=NWg_UC+#nPCbK& zfaA)>>(>t^xs1!XQ%ttyD$LK5Hqh$v& zQ*YDO+8PQCVn~chqc;3OU~P}@-MbgFE|KGT51k9<1uj}*h% zFWK&d)D|b(b&PX9=xV)oNO&dcC$b8{FrfqbT(@qWX>|*loYvM-;Q08a5BK-SVfI^W z)`A=lGGSc76;ZcK-V0gF&65KirEy&~d{f7*Mvcj7joOw&0fSD>Z^g0l!1Jc1l<ZmSEiyLv&-Rh8WojP3=9f7$70oQtmG;9G@gys`=(|ziQXGs} zb=trKGRx(Sxq(ck!1a2aIo{lPbDp7ZM48z_fhRNed7oOkw4G za!ogUvb~EhIk$*uC5}|3Srwh?t5*-AM@W39rm$CRorex`JRwxWnh(qFb_ykep)&$J zWvt>cU>KTrG5$Pk6^^+Y1W<2>kf2` z?x+4^Soh(H8!XDY_OX+I7Y`45gB z5yPf%vo;wsPV) z%J+#-Xxju+q1rkYAfG2E?`Ucf)~2r2_qU*HvpydBcBH%E1e`FA&Z3sD+n^#t8AE@F zYUO9!(gR)n6~kx6?99w@9TG`pA+blm`}Z!t3=PANdPH2Lhb2s{%I^m6F|8zX+p+N5 z3Eu?-z~x{S9KapV-JZ@lkk@;(cjxa)yH3}55s^nmWplu>EBKu@oyqcT2-%S^Uz<`b z)UPf)erjr&CCuEX-u_Fz{)axf_Jo^oC}=jPuA_?wUQu+&n?%11IBjIMlvE6cLYD`6 zBxV`@C!SPx-RswD)YNL;v@ny15rRldIHibSE8ow*;v4dd?i~}oz8V?{9S_k{pS%r} zk%~KaQomdt`pcCbti;6=XGv9hXf}XWIZ)P1%&PLV43`EE;QMb#!vDjcu0>@c{DFHJ z!aBU7Lm$x)4xa_KRWFK31Obdgw0R^0f{Tu90?P+x?RK&YP~`Ekbm5dbh}ElwvX(Ao z3rK@G5ZWQQW^K{NUorN3?zM=KY$PsuBCi0&HKVzIIHf(yZczj%HaUEsd@f=L_=X2T zJ=eI(ku)q`!J<~91s@LgU5fp~d-tGTy(iJ?g!4PL&agp5QmYVYW)uS4D_R2g`8QWs7;d+B$v%wwuqa5?kFv? z(iS19k?T_qsVofmrfrF9cZO0;ricBRYv?!ScXLDvZG3*DqCFxgJZBvBov6s z%NP9nB>qbU<^R_{5l=qp`U*{$w+`4GIFR$itFE!}!In@`x58=AXHCn{w>0%*%~`T$ zH)RcN4vD@5@60LkYUuSavf92AA<^A%od;jQ)9jl`Ky2Xa@ct3PsDuJ#Qh6$FV>k0_ z=VKEz8-%(E^Kjs97}^Qlu4*-)0rQE^%Z;wF;YmSSA%YX+!9zfi2>}_lBskPj@)|jO zt3LAi_aTS}A1a~Rm~LJfHXVd@UT_g`8YrL*D~S&@+25sN5+S}!@PWaXsGPKY2~)za4;ML5(M%H(E~;h z4zV5y&?cccX;X9lLR-Zq#NSv?Dc>QSZEJ6r+XlcQKQGVt*k|b_OG5vRZyG6skZ+O9 zk_|zBp}%Zw^ap=HsScmjof(1qVO9V3?E}Kx6`z(?0m1}2Kp;_I08g{=3$onSuUNsy z??hKcz{{Fwv?%d$anBsY=KbE?-(LzX1Bj_&zxE$^*C~ALd;w(@8AOXU6oXH^KELCt z$h6#XdXB8ttDWj3d0iqrMWm3nP5en|HQMKO`~?g3B8a$P>cN`o>cO+P9a`Ck-``aQ zi^@h`!4M%otVVl|LfLdQ#bEzaDGX<&A*DeW#_83MW|J-w2`6Mn2zSD&jFvlO!SaQS zKLqda4&F*xdyv@e0kB(&r|mvvQeZv73CoPZXvfyj(&Blo4uR*45}%N;&$~B?w^)DQ zR><*Cqi4^ajmu<90fnJT{-!|2(4*KiK0jUa-XY8o30FmIt%UqP<${|r(3 zAB4~t&@Se0Z29pwxREJ-oCD!3snQc72UoQkEi0G+#DdREjWi)=fQ$u2LZ+s+HY+Pj z_BAdvPzw%+iarWmPX{*OKs&1Reci08r#b=KWPK$Ie9oQGFIo=yLvkW9&cJ`uFM?;L zS9C0BpK8SZdPw=m83;$?kISO;^|J8b95nukVRzlSF_q`v094Hj7_$f$Bayr!BR%!fYW$ax1-T2YBWLR<=eBtBSJ zDC;1sK25u9pRKPPg_;974?xlnh29r456%*0{ZPPq_R1J{{Qv%^MYc4OK{Z5)L@6&2 zEq4PH4!wW_#rT>((0UZ&@+n*-pMI^4yIkg~>UfCVT@8lQK>0#ld;28FSmG<(pu9W= z)`MJFT)g)8dB461wP(!2*Y`8-y3Wpp8tcFR5{CVUnv(y$!`ZkC$GTiYqwY=dnC{Kj zw_|y~iHagJ=n5O7maM$Ib7fd)v*>-t7U&V&sZoGy5ICKSj=YlKPJ+Pr!JGl&3sWMM zMbNaLU)?|;DHd;z9MbiWH%(-;DAAaipcg@WVckTe&{^U0z$f0YL0xwfSWBvGkO4Q0 zj?>m`frJG3VHV>wKmhz3l3jOZoDhPB(QSTt9Fq@pkwDOv85_IVyCbvG(NTJ)=%=1VV)qFK=a%0%rM=So|11OUrHXvY00}lFR2Tu<}pyxLB#>&d- zPRliC};VfP}nq~rvyl&85jx#sZQ;ETbp7?KBffO zSVZfA_y)vBM|PjKb}mGj7#!sH9w#G}VcQeWRZX&yyEx?M=`;Ar6E^} z0>s`D_O>QQ`+A!@Vu81PI}Cy*u_%7v_OiXIh$gVwu|IG~M>stTl}J+-0Sj2gq2?hQ za$$BPHiMDhH&lM*@@21s2Xpgkp?+iEhUf$qa(rNhp%8)AuI)I;J`6Gl^I&%}KP^#> zIAm5DE(`-qZK!FNx;|DC?3qH7mKR4NmIiGd#AX!}e`^b>(SGef$UQ6>qMALNm)by$ z(F^_Q;P5GX_yr-sEWPAC!UOBL;-H#3VMJsTZUQW+X=zIrFWyR6UVBQhqy%UAT`b>j z={hqxTk&TqGZR${5br)q%cnl+p*u<$-exUp5OM(^48Ht_aBr?(KZY5R=gRvvhrIhj z1S`M$LTv1ltLTN16%Q7c&d##qxl?K@M|cptBGd-o!N~;{9}IFi=y4c06POcd)NVLT z&=29}+W|W)7==g~mwjSg5-Ayp6S;b>4O z`W^(Us%8tncnW)!uzdw$wH=YA!oT5K+>4M#T(sf5?P&mMuNyBMj3=TAo52MB^UtB1 zuYomgq7VRd&M|3`q($h>%Mzc{CXjNJYWbQl=Vu%CDYA$UY?g!^bO;4nA z&5!n7I-n9fHMYjw1>$Zxmlj5cghOwoDt059q6a}_00#cc=mgFo=S`0EZgdnuk2~*g z@B%&J-i@>2SHmIOKRSV90l*CmM2Mv!rx@=jt`|Bz@84wp@LbgoT9QaW&&rdPL5;U5 z{jhb!27)X2M(ErwK>V+Lohzg(-{MB|ySrB-tIm7vFBIVfgMp14o}z?}a8NWkR6_7hE<4g=X+gJK#*KHW(;h7U#Y-T?3m-i_Jz(MP zK70tppLi4bm$`l$P6KA9Ct~7eop-{dAOpF^Mw2^UB@a#g(xnAKuOW27UQ6HcR$;Nk z)A{LWcHK*DFp^*4V+{-p1XT=xYTlf)(z~I6n3g~b2aGFS&~-ogsZ#BayK zjE6=}27gWR$@h5l7H${xSNLZlcL5bS{Wu0*A2;m>I05gTRA3#bLCeXl?+Z*N2(Tlq zPavon@Fm(G16Z^L?yaOSI3HRf!Oi_?VO9(shfHja`E5&<1SF0qp7Z9(z}!P{58=uDHHuod1;9;vQK$oZ159p3r_x zUG2Japavy%$m}f=ux1R#b&w@F4u1edg%9^PpJNk@ScD5gr}obU^tn-yAuRP z!S>8Z#z~MXZ?JQT;?kv^$a>hhbLY4W|6wBbia1Y1`(%cgP>vz1k~{%NATvXc3!}}I zx#Msr z$-qFAk)?E|y!YF;Mr*oiHOxI89vU=5{%2(uEN~oejWF5N29PB<# z=O%hSGqp?7W3lbi&GPeRSH?mOi;(LXV~r6Bsi156KK5YCu~!*GfP{hhG67C%+@XSIl!g}K*jpZCN_%Mmy>;~jkWcshCqHMq4d!sbEh64mDJy31_ASI%Y00i;doF9ZfT3_F|Vc1Taze4-Y{?Rpo%j)|NtWba6BWgRy>ZU~q6E zLdQ(35n>^{`u4#xiA36PO3UNmlaPjnO>Ko(;tacE>8fqF%VnjdyFY$BislE}?3QlN znIjyB=TVEY-8QJJpVxhHAu>|7lxbDEaqHGg>FMdxmfUw;Oyyrj%FrUpsPSo zT|bo3@EW5606ULAiW->j7KUX7b1JMb3F(%YjPQW?RrrSlNE;r%jf}l>?2?NfSb>#f zfE&~uW15A^#&S5E{th&2a<9?T=QChEk|qX_m(E;F0nh@>m6S}`5d7eu^U1KtWBCA! ze}NtRh0Tt8i+Sn`obo=~AgX}w2ZlsG-cf8$a}r=Lp&K~RT9}z6A#WtG=EkR{PC-EZ zSXvasLEUql5i3veHL$}5j`mNi?vPlaM*^yyYNC7f@#E06gJ|{UFVusCfs)-UcsNf| zT)f+N_v+uL(l%4DXd4!NL$A%Q!z_C&0b8-bs|L@Vn)PxS88IdAP0BI>2I+=}(^9u% zCxd1Y9FsK#{TRbq(7%|Qu8^&wK%>#ZsxG*maC+Jx!4#$Iv@q2m*E)Yp-V#kMEHj7? zh*@H?4k1)%oQK9p`d||uL4Sl7Wg0PKxH3JpbL+p8lw@V!AZP(%xa7@~Pg|1MFvEb{ zq`sRsqVdi&<^xYITk~5T4&bLC`)gNa*9$i=Gs`lt%0L;0D^_nBfd7C5>b)1G_jM>2 z-~@uj)ti(jCMFO$7tJ;s>zP-z`(tIln8owplK^Lx+;h&mjG+Ra{CTD1LbR}h*BLhC8i=*7|XW(qaXP@{COH!@m_otGXZZM9C*f`<|O9mLX)5uboQ7tf06xlO$f zBx=vj0nccHwwD<8n_N)ULyR&P0&B_acy{{^8lCpGHjs}l_B|HYU?i4}+_1jV+|?+VKDP(tE8KU zcOs2*tFl{s@S!jOx`Yg@yW0`?47e9@A|cuj5 zq*WJJZFQn_23j_j?L$1E%QJ{(*y-Eqd}^+uImg6D)#lt4e$5F(i$4daKGFbragO%) z-;=>eW``|7?Cc|Qq{q<|U`xofactgMG0E_}t3*nWF zC6<2zGeJ{I=1yXpHi)CfDEOtTR4GWKf>=6!(p0AP{d**eHTW7pdW>6Mg676lTSMbE z`iUt^B>(F`==}E%m8`MB(3vS_SkB_gr%=hPOsQr`<92sk8N!>7iSs@Nw~j1=HI6mu z+`**g6!u}6R?oojF@K5iF>qhyy(orOr^TW*l;V7*A(1Zs*%^$B0zw$)gRl~BCp7wo~@h9yS;*i9ASt?x(oNLwI;mOS3112NzUk1R`$ zAXp5F(GjYt4W2b@LTMm%)857gv**>t4tvTHhsQX$jxC7Q5<8KKVWM4jfr^<9<2m=f z|Az@w(akplum=oUx8C4(7?C2_6~dWdRSL4Z5oig?M1EPSi#M#U@g}?7-deckqPAi{ zXsj^(z86dy1u>~_Adx^Z9a)(IJ@=EY-7SCfhO|5OTTu`QgKBrEUtS z;SdY$=h~REm%5+_;i=V09Ds%;-B+N;Yp<08Cs}~S|4uolqFhz1%V}-X{HG6d6AeYz zywv7xK3*#7Y`$Hv!Ry5*Y_kC1JYE}*%$lZ};XOfm-3j5f7Exb1iVxC^Jeb^0_(cS( zF9Lqb;f2HKD=bvbJ3?@E(8LNhV++w!{V)B(k-$`11&B|>kXyG>moXB_ptV8^qGQo2 z{(Oj){SN!|R^}7cx`7JJX2tV!PG-LnyNq5C4?Je}#YOA6gLbptLx1c&H)qCf{KrVU zSjo=2;GKEKw-poa(HC&2ol|M-8MuG6yJQGX_RnY7!%a>o-jm%FPosSEc4GRAEiHDc z%AG*m%>`Du2&-OOltRJ20UoatT%CtavXtzx*P?wkp|qDta20neY0(vaqtQ#SC)0?{ zwx06cv57ObS?|O)#6UfP8{0B!RufZ%#m4!@5O)bP>5V5gs=Bc{O!T_=dJ_|DF-3G# zm0xy5KF6fvon4W6q&?>E?K-bSfD+=(LvEe6u839`)|uPb!R7D5o1gAMdCvQ!rxut5 z;8Q*diXUi_hCgWPz%k6yO@DyZITd{a9N_g+w|Cu*2Rfk~wYh?MFLTiT$|u1D>oVE| zD1$t<4KfI1zuOy{$y-zYA9n?yGHbPoV08P#F1b2!+(kmjwU8{sH{Q?4l_=Fo%M;snf&1+urc)T~%Q$D8< x<%)gEkN+3L^gYk%v{ + select(compiled_model) |> + pull() |> + pluck(1) |> + summary() ``` - - - - - - - - - -## Feature Extraction with a Sequential Model - -Once a Sequential model has been built, it behaves like a Functional API model. This means that every layer has an input and output attribute. In `kerasnip`, we can get access to this underlying model structure using `compile_keras_grid()`. - -This allows us to create a new model that outputs the values of the intermediate layers. - -```{r feature-extraction} -# We can reuse the compilation results from the previous chunk -keras_model_obj <- compilation_results$compiled_model[[1]] - -# Create a new Keras model for feature extraction -feature_extractor <- keras_model( - inputs = keras_model_obj$inputs, - outputs = lapply(keras_model_obj$layers, function(x) x$output) -) - -# Call the feature extractor on a dummy input tensor -x_tensor <- op_ones(c(1, 28, 28, 1)) -features <- feature_extractor(x_tensor) - -# Print the shapes of the extracted feature maps -lapply(features, dim) -``` - -## Transfer Learning with a Sequential Model - - - - -Transfer learning consists of freezing the bottom layers in a model and only training the top layers. A common blueprint is to use a Sequential model to stack a pre-trained model and some freshly initialized classification layers. - -`kerasnip` supports this by allowing a `layer_block` to contain a pre-trained model. - -```{r transfer-learning} -# Define a block that incorporates a pre-trained base -# This block creates a new sequential model and adds the pre-trained, -# frozen base model as its first layer. -pretrained_base_block <- function(model, input_shape) { - base_model <- application_xception( - weights = "imagenet", - include_top = FALSE, - pooling = "avg", - input_shape = input_shape - ) - # Freeze the weights of the pre-trained base - freeze_weights(base_model) - - # The block must return a sequential model - keras_model_sequential(input_shape = input_shape) |> - base_model() -} - -# Define a new classification head. This block will be appended to the -# sequential model returned by the previous block. -classification_head_block <- function(model, num_classes) { - model |> - layer_dense(units = 1000, activation = "relu") |> - layer_dense(units = num_classes, activation = "softmax") -} - -# Create a new kerasnip spec with the pre-trained base and new head -create_keras_sequential_spec( - model_name = "transfer_cnn", - layer_blocks = list( - base = pretrained_base_block, - head = classification_head_block - ), - mode = "classification" -) - -# Create a spec instance -transfer_spec <- transfer_cnn( - compile_loss = "categorical_crossentropy", - compile_optimizer = "adam" -) - -# Prepare dummy data for a 224x224x3 image -x_dummy_tl_list <- lapply(1:10, function(i) array(runif(224*224*3), dim = c(224, 224, 3))) -x_dummy_tl_df <- tibble::tibble(x = x_dummy_tl_list) -y_dummy_tl <- factor(sample(0:9, 10, replace = TRUE), levels = 0:9) -y_dummy_tl_df <- tibble::tibble(y = y_dummy_tl) - - -# Use compile_keras_grid to inspect the model and trainable parameters -compilation_results_tl <- compile_keras_grid( - spec = transfer_spec, - grid = tibble::tibble(), - x = x_dummy_tl_df, - y = y_dummy_tl_df -) - -# Print the summary to verify that the base model's parameters are non-trainable -summary(compilation_results_tl$compiled_model[[1]]) +```{r grid-debug-plot, eval=FALSE} +compilation_results |> + select(compiled_model) |> + pull() |> + pluck(1) |> + plot(show_shapes = TRUE) ``` - +![model](images/model_plot_shapes_s.png){fig-alt="A picture showing the model shape"} - - - - \ No newline at end of file From 61b1a13a41f546e537b02ac6e8ba4e1b60c0ed19 Mon Sep 17 00:00:00 2001 From: davidrsch Date: Sat, 16 Aug 2025 12:39:44 +0200 Subject: [PATCH 11/32] Modified to show a multi input multi output model --- vignettes/functional_api.Rmd | 219 +++++++++++--------- vignettes/images/model_plot_shapes_fAPI.png | Bin 0 -> 25752 bytes 2 files changed, 118 insertions(+), 101 deletions(-) create mode 100644 vignettes/images/model_plot_shapes_fAPI.png diff --git a/vignettes/functional_api.Rmd b/vignettes/functional_api.Rmd index add9bbf..68fa310 100644 --- a/vignettes/functional_api.Rmd +++ b/vignettes/functional_api.Rmd @@ -34,9 +34,9 @@ There are two special requirements: Let's see this in action. -## Example 1: A Fork-Join Regression Model +## Example 1: A Two-Input Regression Model -We will build a model that forks the input, passes it through two separate dense layer paths, and then joins the results with a concatenation layer before producing a final prediction. +This model will take two distinct inputs, process them separately, and then concatenate their outputs before a final regression layer. This clearly demonstrates the functional API's ability to handle multiple inputs, which is not possible with the sequential API. ### Step 1: Load Libraries @@ -55,68 +55,73 @@ options(kerasnip.show_removal_messages = FALSE) These are the building blocks of our model. Each function represents a node in the graph. -```{r define-blocks-functional} -# The input node. `input_shape` is supplied automatically by the engine. -input_block <- function(input_shape) { - layer_input(shape = input_shape) +```{r define-blocks-functional-two-input} +# Input blocks for two distinct inputs +input_block_1 <- function(input_shape) { + layer_input(shape = input_shape, name = "input_1") } -# A generic block for a dense path. `units` will be a tunable parameter. -path_block <- function(tensor, units = 16) { +input_block_2 <- function(input_shape) { + layer_input(shape = input_shape, name = "input_2") +} + +# Dense paths for each input +dense_path_1 <- function(tensor, units = 16) { + tensor |> layer_dense(units = units, activation = "relu") +} + +dense_path_2 <- function(tensor, units = 16) { tensor |> layer_dense(units = units, activation = "relu") } -# A block to join two tensors. +# A block to join two tensors concat_block <- function(input_a, input_b) { layer_concatenate(list(input_a, input_b)) } -# The final output block for regression. -output_block_reg <- function(tensor) { - layer_dense(tensor, units = 1) +# The final output block for regression +output_block_1 <- function(tensor) { + layer_dense(tensor, units = 1, name = "output_1") +} + +output_block_2 <- function(tensor) { + layer_dense(tensor, units = 1, name = "output_2") } ``` ### Step 3: Create the Model Specification -Now we assemble the blocks into a graph. We use the `inp_spec()` helper to connect the blocks. This avoids writing verbose anonymous functions like `function(main_input, units) path_block(main_input, units)`. `inp_spec()` automatically creates a wrapper that renames the arguments of our blocks to match the node names from the `layer_blocks` list. +Now we assemble the blocks into a graph. The `inp_spec()` helper simplifies connecting these blocks, eliminating the need for verbose anonymous functions. `inp_spec()` automatically creates a wrapper that renames the arguments of our blocks to match the node names defined in the `layer_blocks` list. -```{r create-spec-functional} -model_name <- "forked_reg_spec" +```{r create-spec-functional-two-input} +model_name <- "two_output_reg_spec" # Changed model name # Clean up the spec when the vignette is done knitting on.exit(remove_keras_spec(model_name), add = TRUE) create_keras_functional_spec( model_name = model_name, layer_blocks = list( - # Node names are defined by the list names - main_input = input_block, - - # `inp_spec()` renames the first argument of `path_block` ('tensor') - # to 'main_input' to match the node name. - path_a = inp_spec(path_block, "main_input"), - path_b = inp_spec(path_block, "main_input"), - - # For multiple inputs, `inp_spec()` takes a named vector to map - # new argument names to the original block's argument names. - concatenated = inp_spec(concat_block, c(path_a = "input_a", path_b = "input_b")), - - # The output block takes the concatenated tensor as its input. - output = inp_spec(output_block_reg, "concatenated") + input_1 = input_block_1, + input_2 = input_block_2, + processed_1 = inp_spec(dense_path_1, "input_1"), + processed_2 = inp_spec(dense_path_2, "input_2"), + concatenated = inp_spec(concat_block, c(processed_1 = "input_a", processed_2 = "input_b")), + output_1 = inp_spec(output_block_1, "concatenated"), # New output block 1 + output_2 = inp_spec(output_block_2, "concatenated") # New output block 2 ), - mode = "regression" + mode = "regression" # Still regression, but will have two columns in y ) ``` ### Step 4: Use and Fit the Model -The new function `forked_reg_spec()` is now available. Its arguments (`path_a_units`, `path_b_units`) were discovered automatically from our block definitions. +The new function `two_input_reg_spec()` is now available. Its arguments (`processed_1_units`, `processed_2_units`) were discovered automatically from our block definitions. -```{r fit-functional} -# We can override the default `units` from `path_block` for each path. -spec <- forked_reg_spec( - path_a_units = 16, - path_b_units = 8, +```{r fit-functional-two-input} +# We can override the default `units` for each path. +spec <- two_output_reg_spec( # Changed spec name + processed_1_units = 16, + processed_2_units = 8, fit_epochs = 10, fit_verbose = 0 # Suppress fitting output in vignette ) |> @@ -124,94 +129,106 @@ spec <- forked_reg_spec( print(spec) -# Fit the model on the mtcars dataset -rec <- recipe(mpg ~ ., data = mtcars) +# Prepare dummy data with two inputs and two outputs +set.seed(123) +x_data_1 <- matrix(runif(100 * 5), ncol = 5) +x_data_2 <- matrix(runif(100 * 3), ncol = 3) +y_data_1 <- runif(100) +y_data_2 <- runif(100) # New second output + +# For tidymodels, inputs and outputs need to be in a data frame, potentially as lists of matrices +train_df <- tibble::tibble( + input_1 = lapply(seq_len(nrow(x_data_1)), function(i) x_data_1[i, , drop = FALSE]), + input_2 = lapply(seq_len(nrow(x_data_2)), function(i) x_data_2[i, , drop = FALSE]), + output_1 = y_data_1, # Named output 1 + output_2 = y_data_2 # Named output 2 +) + +rec <- recipe(output_1 + output_2 ~ input_1 + input_2, data = train_df) # Recipe for two outputs wf <- workflow() |> - add_recipe(rec) |> + add_recipe(rec) |> add_model(spec) +fit_obj <- fit(wf, data = train_df) -fit_obj <- fit(wf, data = mtcars) - -predict(fit_obj, new_data = mtcars[1:5, ]) +# Predict on new data +new_data_df <- tibble::tibble( + input_1 = lapply(seq_len(5), function(i) matrix(runif(5), ncol = 5)), + input_2 = lapply(seq_len(5), function(i) matrix(runif(3), ncol = 3)) +) +predict(fit_obj, new_data = new_data_df) ``` -## Example 2: Tuning a Functional Model's Depth +## A common debugging workflow: `compile_keras_grid()` -A key feature of `kerasnip` is the ability to tune the *depth* of the network by repeating a block multiple times. A block can be repeated if it has **exactly one input tensor** from another block in the graph. +In the original Keras guide, a common workflow is to incrementally add layers and call `summary()` to inspect the architecture. With `kerasnip`, the model is defined declaratively, so we can't inspect it layer-by-layer in the same way. -Let's create a simple functional model and tune both its width (`units`) and its depth (`num_...`). +However, `kerasnip` provides a powerful equivalent: `compile_keras_grid()`. This function checks if your `layer_blocks` define a valid Keras model and returns the compiled model structure, all without running a full training cycle. This is perfect for debugging your architecture. -### Step 1: Define Blocks and Create Spec +Let's see this in action with the `two_input_reg_spec` model: -This model is architecturally sequential, but we build it with the functional API to demonstrate the repetition feature. - -```{r create-tunable-functional-spec} -dense_block <- function(tensor, units = 16) { - tensor |> layer_dense(units = units, activation = "relu") -} -output_block_class <- function(tensor, num_classes) { - tensor |> layer_dense(units = num_classes, activation = "softmax") -} +```{r compile-grid-debug-functional} +# Create a spec instance +spec <- two_output_reg_spec( # Changed spec name + processed_1_units = 16, + processed_2_units = 8 +) -model_name_tune <- "tunable_func_mlp" -on.exit(remove_keras_spec(model_name_tune), add = TRUE) +# Prepare dummy data with two inputs and two outputs +x_dummy_1 <- matrix(runif(10 * 5), ncol = 5) +x_dummy_2 <- matrix(runif(10 * 3), ncol = 3) +y_dummy_1 <- runif(10) +y_dummy_2 <- runif(10) # New second output -create_keras_functional_spec( - model_name = model_name_tune, - layer_blocks = list( - main_input = input_block, - # This block has a single input ('main_input'), so it can be repeated. - dense_path = inp_spec(dense_block, "main_input"), - output = inp_spec(output_block_class, "dense_path") - ), - mode = "classification" +# For tidymodels, inputs and outputs need to be in a data frame, potentially as lists of matrices +x_dummy_df <- tibble::tibble( + input_1 = lapply(seq_len(nrow(x_dummy_1)), function(i) x_dummy_1[i, , drop = FALSE]), + input_2 = lapply(seq_len(nrow(x_dummy_2)), function(i) x_dummy_2[i, , drop = FALSE]) +) +y_dummy_df <- tibble::tibble(output_1 = y_dummy_1, output_2 = y_dummy_2) # Named outputs + +# Use compile_keras_grid to get the model +compilation_results <- compile_keras_grid( + spec = spec, + grid = tibble::tibble(), + x = x_dummy_df, + y = y_dummy_df ) -``` - -### Step 2: Set up and Run Tuning -We will tune `dense_path_units` (the width) and `num_dense_path` (the depth). The `num_dense_path` argument was created automatically because `dense_path` is a repeatable block. +# Print the summary +compilation_results |> + select(compiled_model) |> + pull() |> + pluck(1) |> + summary() +``` -```{r tune-functional, cache=TRUE} -tune_spec <- tunable_func_mlp( - dense_path_units = tune(), - num_dense_path = tune(), - fit_epochs = 5, - fit_verbose = 0 -) |> - set_engine("keras") +```{r grid-debug-plot, eval=FALSE} +compilation_results |> + select(compiled_model) |> + pull() |> + pluck(1) |> + plot(show_shapes = TRUE) +``` -rec <- recipe(Species ~ ., data = iris) -tune_wf <- workflow() |> - add_recipe(rec) |> - add_model(tune_spec) +![model](images/model_plot_shapes_fAPI.png){fig-alt="A picture showing the model shape"} -folds <- vfold_cv(iris, v = 2) +## When to use the functional API -# Define the tuning grid -params <- extract_parameter_set_dials(tune_wf) |> - update( - dense_path_units = hidden_units(c(8, 32)), - num_dense_path = num_terms(c(1, 3)) # Test models with 1, 2, or 3 hidden layers - ) +In general, the functional API is higher-level, easier and safer, and has a number of features that subclassed models do not support. -grid <- grid_regular(params, levels = 2) -grid +However, model subclassing provides greater flexibility when building models that are not easily expressible as directed acyclic graphs of layers. For example, you could not implement a Tree-RNN with the functional API and would have to subclass `Model` directly. -control <- control_grid(save_pred = FALSE, verbose = FALSE) +### Functional API strengths -tune_res <- tune_grid( - tune_wf, - resamples = folds, - grid = grid, - control = control -) +* **Less verbose**: There is no `super$initialize()`, no `call = function(...)`, no `self$...`, etc. +* **Model validation during graph definition**: In the functional API, the input specification (shape and dtype) is created in advance using `layer_input()`. Each time a layer is called, it validates that the input specification matches its assumptions, raising a helpful error message if not. +* **A functional model is plottable and inspectable**: You can plot the model as a graph, and you can easily access intermediate nodes in this graph. +* **A functional model can be serialized or cloned**: As a data structure rather than code, a functional model is safely serializable. It can be saved as a single file, allowing you to recreate the exact same model without needing the original code. -show_best(tune_res, metric = "accuracy") -``` +### Functional API weakness -The results show that `tidymodels` successfully trained and evaluated models with different numbers of hidden layers, demonstrating that we can tune the very architecture of the network. +* **It does not support dynamic architectures**: The functional API treats models as DAGs of layers. This is true for most deep learning architectures, but not all – for example, recursive networks or Tree RNNs do not follow this assumption and cannot be implemented in the functional API. ## Conclusion diff --git a/vignettes/images/model_plot_shapes_fAPI.png b/vignettes/images/model_plot_shapes_fAPI.png new file mode 100644 index 0000000000000000000000000000000000000000..82afe5fec1226fb59537e650b198c8e9199a148a GIT binary patch literal 25752 zcmeEucRber+qSkSBcvg6l@UoLtBjJ75!o40S*1iq_Na)Iy&@|SvO@MMWs_v@9g)5F zb6nrwbN_k2p4WZxea=EVe=RMBzIFI8v-!E^-OYJ8+LPkPDvS0d|xDpA; z_Cfp-Al;6CQTT1fgfF}A%1VioY!UyDE{YGpzw9-;reQ%sLVkewYum4Gr*#sN<0R7J zS5$1m|8`qD(5kNSP1`?FmKHle`i?({AVC#$P>lST-Zoc{(i4P(p0xM369@!}D}hoH zo?cQPf5l$)Wj*BindIHhohO-}`tDSE67bh{E7w_Y>Qh1c!|F--(uIxZ^?yTojDCFX zeo$Y#TY~fohCvb*YhksA_=j27SmJ;?-2VTU|33x)?k*VqJw|bPlUCxf75$g~*OKoG z3JQ|!xBc)94gE7QF)^m&qUYx7${=9&rMZZmGr7ls}kcvg$6M+iCK}WP7j3SNiw}*ASWov|Vg8FxlDOAtY23QkIn`5X#r}5XHFkGL7;$N3Cu?RI4_5e$m`BL7 z)il%hWf;_K&P8qAn9N#VpQ|6SS2XR+zE1jhb^cGW<)T;51z0Qa|V-_mr7jxaIR zKk!>HstYS~JbK~YjfTZ-_xwrfEEWW3NvuL#soW`!W&9lr7qqD5yED~OWbSwUEng^a zkiUP-Bj&V}ND;^U{JO8t_t5ys$J2<_Mu|p^%e^^!h4fwHpTeS|=Ag{X%)C-Ja{AAC zre2bI{qmITkI$b!ukDyGe0iv6_dP+uOBAM@B|mgjEhP zGc!wwDEDpMmfPhSi+S~oVfQRudNjn(MF*uM37n+p$j`WwsUw>=3#$o?E)M%_ZDW1LB zW!nx9_T$H0Jk(OPrNuebl8+N2MOUq@!ExU@M_G z*_n*}!0HSL^fg<|57k&%;= zlbf5H!uE?NW6s>lzcI<5*BGxr)9~}%)nS6SjT2S0StpQO=u*8kvSUkGl1< zTwQkxdhzqRh;eFVpQbkHE$%kX+rT!N>Me1`-i#Esf8px7j`5kR;f-;)92^`z$VV7w z2d>7vaNk+3@I$L?d?SZe<3+ssvgDf|vjn(U67o?qz3E9oFQ=X#J?lNeDZ8rxb5!zc3aTNyo&b zAdZJynwv{feE-6(WpM(3>+N|j&ee5{UgO4={{-WfBnBDu2n62>l zI-DREUGCdKI6x^97B89ioFL{Az@||C{d+DhW3~QntS!B~SN{5NMptWVYx3sj`;Eoxv$C0GaXa=dK`lCl{Sb81i>JS1i-u~i}qLqV{UMyPbxrE(b|4}VP zQ(SzfsLS1j7``K3jEs!W_fYnfI6GTf&f8BFwdOt^Ja*~HbXTHsynLikC{wr*_L359 z+*Nl99L2Sjk*MLa$kt1{0~yO^$q8oc>QN_lm6s?z=@{$k?p~Y!bHcT+fAj!ZiIS7e zy$SW|#lFZ(7GC_1x~RkgIn{rJa&k#72brHMeM1&#Bl+IXBO@b&RX%)lteE8TIVv%# zhMpq(ndwVZegv65^+osnn|~|5eYtm`x?~SooR*>Chn$=ob z)#VBEzxySSxCD!HbF$a2RaRHqP3DbosD9LcZ+>OVyFhp$VmhaPn1fbAF-=}E@KJZZ z;JvsIZe2CnxVvSZd}AsfRiqXB&)h0@Xuopnnfuow`z2b5_I#@cYHBSAXLJ(M^77e= zy+}Tr<=oo2XYYRZa7Yqdjd|4gp>C)u;N`uQ* zdpX{kwzbfzD{i^5x`1p<%ADLX5`u5C21pnE(#m;+7l#$uM_#m&+1P7x;Rxc6J#z5- z2KJ*DLgm&LCtmsZgfz>sHX=xGt&A`&*wtk}?#g{U7BbP1m84rzJe+#=&_2AMQGF!u z0}eq!L8Lcf;akD+A3jJ*O0u3jd6SQ@|IpdK^a{pI#i!j!?eS4g2l6};a6lJNllq1r zSN|RwvYhJa?(P;oeOiu``p}`u%-Yl&o9d*@@h?Z0X9kfZ`Y+-)X=&;HKX2Z=sjaR3 z=1$q0`}o^mj=={b$O2zisIm;%TkKw=!2h$;yf$_vBmEJWK*Iuv=e#yi#7q*Nc{ky)AQPTM9BfQl$ zHJqnUC*S$@!ZW6;*zs*^TatR(moHyFeLBm;#N^FNJPAunl;q?|H`51b6omX*y}2-I zLqkIyi_ovPkH;W5v3s7kUwn9mLBt_{+LifLLoldgMddhEtiT6LWx$S*zl zyWH_0tzCzTLCod%jNk0+Y{sn$$#-^V6l(mDB#8+k9uN@F>S1e6emSzE`O7MOrTQY9 znY;4x6d?|?Rcz~7J|3?nPtX@H54-~CU6`5iyFSrf5TcctFL;#ykxGuJ`A{W7&P+|} ztpZjS7U~9M?mT%PM&YimE}Jt7)0T6;>ued=Synr>{t2Kvn zvh=t8syOmcKGn05=Nq)1!$J1Uv(g}6d?0Qq_Eb_0x>r# zB)*K3!om^ph&XdNe)k(c91}2ektIHRTU#6crT?3I(oyHtKk@J1zqduAZ#48m!sTXX zzef61@6cFYXlN+!gO;HKo9Ahe4Cya%Dt~z0Y%+93H*%AhzqGspk~ORD6$K6E`iK*koM+Ume7Zp~YX+SQUuiuQPhl*z@yFbO4F()Iuq`6!W%;Sk+0x-+`P;X9;?r6r0~ zfS4wvImAeDq$b<&DO_&{?gvzTe$biRQI(O!H%^2TDqpDc* z<>@;M_TKFS*rt3v@BeC?%H&BTWuENHyDE-)f%Ke7$AwN==0pL~$<4q6fZ+`d4PytM zf4z}qteRJ`_`(BzVJYilk z^F~@m#<(jt-ND@eBddE5XJ~n`TZ+kX{%>aeY|Zqv*@CuUmq{Ylf4=Y14-_bdBQjjw zg|^Cz8fmTc`Nv59%v3OK;l;IN0RQDnsuR{5t5yrb%7pU;cRrTu3gQIu8#g7y$%bgA zyl?!2W6NjK5$nX$e>7U@=FK)hoTk>0<_{by3AZ)cVkF2B!vH;JYkvN`vCtJB9*#mn z%Ou5OTq+R^CG2TaLhE*Z5(4>wKH8tf5Q)DSf(}p=-4A&LOY+F(^GCu%BQ3&fz{MAA^vqB!XpSyVZuAw38 z2;I@7fJVvp4Zu`CoWx!R|3fof7q6li~T-pLo*!NBL;)YDg#^prKjW z^TqoULSeQ^r^x2=V0O~3UAtzccA6;D4kH>{jWtyziD9DDlW%@)#1;b9V83C?%xxpT zOXhWmj$PHe3Ed*Q1+WG`ayqJmIQ8z{olVI`S-Q0` zO9;DYWn65vT=UZF&35lW5M(Z|xHK}pS3OYi`FXGjpa(%(Gf6FFVEo6Hkx?c!HT6o% zxkLLfdqX9B+xqk+^1hRO?e4y{S)3NFmV6Vj$C>tlH6xyZVE?gbZL(mcCC_>)Qdn2< z3$@u&{7$pcZYDG4t_@oO!PQC1T(ZecV&MGml??r+O~Hqd*2Y9>cayN(x_sFsset6h zcJMBEzt7})94s#DX>CjgL^qf9BqhtgWj9|uR?H`Kdfsm1s+Aou=xLq8T%bnyTN~Tz5^tv86%*V3L(b*b*LxSh{5T=C(y+-}TIoiMm$nMvBYvJa zN%#Nuiv(%^mb3TPiyD;vH@RdchZos=abQKRQ>_}EBM5A z!~BliuFouNZ0P2duN*sL~{1*QgFWbyXa)gr${KbZqH+l4wEh!s+ds zXSgokKW|w3COZ1K`qFlp(~ABWalYNwHkTsIo&Huck*rc*q9;kV=&Hz9&0X7p4f5~u z*8dsv@2!-m#8R_$AQOo#b&{Th)*tZNpyr^uHJOa)NjkG+bzMhQMOqMiaWX+oO-)FC zHO<1?T?mROoN0A+d%=l5nd)&QuFE8KpgCxzePSNDxw$SbBm^SXr$4f9H?z2!npzNt zs(xKfMM{_^F^)E;MFw&Jz9YRX~9oTX81Gu%A#&J z*t^G8&R@e<{q>zUSJ*-Vxpj3&shOBcjOrOT50uhT9J}~kcA~dL6jdteOs8oHbpuWl zyQ_fHx`1wBhKRH5%3Qsu50mH_SDT5f*6@6{+fic9#`F~aWCAie(h{VKX*$Bp-?6({ z$cH^Vs0Qjqw^Y=F8>PG%h%~gp1ECE-TQX+szQJ6sX9O{&Q-7M&8P>p-=IWz#HdjNA zaiapu@S!Vqww9;z2hUPni<~Eyra2tIpBYoLF9)Dp7u< zl-}_y`TqIPgWNg~)z}5^cSaxoTjKj_SETZ`k%m8(;+BhlHJnr=;u}xp2?+^3i|qW_JpCsnVs5K<%{&Q=g(K}gNo-ZP!Mg%}h*wF~kw-KMQtN{h*#Bii2*D)6DwpXhkC@q}aTwY#I;H=TgM^0A^SsH7udD*#VF=I>Y%)v*gJv!`J2M!!)5*3xd ze*LR6I#@yZ7=H?qTr4+BDE6k@#VR}CVAhw%dc zfV2?c@4s*Nb3OjOef=MbPgC06}WT-&)1yXpg z-OM4Y|9AO5#L54JHO$M)-%)lo{^}5itsE;&f15DF@=bb*AHM#i>eh)Wf#iAYGSW%_ z89I1w?QU}RLlTM}L#_0`qpqaDghN}%E%s#I20Nh`h^=sC(7aEgV6 zg^MfYP-h_5h*pK5lp5*R0>up@;bTR{I%yLZyEv_+LFumE6c|~pmThQ_jEI2Bb*@rDjI}A<g9-+~hmHL0 z#r~Hbx*as9XARDWj+!5E2GL#N#t+o3|WJpR93gqX{Nl0KPfWo@b{89B-YaD@^ zMdR*Wf~f?UOfO@JUO9VBv-uxGJPN7f5UzlZ_)4^k?6RJ>h&*UH-SZ9NK$fDH$5=~B z(ix3hvv=<9?i1yC!@5o@Ws(t5POJBS*M=I(BD~Z~zRgzG^jO8Lp%idfY&UYFLUDZV zQq-}WF%J(?oz}Xy{Bs{I5JJCvxx8?+4042VOLA2s*aqUnQw_ZL5t&_D$|<_|NS~O< zmOe=bur3x#1$if3VrRcs>bh6VgHmMKC8y)|Spmyt5X8h!wWdTZ5TT++sonPJ&Nq+% zZ8wckMlCWFlzSbG610#k2FLrk=-L%)iQBiMC&(T@evF)Y`o`N`65>eE!!Lb(sR-3o zReqTa>E%em;FK@+ZYYRDrtaTsyYag-jxExrl%=Y_Uxq`xsnFI!XDCQrhgwWrN{W&} zz_;Y)_`>SYkUq0`nJ3+F1`&iet1Kb0SI*9j)LRTf24@aIXm?LqSy`zvUSE+B_d3Em za6ree;T}Hif;}gZ?5fKyZEdxlA8i=pBE-bTo_r_6@#oKv zASB?h{E6!Bk8r^ZyBR`I1oJbTG7Z+L?3P2&%cYsmibGpYT?!=pX{=va0 zP5C2{1=6)|fq-h3@+2-~4L7Ii$CIwa`9e9Cw+YQ{4k$oFI3vrr*?*Wd;N@E^%;Ati zP#R(jl5*lc!i$HvZvv*RL`?NMZ_$aJWMiu_h7NX63m$H>}PCnzDwRx#gea;pIgq=KA7=X1C|- z*ZVMH?B5gbUH*0=(HSVu^1Sgt%Fg?Q6s29MI}j=-WO(=P-AL^7;YB z5+k+U`SM;%vWA!kB;h&f8aGV)U|5FuV#_b+;Q)mTAj>;@w!|$L##AG1#uDT8iK@2A zX}4IVLvGhNTGMpZ=_RG4YIqb_?J{Et$q~;oPF}iHV81k}o#C)zzR(555>P@^+dn*9 z8=`q+8Iz`%2hKnQh!I^(9qls zV&^QEQ4d{udb+s&g(D_YUmOBtAZqkx`yPJf5fc;hkXvGx0>_vbhQKJ)g3vbdXL@=X zf_6>p-7H59|H#N6d9Er46PNY)A3SnoTx>Plc z5S~>;Fn3Vue~I{Dw=lMsnO34dQ;hWSSW_aRRSj{1cEAua^b(~+*(>_D&HaBc7gb3g zpz32KA&Gr*;8FD98_wcdPyeo)g>e}qZLSYqOE+%j(D`0b5f3?5C`7ags(Kl3$?^bs z9${0rdZOzoUm?+>GT-mcG+TouZr7>)5&Mf+_m;9=4w)@0F13Q?pI>F)zMZ&qNoy!_FM`70OEA`?%**7Ts0bG>s6^g| zjDbv%nXeRVssy<)$zYV3up<1E<77}NySUOh&ez`FjC>E@XuK1acI%bYwfhr?05#Gq za0w6&rhwAYQhg;qCdS>*pH}Q>o?RWv?hq-7NDT1oqFZlaE%{ZoDt-av8ex?{Po1Kj z_iK^jbA(DPIEjQ-1AG1E#l;8NChX^JepZitBKPu7?WUgNyX1==M zH@vN<@W$5w&($5ps!3|*b`Th{)y03vaF}Z<(o&d+kp@)JGB7+eQtDqQUTD!;-{>(O z6!Uof`Ze@UlzJ=z0&fDU0srehwp&m*Y%H{zA7rCr{tmh**=a6R$Tq`n=FOW!?IO;5 z{dgZVCER&J(U!MkPD0#yea2l_wmnp3h}jp?ZN({}-dMs2#o_ZqNK$H)z8Xq?l#5=6 zd5B!_z}Es%Fs^R}9upKQ%FCU$HkM3_77h-x(lN3a*cCa|M+z@8v?H`8t<_E``UuTd z5@$xmsPfG@y`~7$#t-}1=)Ad+3VeBi6RX`9HkNvGitJ6;f?&X)X@HC5lxqT>-puTi zcEN<f4{$^S(@ra7$z1s0MIEdm{vvF;l8iW5s12ADl%|OPEH1S z-uPT9qGWgbeWcnkzpG07V@R3Veb^-PKFOyd&GX;;`DHVxL?lX6PdZM`lt%0skdZjI&=6$~Y6|-`NwsWJ^|8>_ zk0!4ZfsUV>OMP77x_ZWH^e1y?osyFtN4!E*u-3^C|HfI{XzL!^3UfK|LNDSm&N3hB zLTG5(zkm-hAdf>R3mAnIav-itLp|ej!B(JN%hb^_kl3Q z%>eZoO0I11>2~OUA(i{?hdliULDfM zQ$51fMFzFKl_sYRM&4AzY}2eDk?=M`)VY{}TSYYR1lv#wf6oYLW4K+;CK>X5FyCCa zbs}@kGU;p?d*cZUrrn$ymE-EX-_n!8!9jRs0zyL*MYhde%)ky%U()D#^|vXJMpS{- zb+kDtB!*o>L!;{|0gi%!b9AZ6jLLqT-Ovms?8mb!RF>UC*Ry%_E2*s5ApG2YuhF$} z>h)b|j@dO|A0Nv4k3i@lKp;;8Jlm%WsKhohYuhu8v`mCfoVe;#(UE5nzdr9zX{Zo! zQI*s$IG9z+08#E6*G&<{@!!A8qZACs2jRdI|c=;#4oUwVoRpT~kHap8S{WA|jE!yK=>c zyaQ*oZS|g&B#$nv3R?WGjJ?}Z4@te!B+sn>*{3J^dng%e>EKVo@Qj?Q#YG;U8lQ{4 zFgUV#Ix1mUAsQ9Z%RK>}l|__gji`hIt`i*%89avInr1*HRfmBW%gXl}t5Xzx_TFhO zs0rU_fKv15#wO{tHpk?MD3u|F`(ga1KQfo?Kj!3E!?)1vsGL*fV=zsjD0(*bQ8<$W zYs`6r>RJdxkRFh|M_W>k(pT>Z_hAquP*xxgyJfK%uh5sQHzYx5Aq`b=r5d9?y<+J` z)LBrp1n{o>cs5~iS$G3@FKJ|ULz~ZG+3d_vd_zihIk2FZ2T{A`m)W>UN58tfT-i%k z&9_22TLV&YZ*wmNm@y%Z!ft6mioiNBl1b8TTE}C+KfW*zsCM33(>)Vdk{f>B!1J$Q zuWUw_PI2$aOD6(6*S+{lR466pr0?=#`j=PEmMf%FDWVt{_U4u4(4RRO>Tn3P-pX*O z2f%~=#N`bpiZ*q4YK70EMQeLw&%{v<;2Dhl{923xV#5g%| zaXtPx(IX>wi&LI*jV22r##Wee{J0_ImyY_nIxYbLo6c;LvQDN%gx^J>#EZ_G4o=bC z*`GguCMqxjtE$)Bdg|R+_jr+y{fUpFprH6YLc-7=pdjLxQ}vKtTaK9kO&X5F9J88> zfdTc~`1$$Qm^yfQY^HC!FA^O_i-isCG|KP*9AjWOz4ggmR$!$P_MU7r(Zb^~5WYWd zQOT!gvVEXx!1Sz5_*3a@T5NWq;Aq{2M_p}rmH8g%h?=7YQ2TBb7yPyCFWseX&8$GV z_U0cO-r@hM8&CbzN{y-8W}|PPIRtR-o-G+hnNYDU+Z$$r442ps9#Y;0(rSXEU5^8X zTRgLc6*q~82lk6?8N)WCBSPV@``=K32Wc$C zuBZpH)UUd4->z>qc9k@y~2;*#q#1Ud3o;#mv;0qKoh#LujBLmZ!~bRTbW>N z_Ao2k-i0k8bz-keBjn5ILoh!eHPSSoKOo9sR$AS9k4y@1q~%0M5Ye0zWlj>u=udWe zrtgle_Km)K@Wk2J9NM>A-FoL2>1lrBZJrg={olBC3P;CuF8@(Sh@#iA`Py7WV#Wx{ zmevPH4TL%jl|=Qcx~!~BHC4O)e}CJh2;uMoykro#NRBJOQ+0N2P4WiW@kEy*Io5!6 z<})NNyZg^vNp#gc1$FB`A#MK+|L2#jJg732Rx7gxMyYmH{w!rZzj0_ii>{NHi9m5Z zjs}B&a5$m}oy1{4$rC<81E05o)kqcblXp$$8l;%Gdib_|VZkPL;d|6|P24l)FPs)0 z2T$Mp3GxLrRwd;c^zu{Z&OL^KS3zjTRr@1KKGlyZL|Pc?5VUOgyaff59uJnIBFKe* z5Fcl72kQ>zlf`v;zEQ#IV#GsN56t3s8qXLx7{5v3zR<-=m>ok%io9ewdV~2pF#5cm zFi|tph7*Kp0G1Kh(v)&eu|Y8*+N+trgU+LsIQ83c4~aJ`$zL9nc{sI@9*!3R_yBg^PE z)EReEx(^7yUFH}T-GvLqYO;D-4=@@X8$@i|R@-IH-+AO!6uinMO-W{`Gc**)LBJxc8eP@MsWCq}yvjF0$GfB!kV`FoU< zlu0JD-$AH)TuW=^^!x?dkMI#kGkubo#OyRy@6#~EVAdFzB%bLyE%&>Nt7IBp&myHO{dtsb56Ib#VxMK^7|qgEZmPLWcwjXqKZU6h|(2 zc%U+JrBYN>6ur=q)gTo3<==s;n;$R5H1O5CI zf9`gtWZDFVCVnb(HLyybn=gWWwZ)-?>cV-c^C4Wp=&|58=^z&AP@S|=XqcFELF!o% zeIi`VWv8ls8~Heu(rzY_XcYEDYY@zm>&3odVKU-b4_f&`<)Fi2`@+B4 zrvVloX4HdxQL-jY3Dk)hVnoC2%p5`6X5jnYObHn6)1iNZ*`)?R9L(*&Iz6&|E6qh9 zJR_PrU8z9npCrh0Dapz08+nT&REpFuFwhb;(E9qGXxGPgIiC6N*W;jk4jNrtQu4`U zS9`M1*|TOaDXiDJ(*n0ZxoUjAd%u6<-{+JfH$l#=$DO0jdf09bk26X@Br=$o*WVr| z*!@kfKnoM>mbKjQaxon~yfT)QR=o?vB>ow`ZLZc%xWOtbE0GT%PQgIEINs`0gPtUY z!-w19T61a#ordp&GoUvwxxDx3xgaHJ>g77dHX~@mnMU1eqx)HY zBFK)p8fZR(k9J^3p=k&}&&0a+p%&C52*GlcjL|W~jv_b&3b0|AoYxi8TDi5dYoDZ% zHRT@sF#=K&H70cfSlJKkGGI$A)P$kNc24*lIw!_hjewb86dLgj@zg?F^vd{77e`jP zH@E8rmFEMhBM&k-jvZ(KmbKMPYgLF8(jc7yvpUC^3m4J*g8t#*ps5}ao9So=&JlkQ z2}RST)J$?Fhf$*JD z*TS)#uoq_0D1xYPrP^BS-2#m%)+3Rv{{6vAFRv`c5eU-)cJm4CN~cbp;_LWa^h>PC zx^#sBwYlpu=yzT#3YrEu7dTuiYim!&Ag`F)0tz(MKxJqw7d-8XzJ-BRIHgX1n0)>M zi^}jX*Tv7DKR>D_5hJU36WU{7wdTF|Ga6q!W1z4)EOlE0OAN2V@=2te4j(xpBfg#% z+d2-0hWfyPfj@ulpFNcOv%a1iB?QD^Ip&`1<;S z-G(&bNpxH5K(WJTs#_=1C(X~D2$4uX&M^MFjCFMav)WgX!X`=v*~NfRkr zdFCwCzm@rc??PON&Oj+l(0(VR zKGrMr)B^2`K-O_^K>jNLX^|=l)6VGVy`&KyZuH3vOin(!a3t>&p7Y(ickpcfVTV&thHE{s6{bXch`}YS)F>`Qy#9ZRl7&mmMD35z!O5nxzv9p3UGqh+> z1xxXR=8p#1Oie}~CM=eLWc^r(71o0a6oJCM7B@-1J{NWfTyZ`u1V;+!Afg)2+c|PJ z(CJK_^bLPnOF4;KL|=@U2d3QB&Fyg&8dp&I=l{zaZB_+Nfr3$pJ=w7;fMxgdJ5U$2 zp;Uc{EI`)_^x-ori=C#Ze|tl-#}k$Wsjov_2Soix zBk8#>4kVgYcyPqn+{$65h;;8+G{ENR(uqh)ZRwyQXSxj5t~il6*# z`^ti3E@JWfb@SrzOQW-g!0%dHMPM63x>O-O2Pb(AeKr_tx}gq5yHM?oCay#^#79WQd`U%!3n9A*dgcaPYZnXg2s4853k2Xn1{n&yB)LQX0yotV%QiA^b5t@ZWw z_bwRf>r-KFH-?nJ0?~NOaY(+OL608olaq9L*Pc9if^mas=}H2COTL+2nsG|-7!A!F za3#Fp*McqowlyS&36EsjK?uO+oSB;om{?v}Nu-CLIs=u+rzaG9cJID8%J4g!mSO#~ zy4Qbr0T|W+w)ZkONxO|FYv-L$JLsz-!{4>BEnWt;mB#gLrj$-;U05R!2`gO9ztF@1 zA%6{J>ITU5^`c%mV`F2C_|Mm+ItTTXMvXkld<6R7;DhK<;Lg*i(m$`7aN}vbpq?~9 z`cCR&KMWN1kbQjEo?(9R{COpLt#02!@SXGMY*G%1{qSLSr>H`-&Z|&p?^P)Xz2O_2 zycTFo1pf&yj|D*kWpKDbw)!iNs2ar(*zQGygu0T$CY}cc2A-Uv8P!q#=ORz(S*%`icw%9ZdHNwbt?leC`Yci_2C!alvOuZFb|9mz6Dsye zx?aKXe|$J7MqH@?YQzJ5eW#&Wad?(?tBLw*6;>RRIZ)xw%+BIzASWS`k_QqGl%KeS zgcllAWEm0oiA-^qpG8eM5KXwBAG7P(P&s%50`)kE?FGho(^)6q`1p9Ql)iIxPItfk zo%Ep+dqKkd9jy!=_aJnLd4Ol4Y5@F6dEy^cBw)XoOm!Dbk5J4fXbA|_?LDE|Txl0G zIv`~K6GCAp6`w~D0Ip1$)G7K`2DMW6q4VR@5CzzmXmYUQAaK-IR`yjm*0ylQzf4ek zFDd!{9JDf|E}`3o??=q#%FNl8S69K|!3r}?vyaIA$wmk~7Q->GM(5{IUgMLxXh)m& zeM3dy;Zg3fA=|q*ok7{}1UtZ~mk$`&WYM1B*ZRh@kj#||vAMJ(ksvE@qFN)I z=%;}lW(=`W+ALd*Xv$5b~`HARz`z=3WfR7VC% z=d@~oVOTS*(7e!GZrYXml}<=lcxj@OA8;nJOn2L{#|y7YwE zYrlPazNfS9_w0_h+IP<$T4vlyO{CXE<(tMq@`G0k=BnXedpNJnN|$u$9TW}io@5Q4 znvEOU(h@@=*S>>Scwa14RrQ*f@PA0Vm5f^qVpfOqfUpNwg(L??M))7Loefo~T}@q-X=TI5WSv_`4bsRz@N==U!$>ZwA}RF2RK{0o&=% z7s036JbWF$GSXYSXp#uAs49%87dM-H)|R;$z84zrd^9{hg6a_}wUK8tU1OE<;jk@2 z0OCWr@b1u>o{!j4uu6$is%laTbgEZr%D=&CZ2hXp20W!t-GEUc3U}St`JNl{tv@ZJY|4+XSwGfr}_C$JwECBr>+>*s-5hcsOgYMj796 z78@U=*4#YI`6+4;!V1@=&C&-;wDjJtE{5(;D_%85h)*Du%047mLb)C0X?!D~wB|o>zQ}Nn051}lN zO8xUQ^AN)Sj-_rB*a2T9S-t*wI7l-~v4}BA# zJjE4hGyNgOnwpx1Snifdf~k}!Z_94qu>dk0IS_iPclNBx^p0knBYt$ z+B=8waNS*9=#HTz(w&L(vj;Dc&}!nqqsgv?Xur4KDI*wAHUCSN8+AR2yy~;5&6}HL zTqHkPBjiH@0`!awv^z*xJZO}5{*P`Xm#d`YM9$VKo+LWqH1`Bbqx^j38yy;vPgvZN zmKfRyt{$&m4gNutBKGf*Dxzbsb&*UPcIy6LXO1;hB*2)3+JaV$*qEJFw4L7kOGQDM z{WF@D{^O}47jM71!9qtcx3Higs_Fr{*~ZOdNag^3Xm}2=Y9Y33OmtwuC8VU-SXt3g zsrjP=vqVMU<;6)%8Oy@$Ey|>iksz_CdrYZl_xj5u-m!h{N(I(PNEjwc^eCgKo}cX2 zQoD&aBVL7o0F30STWE6*g>(%T;$J>`(4;VW8aOH_)T3LOj3{?b!HBnP@{Dk*N8!o% z)L!|U?$2{qM_srEzQ&4ue~g zGs5irkPyk3?H2!@s;LR$jDKYE#RUKbZ65?OQc^0hLNo_c^sut~i#a&Z)56{FTFIT5 zu(G`DTXreB@;XPvCF@{b(C({kdOp+*Q0h%v&AxbQK-PLi7%`qK2eu;k=L}Zad&=JI6B)tD}t|z@Qk@&(qXwg-H(|r zhUYwn9TXq2-=N8h#H9v%DHCw@i5l(`7<+w~N9^iVw9hu8b_xtc-v?}l)9|Vd{(7-T zfO)^5$3GiUMUUfR#3W)6t=I+QCJ)qwl5>wP?WYzFtQByG$5#(~zwTIUGP zTyI=CK{WsGxtX9>;w$b+MbxkhqWJ~^b zMcdKx%fXt#EMgab8S~YANaJ?YR^;I-hV_vtsh|g|=&;?B&ZH^FdlTC^NdB2)omuL1 zK;4fu#F}`4Kx_j9YZ;#}5>G=QzD0f)TKfKI?cH^IEAG}kPxkaFa?hCDd~)A-7@I5q z0Qf|RI9}xBl)d#1DIdjL59+#-t<5AI{h5KXSpylFYWR6xR(|{T9vKY)q6_`(Vx?OT zrY4^!B}|HXM_u18{pdKAw5#i$r_cJ+PtrE>eK{l{$$qF>`O5dUS`sc9I%>^CQN=@_ zf5v2IdDF!{5l^~3ednif+ByGNH{I-b=ErGixRRm zhaNUOJpAD#JNqb@o4*b3D?s8XYis|Fl%SZvPl9n1M6?`!tDmA}BSw)|_T3D5BBJ z5TiLHN+l&LbM0Ci%4HhE*FHYD!e`6Mlv491+CZ0zmxSX~Y;A3AY?fQQ8MOHm6BCc} zJ-iocb&hzm*MD^s{~Gk_3$`GXmcq+@&s#HGK_!4DSWU~cnHhk$>w$*sd)yR+5c?o5 za8ST|Ffb@6G>4($e({3$p>w5MvwQX*8ER_JF^yW|I#pQg)uegJ8y7MBp0U9ZMOW)a zYso}*=XTv?BCQ6xn@=-EgFjcv7>M~-EBV3GW(*^!*G)rJDG!}J?mJviU)sZ7(C5Zm?Aj=%Q9?aA~4%b~DpqC*!GAkbAmH@6Lt-5@jt(qu=QNhoU1iOD0 zwx;_*3fs4T|ALJWv`_EB9r*mb++P2_WTTrwhyjun??yOrbm4s&0aKb(%;S!<{`SHw+96uw<&Ls<_+Y?%jbb?iP+H$3+cQ z)garI@}s73(OwuoQd3nH%84F$4?* zn;lmW5r<&$yLJU6fPr08f}(ZTV=0TY8GIHxsGQJCvaSyI>|MlqFokk>1ZnskcycB= z9JKsp#KpfCOcirVKY?RhX;eTSPP*7>TwgI(eVelp(OW0`VQA~#4sBz^X-OH3_&5vN zcm4fw(aT%o^8*6{KR{f>e%Br^0nM$$v-+YJmLtwM5T--JAYy?WvS1!hGtTLCD$p1N ze;K4iPJhOdAs#A@(lY4%Yjto{o_cIV&W&CHz*GnyibJff2$`a)sb@ zb==~C8j5?DoDR}?z=5w2E^zMr`Coza6P>NE&#o*+3$Dsms08W->A|f_$KUW`8G+r=sG!Kp4tiHvQ=>DJh6} z1RLKhM-&p2@KV>H>I9L8dJxWU8mXa)<{8=^$E9wi(HWb!V3muA&T_*7|#>)#JB`A%y3R)hbxB#pG4J2ki43~8Xr_L2 z-*@;fnT3S~iop>^BjWvDuxuW=73R##%Zr|F(e-}H>(V$v2pAk8M!#(oWiYl=>Xg+b zL@%7eg#GfHi@FXoZmo;NyW(nDww}g#czXwzg4JP<`{zEg_1?`z{>sB%hvVyHchN2Xn_YFhfTwc! zk&&3#Go_Dso9ygGH+mgkvAh-t5~e05DX)b`^E2zQrH4y5BJC$

o#JAu^3J3S}sE)$#u@_+kddno*&K+$1k48 zGY>Q0@8|ozyl%8!89U|yj{~XXAV&=iLA6_GmIY{@WZb)Z*K-}bIRDhp2nP7$U+CiE zf)KsV#ZBLMcn}Iou_ZEoP*DOWy`dGL*@Vt^xzNYkEElPIz`ga#i|8m|;)R?@t%P*o z_Su?2L<*M;KnamTyncE~kTll*ATkZ0eGBwxj1e-3JH>DaAsr~d79BAL1ok<56wNae z5^!+^itzDWHB`n0+`d_r_ka@hku+YAG z5P?N4dy$1e9|wDe#3a2QuqH5fEkAtt5UFyftWnSrdr_7B8sO&DQgbr73T)>i0aS6E zBw`HYxjQ*Qj%v#7#G&8`P0v}Hm^eL8uC}96uXy#1RS-8Kj`NKhi3@;#sng8G{*J+X zZ)kKvRPYJGT0RQ|XMvB}oGmFqi1)HGD+sr~a19g@xOm;Em5ECf3wV#H#KuISC`xOV z`Q~9r1={1&)!B)`f7GmI{!T?g_L2#u6P>J`t>japJM-J;KgAw-*m^JZHz^1rRp8w1k8m{Z;r#isg?Q*;E!w zSV;%VL1Yn4E>s_gmKVnzNw7J3^k+1aQKwG=9gL=cr#KtbhV+dj(=ZbxJ7A<=dGY;5 z%()?I7jyC40%Rhkta*Qu>gtX;^06i=-ZfAP-u~un1H_8K?C6h0l#5h@NfhRS%AsSw6g_n4-shA|p>B_6kqP zY-&(EN+gD#aP^lCVt-)9N*=%phP(Yyr!BV)* zp!5GCK~~rC%8H6ZV7vhk8DGAiMB2MIlN|h0Pm^FUXPc_(aIi%%f@(*KSMQnD?Z4V# zOCO(@n1C${Qd=f!GKO|O`vzc!aY^|Mj=Sf&{!6{5Qw)d&Z|6=AtW7RmG~d?)Rt>(G zPCaa%*lhxPloU;s<^=#@5xBTZmFU}5Og{Ej%x-mzcYSg)RtsrPGInSK4Vgx$ILsMD z?)9^&R$=}+yGDZS^n_pitgP44+UkZ{g-5MlEk;G9Rk1WmVV9n*ruYRvG)nQsy5;wA&ZjF4w4q9W7&>Kjb7QC0pn?X|8P{T3+M<-z9oy;X zMMU|0A^~xMc1Yc*Zp_KgKb8B`PXcO#%QX%i=a=41Nr8R6oslnPQ-^_WTc83{v&q5E zOlg@$&u{h;+&WAMTuxDu{O&oXIW%0rDR{b-TtjndQwi~R48JTql6B!3y)DD6%~x|S z23B-|xbGj_q;joLfDX{GiGo~6>87~^cd76Xb@d<-e!;qzFJD5!FZFHM204{)TYhLg z4H&dnKu%s(5D6CB=SHg1$xO7MnC}J#2GF4(GC~FJp)J0S4L?715@dAzxXR>KyoHTg zhreVlc=01uG`1iKDy38YuA6WjQmsgp4YU@1Bse=FxLF~dk^0*32u087mK-v64u7v~ zq^K$U-d`~(tJ@Ww+{7?CI(j(su`mm3gL?j5i&qQH9cAWmJ9qA6Fc<(%TM8q2BPO{; zz(E1vi=HG^?ye^H!4N7)00P=!d+XLM$iNH@43Jn*%oE$Y*8psknTk@Q-nTFU&3@*n{onRd*sqrC~Bei0IwMq_ezn=t;mdg6| zh(9v3NOQ0WOWm#Z)Xi2^u9L#^Nr+N!ZD+;yegaCQLT0YYyOI#)V(XtbGFd8p+kaig zF=k2*3hw#A1T$t+y-lY>zS@0N6^B)8ZN;?8aM!DBtxNe<1@UYzCh+fjk^;{$4 zZCkfi!1u_Xg&Uh0apc;! za&iSu{a+UNysw|0^D@qWklf!VS$$qQ^LQe%p6FAb&D%z*2{_N}-2+)2wKrBR&-HdE z(=J@{>U!9Py|P4%y|KD7(yw3GNf&jT_WRxDUQlAx;6cw6dzSS@!s2@_{nnWbQ$FfP zQGI5nrmH&}nz~}t0OLtJlIIe&dvc~8`+z`LV8(l?&%RA~#cjf-zt=uc5`U2%V>wq@ z)Dp4b+qr)2nf{(QrNk)lAtR&qC7T{iW`!(iSC#ku)K?bVq|B1z$Btp^=d49bQ7E zjGQFnf7{JouI3DcueHr8U5n8MdinmF>TthwNo=byUn%rG+J`(5$Qh(k6}a{=?||+4 zITmhY;;54i?4VV9{Cb*4wiuF2mL~4o2_ClL}pR;|o7bO)Uu?K-zpH zYXW0n>7oz0lYzzvW#}DehLKLP#&N7Yx64cufqNXFU&8N`r&Je+vr(6!{0t@>iH)wru-8^8SCHM>54%dmWjSsb6+214*F-kBVwt#HOD4Yh!pTCebc%lkO zhvH2h#Y8G!Wq#*TZL1QFTym@M7glg!<#1D*cJ7&MgPQg_m2W}gf;o98d%myQ2^?6z_SJX)0g|~+&Hw-a literal 0 HcmV?d00001 From e9cb78eccd6c18e78d6119f8c7a091c0ad81edca Mon Sep 17 00:00:00 2001 From: David Date: Sat, 16 Aug 2025 15:59:55 +0200 Subject: [PATCH 12/32] Update R/build_and_compile_model.R Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- R/build_and_compile_model.R | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/R/build_and_compile_model.R b/R/build_and_compile_model.R index 0d84efb..5be5fb0 100644 --- a/R/build_and_compile_model.R +++ b/R/build_and_compile_model.R @@ -312,7 +312,9 @@ build_and_compile_functional_model <- function( is_cls <- !is.null(y_processed$class_levels[[y_name]]) && length(y_processed$class_levels[[y_name]]) > 0 if (is_cls) { - block_hyperparams$num_classes <- length(y_processed$class_levels[[y_name]]) + block_hyperparams$num_classes <- length(y_processed$class_levels[[ + y_name + ]]) } } else if (block_name %in% y_names) { # Standard case: block name matches an output name From f2e8ec9b555e1df550a4173c7f1c7d7fd9738076 Mon Sep 17 00:00:00 2001 From: David Date: Sat, 16 Aug 2025 16:00:02 +0200 Subject: [PATCH 13/32] Update R/build_and_compile_model.R Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- R/build_and_compile_model.R | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/R/build_and_compile_model.R b/R/build_and_compile_model.R index 5be5fb0..bc354ec 100644 --- a/R/build_and_compile_model.R +++ b/R/build_and_compile_model.R @@ -319,7 +319,9 @@ build_and_compile_functional_model <- function( } else if (block_name %in% y_names) { # Standard case: block name matches an output name current_y_info <- list( - is_classification = !is.null(y_processed$class_levels[[block_name]]) && + is_classification = !is.null(y_processed$class_levels[[ + block_name + ]]) && length(y_processed$class_levels[[block_name]]) > 0, num_classes = if (!is.null(y_processed$class_levels[[block_name]])) { length(y_processed$class_levels[[block_name]]) From 8267e804f66b1b4f1df6a9b1f90bdb2b34b5319b Mon Sep 17 00:00:00 2001 From: David Date: Sat, 16 Aug 2025 16:00:41 +0200 Subject: [PATCH 14/32] Update R/build_and_compile_model.R Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- R/build_and_compile_model.R | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/R/build_and_compile_model.R b/R/build_and_compile_model.R index bc354ec..44ab7ab 100644 --- a/R/build_and_compile_model.R +++ b/R/build_and_compile_model.R @@ -302,7 +302,8 @@ build_and_compile_functional_model <- function( # This is primarily for output layers that might need num_classes if ("num_classes" %in% block_fml_names) { # Check if this block is an output block and if it's a classification task - if (is.list(y_processed$y_proc) && !is.null(names(y_processed$y_proc))) { # Multi-output case + if (is.list(y_processed$y_proc) && !is.null(names(y_processed$y_proc))) { + # Multi-output case # Find the corresponding output in y_processed based on block_name y_names <- names(y_processed$y_proc) # If there is only one output, and this block is named 'output', From 01f2612b7a71492fa20d000591ba4e71eb99f4e0 Mon Sep 17 00:00:00 2001 From: David Date: Sat, 16 Aug 2025 16:00:51 +0200 Subject: [PATCH 15/32] Update R/generic_fit_helpers.R Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- R/generic_fit_helpers.R | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/R/generic_fit_helpers.R b/R/generic_fit_helpers.R index 43ffa27..6e2579f 100644 --- a/R/generic_fit_helpers.R +++ b/R/generic_fit_helpers.R @@ -54,7 +54,8 @@ collect_compile_args <- function( } # Handle loss: can be single or multiple outputs - if (is.list(default_loss) && !is.null(names(default_loss))) { # Multiple outputs + if (is.list(default_loss) && !is.null(names(default_loss))) { + # Multiple outputs # User can provide a single loss for all outputs, or a named list loss_arg <- user_compile_args$loss %||% default_loss if (is.character(loss_arg) && length(loss_arg) == 1) { # Single loss string for all outputs From beb2425bb9a3db8492bb3f5abed1072f48ca46c2 Mon Sep 17 00:00:00 2001 From: David Date: Sat, 16 Aug 2025 16:00:59 +0200 Subject: [PATCH 16/32] Update R/generic_fit_helpers.R Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- R/generic_fit_helpers.R | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/R/generic_fit_helpers.R b/R/generic_fit_helpers.R index 6e2579f..9feba0a 100644 --- a/R/generic_fit_helpers.R +++ b/R/generic_fit_helpers.R @@ -58,7 +58,8 @@ collect_compile_args <- function( # Multiple outputs # User can provide a single loss for all outputs, or a named list loss_arg <- user_compile_args$loss %||% default_loss - if (is.character(loss_arg) && length(loss_arg) == 1) { # Single loss string for all outputs + if (is.character(loss_arg) && length(loss_arg) == 1) { + # Single loss string for all outputs final_compile_args$loss <- get_keras_object(loss_arg, "loss") } else if (is.list(loss_arg) && !is.null(names(loss_arg))) { # Named list of losses final_compile_args$loss <- lapply(loss_arg, function(l) { From 53514cca24ee93222af232c1eb96ffdf8ec2b720 Mon Sep 17 00:00:00 2001 From: David Date: Sat, 16 Aug 2025 16:01:12 +0200 Subject: [PATCH 17/32] Update R/generic_fit_helpers.R Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- R/generic_fit_helpers.R | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/R/generic_fit_helpers.R b/R/generic_fit_helpers.R index 9feba0a..7b71e76 100644 --- a/R/generic_fit_helpers.R +++ b/R/generic_fit_helpers.R @@ -61,7 +61,8 @@ collect_compile_args <- function( if (is.character(loss_arg) && length(loss_arg) == 1) { # Single loss string for all outputs final_compile_args$loss <- get_keras_object(loss_arg, "loss") - } else if (is.list(loss_arg) && !is.null(names(loss_arg))) { # Named list of losses + } else if (is.list(loss_arg) && !is.null(names(loss_arg))) { + # Named list of losses final_compile_args$loss <- lapply(loss_arg, function(l) { if (is.character(l)) get_keras_object(l, "loss") else l }) From d2d1198e2ac07425d675bedaeafa014f8d566ba2 Mon Sep 17 00:00:00 2001 From: David Date: Sat, 16 Aug 2025 16:01:20 +0200 Subject: [PATCH 18/32] Update R/generic_fit_helpers.R Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- R/generic_fit_helpers.R | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/R/generic_fit_helpers.R b/R/generic_fit_helpers.R index 7b71e76..21b9bd1 100644 --- a/R/generic_fit_helpers.R +++ b/R/generic_fit_helpers.R @@ -67,7 +67,9 @@ collect_compile_args <- function( if (is.character(l)) get_keras_object(l, "loss") else l }) } else { - stop("For multiple outputs, 'compile_loss' must be a single string or a named list of losses.") + stop( + "For multiple outputs, 'compile_loss' must be a single string or a named list of losses." + ) } } else { # Single output loss_arg <- user_compile_args$loss %||% default_loss From f5156feb13967b40e6a68f75a1fc34e68d2641d8 Mon Sep 17 00:00:00 2001 From: David Date: Sat, 16 Aug 2025 16:01:36 +0200 Subject: [PATCH 19/32] Update R/generic_fit_helpers.R Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- R/generic_fit_helpers.R | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/R/generic_fit_helpers.R b/R/generic_fit_helpers.R index 21b9bd1..78f70e9 100644 --- a/R/generic_fit_helpers.R +++ b/R/generic_fit_helpers.R @@ -71,7 +71,8 @@ collect_compile_args <- function( "For multiple outputs, 'compile_loss' must be a single string or a named list of losses." ) } - } else { # Single output + } else { + # Single output loss_arg <- user_compile_args$loss %||% default_loss if (is.character(loss_arg)) { final_compile_args$loss <- get_keras_object(loss_arg, "loss") From 12835e2b73685b077611c4e9a59a069bcb66c8cd Mon Sep 17 00:00:00 2001 From: David Date: Sat, 16 Aug 2025 16:01:50 +0200 Subject: [PATCH 20/32] Update R/generic_fit_helpers.R Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- R/generic_fit_helpers.R | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/R/generic_fit_helpers.R b/R/generic_fit_helpers.R index 78f70e9..a83e8a4 100644 --- a/R/generic_fit_helpers.R +++ b/R/generic_fit_helpers.R @@ -82,7 +82,8 @@ collect_compile_args <- function( } # Handle metrics: can be single or multiple outputs - if (is.list(default_metrics) && !is.null(names(default_metrics))) { # Multiple outputs + if (is.list(default_metrics) && !is.null(names(default_metrics))) { + # Multiple outputs # User can provide a single metric for all outputs, or a named list metrics_arg <- user_compile_args$metrics %||% default_metrics if (is.character(metrics_arg) && length(metrics_arg) == 1) { # Single metric string for all outputs From a00138553980c2608132f7f0406e8788514ca684 Mon Sep 17 00:00:00 2001 From: David Date: Sat, 16 Aug 2025 16:02:03 +0200 Subject: [PATCH 21/32] Update R/generic_fit_helpers.R Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- R/generic_fit_helpers.R | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/R/generic_fit_helpers.R b/R/generic_fit_helpers.R index a83e8a4..de54432 100644 --- a/R/generic_fit_helpers.R +++ b/R/generic_fit_helpers.R @@ -86,7 +86,8 @@ collect_compile_args <- function( # Multiple outputs # User can provide a single metric for all outputs, or a named list metrics_arg <- user_compile_args$metrics %||% default_metrics - if (is.character(metrics_arg) && length(metrics_arg) == 1) { # Single metric string for all outputs + if (is.character(metrics_arg) && length(metrics_arg) == 1) { + # Single metric string for all outputs final_compile_args$metrics <- get_keras_object(metrics_arg, "metric") } else if (is.list(metrics_arg) && !is.null(names(metrics_arg))) { # Named list of metrics final_compile_args$metrics <- lapply(metrics_arg, function(m) { From dff4354021d7a5d387c287bc70366c66cf35fa83 Mon Sep 17 00:00:00 2001 From: David Date: Sat, 16 Aug 2025 16:02:17 +0200 Subject: [PATCH 22/32] Update R/generic_fit_helpers.R Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- R/generic_fit_helpers.R | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/R/generic_fit_helpers.R b/R/generic_fit_helpers.R index de54432..ef29550 100644 --- a/R/generic_fit_helpers.R +++ b/R/generic_fit_helpers.R @@ -89,7 +89,8 @@ collect_compile_args <- function( if (is.character(metrics_arg) && length(metrics_arg) == 1) { # Single metric string for all outputs final_compile_args$metrics <- get_keras_object(metrics_arg, "metric") - } else if (is.list(metrics_arg) && !is.null(names(metrics_arg))) { # Named list of metrics + } else if (is.list(metrics_arg) && !is.null(names(metrics_arg))) { + # Named list of metrics final_compile_args$metrics <- lapply(metrics_arg, function(m) { if (is.character(m)) get_keras_object(m, "metric") else m }) From fa722998f8371a9925e98ac3b3c8bdd57934f970 Mon Sep 17 00:00:00 2001 From: David Date: Sat, 16 Aug 2025 16:02:41 +0200 Subject: [PATCH 23/32] Update R/generic_fit_helpers.R Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- R/generic_fit_helpers.R | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/R/generic_fit_helpers.R b/R/generic_fit_helpers.R index ef29550..c5e8ab9 100644 --- a/R/generic_fit_helpers.R +++ b/R/generic_fit_helpers.R @@ -95,7 +95,9 @@ collect_compile_args <- function( if (is.character(m)) get_keras_object(m, "metric") else m }) } else { - stop("For multiple outputs, 'compile_metrics' must be a single string or a named list of metrics.") + stop( + "For multiple outputs, 'compile_metrics' must be a single string or a named list of metrics." + ) } } else { # Single output metrics_arg <- user_compile_args$metrics %||% default_metrics From f5ebf56637f3378bf82285863a4d125d9ba07ba6 Mon Sep 17 00:00:00 2001 From: David Date: Sat, 16 Aug 2025 16:03:01 +0200 Subject: [PATCH 24/32] Update R/generic_fit_helpers.R Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- R/generic_fit_helpers.R | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/R/generic_fit_helpers.R b/R/generic_fit_helpers.R index c5e8ab9..a358f1a 100644 --- a/R/generic_fit_helpers.R +++ b/R/generic_fit_helpers.R @@ -99,7 +99,8 @@ collect_compile_args <- function( "For multiple outputs, 'compile_metrics' must be a single string or a named list of metrics." ) } - } else { # Single output + } else { + # Single output metrics_arg <- user_compile_args$metrics %||% default_metrics if (is.character(metrics_arg)) { final_compile_args$metrics <- lapply(metrics_arg, get_keras_object, "metric") From 94d70bc19025bde4ae322f96cd712f645fa15763 Mon Sep 17 00:00:00 2001 From: David Date: Sat, 16 Aug 2025 16:03:16 +0200 Subject: [PATCH 25/32] Update R/generic_fit_helpers.R Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- R/generic_fit_helpers.R | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/R/generic_fit_helpers.R b/R/generic_fit_helpers.R index a358f1a..dfa0e1b 100644 --- a/R/generic_fit_helpers.R +++ b/R/generic_fit_helpers.R @@ -103,7 +103,11 @@ collect_compile_args <- function( # Single output metrics_arg <- user_compile_args$metrics %||% default_metrics if (is.character(metrics_arg)) { - final_compile_args$metrics <- lapply(metrics_arg, get_keras_object, "metric") + final_compile_args$metrics <- lapply( + metrics_arg, + get_keras_object, + "metric" + ) } else { final_compile_args$metrics <- metrics_arg } From 61f7ecc68804b803e73180da018efd1d4f250b91 Mon Sep 17 00:00:00 2001 From: David Date: Sat, 16 Aug 2025 16:04:02 +0200 Subject: [PATCH 26/32] Update R/register_fit_predict.R Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- R/register_fit_predict.R | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/R/register_fit_predict.R b/R/register_fit_predict.R index d362ff0..96deedb 100644 --- a/R/register_fit_predict.R +++ b/R/register_fit_predict.R @@ -156,11 +156,18 @@ keras_postprocess_numeric <- function(results, object) { keras_postprocess_probs <- function(results, object) { if (is.list(results) && !is.null(names(results))) { # Multi-output case: results is a named list of arrays/matrices - combined_preds <- purrr::map2_dfc(results, names(results), function(res, name) { - lvls <- object$fit$lvl[[name]] # Assuming object$fit$lvl is a named list of levels - if (is.null(lvls)) { - # Fallback if levels are not specifically named for this output - lvls <- paste0("class", 1:ncol(res)) + combined_preds <- purrr::map2_dfc( + results, + names(results), + function(res, name) { + lvls <- object$fit$lvl[[name]] # Assuming object$fit$lvl is a named list of levels + if (is.null(lvls)) { + # Fallback if levels are not specifically named for this output + lvls <- paste0("class", 1:ncol(res)) + } + colnames(res) <- lvls + tibble::as_tibble(res, .name_repair = "unique") %>% + dplyr::rename_with(~ paste0(".pred_", name, "_", .x)) } colnames(res) <- lvls tibble::as_tibble(res, .name_repair = "unique") %>% From e8b320ebd55140fd9790760f809d9c39b1ba3e74 Mon Sep 17 00:00:00 2001 From: David Date: Sat, 16 Aug 2025 16:04:14 +0200 Subject: [PATCH 27/32] Update tests/testthat/test_e2e_func_regression.R Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- tests/testthat/test_e2e_func_regression.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/testthat/test_e2e_func_regression.R b/tests/testthat/test_e2e_func_regression.R index 653ca37..d519c90 100644 --- a/tests/testthat/test_e2e_func_regression.R +++ b/tests/testthat/test_e2e_func_regression.R @@ -99,7 +99,7 @@ test_that("E2E: Block repetition works for functional models", { set_engine("keras") fit_3 <- fit(spec_3, mpg ~ ., data = mtcars) model_3_layers <- fit_3 |> - extract_keras_model() |> + extract_keras_model() |> pluck("layers") # Expect 2 layers: Input, Output expect_equal(length(model_3_layers), 2) From ca4a878b8a1c593de1361a6a33f51a36b0d811b5 Mon Sep 17 00:00:00 2001 From: David Date: Sat, 16 Aug 2025 16:04:20 +0200 Subject: [PATCH 28/32] Update tests/testthat/test_e2e_func_regression.R Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- tests/testthat/test_e2e_func_regression.R | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/testthat/test_e2e_func_regression.R b/tests/testthat/test_e2e_func_regression.R index d519c90..d2a98dd 100644 --- a/tests/testthat/test_e2e_func_regression.R +++ b/tests/testthat/test_e2e_func_regression.R @@ -111,8 +111,12 @@ test_that("E2E: Multi-input, multi-output functional regression works", { on.exit(options(kerasnip.show_removal_messages = TRUE), add = TRUE) # Define layer blocks - input_block_1 <- function(input_shape) layer_input(shape = input_shape, name = "input_1") - input_block_2 <- function(input_shape) layer_input(shape = input_shape, name = "input_2") + input_block_1 <- function(input_shape) { + layer_input(shape = input_shape, name = "input_1") + } + input_block_2 <- function(input_shape) { + layer_input(shape = input_shape, name = "input_2") + } dense_path <- function(tensor, units = 16) { tensor |> layer_dense(units = units, activation = "relu") } From a97ec1c2a602e623f5e0c979add603e0c46f45e0 Mon Sep 17 00:00:00 2001 From: David Date: Sat, 16 Aug 2025 16:04:27 +0200 Subject: [PATCH 29/32] Update tests/testthat/test_e2e_func_regression.R Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- tests/testthat/test_e2e_func_regression.R | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/testthat/test_e2e_func_regression.R b/tests/testthat/test_e2e_func_regression.R index d2a98dd..840d08c 100644 --- a/tests/testthat/test_e2e_func_regression.R +++ b/tests/testthat/test_e2e_func_regression.R @@ -121,8 +121,12 @@ test_that("E2E: Multi-input, multi-output functional regression works", { tensor |> layer_dense(units = units, activation = "relu") } concat_block <- function(in_1, in_2) layer_concatenate(list(in_1, in_2)) - output_block_1 <- function(tensor) layer_dense(tensor, units = 1, name = "output_1") - output_block_2 <- function(tensor) layer_dense(tensor, units = 1, name = "output_2") + output_block_1 <- function(tensor) { + layer_dense(tensor, units = 1, name = "output_1") + } + output_block_2 <- function(tensor) { + layer_dense(tensor, units = 1, name = "output_2") + } model_name <- "multi_in_out_reg" on.exit(suppressMessages(remove_keras_spec(model_name)), add = TRUE) From 4e98d28e947856046b050a99d8a755aa5f0c98c9 Mon Sep 17 00:00:00 2001 From: David Date: Sat, 16 Aug 2025 16:04:35 +0200 Subject: [PATCH 30/32] Update tests/testthat/test_e2e_func_regression.R Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- tests/testthat/test_e2e_func_regression.R | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/testthat/test_e2e_func_regression.R b/tests/testthat/test_e2e_func_regression.R index 840d08c..99d54c1 100644 --- a/tests/testthat/test_e2e_func_regression.R +++ b/tests/testthat/test_e2e_func_regression.R @@ -138,7 +138,10 @@ test_that("E2E: Multi-input, multi-output functional regression works", { input_b = input_block_2, path_a = inp_spec(dense_path, "input_a"), path_b = inp_spec(dense_path, "input_b"), - concatenated = inp_spec(concat_block, c(path_a = "in_1", path_b = "in_2")), + concatenated = inp_spec( + concat_block, + c(path_a = "in_1", path_b = "in_2") + ), output_1 = inp_spec(output_block_1, "concatenated"), output_2 = inp_spec(output_block_2, "concatenated") ), From 938e4f6d1c96d65c0f5d4e05de618cbcae390137 Mon Sep 17 00:00:00 2001 From: David Date: Sat, 16 Aug 2025 16:05:04 +0200 Subject: [PATCH 31/32] Update R/build_and_compile_model.R Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- R/build_and_compile_model.R | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/R/build_and_compile_model.R b/R/build_and_compile_model.R index 44ab7ab..fc57e66 100644 --- a/R/build_and_compile_model.R +++ b/R/build_and_compile_model.R @@ -324,7 +324,9 @@ build_and_compile_functional_model <- function( block_name ]]) && length(y_processed$class_levels[[block_name]]) > 0, - num_classes = if (!is.null(y_processed$class_levels[[block_name]])) { + num_classes = if ( + !is.null(y_processed$class_levels[[block_name]]) + ) { length(y_processed$class_levels[[block_name]]) } else { NULL From c5b8465124e2ae8f3c5c08fa2c64846d502d7a7d Mon Sep 17 00:00:00 2001 From: David Date: Sat, 16 Aug 2025 16:05:13 +0200 Subject: [PATCH 32/32] Update R/build_and_compile_model.R Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- R/build_and_compile_model.R | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/R/build_and_compile_model.R b/R/build_and_compile_model.R index fc57e66..4089ddb 100644 --- a/R/build_and_compile_model.R +++ b/R/build_and_compile_model.R @@ -336,7 +336,8 @@ build_and_compile_functional_model <- function( block_hyperparams$num_classes <- current_y_info$num_classes } } - } else { # Single output case + } else { + # Single output case if (is_classification) { block_hyperparams$num_classes <- num_classes }