-
Notifications
You must be signed in to change notification settings - Fork 4
/
light_global_surrogate.R
177 lines (164 loc) · 6.19 KB
/
light_global_surrogate.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
#' Global Surrogate Tree
#'
#' Model predictions are modelled by a single decision tree, serving as an
#' easy-to-interpret surrogate for the original model.
#' As suggested in Molnar (see reference below), the quality of the surrogate
#' tree can be measured by its R-squared. The size of the tree can be modified
#' by passing `...` arguments to [rpart::rpart()].
#'
#' @param x An object of class "flashlight" or "multiflashlight".
#' @param data An optional `data.frame`.
#' @param by An optional vector of column names used to additionally group the results.
#' For each group, a separate tree is grown.
#' @param v Vector of variables used in the surrogate model.
#' Defaults to all variables in `data` except "by", "w" and "y".
#' @param use_linkinv Should retransformation function be applied? Default is `TRUE`.
#' @param n_max Maximum number of data rows to consider to build the tree.
#' @param seed An integer random seed used to select data rows if `n_max` is lower than
#' the number of data rows.
#' @param keep_max_levels Number of levels of categorical and factor variables to keep.
#' Other levels are combined to a level "Other". This prevents [rpart::rpart()] from
#' taking too long to split non-numeric variables with many levels.
#' @param ... Arguments passed to [rpart::rpart()], such as `maxdepth`.
#' @returns
#' An object of class "light_global_surrogate" with the following elements:
#' - `data` A tibble with results.
#' - `by` Same as input `by`.
#' @export
#' @references Molnar C. (2019). Interpretable Machine Learning.
#' @examples
#' fit <- lm(Sepal.Length ~ ., data = iris)
#' x <- flashlight(model = fit, label = "lm", data = iris)
#' sur <- light_global_surrogate(x)
#' sur$data$r_squared
#' plot(sur)
#' @seealso [plot.light_global_surrogate()]
light_global_surrogate <- function(x, ...) {
  # S3 generic: the class of `x` selects the method, `...` is handed on to it.
  UseMethod("light_global_surrogate", x)
}
#' @describeIn light_global_surrogate Fallback for unsupported classes; always
#'   signals an error.
#' @export
light_global_surrogate.default <- function(x, ...) {
  # Reached only when no more specific method matched the class of `x`.
  msg <- "light_global_surrogate method is only available for objects of class flashlight or multiflashlight."
  stop(msg)
}
#' @describeIn light_global_surrogate Surrogate model for a flashlight.
#'   If the flashlight has case weights, they are used both to grow the tree
#'   and to evaluate its R-squared.
#' @export
light_global_surrogate.flashlight <- function(x, data = x$data, by = x$by,
                                              v = NULL, use_linkinv = TRUE,
                                              n_max = Inf, seed = NULL,
                                              keep_max_levels = 4L, ...) {
  # Input checks. `by` and `v` may be NULL; the `%in%` conditions then have
  # length zero and pass.
  stopifnot(
    "No data!" = is.data.frame(data) && nrow(data) >= 1L,
    "'by' not in 'data'!" = by %in% colnames(data),
    "Not all 'v' in 'data'" = v %in% colnames(data),
    !any(c("label_", "r_squared", "tree_") %in% by)
  )

  # Set v and remove 'by' from it
  if (is.null(v)) {
    v <- setdiff(colnames(data), c(x$y, by, x$w))
  } else if (!is.null(by)) {
    v <- setdiff(v, by)
  }
  # Fail early with a readable message; otherwise stats::reformulate() would
  # raise a less helpful error further down.
  stopifnot("No variables left in 'v' to grow the tree." = length(v) >= 1L)

  # Subsample if data is very large
  n <- nrow(data)
  if (n > n_max) {
    if (!is.null(seed)) {
      set.seed(seed)
    }
    data <- data[sample(n, n_max), , drop = FALSE]
  }

  # Update the flashlight with the (possibly subsampled) data and grouping.
  # With use_linkinv = FALSE, the inverse link is replaced by the identity.
  x <- flashlight(
    x,
    data = data,
    by = by,
    linkinv = if (use_linkinv) x$linkinv else function(z) z
  )

  # Add response of tree model; "pred_" is reserved for it
  stopifnot(!("pred_" %in% colnames(data)))
  data$pred_ <- stats::predict(x)

  # Lump factors with many levels for tree fit
  for (vv in v) {
    data[[vv]] <- .fct_lump(data[[vv]], keep_max_levels = keep_max_levels)
  }

  # Fit tree on one chunk `X` of the data and return a one-row data.frame
  # holding its R-squared and the fitted tree (as list column).
  core_func <- function(X) {
    fit <- rpart::rpart(
      stats::reformulate(v, "pred_"),
      data = X,
      weights = if (!is.null(x$w)) X[[x$w]],
      model = FALSE,
      y = FALSE,
      xval = 0,
      ...
    )
    # Evaluate the fit with the same case weights that were used to grow the
    # tree; without weights, `w` is NULL, i.e., the unweighted R-squared.
    r2 <- MetricsWeighted::r_squared(
      X$pred_,
      stats::predict(fit, X),
      w = if (!is.null(x$w)) X[[x$w]]
    )
    stats::setNames(data.frame(r2, I(list(fit))), c("r_squared", "tree_"))
  }
  # NOTE(review): Reframe() is defined elsewhere; presumably it applies
  # `core_func` per 'by' group and stacks the results - confirm.
  res <- Reframe(data, FUN = core_func, .by = by)

  # Organize output
  res$label_ <- x$label
  out <- list(data = res[, c("label_", by, "r_squared", "tree_")], by = by)
  add_classes(out, c("light_global_surrogate", "light"))
}
#' @describeIn light_global_surrogate Surrogate model for a multiflashlight.
#' @export
light_global_surrogate.multiflashlight <- function(x, ...) {
  # One surrogate result per element of `x`; `...` goes to each call.
  per_element <- lapply(x, light_global_surrogate, ...)
  # Merge the individual results and tag them with the "multi" class.
  light_combine(per_element, new_class = "light_global_surrogate_multi")
}
#' Plot Global Surrogate Trees
#'
#' Use [rpart.plot::rpart.plot()] to visualize trees fitted by
#' [light_global_surrogate()].
#'
#' @param x An object of class "light_global_surrogate".
#' @param type Plot type, see help of [rpart.plot::rpart.plot()]. Default is 5.
#' @param auto_main Automatic plot titles (only if multiple trees are shown).
#' @param mfrow If multiple trees are shown in the same figure:
#'   what value of `mfrow` to use in [graphics::par()]?
#' @param ... Further arguments passed to [rpart.plot::rpart.plot()].
#' @returns
#'   Called for its side effect of drawing the tree(s) with base graphics; no
#'   "ggplot" object is created. With a single tree, the result of
#'   [rpart.plot::rpart.plot()] is passed through.
#' @export
#' @seealso [light_global_surrogate()]
plot.light_global_surrogate <- function(x, type = 5, auto_main = TRUE,
                                        mfrow = NULL, ...) {
  data <- x$data
  multi <- is.light_global_surrogate_multi(x)

  # Number of dimensions over which trees vary: 'by' variables plus one if the
  # object stems from a multiflashlight.
  ndim <- length(x$by) + multi

  if (ndim == 0L) {
    # Single tree
    rpart.plot::rpart.plot(data$tree_[[1L]], roundint = FALSE, type = type, ...)
  } else if (ndim == 1L) {
    # One tree per row of `data`, titled by model label or 'by' value
    dim_col <- data[[if (multi) "label_" else x$by[1L]]]
    m <- length(dim_col)

    # Default layout: a roughly square grid with enough cells for all trees
    if (is.null(mfrow)) {
      nr <- floor(sqrt(m))
      mfrow <- c(nr, ceiling(m / nr))
    }

    # Change the layout temporarily and restore the old one on exit
    old_params <- graphics::par(mfrow = mfrow)
    on.exit(graphics::par(old_params), add = TRUE)

    for (i in seq_len(m)) {
      rpart.plot::rpart.plot(
        data$tree_[[i]],
        roundint = FALSE,
        type = type,
        main = if (auto_main) dim_col[i],
        ...
      )
    }
  } else {
    stop("Either one 'by' variable or a multiflashlight is supported.")
  }
}
# Collapse rare categories: keeps the `keep_max_levels` most frequent levels of
# a character or factor vector and merges the remaining ones into `other_name`.
# Character input is converted to factor; any other type is returned untouched.
.fct_lump <- function(x, keep_max_levels = 4L, other_name = "Other") {
  if (is.character(x)) {
    x <- factor(x)
  } else if (!is.factor(x)) {
    return(x)
  }

  n_lev <- nlevels(x)
  # Lumping starts only with at least two surplus levels (presumably because
  # renaming a single level to "Other" would not reduce the level count).
  if (n_lev <= keep_max_levels + 1L) {
    return(x)
  }

  # Level names ordered from most to least frequent
  by_freq <- names(sort(table(x), decreasing = TRUE))
  rare <- by_freq[(keep_max_levels + 1L):n_lev]

  # Giving several levels the same name merges them
  lev <- levels(x)
  lev[match(rare, lev)] <- other_name
  levels(x) <- lev
  x
}