-
-
Notifications
You must be signed in to change notification settings - Fork 7
/
FilterVariableImportance.R
48 lines (45 loc) · 1.47 KB
/
FilterVariableImportance.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#' @title Variable Importance Filter
#'
#' @aliases mlr_filters_variable_importance
#' @format [R6::R6Class] inheriting from [Filter].
#' @include Filter.R
#'
#' @description
#' Filter that scores features by the variable importance of a fitted model.
#' The supplied [mlr3::Learner] must have the property "importance".
#' When the filter is calculated, a deep clone of the learner is trained on the
#' task and the importance values it reports are used as filter scores.
#' Scores start out as `0` for every feature of the task; the reported
#' importance values are then inserted by feature name.
#'
#' The parameter set handed to the parent [Filter] is a deep clone of the
#' learner's parameter set, taken at construction time.
#'
#' @family Filter
#' @export
#' @examples
#' task = mlr3::mlr_tasks$get("iris")
#' learner = mlr3::mlr_learners$get("classif.rpart")
#' filter = FilterVariableImportance$new(learner = learner)
#' filter$calculate(task)
#' head(as.data.table(filter), 3)
FilterVariableImportance = R6Class("FilterVariableImportance",
  inherit = Filter,

  public = list(
    # Learner whose importance values provide the scores; set by `initialize()`.
    learner = NULL,

    # Constructor.
    #   id:      identifier forwarded to the parent constructor.
    #   learner: learner that must pass `assert_learner()` with the
    #            property "importance".
    initialize = function(id = "variable_importance", learner) {
      checked_learner = assert_learner(learner, properties = "importance")
      self$learner = checked_learner

      # The parent receives its own deep copy of the learner's parameter set.
      filter_params = learner$param_set$clone(deep = TRUE)

      super$initialize(
        id = id,
        packages = learner$packages,
        feature_types = learner$feature_types,
        task_type = learner$task_type,
        param_set = filter_params
      )
    }
  ),

  private = list(
    # Trains a copy of the learner on `task` and returns a named numeric
    # vector with one score per feature of the task.
    .calculate = function(task) {
      # Parameter values and training are applied to a deep clone rather than
      # to `self$learner`.
      model = self$learner$clone(deep = TRUE)
      # presumably `self$param_set` is the clone passed to the parent
      # constructor -- verify against Filter.R
      model$param_set$values = self$param_set$values

      experiment = Experiment$new(task = task, learner = model)
      trained = experiment$train()
      reported = trained$learner$importance()

      # Zero-filled score vector over all task features; values reported by
      # the learner are inserted by name on top of it.
      feature_names = task$feature_names
      scores = set_names(numeric(length(feature_names)), feature_names)
      insert_named(scores, reported)
    }
  )
)