The gadget R package implements the GADGET algorithm for interpretable machine learning. It recursively partitions the feature space to minimize the heterogeneity of feature effects (e.g., Accumulated Local Effects or Partial Dependence), producing a tree of regions where effects are more stable and easier to interpret. The package integrates with the mlr3 ecosystem.
- Interaction detection: Identifies feature interactions by recursively splitting on heterogeneity of effects.
- ALE and PD support: `AleStrategy` for ALE (computed internally), `PdStrategy` for PD/ICE (via precomputed effects from `iml` or similar).
- Visualization: tree structure plots and regional effect curves (ALE, PD, ICE).
- Extensible design: R6-based strategy pattern; plug in custom effect strategies.
- Performance: Core calculations in C++ (Rcpp/RcppArmadillo).
Install the development version from GitHub:
```r
# install.packages("devtools")
devtools::install_github("mlr-org/gadget")
```

Requires R6, ggplot2, data.table, Rcpp; see DESCRIPTION for details.
| Component | Description |
|---|---|
| `GadgetTree` | Main entry point: `$new()`, `$fit()`, `$plot()`, `$plot_tree_structure()`, `$extract_split_info()` |
| `AleStrategy` | ALE-based trees; pass `model` to `$fit()`. ALE computed internally. |
| `PdStrategy` | PD/ICE-based trees; pass `effect` (e.g. from `iml::FeatureEffects`) to `$fit()`. |
Fit arguments

- `AleStrategy`-specific
  - `model` (required): fitted mlr3 learner used to compute ALE.
  - `n_intervals` (optional): number of intervals for ALE grids (default: 10).
  - `predict_fun` (optional): custom prediction function; if `NULL`, uses the learner's default.
  - `order_method` (optional): how to order categorical split-feature levels before searching over binary splits. Internally, GADGET builds a pairwise distance matrix between levels (using the other features), embeds it into 1D, and uses that 1D order for splitting. Supported methods:
    - `"raw"`: keep the original factor level order (no reordering).
    - `"mds"` (default): multi-dimensional scaling on the level-distance matrix, then order levels by the 1D coordinates.
    - `"pca"`: PCA on the level-distance matrix, then order levels by the first principal component.
    - `"random"`: use a random order of levels (mainly for robustness checks or baselines).
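The `"mds"` ordering idea can be illustrated with base R alone. This is a toy sketch with a made-up level-distance matrix, not the package's internal distance computation:

```r
# Hypothetical pairwise distance matrix between three factor levels
d <- matrix(c(0, 2, 5,
              2, 0, 4,
              5, 4, 0),
            nrow = 3,
            dimnames = list(c("a", "b", "c"), c("a", "b", "c")))

coord1d <- cmdscale(as.dist(d), k = 1)       # classical MDS: embed levels into 1D
level_order <- rownames(d)[order(coord1d)]   # order levels by the 1D coordinate
level_order
```

Binary split candidates are then searched over this 1D order, so only contiguous groups of (reordered) levels are considered.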
- `PdStrategy`-specific
  - `effect` (required): object of class `FeatureEffects` (e.g. from `iml::FeatureEffects`).
- Shared tree arguments (both strategies)
  - `feature_set` (optional): subset of features used to compute effects and search for splits.
  - `split_feature` (optional): subset of features allowed as splitting variables.
  - `impr_par`: minimum required improvement in heterogeneity to accept a split.
  - `min_node_size`: minimum number of observations in each node.
  - `n_quantiles`: number of candidate split points per numerical feature.
GADGET recursively partitions the feature space. At each node it:

- Computes the effect heterogeneity (e.g., variance of ALE derivatives or of PD/ICE curves).
- Searches for the split (on a chosen feature) that maximally reduces this heterogeneity.
- Accepts the split if the reduction exceeds the threshold (`impr_par`) and the resulting nodes are large enough (`min_node_size`).

Splits isolate regions where feature effects are more stable, revealing interaction structure.
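The per-node computation above can be sketched in plain base R. This is an illustrative toy, not the package's C++ implementation: heterogeneity is taken as the summed pointwise variance across a set of ICE curves, and one candidate binary split is scored by its relative reduction (the kind of quantity compared against `impr_par`):

```r
set.seed(1)
z <- runif(200)  # hypothetical split feature

# 200 toy ICE curves over a 20-point grid: sign flips with z, i.e. a pure interaction
grid <- seq(0, 1, length.out = 20)
ice <- outer(ifelse(z > 0.5, 1, -1), grid)

# Heterogeneity: summed variance of the curves at each grid point
het <- function(curves) sum(apply(curves, 2, var))

parent <- het(ice)
left   <- het(ice[z <= 0.5, , drop = FALSE])
right  <- het(ice[z >  0.5, , drop = FALSE])

# Relative improvement of the candidate split z <= 0.5 (cf. impr_par)
reduction <- (parent - (left + right)) / parent
```

Because the toy curves are perfectly homogeneous within each half, the split removes all heterogeneity (`reduction` is 1); real data yields partial reductions, and GADGET keeps only splits whose improvement clears `impr_par`.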
This section shows how to use GADGET with PD and ALE on the Bikeshare data.
We first build a PD-based tree (with effects computed via iml), then an ALE-based tree (effects computed internally).
For PD/ICE, compute effects with iml first, then pass them to PdStrategy:
```r
library(gadget)
library(iml)
library(mlr3)
library(mlr3learners)
library(ISLR2)

# 1) Load data and fit a black-box model
data("Bikeshare", package = "ISLR2")
set.seed(123)
bike = Bikeshare[sample(seq_len(nrow(Bikeshare)), 1000), ]
bike$workingday = as.factor(bike$workingday)
bike_data = bike[, c("hr", "temp", "workingday", "bikers")]
names(bike_data)[names(bike_data) == "bikers"] = "target"

task = TaskRegr$new(id = "bike", backend = bike_data, target = "target")
learner = lrn("regr.ranger")
learner$train(task)

# 2) Compute ICE/PD effects with iml
predictor = iml::Predictor$new(
  model = learner,
  data = bike_data[, c("hr", "temp", "workingday")],
  y = bike_data$target
)
effect = iml::FeatureEffects$new(
  predictor,
  method = "ice",
  grid.size = 20
)

# 3) Grow a PD-based GadgetTree on the iml effects
tree = GadgetTree$new(
  strategy = PdStrategy$new(),
  n_split = 2,
  min_node_size = 50
)
tree$fit(
  data = bike_data,
  target_feature_name = "target",
  effect = effect
)

# 4) Visualize tree structure and regional PD/ICE curves
tree$plot_tree_structure()  # prints the tree topology (depth, node IDs, split features)
tree$extract_split_info()
tree$plot(
  effect = effect,
  data = bike_data,
  target_feature_name = "target",
  features = c("hr", "temp")
)
```

Sample split info (PD + Bikeshare):
| id | depth | n_obs | node_type | split_feature | split_value | node_objective | int_imp | int_imp_parent | split_feature_parent | split_value_parent | objective_value_parent | is_final | time |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | 1 | 1000 | root | workingday | 0 | 19724312.3 | 0.36 | NA | <NA> | <NA> | NA | FALSE | 0.002 |
| 2 | 2 | 316 | left | temp | 0.45 | 4795126.7 | 0.21 | 0.36 | workingday | 0 | 19724312 | FALSE | 0.002 |
| 3 | 2 | 684 | right | temp | 0.51 | 7816195.0 | 0.33 | 0.36 | workingday | 0 | 19724312 | FALSE | 0.002 |
| 4 | 3 | 148 | left | <NA> | <NA> | 287650.3 | NA | 0.21 | temp | 0.45 | 4795127 | TRUE | NA |
| 5 | 3 | 168 | right | <NA> | <NA> | 286534.8 | NA | 0.21 | temp | 0.45 | 4795127 | TRUE | NA |
| 6 | 3 | 345 | left | <NA> | <NA> | 626584.2 | NA | 0.33 | temp | 0.51 | 7816195 | TRUE | NA |
| 7 | 3 | 339 | right | <NA> | <NA> | 747511.4 | NA | 0.33 | temp | 0.51 | 7816195 | TRUE | NA |
Tree structure and regional PD/ICE plots (root and first split):
```r
library(gadget)
library(mlr3)
library(mlr3learners)
library(ISLR2)

# 1) Load and subsample the Bikeshare data
data("Bikeshare", package = "ISLR2")
set.seed(123)
bike = Bikeshare[sample(seq_len(nrow(Bikeshare)), 1000), ]
bike$workingday = as.factor(bike$workingday)
bike_data = bike[, c("hr", "temp", "workingday", "bikers")]
names(bike_data)[names(bike_data) == "bikers"] = "target"

# 2) Fit a black-box regression model with mlr3
task = TaskRegr$new(id = "bike", backend = bike_data, target = "target")
learner = lrn("regr.ranger")
learner$train(task)

# 3) Grow an ALE-based GadgetTree on top of the model
tree = GadgetTree$new(
  strategy = AleStrategy$new(),
  n_split = 2,
  impr_par = 0.01,
  min_node_size = 50
)
tree$fit(
  data = bike_data,
  target_feature_name = "target",
  model = learner,
  n_intervals = 10
)

# 4) Inspect the tree structure, splits, and regional ALE plots
tree$plot_tree_structure()  # prints the tree topology (depth, node IDs, split features)
tree$extract_split_info()
tree$plot(
  data = bike_data,
  target_feature_name = "target",
  features = c("hr", "temp"),
  mean_center = TRUE
)
```

Sample split info (ALE + Bikeshare):
| id | depth | n_obs | node_type | split_feature | split_value | node_objective | int_imp | int_imp_parent | int_imp_hr | int_imp_temp | int_imp_workingday | split_feature_parent | split_value_parent | objective_value_parent | is_final | time |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | 1 | 1000 | root | workingday | 0 | 2220499.170 | 0.90 | NA | 0.68 | 0.17 | 1 | <NA> | <NA> | NA | FALSE | 0.006 |
| 2 | 2 | 316 | left | temp | 0.47 | 49879.915 | 0.02 | 0.90 | 0.04 | 0.27 | 0 | workingday | 0 | 2220499.17 | FALSE | 0.004 |
| 3 | 2 | 684 | right | temp | 0.47 | 167776.262 | 0.07 | 0.90 | 0.21 | 0.55 | 0 | workingday | 0 | 2220499.17 | FALSE | 0.004 |
| 4 | 3 | 160 | left | <NA> | <NA> | 2978.705 | NA | 0.02 | NA | NA | NA | temp | 0.47 | 49879.91 | TRUE | NA |
| 5 | 3 | 156 | right | <NA> | <NA> | 2944.015 | NA | 0.02 | NA | NA | NA | temp | 0.47 | 49879.91 | TRUE | NA |
| 6 | 3 | 316 | left | <NA> | <NA> | 6558.701 | NA | 0.07 | NA | NA | NA | temp | 0.47 | 167776.26 | TRUE | NA |
| 7 | 3 | 368 | right | <NA> | <NA> | 16683.457 | NA | 0.07 | NA | NA | NA | temp | 0.47 | 167776.26 | TRUE | NA |
Tree structure and regional ALE plots (root and first split):
The tree$plot() method is flexible and can be used to drill down into specific depths, nodes, and features.
It always returns a nested list of ggplot2 objects indexed by depth and node.
- Controlling which nodes to plot

  ```r
  # Only depth 1 (root)
  pl = tree$plot(
    data = bike_data,
    target_feature_name = "target",
    features = c("hr", "temp"),
    depth = 1
  )

  # A specific node at depth 2 (e.g., the right child)
  pl = tree$plot(
    data = bike_data,
    target_feature_name = "target",
    features = c("hr", "temp"),
    depth = 2,
    node_id = 3  # see node IDs in tree$plot_tree_structure()
  )

  # Inspect or manually print a single ggplot object
  print(pl[[2]][[1]])  # depth 2, first node
  ```
- Selecting features and centering

  ```r
  # Only plot effects for "hr", without mean-centering
  pl = tree$plot(
    data = bike_data,
    target_feature_name = "target",
    features = "hr",
    mean_center = FALSE
  )
  ```
- Overlaying raw observations

  If available for your strategy, you can overlay the observed (x, y) points on top of the regional curves:

  ```r
  pl = tree$plot(
    data = bike_data,
    target_feature_name = "target",
    features = c("hr", "temp"),
    mean_center = TRUE,
    show_point = TRUE  # add raw data points
  )
  ```
In practice, a common workflow is:

- Use `tree$plot_tree_structure()` and `tree$extract_split_info()` to identify interesting regions.
- Call `tree$plot()` with `depth`/`node_id`/`features` to inspect those regions.
- Manually inspect or save individual plots with `print(pl[[d]][[k]])`.
- In R: `?gadget`, `?GadgetTree`, `?AleStrategy`, `?PdStrategy`
- Paper draft: `paper/` (see `paper/README.md`)
Herbinger, J., Wright, M. N., Nagler, T., Bischl, B., and Casalicchio, G. (2024). Decomposing Global Feature Effects Based on Feature Interactions. Journal of Machine Learning Research, 25(23-0699), 1–65. https://jmlr.org/papers/volume25/23-0699/23-0699.pdf
MIT