Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
73b0749
refactoring process to ensure there x and y for sequential and functi…
davidrsch Aug 16, 2025
738c738
Ensuring build and compile sequential uses the correct process
davidrsch Aug 16, 2025
2b29c87
Ensuring build and compile functional uses the correct process and ha…
davidrsch Aug 16, 2025
e362717
Modifying collect compile args to handle multi loss and metrics
davidrsch Aug 16, 2025
6f65be4
Ensuring correct use of new process in corresponding fit engines
davidrsch Aug 16, 2025
9128789
Ensuring evaluate retireve process from fit
davidrsch Aug 16, 2025
ff68345
Ensuring predict uses correct process and handles multi output correctly
davidrsch Aug 16, 2025
b7d1844
Updating documentation
davidrsch Aug 16, 2025
2d43e24
Refactoring and adding tests
davidrsch Aug 16, 2025
1039a9b
Removing sections that exeeds the scope of the package
davidrsch Aug 16, 2025
61b1a13
Modified to show a multi input multi output model
davidrsch Aug 16, 2025
e9cb78e
Update R/build_and_compile_model.R
davidrsch Aug 16, 2025
f2e8ec9
Update R/build_and_compile_model.R
davidrsch Aug 16, 2025
8267e80
Update R/build_and_compile_model.R
davidrsch Aug 16, 2025
01f2612
Update R/generic_fit_helpers.R
davidrsch Aug 16, 2025
beb2425
Update R/generic_fit_helpers.R
davidrsch Aug 16, 2025
53514cc
Update R/generic_fit_helpers.R
davidrsch Aug 16, 2025
d2d1198
Update R/generic_fit_helpers.R
davidrsch Aug 16, 2025
f5156fe
Update R/generic_fit_helpers.R
davidrsch Aug 16, 2025
12835e2
Update R/generic_fit_helpers.R
davidrsch Aug 16, 2025
a001385
Update R/generic_fit_helpers.R
davidrsch Aug 16, 2025
dff4354
Update R/generic_fit_helpers.R
davidrsch Aug 16, 2025
fa72299
Update R/generic_fit_helpers.R
davidrsch Aug 16, 2025
f5ebf56
Update R/generic_fit_helpers.R
davidrsch Aug 16, 2025
94d70bc
Update R/generic_fit_helpers.R
davidrsch Aug 16, 2025
61f7ecc
Update R/register_fit_predict.R
davidrsch Aug 16, 2025
e8b320e
Update tests/testthat/test_e2e_func_regression.R
davidrsch Aug 16, 2025
ca4a878
Update tests/testthat/test_e2e_func_regression.R
davidrsch Aug 16, 2025
a97ec1c
Update tests/testthat/test_e2e_func_regression.R
davidrsch Aug 16, 2025
4e98d28
Update tests/testthat/test_e2e_func_regression.R
davidrsch Aug 16, 2025
938e4f6
Update R/build_and_compile_model.R
davidrsch Aug 16, 2025
c5b8465
Update R/build_and_compile_model.R
davidrsch Aug 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ export(keras_metrics)
export(keras_optimizers)
export(loss_function_keras)
export(optimizer_function)
export(process_x)
export(process_y)
export(process_x_functional)
export(process_x_sequential)
export(process_y_functional)
export(process_y_sequential)
export(register_keras_loss)
export(register_keras_metric)
export(register_keras_optimizer)
Expand Down
249 changes: 214 additions & 35 deletions R/build_and_compile_model.R
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
#' Build and Compile a Keras Sequential Model
#'
#' @description
#' This internal helper function constructs and compiles a Keras sequential model
#' based on a list of layer blocks and other parameters. It handles data
#' processing, dynamic architecture construction, and model compilation.
#'
#' @param x A data frame or matrix of predictors.
#' @param y A vector or data frame of outcomes.
#' @param layer_blocks A named list of functions that define the layers of the
#' model. The order of the list determines the order of the layers.
#' @param ... Additional arguments passed to the function, including layer
#' hyperparameters, repetition counts for blocks, and compile/fit arguments.
#'
#' @return A compiled Keras model object.
#' @noRd
build_and_compile_sequential_model <- function(
x,
Expand All @@ -11,16 +26,18 @@ build_and_compile_sequential_model <- function(
verbose <- all_args$verbose %||% 0

# Process x input
x_processed <- process_x(x)
x_processed <- process_x_sequential(x)
x_proc <- x_processed$x_proc
input_shape <- x_processed$input_shape

# Process y input
y_processed <- process_y(y)
y_mat <- y_processed$y_proc
y_processed <- process_y_sequential(y)

# Determine is_classification, class_levels, and num_classes
is_classification <- y_processed$is_classification
class_levels <- y_processed$class_levels
num_classes <- y_processed$num_classes
y_mat <- y_processed$y_proc

# Determine default compile arguments based on mode
default_loss <- if (is_classification) {
Expand Down Expand Up @@ -93,6 +110,7 @@ build_and_compile_sequential_model <- function(
}

# --- 3. Model Compilation ---
# Collect all arguments starting with "compile_" from `...`
compile_args <- collect_compile_args(
all_args,
learn_rate,
Expand All @@ -104,6 +122,24 @@ build_and_compile_sequential_model <- function(
return(model)
}

#' Build and Compile a Keras Functional Model
#'
#' @description
#' This internal helper function constructs and compiles a Keras functional model
#' based on a list of layer blocks and other parameters. It handles data
#' processing, dynamic architecture construction (including multiple inputs and
#' branches), and model compilation.
#'
#' @param x A data frame or matrix of predictors. For multiple inputs, this is
#' often a data frame with list-columns.
#' @param y A vector or data frame of outcomes. Can handle multiple outputs if
#' provided as a data frame with multiple columns.
#' @param layer_blocks A named list of functions that define the building blocks
#' of the model graph. Connections are defined by referencing other block names.
#' @param ... Additional arguments passed to the function, including layer
#' hyperparameters, repetition counts for blocks, and compile/fit arguments.
#'
#' @return A compiled Keras model object.
#' @noRd
build_and_compile_functional_model <- function(
x,
Expand All @@ -117,44 +153,117 @@ build_and_compile_functional_model <- function(
verbose <- all_args$verbose %||% 0

# Process x input
x_processed <- process_x(x)
x_processed <- process_x_functional(x)
x_proc <- x_processed$x_proc
input_shape <- x_processed$input_shape

# Process y input
y_processed <- process_y(y)
y_mat <- y_processed$y_proc
is_classification <- y_processed$is_classification
class_levels <- y_processed$class_levels
num_classes <- y_processed$num_classes
y_processed <- process_y_functional(y)

# Determine default compile arguments based on mode
default_loss <- if (is_classification) {
if (num_classes > 2) {
"categorical_crossentropy"
} else {
"binary_crossentropy"
default_losses <- list()
default_metrics_list <- list()

# Check if y_processed$y_proc is a list (indicating multiple outputs)
if (is.list(y_processed$y_proc) && !is.null(names(y_processed$y_proc))) {
# Multiple outputs
for (output_name in names(y_processed$y_proc)) {
# We need to determine is_classification and num_classes for each output
# based on the class_levels for that output.
current_class_levels <- y_processed$class_levels[[output_name]]
current_is_classification <- !is.null(current_class_levels) &&
length(current_class_levels) > 0
current_num_classes <- if (current_is_classification) {
length(current_class_levels)
} else {
NULL
}

default_losses[[output_name]] <- if (current_is_classification) {
if (current_num_classes > 2) {
"categorical_crossentropy"
} else {
"binary_crossentropy"
}
} else {
"mean_squared_error"
}
default_metrics_list[[output_name]] <- if (current_is_classification) {
"accuracy"
} else {
"mean_absolute_error"
}
}
} else {
"mean_squared_error"
}
default_metrics <- if (is_classification) {
"accuracy"
} else {
"mean_absolute_error"
# Single output case
# Determine is_classification and num_classes from the top-level class_levels
is_classification <- !is.null(y_processed$class_levels) &&
length(y_processed$class_levels) > 0
num_classes <- if (is_classification) {
length(y_processed$class_levels)
} else {
NULL
}

default_losses <- if (is_classification) {
if (num_classes > 2) {
"categorical_crossentropy"
} else {
"binary_crossentropy"
}
} else {
"mean_squared_error"
}
default_metrics_list <- if (is_classification) {
"accuracy"
} else {
"mean_absolute_error"
}
}

# --- 2. Dynamic Model Architecture Construction (DIFFERENT from sequential) ---
# Create a list to store the output tensors of each block. The names of the
# list elements correspond to the block names.
block_outputs <- list()
# The first block MUST be the input layer and MUST NOT have `input_from`.
first_block_name <- names(layer_blocks)[1]
first_block_fn <- layer_blocks[[first_block_name]]
block_outputs[[first_block_name]] <- first_block_fn(input_shape = input_shape)
model_input_tensors <- list() # To collect all input tensors for keras_model

# Identify and process input layers based on names matching input_shape
# This assumes that if input_shape is a named list, the corresponding
# input blocks in layer_blocks will have matching names.
if (is.list(input_shape) && !is.null(names(input_shape))) {
input_block_names_in_spec <- intersect(
names(layer_blocks),
names(input_shape)
)

if (length(input_block_names_in_spec) != length(input_shape)) {
stop(
"Mismatch between named inputs from process_x and named input blocks in layer_blocks. ",
"Ensure all processed inputs have a corresponding named input block in your model specification."
)
}

for (block_name in input_block_names_in_spec) {
block_fn <- layer_blocks[[block_name]]
current_input_tensor <- block_fn(input_shape = input_shape[[block_name]])
block_outputs[[block_name]] <- current_input_tensor
model_input_tensors[[block_name]] <- current_input_tensor
}
remaining_layer_blocks_names <- names(layer_blocks)[
!(names(layer_blocks) %in% input_block_names_in_spec)
]
} else {
# Single input case (original logic, but now also collecting for model_input_tensors)
first_block_name <- names(layer_blocks)[1]
first_block_fn <- layer_blocks[[first_block_name]]
current_input_tensor <- first_block_fn(input_shape = input_shape)
block_outputs[[first_block_name]] <- current_input_tensor
model_input_tensors[[first_block_name]] <- current_input_tensor
remaining_layer_blocks_names <- names(layer_blocks)[-1]
}

# Iterate through the remaining blocks, connecting and repeating them as needed.
for (block_name in names(layer_blocks)[-1]) {
for (block_name in remaining_layer_blocks_names) {
block_fn <- layer_blocks[[block_name]]
block_fmls <- rlang::fn_fmls(block_fn)
block_fml_names <- names(block_fmls)
Expand Down Expand Up @@ -189,8 +298,50 @@ build_and_compile_functional_model <- function(
)

# Add special engine-supplied arguments if the block can accept them
if (is_classification && "num_classes" %in% block_fml_names) {
block_hyperparams$num_classes <- num_classes
# Add special engine-supplied arguments if the block can accept them
# This is primarily for output layers that might need num_classes
if ("num_classes" %in% block_fml_names) {
# Check if this block is an output block and if it's a classification task
if (is.list(y_processed$y_proc) && !is.null(names(y_processed$y_proc))) {
# Multi-output case
# Find the corresponding output in y_processed based on block_name
y_names <- names(y_processed$y_proc)
# If there is only one output, and this block is named 'output',
# connect them automatically.
if (length(y_names) == 1 && block_name == "output") {
y_name <- y_names[1]
is_cls <- !is.null(y_processed$class_levels[[y_name]]) &&
length(y_processed$class_levels[[y_name]]) > 0
if (is_cls) {
block_hyperparams$num_classes <- length(y_processed$class_levels[[
y_name
]])
}
} else if (block_name %in% y_names) {
# Standard case: block name matches an output name
current_y_info <- list(
is_classification = !is.null(y_processed$class_levels[[
block_name
]]) &&
length(y_processed$class_levels[[block_name]]) > 0,
num_classes = if (
!is.null(y_processed$class_levels[[block_name]])
) {
length(y_processed$class_levels[[block_name]])
} else {
NULL
}
)
if (current_y_info$is_classification) {
block_hyperparams$num_classes <- current_y_info$num_classes
}
}
} else {
# Single output case
if (is_classification) {
block_hyperparams$num_classes <- num_classes
}
}
}

# --- Get Input Tensors for this block ---
Expand Down Expand Up @@ -232,23 +383,51 @@ build_and_compile_functional_model <- function(
block_outputs[[block_name]] <- current_tensor
}

# The last layer must be named 'output'
output_tensor <- block_outputs[["output"]]
if (is.null(output_tensor)) {
stop("An 'output' block must be defined in layer_blocks.")
# The last layer must be named 'output' or match the names of y_processed outputs
final_output_tensors <- list()

# Check if y_processed$y_proc is a named list, indicating multiple outputs)
if (is.list(y_processed$y_proc) && !is.null(names(y_processed$y_proc))) {
# Multiple outputs
for (output_name in names(y_processed$y_proc)) {
# Iterate over the names of the actual outputs
if (is.null(block_outputs[[output_name]])) {
stop(paste0(
"An output block named '",
output_name,
"' must be defined in layer_blocks for multi-output models."
))
}
final_output_tensors[[output_name]] <- block_outputs[[output_name]]
}
} else {
# Single output case
output_tensor <- block_outputs[["output"]]
if (is.null(output_tensor)) {
stop("An 'output' block must be defined in layer_blocks.")
}
final_output_tensors <- output_tensor
}

# If there's only one input, it shouldn't be a list for keras_model
final_model_inputs <- if (length(model_input_tensors) == 1) {
model_input_tensors[[1]]
} else {
model_input_tensors
}

model <- keras3::keras_model(
inputs = block_outputs[[first_block_name]],
outputs = output_tensor
inputs = final_model_inputs,
outputs = final_output_tensors # This will now be a list if multiple outputs
)

# --- 3. Model Compilation ---
# Collect all arguments starting with "compile_" from `...`
compile_args <- collect_compile_args(
all_args,
learn_rate,
default_loss,
default_metrics
default_losses,
default_metrics_list
)
rlang::exec(keras3::compile, model, !!!compile_args)

Expand Down
Loading
Loading