Skip to content

Commit

Permalink
Merge pull request #23 from dandls/dev_issue_22
Browse files Browse the repository at this point in the history
Early stop for parallel plot if there are no numeric features. Fixes #22
  • Loading branch information
dandls committed Nov 28, 2022
2 parents 2274420 + 94c776f commit 67f5619
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 0 deletions.
4 changes: 4 additions & 0 deletions R/Counterfactuals.R
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,10 @@ Counterfactuals = R6::R6Class("Counterfactuals",

is_numeric_col = sapply(dt, function(x) is.numeric(x))
numeric_cols = names(dt)[is_numeric_col]
if (length(numeric_cols) == 0L) {
stop("Can only consider numeric features for parallel plot, but no numeric features present in data")
}

non_numeric_cols = names(dt)[!is_numeric_col]
if (length(non_numeric_cols) > 0L) {
dt[, (non_numeric_cols) := NULL]
Expand Down
4 changes: 4 additions & 0 deletions tests/testthat/_snaps/Counterfactuals.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@

Assertion on 'feature_names' failed: Names must be a subset of {'var_num_1','var_num_2','var_fact_1','var_fact_2'}, but has additional elements {'non_in_data1','non_in_data2'}.

# plot_parallel returns error for if there are no numeric features

Can only consider numeric features for parallel plot, but no numeric features present in data

# methods that require at least one counterfactuals are blocked when no counterfactuals found

Assertion on 'self$data' failed: Must have at least 1 rows, but has 0 rows.
Expand Down
15 changes: 15 additions & 0 deletions tests/testthat/test-Counterfactuals.R
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,21 @@ test_that("plot_parallel returns error for unknown feature names", {
expect_snapshot_error(cf$plot_surface(c("non_in_data1", "non_in_data2")))
})

test_that("plot_parallel returns error for if there are no numeric features", {
set.seed(1234)
df <- data.frame(
"a" = as.factor(sample(c("x", "y", "z"), size = 30, replace = TRUE)),
"b" = as.factor(sample(c("x", "y", "z"), size = 30, replace = TRUE)),
"c" = as.factor(sample(c("x", "y", "z"), size = 30, replace = TRUE)),
"target" = as.factor(sample(c("p", "n"), size = 30, replace = TRUE))
)
rf = randomForest(target ~ ., data = df, ntree = 5L)
pred = Predictor$new(rf, data = df, type = "class")
x_interest = head(subset(df, select = -target), n = 1L)
wi = WhatIfClassif$new(pred, n_counterfactuals = 5)
cfactuals = wi$find_counterfactuals(x_interest, desired_class = "p")
expect_snapshot_error(cfactuals$plot_parallel())
})

# $evaluate() ----------------------------------------------------------------------------------------------------------
test_that("evaluate returns error if measures are not known", {
Expand Down

0 comments on commit 67f5619

Please sign in to comment.