/
FilterImportance.R
75 lines (68 loc) · 2.45 KB
/
FilterImportance.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
#' @title Filter for Embedded Feature Selection via Variable Importance
#'
#' @name mlr_filters_importance
#'
#' @description Variable Importance filter using embedded feature selection of
#' machine learning algorithms. Takes a [mlr3::Learner] which is capable of
#' extracting the variable importance (property "importance"), fits the model
#' and extracts the importance values to use as filter scores.
#'
#' @family Filter
#' @template seealso_filter
#' @export
#' @examples
#' task = mlr3::tsk("iris")
#' learner = mlr3::lrn("classif.rpart")
#' filter = flt("importance", learner = learner)
#' filter$calculate(task)
#' as.data.table(filter)
FilterImportance = R6Class("FilterImportance", inherit = Filter,
  public = list(
    #' @field learner ([mlr3::Learner])\cr
    #' Learner to extract the importance values from.
    learner = NULL,

    #' @description Create a FilterImportance object.
    #' @param id (`character(1)`)\cr
    #' Identifier for the filter.
    #' @param task_type (`character()`)\cr
    #' Types of the task the filter can operate on. E.g., `"classif"` or
    #' `"regr"`. Defaults to the task type of `learner`.
    #' @param param_set ([paradox::ParamSet])\cr
    #' Set of hyperparameters. Defaults to the parameter set of the cloned
    #' learner stored in `$learner`.
    #' @param feature_types (`character()`)\cr
    #' Feature types the filter operates on.
    #' Must be a subset of
    #' [`mlr_reflections$task_feature_types`][mlr3::mlr_reflections].
    #' Defaults to the feature types of `learner`.
    #' @param learner ([mlr3::Learner])\cr
    #' Learner to extract the importance values from. Must have the property
    #' `"importance"`. The learner is cloned, so the object passed in is not
    #' modified by the filter.
    #' @param packages (`character()`)\cr
    #' Set of required packages. Defaults to the packages of `learner`.
    #' Note that these packages will be loaded via [requireNamespace()], and
    #' are not attached.
    initialize = function(id = "importance",
      task_type = learner$task_type,
      feature_types = learner$feature_types,
      learner = mlr3::lrn("classif.rpart"),
      packages = learner$packages,
      param_set = learner$param_set) {
      # Validate up front that the learner can report importance values, so an
      # unsuitable learner is rejected here instead of failing inside
      # `.calculate()` when `$importance()` is called.
      # The defaults of `task_type`, `feature_types`, `packages` and
      # `param_set` are lazy promises: they are only forced below, after
      # `learner` has been reassigned to the clone, so they read from the clone.
      self$learner = learner = assert_learner(as_learner(learner, clone = TRUE),
        properties = "importance")
      super$initialize(
        id = id,
        task_type = task_type,
        feature_types = feature_types,
        packages = packages,
        param_set = param_set,
        man = "mlr3filters::mlr_filters_importance"
      )
    }
  ),

  private = list(
    # Trains a deep copy of `self$learner` on `task` (leaving `self$learner`
    # itself untrained) and returns the trained learner's `$importance()`
    # values as the filter scores. `nfeat` is not used.
    # NOTE(review): `$importance()` may omit features the model did not use;
    # confirm how `Filter$calculate()` treats features without a score.
    .calculate = function(task, nfeat) {
      learner = self$learner$clone(deep = TRUE)
      learner = learner$train(task = task)
      learner$importance()
    }
  )
)
# Register FilterImportance under the key "importance" so that it can be
# constructed by key, e.g. `flt("importance", learner = ...)`.
#' @include mlr_filters.R
mlr_filters$add(key = "importance", value = FilterImportance)