diff --git a/R/Counterfactuals.R b/R/Counterfactuals.R index e7713bc..df3beed 100644 --- a/R/Counterfactuals.R +++ b/R/Counterfactuals.R @@ -278,6 +278,10 @@ Counterfactuals = R6::R6Class("Counterfactuals", is_numeric_col = sapply(dt, function(x) is.numeric(x)) numeric_cols = names(dt)[is_numeric_col] + if (length(numeric_cols) == 0L) { + stop("Can only consider numeric features for parallel plot, but no numeric features present in data") + } + non_numeric_cols = names(dt)[!is_numeric_col] if (length(non_numeric_cols) > 0L) { dt[, (non_numeric_cols) := NULL] diff --git a/tests/testthat/_snaps/Counterfactuals.md b/tests/testthat/_snaps/Counterfactuals.md index 80a45c2..101b7e4 100644 --- a/tests/testthat/_snaps/Counterfactuals.md +++ b/tests/testthat/_snaps/Counterfactuals.md @@ -6,6 +6,10 @@ Assertion on 'feature_names' failed: Names must be a subset of {'var_num_1','var_num_2','var_fact_1','var_fact_2'}, but has additional elements {'non_in_data1','non_in_data2'}. +# plot_parallel returns error for if there are no numeric features + + Can only consider numeric features for parallel plot, but no numeric features present in data + # methods that require at least one counterfactuals are blocked when no counterfactuals found Assertion on 'self$data' failed: Must have at least 1 rows, but has 0 rows. diff --git a/tests/testthat/test-Counterfactuals.R b/tests/testthat/test-Counterfactuals.R index a324141..e87b3cd 100644 --- a/tests/testthat/test-Counterfactuals.R +++ b/tests/testthat/test-Counterfactuals.R @@ -122,6 +122,21 @@ test_that("plot_parallel returns error for unknown feature names", { expect_snapshot_error(cf$plot_surface(c("non_in_data1", "non_in_data2"))) }) +test_that("plot_parallel returns error for if there are no numeric features", { + set.seed(1234) + df <- data.frame( + "a" = as.factor(sample(c("x", "y", "z"), size = 30, replace = TRUE)), + "b" = as.factor(sample(c("x", "y", "z"), size = 30, replace = TRUE)), + "c" = as.factor(sample(c("x", "y", "z"), size = 30, replace = TRUE)), + "target" = as.factor(sample(c("p", "n"), size = 30, replace = TRUE)) + ) + rf = randomForest(target ~ ., data = df, ntree = 5L) + pred = Predictor$new(rf, data = df, type = "class") + x_interest = head(subset(df, select = -target), n = 1L) + wi = WhatIfClassif$new(pred, n_counterfactuals = 5) + cfactuals = wi$find_counterfactuals(x_interest, desired_class = "p") + expect_snapshot_error(cfactuals$plot_parallel()) +}) # $evaluate() ---------------------------------------------------------------------------------------------------------- test_that("evaluate returns error if measures are not known", {