Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ Imports:
tibble,
purrr,
dplyr,
cli
cli,
recipes
Suggests:
testthat (>= 3.0.0),
modeldata,
Expand Down
6 changes: 6 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Generated by roxygen2: do not edit by hand

S3method(bake,step_collapse)
S3method(prep,step_collapse)
S3method(print,step_collapse)
export(compile_keras_grid)
export(create_keras_functional_spec)
export(create_keras_sequential_spec)
Expand All @@ -24,6 +27,7 @@ export(register_keras_loss)
export(register_keras_metric)
export(register_keras_optimizer)
export(remove_keras_spec)
export(step_collapse)
importFrom(cli,cli_alert_danger)
importFrom(cli,cli_alert_info)
importFrom(cli,cli_alert_success)
Expand All @@ -37,6 +41,8 @@ importFrom(dplyr,filter)
importFrom(dplyr,select)
importFrom(keras3,to_categorical)
importFrom(parsnip,update_dot_check)
importFrom(recipes,bake)
importFrom(recipes,prep)
importFrom(rlang,arg_match)
importFrom(rlang,dots_list)
importFrom(rlang,enquos)
Expand Down
1 change: 1 addition & 0 deletions R/register_core_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ register_core_model <- function(model_name, mode) {
parsnip::set_model_mode(model_name, mode)
parsnip::set_model_engine(model_name, mode, "keras")
parsnip::set_dependency(model_name, "keras", "keras3")
parsnip::set_dependency(model_name, "keras", "kerasnip")

parsnip::set_encoding(
model = model_name,
Expand Down
141 changes: 141 additions & 0 deletions R/step_collapse.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
#' Collapse Predictors into a single list-column
#'
#' `step_collapse()` creates a a *specification* of a recipe step that will
#' convert a group of predictors into a single list-column. This is useful
#' for custom models that need the predictors in a different format.
#'
#' @param recipe A recipe object. The step will be added to the sequence of
#' operations for this recipe.
#' @param ... One or more selector functions to choose which variables are
#' affected by the step. See `[selections()]` for more details. For the `tidy`
#' method, these are not currently used.
#' @param role For model terms created by this step, what analysis role should
#' they be assigned?. By default, the new columns are used as predictors.
#' @param trained A logical to indicate if the quantities for preprocessing
#' have been estimated.
#' @param columns A character string of the selected variable names. This is
#' `NULL` until the step is trained by `[prep.recipe()]`.
#' @param new_col A character string for the name of the new list-column. The
#' default is "predictor_matrix".
#' @param skip A logical. Should the step be skipped when the recipe is
#' baked by `[bake.recipe()]`? While all operations are baked when `prep` is run,
#' skipping when `bake` is run may be other times when it is desirable to
#' skip a processing step.
#' @param id A character string that is unique to this step to identify it.
#'
#' @return An updated version of `recipe` with the new step added to the
#' sequence of existing steps (if any). For the `tidy` method, a tibble with
#' columns `terms` which is the columns that are affected and `value` which is
#' the type of collapse.
#'
#' @examples
#' library(recipes)
#'
#' # 2 predictors
#' dat <- data.frame(
#' x1 = 1:10,
#' x2 = 11:20,
#' y = 1:10
#' )
#'
#' rec <- recipe(y ~ ., data = dat) %>%
#' step_collapse(x1, x2, new_col = "pred") %>%
#' prep()
#'
#' bake(rec, new_data = NULL)
#' @importFrom recipes prep bake
#' @export
step_collapse <- function(
recipe,
...,
role = "predictor",
trained = FALSE,
columns = NULL,
new_col = "predictor_matrix",
skip = FALSE,
id = recipes::rand_id("collapse")
) {
recipes::add_step(
recipe,
step_collapse_new(
terms = enquos(...),
role = role,
trained = trained,
columns = columns,
new_col = new_col,
skip = skip,
id = id
)
)
}

step_collapse_new <- function(
terms,
role,
trained,
columns,
new_col,
skip,
id
) {
recipes::step(
subclass = "collapse",
terms = terms,
role = role,
trained = trained,
columns = columns,
new_col = new_col,
skip = skip,
id = id
)
}

#' @export
prep.step_collapse <- function(x, training, info = NULL, ...) {
col_names <- recipes::recipes_eval_select(x$terms, training, info)

step_collapse_new(
terms = x$terms,
role = x$role,
trained = TRUE,
columns = col_names,
new_col = x$new_col,
skip = x$skip,
id = x$id
)
}

#' @export
bake.step_collapse <- function(object, new_data, ...) {
recipes::check_new_data(object$columns, object, new_data)

rows_list <- apply(
new_data[, object$columns, drop = FALSE],
1,
function(row) matrix(row, nrow = 1),
simplify = FALSE
)

new_data[[object$new_col]] <- rows_list

# drop original predictor columns
new_data <- new_data[, setdiff(names(new_data), object$columns), drop = FALSE]

new_data
}

#' @export
print.step_collapse <- function(x, ...) {
if (is.null(x$columns)) {
cat("Collapse predictors into list-column (unprepped)\\n")
} else {
cat(
"Collapse predictors into list-column:",
paste(x$columns, collapse = ", "),
" -> ",
x$new_col,
"\\n"
)
}
invisible(x)
}
13 changes: 13 additions & 0 deletions _pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ guides:
contents:
- sequential_model
- functional_api
- workflows_sequential
- workflows_functional

# examples:

Expand Down Expand Up @@ -54,6 +56,12 @@ reference:
- extract_keras_model
- keras_evaluate

- title: "Custom recipe steps"
desc: >
Custom stpes for recipe which uses kerasnip models specifications
contents:
- step_collapse

development:
mode: auto

Expand All @@ -75,6 +83,11 @@ navbar:
href: articles/sequential_model.html
- text: "Functional API"
href: articles/functional_api.html
- text: "Workflows"
- text: "Sequential Model"
href: articles/workflows_sequential.html
- text: "Functional API"
href: articles/workflows_functional.html
github:
icon: fa-github
href: https://github.com/davidrsch/kerasnip
71 changes: 71 additions & 0 deletions man/step_collapse.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading