Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@

export(create_keras_functional_spec)
export(create_keras_sequential_spec)
export(extract_keras_history)
export(extract_keras_summary)
export(generic_functional_fit)
export(generic_sequential_fit)
export(inp_spec)
export(keras_evaluate)
export(keras_losses)
export(keras_metrics)
export(keras_optimizers)
Expand All @@ -14,6 +17,7 @@ export(register_keras_loss)
export(register_keras_metric)
export(register_keras_optimizer)
export(remove_keras_spec)
importFrom(keras3,to_categorical)
importFrom(parsnip,update_dot_check)
importFrom(rlang,arg_match)
importFrom(rlang,dots_list)
Expand Down
41 changes: 21 additions & 20 deletions R/generic_functional_fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -83,31 +83,32 @@ generic_functional_fit <- function(
learn_rate <- all_args$learn_rate %||% 0.01
verbose <- all_args$verbose %||% 0

if (is.data.frame(x) && ncol(x) == 1 && is.list(x[[1]])) {
x_proc <- do.call(abind::abind, c(x[[1]], list(along = 0)))
} else {
x_proc <- as.matrix(x)
}
input_shape <- if (length(dim(x_proc)) > 2) dim(x_proc)[-1] else ncol(x_proc)
is_classification <- is.factor(y)
if (is_classification) {
class_levels <- levels(y)
num_classes <- length(class_levels)
y_mat <- keras3::to_categorical(
as.numeric(y) - 1,
num_classes = num_classes
)
default_loss <- if (num_classes > 2) {
# Process x input
x_processed <- process_x(x)
x_proc <- x_processed$x_proc
input_shape <- x_processed$input_shape

# Process y input
y_processed <- process_y(y)
y_mat <- y_processed$y_proc
is_classification <- y_processed$is_classification
class_levels <- y_processed$class_levels
num_classes <- y_processed$num_classes

# Determine default compile arguments based on mode
default_loss <- if (is_classification) {
if (num_classes > 2) {
"categorical_crossentropy"
} else {
"binary_crossentropy"
}
default_metrics <- "accuracy"
} else {
class_levels <- NULL
y_mat <- as.matrix(y)
default_loss <- "mean_squared_error"
default_metrics <- "mean_absolute_error"
"mean_squared_error"
}
default_metrics <- if (is_classification) {
"accuracy"
} else {
"mean_absolute_error"
}

# --- 2. Dynamic Model Architecture Construction (DIFFERENT from sequential) ---
Expand Down
43 changes: 18 additions & 25 deletions R/generic_sequential_fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -78,39 +78,32 @@ generic_sequential_fit <- function(
learn_rate <- all_args$learn_rate %||% 0.01
verbose <- all_args$verbose %||% 0

# Handle both standard tabular data (matrix) and list-columns of arrays
# (for images/sequences) that come from recipes.
if (is.data.frame(x) && ncol(x) == 1 && is.list(x[[1]])) {
# Assumes a single predictor column containing a list of arrays.
# We stack them into a single higher-dimensional array.
x_proc <- do.call(abind::abind, c(x[[1]], list(along = 0)))
} else {
x_proc <- as.matrix(x)
}
# Process x input
x_processed <- process_x(x)
x_proc <- x_processed$x_proc
input_shape <- x_processed$input_shape

# Determine the correct input shape for the Keras model.
input_shape <- if (length(dim(x_proc)) > 2) dim(x_proc)[-1] else ncol(x_proc)
# Process y input
y_processed <- process_y(y)
y_mat <- y_processed$y_proc
is_classification <- y_processed$is_classification
class_levels <- y_processed$class_levels
num_classes <- y_processed$num_classes

# Determine default compile arguments based on mode
is_classification <- is.factor(y)
if (is_classification) {
class_levels <- levels(y)
num_classes <- length(class_levels)
y_mat <- keras3::to_categorical(
as.numeric(y) - 1,
num_classes = num_classes
)
default_loss <- if (num_classes > 2) {
default_loss <- if (is_classification) {
if (num_classes > 2) {
"categorical_crossentropy"
} else {
"binary_crossentropy"
}
default_metrics <- "accuracy"
} else {
class_levels <- NULL
y_mat <- as.matrix(y)
default_loss <- "mean_squared_error"
default_metrics <- "mean_absolute_error"
"mean_squared_error"
}
default_metrics <- if (is_classification) {
"accuracy"
} else {
"mean_absolute_error"
}

# --- 2. Dynamic Model Architecture Construction ---
Expand Down
60 changes: 60 additions & 0 deletions R/keras_tools.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#' Evaluate a Kerasnip Model
#'
#' This function provides an `kera_evaluate()` method for `model_fit` objects
#' created by `kerasnip`. It preprocesses the data into the format expected by
#' Keras and then calls `keras3::evaluate()` on the underlying model.
#'
#' @param object A `model_fit` object produced by a `kerasnip` specification.
#' @param x A data frame or matrix of predictors.
#' @param y A vector or data frame of outcomes.
#' @param ... Additional arguments passed on to `keras3::evaluate()`.
#'
#' @return A `list` with evaluation results
#'
#' @export
keras_evaluate <- function(object, x, y = NULL, ...) {
# 1. Preprocess predictor data (x)
x_processed <- process_x(x)
x_proc <- x_processed$x_proc

# 2. Preprocess outcome data (y)
y_proc <- NULL
if (!is.null(y)) {
y_processed <- process_y(
y,
is_classification = !is.null(object$fit$lvl),
class_levels = object$fit$lvl
)
y_proc <- y_processed$y_proc
}

# 3. Call the underlying Keras evaluate method
keras_model <- object$fit$fit
keras3::evaluate(keras_model, x = x_proc, y = y_proc, ...)
}

#' Extract Keras Model Summary
#'
#' @description
#' Extracts and returns the summary of a Keras model fitted with `kerasnip`.
#'
#' @param object A `model_fit` object produced by a `kerasnip` specification.
#' @param ... Additional arguments passed on to `keras3::summary()`.
#'
#' @return A character vector, where each element is a line of the model summary.
#' @export
extract_keras_summary <- function(object, ...) {
object$fit$fit
}

#' Extract Keras Training History
#'
#' @description
#' Extracts and returns the training history of a Keras model fitted with `kerasnip`.
#'
#' @param object A `model_fit` object produced by a `kerasnip` specification.
#' @return A `keras_training_history` containing the training history (metrics per epoch).
#' @export
extract_keras_history <- function(object) {
object$fit$history
}
71 changes: 71 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -175,3 +175,74 @@ loss_function_keras <- function(values = NULL) {
finalize = NULL
)
}

#' Process Predictor Input for Keras
#'
#' @description
#' Preprocesses predictor data (`x`) into a format suitable for Keras models.
#' Handles both tabular data and list-columns of arrays (e.g., for images).
#'
#' @param x A data frame or matrix of predictors.
#' @return A list containing:
#' - `x_proc`: The processed predictor data (matrix or array).
#' - `input_shape`: The determined input shape for the Keras model.
#' @noRd
process_x <- function(x) {
if (is.data.frame(x) && ncol(x) == 1 && is.list(x[[1]])) {
# Assumes a single predictor column containing a list of arrays.
# We stack them into a single higher-dimensional array.
x_proc <- do.call(abind::abind, c(x[[1]], list(along = 0)))
} else {
x_proc <- as.matrix(x)
}
input_shape <- if (length(dim(x_proc)) > 2) dim(x_proc)[-1] else ncol(x_proc)
list(x_proc = x_proc, input_shape = input_shape)
}

#' Process Outcome Input for Keras
#'
#' @description
#' Preprocesses outcome data (`y`) into a format suitable for Keras models.
#' Handles both regression (numeric) and classification (factor) outcomes,
#' including one-hot encoding for classification.
#'
#' @param y A vector of outcomes.
#' @param is_classification Logical, optional. If `TRUE`, treats `y` as
#' classification. If `FALSE`, treats as regression. If `NULL` (default),
#' it's determined from `is.factor(y)`.
#' @param class_levels Character vector, optional. The factor levels for
#' classification outcomes. If `NULL` (default), determined from `levels(y)`.
#' @return A list containing:
#' - `y_proc`: The processed outcome data (matrix or one-hot encoded array).
#' - `is_classification`: Logical, indicating if `y` was treated as classification.
#' - `num_classes`: Integer, the number of classes for classification, or `NULL`.
#' - `class_levels`: Character vector, the factor levels for classification, or `NULL`.
#' @importFrom keras3 to_categorical
#' @noRd
process_y <- function(y, is_classification = NULL, class_levels = NULL) {
if (is.null(is_classification)) {
is_classification <- is.factor(y)
}

y_proc <- NULL
num_classes <- NULL
if (is_classification) {
if (is.null(class_levels)) {
class_levels <- levels(y)
}
num_classes <- length(class_levels)
y_factored <- factor(y, levels = class_levels)
y_proc <- keras3::to_categorical(
as.numeric(y_factored) - 1,
num_classes = num_classes
)
} else {
y_proc <- as.matrix(y)
}
list(
y_proc = y_proc,
is_classification = is_classification,
num_classes = num_classes,
class_levels = class_levels
)
}
9 changes: 9 additions & 0 deletions _pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,15 @@ reference:
- register_keras_optimizer
- keras_objects

- title: "Model Inspection and Evaluation"
desc: >
Functions for summarizing, evaluating, and extracting information
from trained Keras models.
contents:
- extract_keras_history
- extract_keras_summary
- keras_evaluate

development:
mode: auto

Expand Down
17 changes: 17 additions & 0 deletions man/extract_keras_history.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 19 additions & 0 deletions man/extract_keras_summary.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 25 additions & 0 deletions man/keras_evaluate.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading