# complete_tflow.R
#' Fit the best model from a tuning grid
#'
#' @param x A tidyflow
#' @param metric The metric used to pick the best model.
#' @param ... Extra arguments passed to
#' \code{\link[tune:show_best]{select_by_one_std_err}} or
#' \code{\link[tune:show_best]{select_by_pct_loss}}
#' @param best_params A one-row tibble with the best parameters to fit the final
#' model. Should have the same format as the result of
#' \code{\link[tune:show_best]{select_best}},
#' \code{\link[tune:show_best]{select_by_one_std_err}} or
#' \code{\link[tune:show_best]{select_by_pct_loss}}. If \code{best_params} is
#' specified, the \code{method}, \code{metric} and \code{...} arguments are
#' ignored.
#' @param method Which method to use. The possible values are
#' \code{\link[tune:show_best]{select_best}},
#' \code{\link[tune:show_best]{select_by_one_std_err}} or
#' \code{\link[tune:show_best]{select_by_pct_loss}}. By default, it uses
#' \code{\link[tune:show_best]{select_best}}.
#'
#' @param control A \code{\link{control_tidyflow}} object. The
#' \code{\link[parsnip]{control_parsnip}} control object inside
#' \code{\link{control_tidyflow}} is passed to
#' \code{\link[generics]{fit}}.
#'
#' @details The finalized model is fitted on the training data if
#' \code{plug_split} was specified; otherwise, it is fitted on the complete
#' data.
#'
#' @return The tidyflow object updated with the fitted best model. The fit can
#' be extracted with \code{\link{pull_tflow_fit}} and used to predict on the
#' training or test data with \code{\link{predict_training}} or
#' \code{\link{predict_testing}}.
#'
#' @export
#' @examples
#' \dontrun{
#' library(parsnip)
#' library(tune)
#' library(dials)
#' library(rsample)
#'
#' # Fit a regularized regression through a grid search.
#' reg_mod <- set_engine(linear_reg(penalty = tune(), mixture = tune()),
#'                       "glmnet")
#' tuned_res <-
#'   mtcars %>%
#'   tidyflow() %>%
#'   plug_resample(vfold_cv, v = 2) %>%
#'   plug_formula(mpg ~ .) %>%
#'   plug_model(reg_mod) %>%
#'   plug_grid(grid_regular, levels = 1) %>%
#'   fit()
#'
#' # Finalize the best model and refit on the whole dataset
#' final_model <- complete_tflow(tuned_res, metric = "rmse")
#'
#' # complete_tflow uses tune::select_best as the default method. However,
#' # tune::select_by_one_std_err and tune::select_by_pct_loss can also be
#' # used. These require specifying the metric and the tuning parameter by
#' # which to sort the selection. For example:
#' final_model_stderr <- complete_tflow(tuned_res,
#'                                      metric = "rmse",
#'                                      method = "select_by_one_std_err",
#'                                      penalty)
#'
#' # select_by_one_std_err finalizes the best model with the simplest tuning
#' # values within one standard error of the most optimal combination. For
#' # more information on these methods, see ?select_best
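#'
#' # select_by_pct_loss can be used the same way: it keeps the simplest
#' # tuning values whose performance is within a tolerable percentage loss
#' # of the best combination. A sketch, assuming the same tuning setup as
#' # above (any extra arguments are forwarded to the tune helper via `...`):
#' final_model_loss <- complete_tflow(tuned_res,
#'                                    metric = "rmse",
#'                                    method = "select_by_pct_loss",
#'                                    penalty)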
#'
#' # You can also specify the best parameters yourself, in case you want to
#' # override the automatic extraction of the best fit. If you specify
#' # `best_params`, it overrides the `method`, `metric` and `...` arguments.
#'
#' best_params <- select_best(pull_tflow_fit_tuning(tuned_res), metric = "rmse")
#' final_model_custom <- complete_tflow(tuned_res, best_params = best_params)
#'
#' # To see the final tuning values, extract the model spec
#' pull_tflow_spec(final_model)
#'
#' # To extract the final fitted model:
#' pull_tflow_fit(final_model)
#'
#' # Since there was no `plug_split`, the final model is fitted on the
#' # complete data (no training/testing split). Trying to predict on either
#' # set will not work:
#' final_model %>%
#'   predict_training()
#'
#' # Add a split step, fit again and then finalize the model
#' # to predict on the training set
#' tuned_split <-
#'   tuned_res %>%
#'   replace_grid(grid_regular) %>%
#'   plug_split(initial_split) %>%
#'   fit()
#'
#' tuned_split %>%
#'   complete_tflow(metric = "rmse") %>%
#'   predict_training()
#' }
#'
complete_tflow <- function(x,
                           metric,
                           ...,
                           best_params = NULL,
                           method = c("select_best",
                                      "select_by_one_std_err",
                                      "select_by_pct_loss"),
                           control = control_tidyflow()) {
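
  # Input checks: `x` must be a tidyflow that has already been tuned through
  # a grid; a plain resampling result cannot be finalized.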
  if (!inherits(x, "tidyflow")) {
    stop("`x` should be a tidyflow")
  }

  if (!(has_fit_tuning(x))) {
    abort("The tidyflow must be tuned to be able to complete the final model") #nolintr
  }

  if (inherits(pull_tflow_fit_tuning(x), "resample_results")) {
    abort("`complete_tflow` cannot finalize a model with a resampling result. To finalize a model you need a tuning result. Did you want `plug_grid`?") #nolintr
  }
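
  # Figure out which of tune's selection helpers to use (select_best is the
  # default).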
  select_fun <- match.arg(method)
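
  # If the user did not supply `best_params`, extract them from the tuning
  # results with the chosen helper, forwarding any extra arguments in `...`.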
  if (is.null(best_params)) {
    tune_grid <- pull_tflow_fit_tuning(x)
    raw_select_fun <- getExportedValue("tune", select_fun)
    best_params <- raw_select_fun(x = tune_grid, metric = metric, ...)
  }
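
  # Validate the selected parameters and plug them into the model
  # specification stored in the tidyflow.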
  parsnip::check_final_param(best_params)
  mod <- tidyflow::pull_tflow_spec(x)
  mod <- tune::finalize_model(mod, best_params)
  x$fit$actions$model$spec <- mod
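
  # Pull the preprocessor and, when it is a recipe, finalize any tuned
  # recipe steps with the same parameters.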
  preproc <- tidyflow::pull_tflow_preprocessor(x)

  if (has_preprocessor_recipe(x)) {
    preproc <- tune::finalize_recipe(preproc, best_params)
    x$pre$actions$recipe$recipe_res <- preproc
  }
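
  # Rebuild a workflows object from the finalized preprocessor, the model
  # and any model formula.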
  form <- x$fit$actions$formula

  if (inherits(preproc, "recipe")) {
    add_preprocessor <- workflows::add_recipe
  } else {
    add_preprocessor <- workflows::add_formula
  }

  wflow <- add_preprocessor(workflows::workflow(), preproc)
  wflow <- workflows::add_model(wflow, mod, formula = form)
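
  # Refit on the data stored in the tidyflow, passing along the parsnip
  # control settings from the control_tidyflow object.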
  ctrl <- workflows::control_workflow(control_parsnip = control$control_parsnip)
  wflow_fit <- generics::fit(wflow, data = x$pre$mold, control = ctrl)
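
  # Store the fitted parsnip model and the underlying workflow back in the
  # tidyflow and flag it as trained.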
  x$fit$fit$fit <- workflows::pull_workflow_fit(wflow_fit)
  x$fit$fit$wflow <- wflow_fit
  x$trained <- TRUE
  x
}