Fast marginal effect plots (i.e., "poor man's PDPs") #91

Closed
bgreenwell opened this issue May 13, 2019 · 3 comments

bgreenwell commented May 13, 2019

Rather than averaging over the entire training set, you can fix all other features at a single typical value: the median for numeric features and the most frequent level for categorical features. If there are no interaction effects, the resulting plots will be parallel to the corresponding PDPs. This is similar in spirit to the excellent plotmo package:

exemplar <- function(object) {
  UseMethod("exemplar")
}


exemplar.data.frame <- function(object) {
  res <- as.data.frame(lapply(object, FUN = function(x) {
    if (is.numeric(x)) {
      stats::median(x, na.rm = TRUE)
    } else {
      names(which.max(table(x, useNA = "no")))
    }
  }))
  # res[] <- mapply(FUN = as, res, sapply(object, FUN = class), SIMPLIFY = FALSE)
  # res
  res <- rbind(object[1L, ] , res)  # trick to copy column classes
  res[-1L, ]
}


#
# Example
#

# Load required packages
library(ggplot2)  # for ggtitle() function
library(pdp)      # for visualizing feature effects

# Fit a random forest to the Ames housing data
set.seed(101)
ames <- AmesHousing::make_ames()
rfo <- ranger::ranger(Sale_Price ~ ., data = ames)

# Marginal-effect plot
system.time(
  p1 <- partial(
    object = rfo, 
    pred.var = "Gr_Liv_Area", 
    train = exemplar(ames), 
    pred.grid = data.frame(
      "Gr_Liv_Area" = seq(from = min(ames$Gr_Liv_Area), to = max(ames$Gr_Liv_Area), length.out = 100)
    ),
    plot = TRUE,
    plot.engine = "ggplot2"
  ) + ggtitle("Marginal effect plot")
)
#   user  system elapsed 
#  1.972   1.383   3.101 

# Partial dependence plot
system.time(
  p2 <- partial(
    object = rfo, 
    pred.var = "Gr_Liv_Area", 
    grid.resolution = 100,
    plot = TRUE,
    plot.engine = "ggplot2"
  ) + ggtitle("Partial dependence plot")
)
#   user  system elapsed 
# 44.232   2.273   9.572 

# Display plots side by side
gridExtra::grid.arrange(p1, p2, nrow = 1)  # requires the gridExtra package

[Image: "Marginal effect plot" and "Partial dependence plot" for Gr_Liv_Area, side by side]
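The claim that the two kinds of plots are parallel when there are no interaction effects can be checked on a toy problem. The following is only a sketch under stated assumptions: it uses base R, an additive `lm()` fit, and the built-in `mtcars` data as stand-ins for the random forest and Ames data, and it computes both curves by hand rather than through `partial()`.

```r
# Toy illustration with base R only; the lm() model and mtcars data are
# stand-ins (assumptions), not the random forest / Ames data used above.
fit <- lm(mpg ~ wt + hp + disp, data = mtcars)  # additive: no interaction terms

grid <- seq(min(mtcars$wt), max(mtcars$wt), length.out = 25)

# One "exemplar" row: every feature fixed at its median (all numeric here)
ex <- as.data.frame(lapply(mtcars, stats::median))

# Marginal effect: score the single exemplar row once per grid value
me <- vapply(grid, function(w) {
  row <- ex
  row$wt <- w
  unname(predict(fit, newdata = row))
}, numeric(1))

# Brute-force partial dependence: average predictions over all training rows
pd <- vapply(grid, function(w) {
  dat <- mtcars
  dat$wt <- w
  mean(predict(fit, newdata = dat))
}, numeric(1))

# For an additive model the two curves differ only by a constant vertical
# shift, so the spread of their difference is zero up to rounding error
diff(range(pd - me))
```

Because the model has no interaction terms, the gap between the two curves does not depend on `wt`; only its constant offset reflects the difference between the median row and the training-set average.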

DeFilippis commented Jul 1, 2019

Thank you so much for the fantastic work on this package. It's been an absolute life-saver. I had a couple of questions for you about this:

  1. I notice you computed the system.time. Is there any chance you could display what the time savings are for a margins plot over a partial dependence plot?

  2. I'm a little confused as to the difference between a PDP and a margins plot. It looks like your custom squash function returns a dataframe of equal size to the input dataframe, except where all numeric columns are replaced by the median, and all factor columns are replaced with the mode. It then computes the normal partial procedure on this dataset.

Is the difference between this and what partial does that, in the partial case, all the covariates keep their real values (instead of being fixed at their median), so you're getting predictions averaged over every real value of the other variables, rather than predictions evaluated at the median values?

If so, where does the speed-up come from?

  3. Will you be implementing this in the package? Perhaps with a "marginsPlot = TRUE" tag?

@bgreenwell

Thanks @DeFilippis, glad you've found the package useful. And I've been meaning to come back to this. Responses to your questions below:

  1. Times added. Note, however, that I expect the time savings to be considerably more dramatic for larger data sets or when computing bivariate plots.

  2. You seem to be correct about the difference between the two types of plots. Marginal effect plots look at one variable (or several) vs. the response while holding all other features constant (e.g., at their median). PDPs, on the other hand, look at one variable (or several) vs. the response while taking into account the average effect of all the other features. In particular, each point on a PDP is computed as the average prediction obtained from a modified copy of the original training data. In other words, it requires scoring lots of data (albeit independently) and many calls to the prediction function. Marginal effect plots, on the other hand, require scoring only one observation per point on the plot, so they are much quicker and more efficient (but less accurate than PDPs, especially when strong interactions are present).
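In symbols, the distinction can be written as follows. The notation is introduced here only for illustration and does not come from the package: $\hat f$ is the fitted model, $x_S$ the feature(s) being plotted, $x_C^{(i)}$ the remaining features of the $i$-th training row, $n$ the number of training rows, and $\tilde x_C$ the single exemplar row of medians and modes.

```latex
\hat f_{\mathrm{PD}}(x_S) = \frac{1}{n} \sum_{i=1}^{n} \hat f\left(x_S, x_C^{(i)}\right),
\qquad
\hat f_{\mathrm{ME}}(x_S) = \hat f\left(x_S, \tilde x_C\right)
```

With a grid of $k$ values for $x_S$, the first curve needs $n \cdot k$ predictions and the second only $k$. If the model is additive, $\hat f(x_S, x_C) = g(x_S) + h(x_C)$, the difference $\hat f_{\mathrm{PD}}(x_S) - \hat f_{\mathrm{ME}}(x_S) = \frac{1}{n}\sum_{i} h\left(x_C^{(i)}\right) - h(\tilde x_C)$ does not depend on $x_S$, which is why the two curves are then parallel.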

The code above is nothing more than a way to trick the partial() function into computing a marginal effect plot. An alternative to pdp, as well as another good reference on the difference between the two types of plots, is the excellent plotmo package, which refers to marginal effect plots as "poor man's partial dependence plots".
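To make the cost and accuracy trade-off concrete, here is a rough base R sketch. Everything in it is an assumption made for illustration: the `lm()` fit with an interaction term, the `mtcars` data, and the hypothetical `score()` wrapper that counts rows passed to `predict()`. It does not use `partial()` itself.

```r
# Base R sketch; the lm() model, mtcars data, and row counter are
# illustrative assumptions, not part of the pdp package.
fit <- lm(mpg ~ wt * hp, data = mtcars)  # includes a wt:hp interaction

n_scored <- 0  # running count of rows passed to predict()
score <- function(newdata) {
  n_scored <<- n_scored + nrow(newdata)
  unname(predict(fit, newdata = newdata))
}

grid <- seq(min(mtcars$wt), max(mtcars$wt), length.out = 25)
ex <- data.frame(wt = NA_real_, hp = stats::median(mtcars$hp))  # exemplar row

# Marginal effect: one row scored per grid point
me <- vapply(grid, function(w) {
  ex$wt <- w
  score(ex)
}, numeric(1))
rows_me <- n_scored

# Partial dependence: every training row scored per grid point
n_scored <- 0
pd <- vapply(grid, function(w) {
  d <- mtcars
  d$wt <- w
  mean(score(d))
}, numeric(1))
rows_pd <- n_scored

c(marginal = rows_me, pdp = rows_pd)  # 25 rows vs 25 * nrow(mtcars) = 800 rows
diff(range(pd - me))  # nonzero: with an interaction the curves are not parallel
```

The row counts scale as described above: the marginal effect curve costs one prediction per grid point, while the PDP costs one per grid point per training row. The nonzero spread in `pd - me` shows how an interaction makes the cheaper curve drift away from the PDP.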

@DeFilippis

Perfect! Thanks so much for the thorough and speedy replies. Really appreciate it.
