-
-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
#' @title Correlation Filter
#'
#' @usage NULL
#' @name mlr_filters_find_correlation
#' @format [R6::R6Class] inheriting from [Filter].
#' @include Filter.R
#'
#' @section Construction:
#' ```
#' FilterFindCorrelation$new()
#' mlr_filters$get("findCorrelation")
#' flt("findCorrelation")
#' ```
#'
#' @section Parameters:
#' * `use` and `method`: passed on to [stats::cor()] when the feature
#'   correlation matrix is computed.
#'
#' @description
#' Simple filter emulating `caret::findCorrelation(exact = FALSE)`.
#'
#' This gives each feature a score between 0 and 1 that is *one minus* the
#' cutoff value for which it is excluded when using [caret::findCorrelation()].
#' The negative is used because [caret::findCorrelation()] excludes everything
#' *above* a cutoff, while filters exclude everything below a cutoff.
#' Here the filter scores are shifted by +1 to get positive values, to align
#' with the way other filters work.
#'
#' Consequently `caret::findCorrelation(cutoff = 0.9)` lists the same features
#' that are excluded with `FilterFindCorrelation` at score 0.1 (= 1 - 0.9).
#'
#' @family Filter
#' @template seealso_filter
#' @export
#' @examples
#' ## Pearson (default)
#' task = mlr3::tsk("mtcars")
#' filter = flt("findCorrelation")
#' filter$calculate(task)
#' as.data.table(filter)
#'
#' ## Spearman
#' filter = flt("findCorrelation", method = "spearman")
#' filter$calculate(task)
#' as.data.table(filter)
FilterFindCorrelation = R6Class("FilterFindCorrelation", inherit = Filter,
  public = list(
    # Set up the filter: numeric/integer features only, applicable to every
    # task type known to mlr_reflections, with the `use` / `method` parameters
    # that are later handed to stats::cor().
    initialize = function() {
      super$initialize(
        id = "findCorrelation",
        packages = "stats",
        feature_types = c("integer", "numeric"),
        task_type = mlr_reflections$task_types$type, # basically any task
        param_set = ParamSet$new(list(
          ParamFct$new("use", default = "everything",
            levels = c("everything", "all.obs", "complete.obs", "na.or.complete", "pairwise.complete.obs")),
          ParamFct$new("method", default = "pearson",
            levels = c("pearson", "kendall", "spearman"))
        ))
      )
    },

    # Compute one score per feature of `task`. `nfeat` is part of the
    # signature but is not used here: all features are always scored.
    # Returns a named numeric vector (names = feature names) holding
    # 1 - (largest cutoff at which the feature would still be removed).
    calculate_internal = function(task, nfeat) {

      fn = task$feature_names
      pv = self$param_set$values
      # absolute pairwise correlations; user-set params go straight to cor()
      cm = invoke(stats::cor,
        x = task$data(cols = fn),
        .args = pv)
      cm = abs(cm)
      # a feature is removed as soon as it is in the higher average correlation
      # col in a pair (note: tie broken by removing /later/ feature first)
      avg_cor = colMeans(cm)
      # decreasing = TRUE to emulate tie breaking
      avg_cor_order = order(avg_cor, decreasing = TRUE)
      cm = cm[avg_cor_order, avg_cor_order, drop = FALSE]
      # Rows / Columns of cm are now ordered by correlation mean, highest first.
      # A feature i is excluded as soon as a lower-average-correlation feature
      # has correlation with i > cutoff. This means the cutoff at which i is
      # excluded is the max of the correlation with all lower-avg-cor features.
      # Therefore we look for the highest feature correlation col-wise in the
      # lower triangle of the ordered cm.

      # Blank out the diagonal and upper triangle so only correlations with
      # lower-avg-cor features remain in each column. The lowest avg col
      # feature is never removed by caret, so its column becomes all 0 and its
      # cutoff is 0.
      cm[upper.tri(cm, diag = TRUE)] = 0
      # The column maxima carry the correct feature names and exclusion
      # cutoffs. Filters drop *low* scores whereas caret drops *high*
      # correlations, so invert the scale: score = 1 - cutoff (stays in [0, 1]).
      1 - apply(cm, 2, max)
    }
  )
)
|
||
#' @include mlr_filters.R
# Register the filter in the mlr_filters dictionary under the key
# "findCorrelation", which is the key used by mlr_filters$get() and flt().
mlr_filters$add("findCorrelation", FilterFindCorrelation)
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
context("FilterFindCorrelation")

test_that("FilterFindCorrelation", {
  # caret only serves as the reference implementation here, so skip instead of
  # failing when it is not installed.
  skip_if_not_installed("caret")

  task = mlr3::mlr_tasks$get("sonar")
  # Additional synthetic columns; the length-4 patterns are recycled by cbind()
  # to the length of `a`. Presumably chosen to produce ties in the average
  # correlation so that the tie-breaking emulation gets exercised.
  equalcor = cbind(
    a = rep(c(1, 0, 0, 0), task$nrow / 4),
    b = c(0, 1, 0, 0),
    c = c(0, 0, 1, 0),
    d = c(0, 0, 0, 1),
    e = c(0.1, -0.1, 0.1, 0.99),
    f = c(-0.1, 0.1, 0.1, 0.99)
  )
  task$cbind(as.data.frame(equalcor))
  data = task$data(cols = task$feature_names)
  cm = cor(data)

  # Compare against caret on a grid of cutoffs 0, 0.01, ..., 1.
  checkpoints = (0:100) / 100
  remove_caret = lapply(checkpoints, caret::findCorrelation, x = cm, exact = FALSE)

  filter = FilterFindCorrelation$new()
  filter$calculate(task)
  # caret returns column indices of features to drop at `cutoff`; the filter
  # drops features with score < 1 - cutoff. Map names back to indices.
  remove_filter = lapply(checkpoints, function(cutoff) {
    match(names(filter$scores)[filter$scores < 1 - cutoff], task$feature_names)
  })

  # One expectation per cutoff.
  for (i in seq_along(checkpoints)) {
    expect_set_equal(remove_caret[[i]], remove_filter[[i]])
  }
})