/
fit.quantile_rf.R
124 lines (116 loc) · 4.45 KB
/
fit.quantile_rf.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#' @name .fit.quantile_rf
#' @title Quantile regression forest for \code{tidyfit}
#' @description Fits a nonlinear quantile regression forest on a 'tidyFit' \code{R6} class. The function can be used with \code{\link{regress}}.
#'
#' @param self a 'tidyFit' R6 class.
#' @param data a data frame, data frame extension (e.g. a tibble), or a lazy data frame (e.g. from dbplyr or dtplyr).
#' @return A fitted 'tidyFit' class model.
#'
#' @details **Hyperparameters:**
#'
#' - ntree (*number of trees*)
#' - mtry (*number of variables randomly sampled at each split*)
#'
#' **Important method arguments (passed to \code{\link{m}})**
#'
#' - \code{tau} (the quantile(s) to be estimated)
#'
#' The function provides a wrapper for \code{quantregForest::quantregForest}. See \code{?quantregForest} for more details.
#' The argument \code{tau} is the chosen quantile (default \code{tau = 0.5}).
#' \code{tau} is passed directly to \code{m('quantile_rf', tau = c(0.1, 0.5, 0.9)} and is not passed to \code{predict} as in the \code{quantregForest::quantregForest} package. This is done to ensure a consistent interface with the quantile regression from \code{quantreg}.
#'
#' **Implementation**
#'
#' *No implementation notes*
#'
#' @author Johann Pfitzinger
#' @references
#' Meinshausen N (2017). _quantregForest: Quantile Regression Forests_. R package version 1.3-7, <https://CRAN.R-project.org/package=quantregForest>.
#'
#' @examples
#' # Load data
#' data <- tidyfit::Factor_Industry_Returns
#' data <- dplyr::filter(data, Industry == "HiTec")
#' data <- dplyr::select(data, -Date, -Industry)
#'
#' # Stand-alone function
#' fit <- m("quantile_rf", Return ~ ., data, tau = 0.5, ntree = 50)
#' fit
#'
#' # Within 'regress' function
#' fit <- regress(data, Return ~ .,
#' m("quantile_rf", tau = c(0.1, 0.5, 0.9), ntree = 50))
#' explain(fit)
#'
#' @seealso \code{\link{.fit.quantile}}, \code{\link{.fit.rf}} and \code{\link{m}} methods
#'
#' @importFrom purrr safely quietly
#' @importFrom methods formalArgs
.fit.quantile_rf <- function(
self,
data = NULL
) {
mf <- stats::model.frame(self$formula, data)
x <- stats::model.matrix(self$formula, mf)
y <- stats::model.response(mf)
self$set_args(tau = 0.5, overwrite = FALSE)
valid_args <- c("ntree", "mtry", "weights", "replace", "classwt", "cutoff",
"strata", "sampsize", "nodesize", "maxnodes", "importance",
"localImp", "nPerm", "proximity", "oob.prox", "norm.votes",
"do.trace", "keep.forest", "corr.bias", "keep.inbag", "nthreads")
ctr <- self$args[names(self$args) %in% valid_args]
ctr$importance <- TRUE
eval_fun_ <- function(...) {
args <- list(...)
do.call(quantregForest::quantregForest, args)
}
eval_fun <- purrr::safely(purrr::quietly(eval_fun_))
res <- do.call(eval_fun,
append(list(x = x, y = y), ctr))
.store_on_self(self, res)
self$estimator <- "quantregForest::quantregForest"
invisible(self)
}
.predict.quantregForest <- function(object, data, self = NULL, training_context = FALSE, ...) {
augmented_data <- dplyr::bind_rows(data, self$data)
response_var <- all.vars(self$formula)[1]
if (response_var %in% colnames(data)) {
truth <- data[, response_var]
} else {
data[, response_var] <- 1
truth <- NULL
}
if (self$mode == "regression") {
if (!training_context) {
tau <- self$args$tau
} else {
tau <- 0.5
}
mf <- stats::model.frame(self$formula, augmented_data)
x <- stats::model.matrix(self$formula, mf)
x <- x[1:nrow(data),]
pred_mat <- predict(object, newdata = x, what = tau)
if (length(tau) == 1) {
pred <- dplyr::tibble(prediction = pred_mat, truth = truth, tau = tau)
} else {
pred_mat <- data.frame(pred_mat)
colnames(pred_mat) <- sort(tau)
pred <- dplyr::tibble(pred_mat) %>%
dplyr::mutate(truth = truth) %>%
tidyr::gather("tau", "prediction", -any_of(c("truth")))
}
pred <- dplyr::mutate(pred, tau = as.numeric(tau))
} else {
}
return(pred)
}
.fitted.quantregForest <- function(object, self = NULL, ...) {
.predict.quantregForest(object, data = self$data, self = self, ...) %>%
dplyr::rename(fitted = "prediction") %>%
dplyr::select(-any_of(c("truth")))
}
.resid.quantregForest <- function(object, self = NULL, ...) {
.predict.quantregForest(object, data = self$data, self = self, ...) %>%
dplyr::mutate(residual = .data$truth - .data$prediction) %>%
dplyr::select(-any_of(c("truth", "prediction")))
}