/
FitnessFunction.R
153 lines (143 loc) · 5.62 KB
/
FitnessFunction.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
#' @title FitnessFunction Class
#'
#' @description
#' Implements a fitness function for \pkg{mlr3} as `R6` class `FitnessFunction`. An object of that class
#' contains all relevant information necessary to conduct tuning (`mlr3::Task`, `mlr3::Learner`, `mlr3::Resampling`, `mlr3::Measure`s,
#' `paradox::ParamSet`).
#' After defining a fitness function, we can use it to predict the generalization error of a specific learner configuration
#' defined by its hyperparameters (using `$eval()`).
#' The `FitnessFunction` class is the basis for further tuning strategies, i.e., grid or random search.
#'
#' @section Usage:
#' ```
#' # Construction
#' ff = FitnessFunction$new(task, learner, resampling, param_set,
#' ctrl = tune_control())
#'
#' # Public members
#' ff$task
#' ff$learner
#' ff$resampling
#' ff$param_set
#' ff$ctrl
#' ff$hooks
#' ff$bmr
#'
#' # Public methods
#' ff$eval(xt)
#' ff$eval_vectorized(xts)
#' ff$get_best()
#' ff$run_hooks(id)
#' ```
#'
#' @section Arguments:
#' * `task` (`mlr3::Task`):
#' The task that we want to evaluate.
#' * `learner` (`mlr3::Learner`):
#' The learner that we want to evaluate.
#' * `resampling` (`mlr3::Resampling`):
#' The Resampling method that is used to evaluate the learner.
#' * `param_set` ([paradox::ParamSet]):
#' Parameter set to define the hyperparameter space.
#' * `ctrl` (`list()`):
#' See [tune_control()].
#' * `xt` (`list()`):
#' A specific (transformed) parameter configuration given as named list (e.g. for rpart `list(cp = 0.05, minsplit = 4)`).
#' * `xts` (`list()`):
#'   Collection of multiple (transformed) parameter configurations that is, for example, gained from a tuning strategy like grid search (see `?paradox::generate_design_grid`).
#' * `id` (`character(1)`):
#' Identifier of a hook.
#'
#' @section Details:
#' * `$new()` creates a new object of class [FitnessFunction].
#' * `$task` (`mlr3::Task`) the task for which the tuning should be conducted.
#' * `$learner` (`mlr3::Learner`) the algorithm for which the tuning should be conducted.
#' * `$resampling` (`mlr3::Resampling`) strategy to evaluate a parameter setting
#' * `$param_set` (`paradox::ParamSet`) parameter space given to the `Tuner` object to generate parameter values.
#' * `$ctrl` (`list()`) execution control object for tuning (see `?tune_control`).
#' * `$hooks` (`list()`) list of functions that could be executed with `run_hooks()`.
#' * `$bmr` (`mlr3::BenchmarkResult`) object that contains all tuning results as `BenchmarkResult` object (see `?BenchmarkResult`).
#' * `$eval(xt)` evaluates the (transformed) parameter setting `xt` (`list`) for the given learner and resampling.
#' * `$eval_vectorized(xts)` performs resampling for multiple (transformed) parameter settings `xts` (list of lists).
#' * `$get_best()` get best parameter configuration from the `BenchmarkResult` object.
#' * `$run_hooks()` run a function that runs on the whole `FitnessFunction` object.
#'
#' @name FitnessFunction
#' @keywords internal
#' @family FitnessFunction
#' @examples
#' # Object required to define the fitness function:
#' task = mlr3::mlr_tasks$get("iris")
#' learner = mlr3::mlr_learners$get("classif.rpart")
#' resampling = mlr3::mlr_resamplings$get("holdout")
#' measures = mlr3::mlr_measures$mget("classif.mmce")
#' task$measures = measures
#' param_set = paradox::ParamSet$new(params = list(
#' paradox::ParamDbl$new("cp", lower = 0.001, upper = 0.1)))
#'
#' ff = FitnessFunction$new(
#' task = task,
#' learner = learner,
#' resampling = resampling,
#' param_set = param_set
#' )
#'
#' ff$eval(list(cp = 0.05, minsplit = 5))
#' ff$eval(list(cp = 0.01, minsplit = 3))
#' ff$get_best()
NULL
#' @export
FitnessFunction = R6Class("FitnessFunction",
  public = list(
    task = NULL,
    learner = NULL,
    resampling = NULL,
    param_set = NULL,
    ctrl = NULL,
    hooks = NULL,
    bmr = NULL,

    # Validate and store all components needed to evaluate hyperparameter
    # configurations by resampling. Hooks start out empty and can be filled
    # by callers; they are invoked around each evaluation batch.
    initialize = function(task, learner, resampling, param_set, ctrl = tune_control()) {
      self$task = mlr3::assert_task(task)
      self$learner = mlr3::assert_learner(learner, task = task)
      self$resampling = mlr3::assert_resampling(resampling)
      self$param_set = checkmate::assert_class(param_set, "ParamSet")
      self$ctrl = checkmate::assert_list(ctrl, names = "unique")
      self$hooks = list(update_start = list(), update_end = list())
    },

    # Evaluate a single (transformed) parameter configuration `xt`
    # (named list) by delegating to the vectorized variant.
    eval = function(xt) {
      self$eval_vectorized(list(xt))
    },

    # Evaluate multiple (transformed) parameter configurations `xts`
    # (list of named lists). Each configuration gets its own clone of the
    # stored learner with the parameter values inserted; all clones are
    # benchmarked against the stored task/resampling. Results accumulate in
    # `self$bmr`, tagged with a `dob` ("date of birth") batch counter so
    # later analysis can tell evaluation rounds apart.
    eval_vectorized = function(xts) {
      learners = imap(xts, function(xt, i) {
        learner = self$learner$clone()
        learner$param_vals = insert_named(learner$param_vals, xt)
        # distinct id per configuration so rows in the BenchmarkResult
        # can be attributed to their parameter setting
        learner$id = paste0(learner$id, i)
        return(learner)
      })
      self$run_hooks("update_start")
      bmr = mlr3::benchmark(design = data.table::data.table(task = list(self$task), learner = learners,
        resampling = list(self$resampling)), ctrl = self$ctrl)
      if (is.null(self$bmr)) {
        # first batch: start the dob counter at 1
        bmr$data$dob = 1L
        self$bmr = bmr
      } else {
        # subsequent batches: next dob, then merge into the running result
        bmr$data$dob = max(self$bmr$data$dob) + 1L
        self$bmr$combine(bmr)
      }
      self$run_hooks("update_end")
      invisible(self)
    },

    # Best configuration evaluated so far, w.r.t. the task's first measure.
    get_best = function() {
      self$bmr$get_best(self$task$measures[[1L]])
    },

    # Invoke every hook registered under `id`, passing this object as `ff`.
    run_hooks = function(id) {
      funs = self$hooks[[id]]
      for (fun in funs)
        do.call(fun, list(ff = self))
    }
  )
)