Skip to content

Commit

Permalink
Merge pull request #60 from mayer79/recode
Browse files Browse the repository at this point in the history
Replace "light_recode" by new option in plot.light_effects()
  • Loading branch information
mayer79 committed May 18, 2023
2 parents 7bfd518 + 17797d1 commit eb05ebd
Show file tree
Hide file tree
Showing 9 changed files with 53 additions and 116 deletions.
2 changes: 0 additions & 2 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@ S3method(light_profile,multiflashlight)
S3method(light_profile2d,default)
S3method(light_profile2d,flashlight)
S3method(light_profile2d,multiflashlight)
S3method(light_recode,default)
S3method(light_recode,light)
S3method(light_scatter,default)
S3method(light_scatter,flashlight)
S3method(light_scatter,multiflashlight)
Expand Down
7 changes: 7 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
# flashlight 0.9.0.9000

## New functionality

- `plot.light_effects()` has gained an argument `recode_labels` to modify the curve labels.

## Deprecated functionality

- `add_shap()`: Deprecated in favor of {kernelshap} or {fastshap}.
- Consequently, `type = "shap"` in `light_profile()`, `light_importance()`, `light_scatter()`, and `light_profile2d()` is deprecated as well.
- `plot_counts()` is deprecated.
- `light_recode()` is deprecated.
- The option `stats = "quartile"`of `light_effects()` and `light_profile()` is deprecated.
- Column names of resulting data objects cannot be set via `options()` anymore.

## Exported -> internal
Expand Down
9 changes: 9 additions & 0 deletions R/aa_deprecated.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,12 @@ add_shap <- function(...) {
plot_counts <- function(...) {
stop("'plot_counts()' has been deprecated.")
}

#' DEPRECATED
#'
#' @param ... Any input.
#' @returns Error message.
#' @export
light_recode <- function(...) {
stop("'light_recode()' is deprecated.")
}
18 changes: 16 additions & 2 deletions R/light_effects.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
#' plot(eff)
#'
#' # PDP and ALE
#' plot(eff, use = c("pd", "ale"))
#' plot(eff, use = c("pd", "ale"), recode_labels = c(ale = "ALE"))
#'
#' # Second model with non-linear Petal.Length effect
#' fit_nonlin <- lm(Sepal.Length ~ . + I(Petal.Length^2), data = iris)
Expand Down Expand Up @@ -201,14 +201,18 @@ light_effects.multiflashlight <- function(x, v, data = NULL, breaks = NULL,
#' [ggplot2::geom_point()] and [ggplot2::geom_line()].
#' @param facet_nrow Number of rows in [ggplot2::facet_wrap()].
#' @param show_points Should points be added to the line (default is `TRUE`).
#' @param recode_labels Named vector of curve labels. The names refer to the usual
#' labels, while the values are the desired labels, e.g.,
#' `c("partial dependence" = PDP", "ale" = "ALE").
#' @param ... Further arguments passed to geoms.
#' @returns An object of class "ggplot".
#' @export
#' @seealso [light_effects()], [plot_counts()]
plot.light_effects <- function(x, use = c("response", "predicted", "pd"),
zero_counts = TRUE, size_factor = 1,
facet_scales = "free_x", facet_nrow = 1L,
rotate_x = TRUE, show_points = TRUE, ...) {
rotate_x = TRUE, show_points = TRUE,
recode_labels = NULL, ...) {
# Checks
stopifnot(length(use) >= 1L)
if ("all" %in% use) {
Expand All @@ -230,6 +234,16 @@ plot.light_effects <- function(x, use = c("response", "predicted", "pd"),
data <- dplyr::semi_join(data, x$response, by = c("label_", x$by, x$v))
}

# Optionally change labels of type_
if (!is.null(recode_labels)) {
lab <- levels(data$type_)
if (!all(names(recode_labels) %in% lab)) {
stop("'recode_labels' must be a named vector, see ?plot.light_effects()'")
}
lab[match(names(recode_labels), lab)] <- recode_labels
levels(data$type_) <- lab
}

# Put together the plot
if (n) {
p <- ggplot2::ggplot(data, ggplot2::aes(y = value_, x = .data[[x$v]])) +
Expand Down
57 changes: 5 additions & 52 deletions R/light_recode.R
Original file line number Diff line number Diff line change
@@ -1,55 +1,8 @@
#' Recode Factor Columns
#' Recode Factor Columns - DEPRECATED
#'
#' Recodes factor levels of columns in data slots of an object of class "light".
#'
#' @param x An object of class "light".
#' @param what Column identifier to be recoded, e.g., "type". For backward
#' compatibility, also the option identifier (e.g. "type_name") can be passed.
#' @param levels Current levels/values of `type_name` column (in desired order).
#' @param labels New levels of `type_name` column in same order as `levels`.
#' @param ... Further arguments passed to `factor`.
#' @returns `x` with new factor levels of `type_name` column.
#' @export
#' @examples
#' fit_full <- lm(Sepal.Length ~ ., data = iris)
#' fit_part <- lm(Sepal.Length ~ Petal.Length, data = iris)
#' mod_full <- flashlight(
#' model = fit_full, label = "full", data = iris, y = "Sepal.Length"
#' )
#' mod_part <- flashlight(
#' model = fit_part, label = "part", data = iris, y = "Sepal.Length"
#' )
#' mods <- multiflashlight(list(mod_full, mod_part))
#' eff <- light_effects(mods, v = "Species")
#' eff <- light_recode(
#' eff,
#' what = "type_name",
#' levels = c("response", "predicted", "partial dependence", "ale"),
#' labels = c("Observed", "Fitted", "PD", "ALE")
#' )
#' plot(eff, use = "all")
#' @seealso [plot.light_effects()].
light_recode <- function(x, ...) {
UseMethod("light_recode")
}

#' @describeIn light_recode Default method not implemented yet.
#' @export
light_recode.default <- function(x, ...) {
stop("No default method available yet.")
}

#' @describeIn light_recode Recoding factors in data slots of "light" object.
#' @param ... Deprecated.
#' @returns Deprecated.
#' @export
light_recode.light <- function(x, what, levels, labels, ...) {
if (!is.null(wt <- getOption(paste("flashlight", what, sep = ".")))) {
what <- wt
}
data_slots <- names(x)[vapply(x, FUN = is.data.frame, FUN.VALUE = TRUE)]
for (z in data_slots) {
if (what %in% colnames(x[[z]])) {
x[[z]][[what]] <- factor(x[[z]][[what]], levels = levels, labels = labels, ...)
}
}
x
light_recode <- function(...) {
stop("'light_recode()' is deprecated.")
}
2 changes: 1 addition & 1 deletion man/light_effects.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

58 changes: 10 additions & 48 deletions man/light_recode.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions man/plot.light_effects.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 0 additions & 11 deletions tests/testthat/tests-methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,6 @@ test_that("light_combine works", {
expect_equal(light_combine(ell1), ell1)
})

test_that("light_recode works", {
fit <- stats::lm(Sepal.Length ~ Species + 0, data = iris)
fl <- flashlight(model = fit, label = "lm", data = iris, y = "Sepal.Length")
eff <- light_effects(fl, v = "Species")
eff <- light_recode(eff, what = "type_",
levels = c("response", "predicted",
"partial dependence", "ale"),
labels = c("Observed", "Fitted", "PD", "ALE"))
expect_equal(as.character(eff$pd$type_[1L]), "PD")
})

test_that("selected 'is' functions work", {
fit <- stats::lm(Sepal.Length ~ Species + 0, data = iris)
fl <- flashlight(model = fit, label = "lm", data = iris, y = "Sepal.Length")
Expand Down

0 comments on commit eb05ebd

Please sign in to comment.