Skip to content

Commit

Permalink
Merge pull request #9 from holgstr/higher_order_plots
Browse files Browse the repository at this point in the history
Multivariate Effect Plots for Higher Orders
  • Loading branch information
holgstr committed Jun 4, 2024
2 parents be7c143 + 7be259b commit 5600e61
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 34 deletions.
6 changes: 5 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
# fmeffects 0.1.3

- `fme` and `ame` automatically extract the name of the model's target variable from the model.
- Added support for `"lm"`-type models, such as `stats::glm` and `mgcv::gam`.
- `fme` now computes NLMs via parallel processing with `future` and displays a progress bar while doing so.
- Multivariate effects for more than two features can be computed and visualized.
- Improved visualizations (especially for larger data sets) via hexagon plots.
- Better error communication with `cli`.
- Feature interactions for categorical features are supported.

# fmeffects 0.1.0

Expand Down
7 changes: 4 additions & 3 deletions R/FME.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ ForwardMarginalEffect = R6::R6Class("ForwardMarginalEffect",
} else {
cli::cli_abort("{.arg features} cannot contain both numeric and categorical features.")
}

if (self$step.type == "numerical") {
if(!checkmate::test_numeric(step.size, min.len = 1)) {
cli::cli_abort("{.arg features} must have numeric step lengths for numeric features.")
Expand Down Expand Up @@ -125,9 +124,9 @@ ForwardMarginalEffect = R6::R6Class("ForwardMarginalEffect",
#' @description
#' Plots results, i.e., FME (and NLMs) for non-extrapolation points, for an `FME` object.
#' @param with.nlm Plots NLMs if computed, defaults to `FALSE`.
#' @param bins Numeric vector giving number of bins in both vertical and horizontal directions.
#' @param bins Numeric vector giving number of bins in both vertical and horizontal directions. Applies only to univariate or bivariate numeric effects.
#' See [ggplot2::stat_summary_hex()] for details.
#' @param binwidth Numeric vector giving bin width in both vertical and horizontal directions. Overrides bins if both set.
#' @param binwidth Numeric vector giving bin width in both vertical and horizontal directions. Overrides bins if both set. Applies only to univariate or bivariate numeric effects.
#' See [ggplot2::stat_summary_hex()] for details.
#' @examples
#' # Compute results:
Expand All @@ -139,6 +138,8 @@ ForwardMarginalEffect = R6::R6Class("ForwardMarginalEffect",
FMEPlotUnivariate$new(self$results, self$predictor$X, self$feature, self$step.size)$plot(with.nlm, bins, binwidth)
} else if (length(self$feature) == 2){
FMEPlotBivariate$new(self$results, self$predictor$X, self$feature, self$step.size)$plot(with.nlm, bins, binwidth)
} else if (length(self$feature) >= 3){
FMEPlotHigherOrder$new(self$results, self$predictor$X, self$feature, self$step.size)$plot(with.nlm)
} else {
stop("Cannot plot effects for more than two numerical features.")
}
Expand Down
135 changes: 108 additions & 27 deletions R/FMEPlot.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ FMEPlot = R6::R6Class("FMEPlot",

initializeSubclass = function(results, data, feature, step.size) {

# Check if results is a data.table with a minimum of one observation
checkmate::assertDataTable(results, min.rows = 1)
if (checkmate::test_true(length(unique(results$fme)) == 1)) {
cli::cli_abort(paste("Cannot plot effects if they all have the same value."))
}

self$feature = feature
self$step.size = step.size
Expand All @@ -32,8 +33,86 @@ FMEPlot = R6::R6Class("FMEPlot",
)


# FMEPlot for Higher-Order Numerical Steps
#
# Plot class for effects that involve three or more features. The FME values
# (and, on request, the NLM values) are shown as histograms, each annotated
# with a vertical line and a label at the mean of the plotted values.
FMEPlotHigherOrder = R6::R6Class(
  "FMEPlotHigherOrder",
  inherit = FMEPlot,

  public = list(

    # Delegates validation and field setup to the parent class.
    initialize = function(results, data, feature, step.size) {
      private$initializeSubclass(results, data, feature, step.size)
    },

    # Returns the histogram of FMEs. If `with.nlm` is TRUE and the results
    # contain an `nlm` column, returns a two-panel grid with the NLM histogram
    # next to it; if NLMs are requested but absent, an error is raised.
    plot = function(with.nlm = FALSE) {
      df = as.data.frame(self$df)
      n.bins = private$binCount(nrow(df))

      pfme = ggplot2::ggplot(df) +
        private$histogramLayer(
          ggplot2::aes(x = fme, y = ggplot2::after_stat(count)),
          n.bins
        ) +
        private$annotationLayers(df$fme, n.bins, measure = "FME", mean.label = "AME")

      if (!with.nlm) {
        return(pfme)
      }
      if (!("nlm" %in% names(df))) {
        stop("Only possible to plot NLM for FME objects with NLM computed.")
      }

      # Negative NLM values are clamped to 0 before plotting.
      df$nlm = pmax(df$nlm, 0)
      pnlm = ggplot2::ggplot(df) +
        private$histogramLayer(
          ggplot2::aes(x = nlm, y = ggplot2::after_stat(count)),
          n.bins
        ) +
        private$annotationLayers(df$nlm, n.bins, measure = "NLM", mean.label = "ANLM")

      suppressWarnings(cowplot::plot_grid(pfme, pnlm, ncol = 2, rel_widths = c(0.5, 0.5)))
    }
  ),

  private = list(

    # Number of histogram bins for `n` observations: 40% of `n`, rounded to a
    # whole number and kept between 1 and 20. The scaling happens before the
    # rounding so that downstream code always receives an integer bin count.
    binCount = function(n) {
      max(1L, min(as.integer(round(n * 0.4)), 20L))
    },

    # Height of the tallest bin when the non-missing `values` are cut into
    # `n.bins` equal-width bins spanning their range. Only used to position
    # the mean label, so it needs to approximate the drawn histogram, not
    # match it exactly.
    maxBinCount = function(values, n.bins) {
      observed = values[!is.na(values)]
      if (length(observed) < 2 || min(observed) == max(observed)) {
        # Zero-width range: every observation falls into the same bin.
        return(length(observed))
      }
      breaks = seq(min(observed), max(observed), length.out = n.bins + 1)
      max(hist(observed, breaks = breaks, plot = FALSE)$counts)
    },

    # Histogram layer shared by the FME and the NLM panel.
    histogramLayer = function(mapping, n.bins) {
      ggplot2::geom_histogram(
        mapping = mapping,
        bins = n.bins,
        lwd = 0.3,
        linetype = "solid",
        colour = "black",
        fill = "gray",
        show.legend = FALSE,
        na.rm = TRUE
      )
    },

    # Everything drawn on top of the histogram, returned as a list that can be
    # added to a ggplot in one step: mean line, mean label, axis titles, theme.
    #   values     - the numeric vector shown in the histogram
    #   n.bins     - bin count used for the histogram
    #   measure    - axis title prefix ("FME" or "NLM")
    #   mean.label - label prefix for the mean ("AME" or "ANLM")
    annotationLayers = function(values, n.bins, measure, mean.label) {
      avg = mean(values, na.rm = TRUE)
      steps = paste(paste(self$feature, "=", self$step.size), collapse = " | ")
      list(
        # Constant intercept and annotate() draw the line and label once
        # instead of once per data row.
        ggplot2::geom_vline(xintercept = avg, lwd = 1.2),
        ggplot2::annotate(
          "label",
          x = avg,
          # Label height is derived from the values of this panel, so the NLM
          # label does not depend on the FME bin counts.
          y = private$maxBinCount(values, n.bins) * 0.9,
          label = paste0(mean.label, ": ", round(avg, 4)),
          fill = "white"
        ),
        ggplot2::xlab(paste0(measure, " (", steps, ")")),
        ggplot2::ylab(""),
        ggplot2::theme_bw(),
        ggplot2::theme(
          panel.border = ggplot2::element_rect(colour = "black", fill = NA, linewidth = 0.7),
          axis.title = ggplot2::element_text(size = 12),
          axis.text.x = ggplot2::element_text(colour = "black", size = 10),
          axis.text.y = ggplot2::element_text(colour = "black", size = 10)
        )
      )
    }
  )
)

# FMEPlot for Bivariate Numerical Steps
FMEPlotBivariate = R6::R6Class("FMEPlotBivariate",
FMEPlotBivariate = R6::R6Class(
"FMEPlotBivariate",

inherit = FMEPlot,

Expand Down Expand Up @@ -135,7 +214,8 @@ FMEPlotBivariate = R6::R6Class("FMEPlotBivariate",
)

# FMEPlot for Univariate Numerical Steps
FMEPlotUnivariate = R6::R6Class("FMEPlotUnivariate",
FMEPlotUnivariate = R6::R6Class(
"FMEPlotUnivariate",

inherit = FMEPlot,

Expand Down Expand Up @@ -168,24 +248,24 @@ FMEPlotUnivariate = R6::R6Class("FMEPlotUnivariate",
ggplot2::geom_rug(sides = "b", length = ggplot2::unit(0.015, "npc")) +
ggplot2::geom_smooth(ggplot2::aes(x = x1, y = fme), se = FALSE, fullrange = FALSE, linetype = "dashed", linewidth = 0.7, color = "black") +
ggplot2::annotate("segment",
x = 0.5 * min(df$x1) + 0.5 * max(df$x1) - 0.5 * self$step.size[1],
xend = 0.5 * min(df$x1) + 0.5 * max(df$x1) + 0.5 * self$step.size[1],
y = min.fme - 0.06 * range.fme,
yend = min.fme - 0.06 * range.fme,
colour = 'black', size = 1,
arrow = grid::arrow(length = grid::unit(0.2, "cm")),
lineend = "round", linejoin = "mitre") +
x = 0.5 * min(df$x1) + 0.5 * max(df$x1) - 0.5 * self$step.size[1],
xend = 0.5 * min(df$x1) + 0.5 * max(df$x1) + 0.5 * self$step.size[1],
y = min.fme - 0.06 * range.fme,
yend = min.fme - 0.06 * range.fme,
colour = 'black', size = 1,
arrow = grid::arrow(length = grid::unit(0.2, "cm")),
lineend = "round", linejoin = "mitre") +
ggplot2::geom_hline(lwd = 1.2, mapping = ggplot2::aes(yintercept = mean(fme, na.rm = TRUE))) +
ggplot2::geom_label(x = max.x1 + 0.1 * range.x1, y = mean(df$fme, na.rm = TRUE), label = "AME", size = 3, fill = 'white') +
ggplot2::xlab(self$feature[1]) +
ggplot2::ylab("FME") +
ggplot2::theme_bw() +
ggplot2::theme(panel.border = ggplot2::element_rect(colour = "black", fill=NA, size=0.7),
axis.title = ggplot2::element_text(size = 12),
axis.text.x = ggplot2::element_text(colour = "black", size = 10),
axis.text.y = ggplot2::element_text(colour = "black", size = 10),
legend.title = ggplot2::element_text(color = "black", size = 12),
legend.text = ggplot2::element_text(color = "black", size = 10))
axis.title = ggplot2::element_text(size = 12),
axis.text.x = ggplot2::element_text(colour = "black", size = 10),
axis.text.y = ggplot2::element_text(colour = "black", size = 10),
legend.title = ggplot2::element_text(color = "black", size = 12),
legend.text = ggplot2::element_text(color = "black", size = 10))

if (with.nlm == FALSE) {
pfme
Expand Down Expand Up @@ -232,7 +312,8 @@ FMEPlotUnivariate = R6::R6Class("FMEPlotUnivariate",
)

# FMEPlot for Categorical Steps
FMEPlotCategorical = R6::R6Class("FMEPlotCategorical",
FMEPlotCategorical = R6::R6Class(
"FMEPlotCategorical",

inherit = FMEPlot,

Expand All @@ -252,22 +333,22 @@ FMEPlotCategorical = R6::R6Class("FMEPlotCategorical",
plot = FALSE)$counts)
ggplot2::ggplot(df) +
ggplot2::geom_histogram(lwd = 0.3,
linetype = "solid",
colour = "black",
fill = "gray",
show.legend = FALSE,
mapping = ggplot2::aes(x = fme, y = ggplot2::after_stat(count)),
bins = min(round(nrow(df))*0.4, 20),
na.rm = TRUE) +
linetype = "solid",
colour = "black",
fill = "gray",
show.legend = FALSE,
mapping = ggplot2::aes(x = fme, y = ggplot2::after_stat(count)),
bins = min(round(nrow(df))*0.4, 20),
na.rm = TRUE) +
ggplot2::geom_vline(lwd = 1.2, mapping = ggplot2::aes(xintercept = mean(fme))) +
ggplot2::geom_label(x = mean(df$fme), y = countmax*0.9, label = paste0('AME: ', round(mean(df$fme), 4)), fill = 'white') +
ggplot2::xlab(paste0("FME (category: ", self$step.size, ", feature: ", self$feature, ")")) +
ggplot2::ylab("") +
ggplot2::theme_bw() +
ggplot2::theme(panel.border = ggplot2::element_rect(colour = "black", fill=NA, size=0.7),
axis.title = ggplot2::element_text(size = 12),
axis.text.x = ggplot2::element_text(colour = "black", size = 10),
axis.text.y = ggplot2::element_text(colour = "black", size = 10))
axis.title = ggplot2::element_text(size = 12),
axis.text.x = ggplot2::element_text(colour = "black", size = 10),
axis.text.y = ggplot2::element_text(colour = "black", size = 10))
} else {
stop("Cannot plot NLM because NLM can only be computed for numerical features.")
}
Expand Down
4 changes: 2 additions & 2 deletions man/ForwardMarginalEffect.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion vignettes/fmeffects.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ effects2 = fme(model = forest,
ep.method = "envelope")
```

For bivariate effects, we can plot the effects (we cannot for more than two features):
For bivariate effects, we can plot the effects in a way similar to univariate effects (for more than two features, we can plot only the histogram of effects):

```{r, message=FALSE}
plot(effects2)
Expand Down

0 comments on commit 5600e61

Please sign in to comment.