Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# Generated by roxygen2: do not edit by hand

export(create_keras_spec)
export(generic_keras_fit_impl)
export(create_keras_functional_spec)
export(create_keras_sequential_spec)
export(generic_functional_fit)
export(generic_sequential_fit)
export(inp_spec)
export(keras_losses)
export(keras_metrics)
export(keras_optimizers)
Expand Down
81 changes: 81 additions & 0 deletions R/build_spec_function.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
#' Assemble a `parsnip` Model Specification Function
#'
#' @description
#' Internal helper that uses metaprogramming to generate the function users
#' call to create a model specification (e.g., `my_mlp()`).
#'
#' @details
#' The generated function is put together in three stages:
#' 1. **Body**: for every name in `parsnip_names`, an `rlang::enquo(<name>)`
#'    call is built. Inside the generated function these calls are collected
#'    with `rlang::list2()`, combined with the quosures that `rlang::enquos()`
#'    captures from `...`, and passed as `args` to
#'    `parsnip::new_model_spec()`.
#' 2. **Signature**: the formals are `all_args` followed by `...`, so that
#'    additional passthrough arguments are accepted.
#'    `rlang::new_function()` joins the signature and the body.
#' 3. **Documentation**: the Roxygen text produced by
#'    `generate_roxygen_docs()` is attached to the function with `comment()`.
#'
#' @param model_name Name of the model specification function being created
#'   (e.g., "my_mlp"). It is inlined into the generated body as the first
#'   argument of `parsnip::new_model_spec()`.
#' @param mode The model mode ("regression" or "classification"); inlined into
#'   the generated body as the `mode` of the spec.
#' @param all_args A named list of formal arguments for the generated
#'   function's signature, as produced by `collect_spec_args()`. The values
#'   are typically `rlang::missing_arg()` or `rlang::zap()`.
#' @param parsnip_names A character vector with the names of all arguments
#'   that are captured as quosures and passed on to
#'   `parsnip::new_model_spec()`.
#' @param layer_blocks The user-provided list of layer block functions. It is
#'   only forwarded to `generate_roxygen_docs()`, which documents the
#'   block-specific parameters.
#' @param functional A logical flag: `TRUE` for a functional model (see
#'   `create_keras_functional_spec()`), `FALSE` for a sequential one. It is
#'   forwarded to `generate_roxygen_docs()` to tailor the documentation.
#' @return The generated function, with the Roxygen text attached as its
#'   `comment()`, ready to be placed in the user's environment.
#' @noRd
build_spec_function <- function(
  model_name,
  mode,
  all_args,
  parsnip_names,
  layer_blocks,
  functional = FALSE
) {
  # One `rlang::enquo(<arg>)` call per named argument, keyed by argument name
  # so the calls are spliced as named elements further down.
  capture_calls <- purrr::map(
    parsnip_names,
    function(arg_name) rlang::expr(rlang::enquo(!!rlang::sym(arg_name)))
  )
  names(capture_calls) <- parsnip_names

  # The call that creates the spec object. `args` here is a local variable of
  # the *generated* function, not of this helper.
  spec_call <- rlang::expr(
    parsnip::new_model_spec(
      !!model_name,
      args = args,
      eng_args = NULL,
      mode = !!mode,
      method = NULL,
      engine = NULL
    )
  )

  fn_body <- rlang::expr({
    # Explicitly named arguments become a list of quosures.
    main_args <- rlang::list2(!!!capture_calls)
    # Anything supplied through ... is captured as quosures as well.
    dot_args <- rlang::enquos(...)
    args <- c(main_args, dot_args)
    !!spec_call
  })

  # Signature: the collected formals plus ... for any other compile arguments.
  formals_list <- c(all_args, list(... = rlang::missing_arg()))
  spec_fn <- rlang::new_function(args = formals_list, body = fn_body)

  # Attach the generated Roxygen text directly to the new function.
  comment(spec_fn) <- generate_roxygen_docs(
    model_name,
    layer_blocks,
    all_args,
    functional = functional
  )
  spec_fn
}
141 changes: 141 additions & 0 deletions R/create_keras_functional_spec.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
#' Create a Custom Keras Functional API Model Specification for Tidymodels
#'
#' This function acts as a factory to generate a new `parsnip` model
#' specification based on user-defined blocks of Keras layers using the
#' Functional API. This allows for creating complex, tunable architectures
#' with non-linear topologies that integrate seamlessly with the `tidymodels`
#' ecosystem.
#'
#' @param model_name A character string for the name of the new model
#' specification function (e.g., "custom_resnet"). This should be a valid R
#' function name.
#' @param layer_blocks A named list of functions where each function defines a
#' "block" (a node) in the model graph. The list names are crucial as they
#' define the names of the nodes. The arguments of each function define how
#' the nodes are connected. See the "Model Graph Connectivity" section for
#' details.
#' @param mode A character string, either "regression" or "classification".
#' @param ... Reserved for future use. Currently not used.
#' @param env The environment in which to create the new model specification
#' function and its associated `update()` method. Defaults to the calling
#' environment (`parent.frame()`).
#'
#' @details
#' This function generates all the boilerplate needed to create a custom,
#' tunable `parsnip` model specification that uses the Keras Functional API.
#' This is ideal for models with complex, non-linear topologies, such as
#' networks with multiple inputs/outputs or residual connections.
#'
#' The function inspects the arguments of your `layer_blocks` functions and
#' makes them available as tunable parameters in the generated model
#' specification, prefixed with the block's name (e.g., `dense_units`).
#' Common training parameters such as `epochs` and `learn_rate` are also added.
#'
#' @section Model Graph Connectivity:
#' `kerasnip` builds the model's directed acyclic graph by inspecting the
#' arguments of each function in the `layer_blocks` list. The connection logic
#' is as follows:
#'
#' 1. The **names of the elements** in the `layer_blocks` list define the names
#' of the nodes in your graph (e.g., `main_input`, `dense_path`, `output`).
#' 2. The **names of the arguments** in each block function specify its inputs.
#' A block function like `my_block <- function(input_a, input_b, ...)`
#' declares that it needs input from the nodes named `input_a` and `input_b`.
#' `kerasnip` will automatically supply the output tensors from those nodes
#' when calling `my_block`.
#'
#' There are two special requirements:
#' * **Input Block**: The first block in the list is treated as the input
#' node. Its function should not take other blocks as input, but it can have
#' an `input_shape` argument, which is supplied automatically during fitting.
#' * **Output Block**: Exactly one block must be named `"output"`. The tensor
#' returned by this block is used as the final output of the Keras model.
#'
#' A key feature is the automatic creation of `num_{block_name}` arguments
#' (e.g., `num_dense_path`). This allows you to control how many times a block
#' is repeated, making it easy to tune the depth of your network. A block can
#' only be repeated if it has exactly one input from another block in the graph.
#'
#' The new model specification function and its `update()` method are created
#' in the environment specified by the `env` argument.
#'
#' @importFrom rlang enquos dots_list arg_match env_poke
#' @importFrom parsnip update_dot_check
#'
#' @return Invisibly returns `NULL`. Its primary side effect is to create a
#' new model specification function (e.g., `custom_resnet()`) in the
#' specified environment and register the model with `parsnip` so it can be
#' used within the `tidymodels` framework.
#'
#' @seealso [remove_keras_spec()], [parsnip::new_model_spec()],
#' [create_keras_sequential_spec()]
#'
#' @export
#' @examples
#' \dontrun{
#' if (requireNamespace("keras3", quietly = TRUE)) {
#' library(keras3)
#' library(parsnip)
#'
#' # 1. Define block functions. These are the building blocks of our model.
#' # An input block that receives the data's shape automatically.
#' input_block <- function(input_shape) layer_input(shape = input_shape)
#'
#' # A dense block with a tunable `units` parameter.
#' dense_block <- function(tensor, units) {
#' tensor |> layer_dense(units = units, activation = "relu")
#' }
#'
#' # A block that adds two tensors together (for the residual connection).
#' add_block <- function(input_a, input_b) layer_add(list(input_a, input_b))
#'
#' # An output block for regression.
#' output_block_reg <- function(tensor) layer_dense(tensor, units = 1)
#'
#' # 2. Create the spec. The `layer_blocks` list defines the graph.
#' create_keras_functional_spec(
#' model_name = "my_resnet_spec",
#' layer_blocks = list(
#' # The names of list elements are the node names.
#' main_input = input_block,
#'
#' # The argument `main_input` connects this block to the input node.
#' dense_path = function(main_input, units = 32) dense_block(main_input, units),
#'
#' # This block's arguments connect it to the original input AND the dense layer.
#' add_residual = function(main_input, dense_path) add_block(main_input, dense_path),
#'
#' # This block must be named 'output'. It connects to the residual add layer.
#' output = function(add_residual) output_block_reg(add_residual)
#' ),
#' mode = "regression"
#' )
#'
#' # 3. Use the newly created specification function!
#' # The `dense_path_units` argument was created automatically.
#' model_spec <- my_resnet_spec(dense_path_units = 64, epochs = 10)
#'
#' # You could also tune the number of dense layers since it has a single input:
#' # model_spec <- my_resnet_spec(num_dense_path = 2, dense_path_units = 32)
#'
#' print(model_spec)
#' # tune::tunable(model_spec)
#' }
#' }
create_keras_functional_spec <- function(
  model_name,
  layer_blocks,
  mode = c("regression", "classification"),
  ...,
  env = parent.frame()
) {
  # Resolve `mode` to one of the choices listed in the signature.
  spec_mode <- rlang::arg_match(mode)

  # Delegate to the shared implementation with the Functional API switched on.
  # `...` is reserved and deliberately not forwarded.
  create_keras_spec_impl(
    model_name,
    layer_blocks,
    spec_mode,
    functional = TRUE,
    env
  )
}
93 changes: 46 additions & 47 deletions R/create_keras_spec.R → R/create_keras_sequential_spec.R
Original file line number Diff line number Diff line change
@@ -1,55 +1,63 @@
#' Create a Custom Keras Model Specification for Tidymodels
#' Create a Custom Keras Sequential Model Specification for Tidymodels
#'
#' @description
#' This function acts as a factory to generate a new `parsnip` model
#' specification based on user-defined blocks of Keras layers. This allows for
#' creating complex, tunable architectures that integrate seamlessly with the
#' `tidymodels` ecosystem.
#' specification based on user-defined blocks of Keras layers using the
#' Sequential API. This is the ideal choice for creating models that are a
#' simple, linear stack of layers. For models with complex, non-linear
#' topologies, see [create_keras_functional_spec()].
#'
#' @param model_name A character string for the name of the new model
#' specification function (e.g., "custom_cnn"). This should be a valid R
#' function name.
#' @param layer_blocks A named list of functions. Each function defines a "block"
#' of Keras layers. The function must take a Keras model object as its first
#' argument and return the modified model. Other arguments to the function
#' will become tunable parameters in the final model specification.
#' @param layer_blocks A named, ordered list of functions. Each function defines
#' a "block" of Keras layers. The function must take a Keras model object as
#' its first argument and return the modified model. Other arguments to the
#' function will become tunable parameters in the final model specification.
#' @param mode A character string, either "regression" or "classification".
#' @param ... Reserved for future use. Currently not used.
#' @param env The environment in which to create the new model specification
#' function and its associated `update()` method. Defaults to the calling
#' environment (`parent.frame()`).
#' @importFrom rlang enquos dots_list arg_match env_poke
#' @importFrom parsnip update_dot_check
#'
#' @details
#' The user is responsible for defining the entire model architecture by providing
#' an ordered list of layer block functions.
#' 1. The first block function must initialize the model (e.g., with
#' \code{keras_model_sequential()}). It can accept an \code{input_shape} argument,
#' which will be provided automatically by the fitting engine.
#' 2. Subsequent blocks add hidden layers.
#' 3. The final block should add the output layer. For classification, it can
#' accept a \code{num_classes} argument, which is provided automatically.
#'
#' The \code{create_keras_spec()} function will inspect the arguments of your
#' \code{layer_blocks} functions (ignoring \code{input_shape} and \code{num_classes})
#' and make them available as arguments in the generated model specification,
#' prefixed with the block's name (e.g.,
#' `dense_units`).
#'
#' It also automatically creates arguments like `num_dense` to control how many
#' times each block is repeated. In addition, common training parameters such as
#' `epochs`, `learn_rate`, `validation_split`, and `verbose` are added to the
#' specification.
#' This function generates all the boilerplate needed to create a custom,
#' tunable `parsnip` model specification that uses the Keras Sequential API.
#'
#' The function inspects the arguments of your `layer_blocks` functions
#' (ignoring special arguments like `input_shape` and `num_classes`)
#' and makes them available as arguments in the generated model specification,
#' prefixed with the block's name (e.g., `dense_units`).
#'
#' The new model specification function and its `update()` method are created in
#' the environment specified by the `env` argument.
#'
#' @section Model Architecture (Sequential API):
#' `kerasnip` builds the model by applying the functions in `layer_blocks` in
#' the order they are provided. Each function receives the Keras model built by
#' the previous function and returns a modified version.
#'
#' 1. The **first block** must initialize the model (e.g., with
#' `keras_model_sequential()`). It can accept an `input_shape` argument,
#' which `kerasnip` will provide automatically during fitting.
#' 2. **Subsequent blocks** add layers to the model.
#' 3. The **final block** should add the output layer. For classification, it
#' can accept a `num_classes` argument, which is provided automatically.
#'
#' A key feature of this function is the automatic creation of `num_{block_name}`
#' arguments (e.g., `num_hidden`). This allows you to control how many times
#' each block is repeated, making it easy to tune the depth of your network.
#'
#' @importFrom rlang enquos dots_list arg_match env_poke
#' @importFrom parsnip update_dot_check
#'
#' @return Invisibly returns `NULL`. Its primary side effect is to create a new
#' model specification function (e.g., `dynamic_mlp()`) in the specified
#' model specification function (e.g., `my_mlp()`) in the specified
#' environment and register the model with `parsnip` so it can be used within
#' the `tidymodels` framework.
#'
#' @seealso [remove_keras_spec()], [parsnip::new_model_spec()]
#' @seealso [remove_keras_spec()], [parsnip::new_model_spec()],
#' [create_keras_functional_spec()]
#'
#' @export
#' @examples
Expand All @@ -75,7 +83,7 @@
#' }
#'
#' # 2. Create the spec, providing blocks in the correct order.
#' create_keras_spec(
#' create_keras_sequential_spec(
#' model_name = "my_mlp",
#' layer_blocks = list(
#' input = input_block,
Expand All @@ -86,7 +94,7 @@
#' )
#'
#' # 3. Use the newly created specification function!
# Note the new arguments `num_hidden` and `hidden_units`.
#' # Note the new arguments `num_hidden` and `hidden_units`.
#' model_spec <- my_mlp(
#' num_hidden = 2,
#' hidden_units = 64,
Expand All @@ -97,28 +105,19 @@
#' print(model_spec)
#' }
#' }
create_keras_spec <- function(
create_keras_sequential_spec <- function(
model_name,
layer_blocks,
mode = c("regression", "classification"),
...,
env = parent.frame()
) {
mode <- arg_match(mode)
args_info <- collect_spec_args(layer_blocks)
spec_fun <- build_spec_function(
create_keras_spec_impl(
model_name,
layer_blocks,
mode,
args_info$all_args,
args_info$parsnip_names,
layer_blocks
functional = FALSE,
env
)

register_core_model(model_name, mode)
register_model_args(model_name, args_info$parsnip_names)
register_fit_predict(model_name, mode, layer_blocks)
register_update_method(model_name, args_info$parsnip_names, env = env)

env_poke(env, model_name, spec_fun)
invisible(NULL)
}
Loading
Loading