Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .github/workflows/R-CMD-check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,11 @@ jobs:
if: runner.os == 'Linux'
run: |
sudo apt-get update
sudo apt-get install -y qpdf ghostscript
sudo apt-get install -y qpdf ghostscript graphviz
if: runner.os == 'macOS'
run: brew install graphviz
if: runner.os == 'Windows'
run: choco install graphviz -y

- uses: r-lib/actions/setup-r-dependencies@v2
with:
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ docs
.httr-oauth
.DS_Store
.quarto
vignettes/*_cache
10 changes: 9 additions & 1 deletion R/generic_functional_fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,15 @@ generic_functional_fit <- function(

# --- Get Repetition Count ---
num_repeats_arg <- paste0("num_", block_name)
num_repeats <- all_args[[num_repeats_arg]] %||% 1
num_repeats_val <- all_args[[num_repeats_arg]]

# If num_repeats_val is NULL or zapped, default to 1.
# Otherwise, use the value provided by the user.
if (is.null(num_repeats_val) || inherits(num_repeats_val, "rlang_zap")) {
num_repeats <- 1
} else {
num_repeats <- as.integer(num_repeats_val)
}

# --- Get Hyperparameters for this block ---
# Hyperparameters are formals that are NOT other block names (graph connections)
Expand Down
9 changes: 8 additions & 1 deletion R/generic_sequential_fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,14 @@ generic_sequential_fit <- function(

num_repeats_arg <- paste0("num_", block_name)
num_repeats_val <- all_args[[num_repeats_arg]]
num_repeats <- num_repeats_val %||% 1

# If num_repeats_val is NULL or zapped, default to 1.
# Otherwise, use the value provided by the user.
if (is.null(num_repeats_val) || inherits(num_repeats_val, "rlang_zap")) {
num_repeats <- 1
} else {
num_repeats <- as.integer(num_repeats_val)
}

# Get the arguments for this specific block from `...`
block_arg_names <- names(block_fmls)[-1] # Exclude 'model'
Expand Down
6 changes: 3 additions & 3 deletions R/register_fit_predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ register_fit_predict <- function(model_name, mode, layer_blocks, functional) {
func = c(fun = "predict"),
args = list(
object = rlang::expr(object$fit$fit),
x = rlang::expr(as.matrix(new_data))
x = rlang::expr(process_x(new_data)$x_proc)
)
)
)
Expand All @@ -74,7 +74,7 @@ register_fit_predict <- function(model_name, mode, layer_blocks, functional) {
func = c(fun = "predict"),
args = list(
object = rlang::expr(object$fit$fit),
x = rlang::expr(as.matrix(new_data))
x = rlang::expr(process_x(new_data)$x_proc)
)
)
)
Expand All @@ -89,7 +89,7 @@ register_fit_predict <- function(model_name, mode, layer_blocks, functional) {
func = c(fun = "predict"),
args = list(
object = rlang::expr(object$fit$fit),
x = rlang::expr(as.matrix(new_data))
x = rlang::expr(process_x(new_data)$x_proc)
)
)
)
Expand Down
5 changes: 5 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,11 @@ process_x <- function(x) {
#' @importFrom keras3 to_categorical
#' @noRd
process_y <- function(y, is_classification = NULL, class_levels = NULL) {
# If y is a data frame/tibble, extract the first column
if (is.data.frame(y)) {
y <- y[[1]]
}

if (is.null(is_classification)) {
is_classification <- is.factor(y)
}
Expand Down
6 changes: 3 additions & 3 deletions _pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ guides:
- title: "Getting Started"
navbar: ~
contents:
- getting-started
- functional-api
- getting_started
- functional_api

# examples:

Expand Down Expand Up @@ -63,7 +63,7 @@ navbar:
components:
intro:
text: "Getting started"
href: guides/getting-started.html
href: guides/getting_started.html
github:
icon: fa-github
href: https://github.com/davidrsch/kerasnip
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ library(modeldata)
library(rsample)
library(dials)
library(tune)
library(purrr)

skip_if_no_keras <- function() {
testthat::skip_if_not_installed("keras3")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,14 +102,19 @@ test_that("E2E: Customizing fit arguments works", {
expect_lt(length(fit_obj$fit$history$metrics$loss), 5)
})

test_that("E2E: Setting num_blocks = 0 works", {
test_that("E2E: Setting num_blocks = 0 works for sequential models", {
skip_if_no_keras()

input_block_zero <- function(model, input_shape) {
keras3::keras_model_sequential(input_shape = input_shape)
}
dense_block_zero <- function(model, units = 16) {
model |> keras3::layer_dense(units = units, activation = "relu")
model |>
keras3::layer_dense(
units = units,
activation = "relu",
name = "i_should_not_exist"
)
}
output_block_zero <- function(model) {
model |> keras3::layer_dense(units = 1)
Expand All @@ -128,10 +133,18 @@ test_that("E2E: Setting num_blocks = 0 works", {
mode = "regression"
)

spec <- e2e_mlp_zero(num_dense = 0, fit_epochs = 2) |>
spec <- e2e_mlp_zero(num_dense = 0, fit_epochs = 1) |>
parsnip::set_engine("keras")
# This should fit a model with only an input and output layer
expect_no_error(parsnip::fit(spec, mpg ~ ., data = mtcars))

fit_obj <- parsnip::fit(spec, mpg ~ ., data = mtcars)

# Check that the dense layer is NOT in the model
keras_model <- fit_obj |> extract_keras_summary()
expect_equal(length(keras_model$layers), 1) # Output layers only

# Check layer names explicitly
layer_names <- sapply(keras_model$layers, function(l) l$name)
expect_false("i_should_not_exist" %in% layer_names)
})

test_that("E2E: Error handling for reserved names works", {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,3 +182,57 @@ test_that("E2E: Functional spec tuning (including repetition) works", {
expect_s3_class(metrics, "tbl_df")
expect_true(all(c("num_dense_path", "dense_path_units") %in% names(metrics)))
})

test_that("E2E: Block repetition works for functional models", {
skip_if_no_keras()

input_block <- function(input_shape) keras3::layer_input(shape = input_shape)
dense_block <- function(tensor, units = 8) {
tensor |> keras3::layer_dense(units = units, activation = "relu")
}
output_block <- function(tensor) keras3::layer_dense(tensor, units = 1)

model_name <- "e2e_func_repeat"
on.exit(suppressMessages(remove_keras_spec(model_name)), add = TRUE)

create_keras_functional_spec(
model_name = model_name,
layer_blocks = list(
main_input = input_block,
dense_path = inp_spec(dense_block, "main_input"),
output = inp_spec(output_block, "dense_path")
),
mode = "regression"
)

# --- Test with 1 repetition ---
spec_1 <- e2e_func_repeat(num_dense_path = 1, fit_epochs = 1) |>
set_engine("keras")
fit_1 <- fit(spec_1, mpg ~ ., data = mtcars)
model_1_layers <- fit_1 |>
extract_keras_summary() |>
pluck("layers")

# Expect 3 layers: Input, Dense, Output
expect_equal(length(model_1_layers), 3)

# --- Test with 2 repetitions ---
spec_2 <- e2e_func_repeat(num_dense_path = 2, fit_epochs = 1) |>
set_engine("keras")
fit_2 <- fit(spec_2, mpg ~ ., data = mtcars)
model_2_layers <- fit_2 |>
extract_keras_summary() |>
pluck("layers")
# Expect 4 layers: Input, Dense, Dense, Output
expect_equal(length(model_2_layers), 4)

# --- Test with 0 repetitions ---
spec_3 <- e2e_func_repeat(num_dense_path = 0, fit_epochs = 1) |>
set_engine("keras")
fit_3 <- fit(spec_3, mpg ~ ., data = mtcars)
model_3_layers <- fit_3 |>
extract_keras_summary() |>
pluck("layers")
# Expect 2 layers: Input, Output
expect_equal(length(model_3_layers), 2)
})
File renamed without changes.
File renamed without changes.
Loading
Loading