From 41d9cd16aeca373068c78e4a3593491857b1a5d1 Mon Sep 17 00:00:00 2001
From: Max Kuhn
Date: Mon, 20 Sep 2021 18:07:38 -0400
Subject: [PATCH] workflow methods and updated parsnip methods

---
Review notes (text between "---" and the diffstat is not part of the commit):
- The line structure of this patch was rebuilt from a copy whose newlines and
  indentation had been collapsed. Hunk line counts match the "@@" headers, but
  the indentation of context lines is assumed (4 spaces in DESCRIPTION,
  2 spaces in R code, two spaces before trailing "# package:" comments).
  TODO confirm against the target tree, or apply with
  `git apply --ignore-whitespace`.
- The From: header carries no e-mail address in the copy this was rebuilt from.
- The new test calls `ggplot2::aes()` rather than a bare `aes()`: the test file
  attaches only parsnip and workflows, so a bare `aes()` resolves only if some
  other code has already attached ggplot2.

 DESCRIPTION                        |  5 ++-
 NAMESPACE                          |  3 ++
 R/vi.R                             |  7 +++-
 R/vi_model.R                       |  7 +++-
 R/vip.R                            |  6 ++-
 inst/tinytest/test_pkg_workflows.R | 60 ++++++++++++++++++++++++++++++
 6 files changed, 83 insertions(+), 5 deletions(-)
 create mode 100644 inst/tinytest/test_pkg_workflows.R

diff --git a/DESCRIPTION b/DESCRIPTION
index 882a6021..b2a68fbf 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -64,7 +64,7 @@ Suggests:
     neuralnet,
     NeuralNetTools,
     nnet,
-    parsnip,
+    parsnip (>= 0.1.7),
     party,
     partykit,
     pdp,
@@ -78,5 +78,6 @@ Suggests:
     sparklyr (>= 0.8.0),
     tinytest,
     varImp,
+    workflows (>= 0.2.3),
     xgboost
-RoxygenNote: 7.1.1
+RoxygenNote: 7.1.2
diff --git a/NAMESPACE b/NAMESPACE
index ad138732..ff3e8a3d 100644
--- a/NAMESPACE
+++ b/NAMESPACE
@@ -5,6 +5,7 @@ S3method(vi,Learner)
 S3method(vi,WrappedModel)
 S3method(vi,default)
 S3method(vi,model_fit)
+S3method(vi,workflow)
 S3method(vi_firm,default)
 S3method(vi_model,C5.0)
 S3method(vi_model,H2OBinomialModel)
@@ -39,11 +40,13 @@ S3method(vi_model,randomForest)
 S3method(vi_model,ranger)
 S3method(vi_model,rpart)
 S3method(vi_model,train)
+S3method(vi_model,workflow)
 S3method(vi_model,xgb.Booster)
 S3method(vi_permute,default)
 S3method(vi_shap,default)
 S3method(vip,default)
 S3method(vip,model_fit)
+S3method(vip,workflow)
 export("%>%")
 export("%T>%")
 export(add_sparklines)
diff --git a/R/vi.R b/R/vi.R
index 555cf7b9..ea087419 100644
--- a/R/vi.R
+++ b/R/vi.R
@@ -185,7 +185,12 @@ vi.default <- function(
 #'
 #' @export
 vi.model_fit <- function(object, ...) {  # package: parsnip
-  vi(object$fit, ...)
+  vi(parsnip::extract_fit_engine(object), ...)
+}
+
+#' @export
+vi.workflow <- function(object, ...) {  # package: workflows
+  vi(workflows::extract_fit_engine(object), ...)
 }
 
 
diff --git a/R/vi_model.R b/R/vi_model.R
index 60309284..ae4c4600 100644
--- a/R/vi_model.R
+++ b/R/vi_model.R
@@ -727,7 +727,12 @@ vi_model.nnet <- function(object, type = c("olden", "garson"), ...) {
 #'
 #' @export
 vi_model.model_fit <- function(object, ...) {  # package: parsnip
-  vi_model(object$fit, ...)
+  vi_model(parsnip::extract_fit_engine(object), ...)
+}
+
+#' @export
+vi_model.workflow <- function(object, ...) {  # package: workflows
+  vi_model(workflows::extract_fit_engine(object), ...)
 }
 
 
diff --git a/R/vip.R b/R/vip.R
index 93d91c75..ec83c8b8 100644
--- a/R/vip.R
+++ b/R/vip.R
@@ -236,6 +236,10 @@ vip.default <- function(
 #'
 #' @export
 vip.model_fit <- function(object, ...) {
-  vip(object$fit, ...)
+  vip(parsnip::extract_fit_engine(object), ...)
 }
 
+#' @export
+vip.workflow <- function(object, ...) {
+  vip(workflows::extract_fit_engine(object), ...)
+}
diff --git a/inst/tinytest/test_pkg_workflows.R b/inst/tinytest/test_pkg_workflows.R
new file mode 100644
index 00000000..76ee28cb
--- /dev/null
+++ b/inst/tinytest/test_pkg_workflows.R
@@ -0,0 +1,60 @@
+# Exits
+if (!requireNamespace("parsnip", quietly = TRUE)) {
+  exit_file("Package parsnip missing")
+}
+if (!requireNamespace("workflows", quietly = TRUE)) {
+  exit_file("Package workflows missing")
+}
+
+# Load required packages
+suppressMessages({
+  library(parsnip)
+  library(workflows)
+})
+
+# Generate Friedman benchmark data
+friedman1 <- gen_friedman(seed = 101)
+
+# Fit a linear model
+mod <- parsnip::linear_reg()
+wflow <- workflow() %>% add_model(mod) %>% add_formula(y ~ .)
+
+fitted <- generics::fit(wflow, data = friedman1)
+
+# Compute model-based VI scores
+vis <- vi(fitted, scale = TRUE)
+
+expect_error(vi(wflow), "The workflow does not have a model fit")
+
+# Expect `vi()` and `vi_model()` to both work
+expect_identical(
+  current = vi(fitted, sort = FALSE),
+  target = vi_model(fitted)
+)
+
+# Check class
+expect_identical(class(vis), target = c("vi", "tbl_df", "tbl", "data.frame"))
+
+# Check dimensions (should be one row for each feature)
+expect_identical(ncol(friedman1) - 1L, target = nrow(vis))
+
+# Display VIP
+vip(vis, geom = "point")
+
+# Try permutation importance
+set.seed(953)  # for reproducibility
+p <- vip(
+  object = fitted,
+  method = "permute",
+  train = friedman1,
+  target = "y",
+  pred_wrapper = predict,
+  metric = "rmse",
+  nsim = 30,
+  geom = "violin",
+  jitter = TRUE,
+  all_permutation = TRUE,
+  mapping = ggplot2::aes(color = Variable)
+)
+expect_true(inherits(p, what = "ggplot"))
+p  # display VIP