Skip to content

Commit

Permalink
feat: added type arg for glm/lm models to use raw coef (#77)
Browse files Browse the repository at this point in the history
  • Loading branch information
bradleyboehmke committed Dec 16, 2019
1 parent 87de938 commit dfa9f00
Show file tree
Hide file tree
Showing 11 changed files with 78 additions and 78 deletions.
30 changes: 23 additions & 7 deletions R/vi_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,11 @@
#' for details.}
#'
#' \item{\code{\link[stats]{lm}}}{In (generalized) linear models, variable
#' importance is based on the absolute value of the corresponding
#' importance is typically based on the absolute value of the corresponding
#' \emph{t}-statistics. For such models, the sign of the original coefficient
#' is also returned.}
#' is also returned. By default, \code{type = "t-stat"} is used; however, if the
#' inputs have been appropriately standardized then the raw coefficients can be
#' used with \code{type = "raw"}.}
#'
#' \item{\code{\link[sparklyr]{ml_feature_importances}}}{The Spark ML
#' library provides standard variable importance for tree-based methods (e.g.,
Expand Down Expand Up @@ -1200,24 +1202,38 @@ vi_model.ml_model_random_forest_classification <- function(object, ...) {
#' @rdname vi_model
#'
#' @export
vi_model.lm <- function(object, ...) {
vi_model.lm <- function(object, type = c("t-stat", "raw"), ...) {

# Determine which type of variable importance to compute
type <- match.arg(type)

# pattern to match based on type
if (type == "t-stat") {
type_pattern <- "^(t|z) value"
} else {
type_pattern <- "Estimate"
}

# Construct model-specific variable importance scores
coefs <- summary(object)$coefficients
if (attr(object$terms, "intercept") == 1) {
coefs <- coefs[-1L, , drop = FALSE]
}
pos <- grep("^(t|z) value", x = colnames(coefs)) # grab pos of z/t stat col
pos <- grep(type_pattern, x = colnames(coefs))
tib <- tibble::tibble(
"Variable" = rownames(coefs),
"Importance" = abs(coefs[, pos]),
"Sign" = ifelse(sign(coefs[, "Estimate"]) == 1, yes = "POS", no = "NEG")
)

# Add variable importance type attribute
label <- colnames(coefs)[pos]
label <- substr(label, start = 1, stop = 1) # strip off t or z
attr(tib, which = "type") <- paste0("|", label, "-statistic|")
if (type == "t-stat") {
label <- colnames(coefs)[pos]
label <- substr(label, start = 1, stop = 1) # strip off t or z
attr(tib, which = "type") <- paste0("|", label, "-statistic|")
} else {
attr(tib, which = "type") <- "|raw coefficients|"
}

# Add "vi" class
class(tib) <- c("vi", class(tib))
Expand Down
6 changes: 4 additions & 2 deletions man/add_sparklines.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 4 additions & 13 deletions man/vi.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion man/vi_ice.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 9 additions & 5 deletions man/vi_model.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion man/vi_pdp.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

27 changes: 6 additions & 21 deletions man/vi_permute.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 2 additions & 8 deletions man/vint.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

27 changes: 7 additions & 20 deletions man/vip.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 19 additions & 0 deletions tests/testthat/test_vi_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -757,6 +757,25 @@ test_that("`vi_model()` works for \"lm\" objects", {

})

test_that("`type` parameter returns proper values", {

# Skips
skip_on_cran()

X <- scale(mtcars[, -1])
Y <- mtcars$mpg
lm_model <- lm(Y ~ X)

# Run checks
coefs <- summary(lm_model)[["coefficients"]][-1, ]
t_stat <- vi_model(lm_model)
raw <- vi_model(lm_model, type = "raw")

expect_equal(as.vector(abs(coefs[, "t value"])), t_stat$Importance)
expect_equal(as.vector(abs(coefs[, "Estimate"])), raw$Importance)

})


# Package: xgboost -------------------------------------------------------------

Expand Down
Binary file modified tests/testthat/xgboost.model
Binary file not shown.

0 comments on commit dfa9f00

Please sign in to comment.