Skip to content

Commit

Permalink
Merge pull request #41 from dandls/github_paradox
Browse files Browse the repository at this point in the history
GitHub paradox
  • Loading branch information
dandls committed May 14, 2024
2 parents 287a264 + f0d5e7a commit 0ed76f9
Show file tree
Hide file tree
Showing 10 changed files with 163 additions and 152 deletions.
11 changes: 10 additions & 1 deletion .github/workflows/R-CMD-check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ jobs:
- {os: windows-latest, r: 'release'}
- {os: macOS-latest, r: 'release'}
- {os: ubuntu-20.04, r: 'devel', rspm: "https://packagemanager.rstudio.com/cran/__linux__/focal/latest"}
- {os: ubuntu-20.04, r: 'release', rspm: "https://packagemanager.rstudio.com/cran/__linux__/focal/latest", check: 'newparadox'}
- {os: ubuntu-20.04, r: 'release', rspm: "https://packagemanager.rstudio.com/cran/__linux__/focal/latest", check: 'newmiesmuschelonly'}

env:
R_REMOTES_NO_ERRORS_FROM_WARNINGS: true
Expand Down Expand Up @@ -86,7 +88,14 @@ jobs:
Rscript -e "reticulate::conda_create('r-reticulate', packages = c('python==3.8'))"
Rscript -e "remotes::install_local()"
Rscript -e "keras::install_keras(tensorflow = Sys.getenv('TF_VERSION'), extra_packages = c('IPython', 'requests', 'certifi', 'urllib3', 'pandas', 'h5py'))"
- if: ${{ matrix.config.r == 'release' && runner.os == 'Linux' }}
shell: Rscript {0}
run: |
remotes::install_github('mlr-org/miesmuschel')
- if: ${{ matrix.config.r == 'release' && runner.os == 'Linux' && matrix.config.check == 'newparadox' }}
shell: Rscript {0}
run: |
remotes::install_github('mlr-org/paradox')
- name: Check
env:
_R_CHECK_CRAN_INCOMING_REMOTE_: false
Expand Down
6 changes: 3 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: counterfactuals
Type: Package
Title: Counterfactual Explanations
Version: 0.1.3
Version: 0.1.4
Authors@R: c(
person("Susanne", "Dandl", email = "dandls.datascience@gmail.com", role = c("aut", "cre"),
comment = c(ORCID = "0000-0003-4324-4163")),
Expand All @@ -10,7 +10,7 @@ Authors@R: c(
person("Giuseppe","Casalicchio", email = "giuseppe.casalicchio@stat.uni-muenchen.de", role = c("ctb"))
)
Maintainer: Susanne Dandl <dandls.datascience@gmail.com>
Description: Modular and unified R6-based interface for counterfactual explanation methods.
Description: Modular and unified R6-based interface for counterfactual explanation methods.
The following methods are currently implemented: Burghmans et al. (2022) <arXiv:2104.07411>,
Dandl et al. (2020) <doi:10.1007/978-3-030-58112-1_31> and Wexler et al. (2019) <doi:10.1109/TVCG.2019.2934619>.
Optional extensions allow these methods to be applied to a variety of models and use cases.
Expand Down Expand Up @@ -58,7 +58,7 @@ Suggests:
mlr
License: LGPL-3
Encoding: UTF-8
RoxygenNote: 7.2.3
RoxygenNote: 7.3.1
VignetteBuilder: R.rsp
Roxygen: list(markdown = TRUE, r6 = TRUE)
Config/testthat/edition: 3
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# counterfactuals 0.1.4

* Compatibility with upcoming 'paradox' release

# counterfactuals 0.1.3
* Resolved bug in `Counterfactuals$evaluate(show_diff = TRUE)` after `$subset_to_valid()` and `$revert_subset_to_valid()` were called.
* Throw errors if `x_nn_correct = TRUE` but no correctly classified observation available.
Expand Down
22 changes: 11 additions & 11 deletions R/make_param_set.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
#'
#' @description
#' Creates a \link[paradox]{ParamSet} for the columns of `dt`. Depending on the class of a column, a different
#' \link[paradox]{Param} is created:
#' * `double`: \link[paradox]{ParamDbl}
#' * `integer`: \link[paradox]{ParamInt}
#' * `character`: \link[paradox]{ParamFct} (with unique values as levels)
#' * `factor`: \link[paradox]{ParamFct} (with factor levels as levels)
#' \link[paradox]{Domain} is created:
#' * `double`: `p_dbl()`
#' * `integer`: `p_int()`
#' * `character`: `p_fct()` (with unique values as levels)
#' * `factor`: `p_fct()` (with factor levels as levels)
#'
#' @param dt (`data.table()`)\cr
#' The data for the \link[paradox]{ParamSet}.
Expand Down Expand Up @@ -40,19 +40,19 @@ make_param_set = function(dt, lower = NULL, upper = NULL) {

# make param
if (is.double(column)) {
param = ParamDbl$new(col_name, lower = lb, upper = ub)
param = p_dbl(lower = lb, upper = ub)
} else if (is.integer(column)) {
param = ParamInt$new(col_name, lower = lb, upper = ub)
param = p_int(lower = lb, upper = ub)
} else if (is.character(column)) {
param = ParamFct$new(col_name, levels = unique(column))
param = p_fct(levels = unique(column))
} else {
param = ParamFct$new(col_name, levels = levels(column))
param = p_fct(levels = levels(column))
}

param
})

ParamSet$new(param_list)
names(param_list) = names(dt)
do.call(paradox::ps, param_list)
}


111 changes: 55 additions & 56 deletions R/moc_algo.R
Original file line number Diff line number Diff line change
@@ -1,33 +1,32 @@
moc_algo = function(predictor, x_interest, pred_column, target, param_set, lower, upper, sdevs_num_feats,
moc_algo = function(predictor, x_interest, pred_column, target, param_set, lower, upper, sdevs_num_feats,
epsilon, fixed_features, max_changed, mu, termination_crit, n_generations, p_rec, p_rec_gen,
p_mut, p_mut_gen, p_mut_use_orig, k, weights, init_strategy, distance_function, cond_sampler = NULL,
p_mut, p_mut_gen, p_mut_use_orig, k, weights, init_strategy, distance_function, cond_sampler = NULL,
ref_point, quiet = TRUE) {
codomain = ParamSet$new(list(
ParamDbl$new("dist_target", tags = "minimize"),
ParamDbl$new("dist_x_interest", tags = "minimize"),
ParamInt$new("no_changed", tags = "minimize"),
ParamDbl$new("dist_train", tags = "minimize")
))

codomain = ps(
dist_target = p_dbl(tags = "minimize"),
dist_x_interest = p_dbl(tags = "minimize"),
no_changed = p_int(tags = "minimize"),
dist_train = p_dbl(tags = "minimize")
)

fitness_function = make_fitness_function(
predictor, x_interest, pred_column, target, weights, k, fixed_features, param_set, distance_function
)

flex_cols = setdiff(names(x_interest), fixed_features)
if (!is.null(sdevs_num_feats)) {
sdevs_flex_num_feats = sdevs_num_feats[names(sdevs_num_feats) %in% flex_cols]
}

param_set_flex = param_set$clone()
param_set_flex$subset(flex_cols)


param_set_flex = param_set$clone()$subset(flex_cols)

objective = bbotk::ObjectiveRFunDt$new(
fun = fitness_function,
domain = param_set_flex,
fun = fitness_function,
domain = param_set_flex,
codomain = codomain
)

if (n_generations > 0L) {
if (termination_crit == "gens") {
terminator = bbotk::trm("gens", generations = n_generations)
Expand All @@ -40,89 +39,89 @@ moc_algo = function(predictor, x_interest, pred_column, target, param_set, lower
},
include_previous_generations = TRUE,
min_delta = 0.00,
patience = n_generations),
patience = n_generations),
bbotk::trm("gens", generations = 500)
)
)
}
} else {
terminator = bbotk::trm("none")
}

oi = bbotk::OptimInstanceMultiCrit$new(
objective,
objective,
terminator = terminator
)

if (n_generations > 0L) {
# Mutator
if (is.null(cond_sampler)) {
op_m = make_moc_mutator(
ps = param_set_flex,
x_interest = x_interest,
max_changed = max_changed,
sdevs = sdevs_flex_num_feats,
ps = param_set_flex,
x_interest = x_interest,
max_changed = max_changed,
sdevs = sdevs_flex_num_feats,
p_mut = p_mut,
p_mut_gen = p_mut_gen,
p_mut_gen = p_mut_gen,
p_mut_use_orig = p_mut_use_orig
)
} else {
op_m = make_moc_conditional_mutator(
ps = param_set_flex,
ps = param_set_flex,
x_interest = x_interest,
max_changed = max_changed,
max_changed = max_changed,
p_mut = p_mut,
p_mut_gen = p_mut_gen,
p_mut_gen = p_mut_gen,
p_mut_use_orig = p_mut_use_orig,
cond_sampler = cond_sampler
)
}

# Recombinator
op_r = make_moc_recombinator(
ps = param_set_flex,
x_interest = x_interest,
max_changed = max_changed,
ps = param_set_flex,
x_interest = x_interest,
max_changed = max_changed,
p_rec = p_rec,
p_rec_gen = p_rec_gen
)

# Selectors
# TODO: Replace this by tournament selection
selobj1 = scl("one", objective = 1L)
op_parent = sel("best", selobj1)

sel_nondom_penalized = ScalorNondomPenalized$new(epsilon)
op_survival = sel("best", sel_nondom_penalized)
op_survival = sel("best", sel_nondom_penalized)

mies_prime_operators(
search_space = oi$search_space,
mutators = list(op_m),
search_space = oi$search_space,
mutators = list(op_m),
recombinators = list(op_r),
selectors = list(op_parent, op_survival)
)
)
}

pop_initializer = make_moc_pop_initializer(
ps = param_set_flex,
x_interest = x_interest,
max_changed = max_changed,
init_strategy = init_strategy,
flex_cols = flex_cols,
sdevs = sdevs_flex_num_feats,
lower = lower,
upper = upper,
ps = param_set_flex,
x_interest = x_interest,
max_changed = max_changed,
init_strategy = init_strategy,
flex_cols = flex_cols,
sdevs = sdevs_flex_num_feats,
lower = lower,
upper = upper,
predictor = predictor,
fitness_function = fitness_function,
mu = mu
)

if (quiet) {
quiet(mies_init_population(inst = oi, mu = mu, initializer = pop_initializer))
} else {
mies_init_population(inst = oi, mu = mu, initializer = pop_initializer)
}

if (n_generations > 0L) {
tryCatch({
repeat {
Expand All @@ -132,7 +131,7 @@ moc_algo = function(predictor, x_interest, pred_column, target, param_set, lower
warning = function(w){
if(grepl("no columns to delete or assign RHS to", w$message)){
invokeRestart("muffleWarning")
}
}
})
if (quiet) {
quiet(mies_evaluate_offspring(oi, offspring))
Expand All @@ -145,20 +144,20 @@ moc_algo = function(predictor, x_interest, pred_column, target, param_set, lower
})
}
bbotk::assign_result_default(oi)

# Re-attach fixed features
if (!is.null(fixed_features)) {
oi$result[, (fixed_features) := x_interest[, fixed_features, with = FALSE]]
}

# Transform factor column w.r.t to original data
factor_cols = names(which(sapply(predictor$data$X, is.factor)))
for (factor_col in factor_cols) {
fact_col_pred = predictor$data$X[[factor_col]]
value = factor(oi$result[[factor_col]], levels = levels(fact_col_pred), ordered = is.ordered(fact_col_pred))
oi$result[, (factor_col) := value]
}

int_cols = names(which(sapply(predictor$data$X, is.integer)))
if (length(int_cols) > 0L) {
oi$result[, (int_cols) := lapply(.SD, as.integer), .SDcols = int_cols]
Expand Down

0 comments on commit 0ed76f9

Please sign in to comment.