Description
When calling evaluate() on a fitted kerasnip model, evaluation fails because:
- The Keras model expects a single numeric matrix/tensor as input, but test data from a tibble is passed as a named list of per-column tensors, causing an input-structure mismatch.
- kerasnip does not currently implement or register an evaluate() S3 method for its model objects, so calling evaluate() directly on a workflow or parsnip model fails with "no applicable method" errors.
Reprex
library(kerasnip)
library(keras3)
library(tidymodels)
input_block_class <- function(model, input_shape) {
  keras3::keras_model_sequential(input_shape = input_shape)
}
dense_block_class <- function(model, units = 16) {
  model |>
    keras3::layer_dense(units = units, activation = "relu")
}
output_block_class <- function(model, num_classes) {
  model |> keras3::layer_dense(units = num_classes, activation = "softmax")
}
create_keras_sequential_spec(
  model_name = "e2e_mlp_class",
  layer_blocks = list(
    input = input_block_class,
    dense = dense_block_class,
    output = output_block_class
  ),
  mode = "classification"
)
spec <- e2e_mlp_class(
  num_dense = 2,
  dense_units = 8,
  fit_epochs = 2
) |>
  set_engine("keras")
multi_data <- iris
rec_multi <- recipe(Species ~ ., data = multi_data)
wf_multi <- workflow(rec_multi, spec)
fit_multi <- fit(wf_multi, data = multi_data)
fit_multi |> evaluate(iris[, 1:4], iris[, 5])
#> Error in UseMethod("evaluate"): no applicable method for 'evaluate' applied to an object of class "workflow"
fit_multi$fit |> evaluate(iris[, 1:4], iris[, 5])
#> Error in UseMethod("evaluate"): no applicable method for 'evaluate' applied to an object of class "c('stage_fit', 'stage')"
fit_multi$fit$fit |> evaluate(iris[, 1:4], iris[, 5])
#> Error in UseMethod("evaluate"): no applicable method for 'evaluate' applied to an object of class "c('_list', 'model_fit')"
fit_multi$fit$fit$fit |> evaluate(iris[, 1:4], iris[, 5])
#> Error in UseMethod("evaluate"): no applicable method for 'evaluate' applied to an object of class "list"
fit_multi$fit$fit$fit$fit |> evaluate(iris[, 1:4], iris[, 5])
#> Exception encountered when calling Sequential.call().
#>
#> The structure of `inputs` doesn't match the expected structure.
#> Expected: keras_tensor
#> Received: inputs={'Sepal.Length': 'Tensor(shape=(None,))', 'Sepal.Width': 'Tensor(shape=(None,))', 'Petal.Length': 'Tensor(shape=(None,))', 'Petal.Width': 'Tensor(shape=(None,))'}
#>
#> Arguments received by Sequential.call():
#> • inputs={'Sepal.Length': 'tf.Tensor(shape=(None,), dtype=float32)', 'Sepal.Width': 'tf.Tensor(shape=(None,), dtype=float32)', 'Petal.Length': 'tf.Tensor(shape=(None,), dtype=float32)', 'Petal.Width': 'tf.Tensor(shape=(None,), dtype=float32)'}
#> • training=False
#> • mask={'Sepal.Length': 'None', 'Sepal.Width': 'None', 'Petal.Length': 'None', 'Petal.Width': 'None'}
#> • kwargs=<class 'inspect._empty'>
Session Info
sessionInfo()
#> R version 4.5.1 (2025-06-13 ucrt)
#> Platform: x86_64-w64-mingw32/x64
#> Running under: Windows 10 x64 (build 19045)
#>
#> Matrix products: default
#> LAPACK version 3.12.1
#>
#> locale:
#> [1] LC_COLLATE=English_United Kingdom.utf8
#> [2] LC_CTYPE=English_United Kingdom.utf8
#> [3] LC_MONETARY=English_United Kingdom.utf8
#> [4] LC_NUMERIC=C
#> [5] LC_TIME=English_United Kingdom.utf8
#>
#> time zone: Europe/Madrid
#> tzcode source: internal
#>
#> attached base packages:
#> [1] stats graphics grDevices utils datasets methods base
#>
#> other attached packages:
#> [1] yardstick_1.3.2 workflowsets_1.1.0 workflows_1.2.0
#> [4] tune_1.3.0 tidyr_1.3.1 tibble_3.2.1
#> [7] rsample_1.3.0 recipes_1.3.0 purrr_1.0.4
#> [10] parsnip_1.3.1 modeldata_1.4.0 infer_1.0.8
#> [13] ggplot2_3.5.2 dplyr_1.1.4 dials_1.4.0
#> [16] scales_1.3.0 broom_1.0.8 tidymodels_1.3.0
#> [19] keras3_1.4.0 kerasnip_0.0.0.9000
#>
#> loaded via a namespace (and not attached):
#> [1] tidyselect_1.2.1 timeDate_4041.110 tensorflow_2.16.0
#> [4] fastmap_1.2.0 reprex_2.1.1 digest_0.6.37
#> [7] rpart_4.1.24 timechange_0.3.0 lifecycle_1.0.4
#> [10] survival_3.8-3 magrittr_2.0.3 compiler_4.5.1
#> [13] rlang_1.1.6 tools_4.5.1 yaml_2.3.10
#> [16] data.table_1.17.0 knitr_1.50 reticulate_1.42.0
#> [19] DiceDesign_1.10 withr_3.0.2 nnet_7.3-20
#> [22] grid_4.5.1 sparsevctrs_0.3.3 colorspace_2.1-1
#> [25] future_1.40.0 iterators_1.0.14 globals_0.17.0
#> [28] MASS_7.3-65 zeallot_0.2.0 cli_3.6.4
#> [31] rmarkdown_2.29 generics_0.1.3 future.apply_1.11.3
#> [34] tfruns_1.5.3 splines_4.5.1 parallel_4.5.1
#> [37] base64enc_0.1-3 vctrs_0.6.5 hardhat_1.4.1
#> [40] Matrix_1.7-3 jsonlite_2.0.0 listenv_0.9.1
#> [43] foreach_1.5.2 gower_1.0.2 pak_0.9.0
#> [46] glue_1.8.0 parallelly_1.43.0 codetools_0.2-20
#> [49] lubridate_1.9.4 gtable_0.3.6 munsell_0.5.1
#> [52] GPfit_1.0-9 pillar_1.10.2 furrr_0.3.1
#> [55] htmltools_0.5.8.1 ipred_0.9-15 lava_1.8.1
#> [58] R6_2.6.1 lhs_1.2.0 evaluate_1.0.3
#> [61] lattice_0.22-7 png_0.1-8 backports_1.5.0
#> [64] class_7.3-23 Rcpp_1.0.14 dotty_0.1.0
#> [67] prodlim_2024.06.25 whisker_0.4.1 xfun_0.52
#> [70] fs_1.6.6 pkgconfig_2.0.3
Expected Behavior
Calling evaluate() on a fitted kerasnip model (or on its fitted Keras backend) should return the model's loss and metrics evaluated on the provided test set.
Actual Behavior
- evaluate() fails because no method is registered for kerasnip or parsnip model objects, nor for the workflow that contains them.
- Even when called on the raw Keras model, inputs from a tibble do not match the expected tensor format, leading to an input-structure mismatch error.
Additional Context
This may require:
- Implementing a dedicated evaluate() method for kerasnip objects.
- Ensuring that evaluation inputs are converted to numeric matrices/arrays compatible with Keras before calling keras3::evaluate().
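A minimal, untested sketch of the conversion step described above, in R. It assumes (not confirmed by kerasnip) that the Keras model was compiled with a categorical cross-entropy style loss expecting one-hot encoded targets, and that the outcome factor levels are in the same order as during training. The names as_numeric_matrix(), one_hot() and evaluate_kerasnip() are hypothetical helpers, not part of kerasnip. The $fit accessor on the engine fit mirrors the fit_multi$fit$fit$fit$fit path from the reprex.

```r
# Hypothetical helpers (not part of kerasnip): turn tibble/data frame inputs
# into the plain numeric arrays a Sequential Keras model expects.

# Convert a data frame of numeric predictors into a double matrix, so Keras
# receives a single tensor instead of a named list of per-column tensors.
as_numeric_matrix <- function(x) {
  x <- as.data.frame(x)
  is_num <- vapply(x, is.numeric, logical(1))
  if (!all(is_num)) {
    stop("Non-numeric predictor columns: ",
         paste(names(x)[!is_num], collapse = ", "))
  }
  m <- as.matrix(x)
  storage.mode(m) <- "double"
  m
}

# One-hot encode a factor outcome (one column per level, in level order).
# Assumes no missing values and that the level order matches training.
one_hot <- function(y) {
  y <- as.factor(y)
  out <- matrix(0, nrow = length(y), ncol = nlevels(y),
                dimnames = list(NULL, levels(y)))
  out[cbind(seq_along(y), as.integer(y))] <- 1
  out
}

# Hypothetical wrapper: preprocess new data with the workflow's trained recipe,
# pull out the underlying Keras model and call keras3::evaluate() on arrays.
# Assumes the model was compiled with a loss that expects one-hot targets.
evaluate_kerasnip <- function(wf_fit, new_data, outcome, ...) {
  rec <- workflows::extract_recipe(wf_fit)
  x <- recipes::bake(rec, new_data = new_data, recipes::all_predictors())
  y <- one_hot(new_data[[outcome]])
  # extract_fit_engine() returns the list seen at fit_multi$fit$fit$fit in the
  # reprex; its $fit element is the Keras model itself.
  keras_model <- workflows::extract_fit_engine(wf_fit)$fit
  keras3::evaluate(keras_model, x = as_numeric_matrix(x), y = y, ...)
}
```

With the reprex objects, a call would look like evaluate_kerasnip(fit_multi, iris, outcome = "Species"), which should return the loss and any compiled metrics if the assumptions above hold. The same body could be registered as an S3 method for the evaluate() generic once the right class to dispatch on is settled.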