Skip to content

Commit

Permalink
FilterFindCorrelation (#62)
Browse files Browse the repository at this point in the history
* new filter `FilterFindCorrelation` (#62, @mb706)
  • Loading branch information
mb706 committed Feb 24, 2020
1 parent 551e264 commit 8c9bc21
Show file tree
Hide file tree
Showing 23 changed files with 196 additions and 1 deletion.
8 changes: 7 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@ Authors@R:
family = "Bischl",
role = "aut",
email = "bernd_bischl@gmx.net",
comment = c(ORCID = "0000-0001-6002-6980")))
comment = c(ORCID = "0000-0001-6002-6980")),
person(given = "Martin",
family = "Binder",
role = "aut",
email = "mlr.developer@mb706.com"))
Description: Extends 'mlr3' with filter methods for feature
selection. Besides standalone filter methods built-in methods of any
machine-learning algorithm are supported. Partial scoring of
Expand All @@ -37,6 +41,7 @@ Imports:
R6
Suggests:
care,
caret,
FSelectorRcpp,
lgr,
mlr3learners,
Expand All @@ -57,6 +62,7 @@ Collate:
'FilterCarScore.R'
'FilterCorrelation.R'
'FilterDISR.R'
'FilterFindCorrelation.R'
'FilterImportance.R'
'FilterInformationGain.R'
'FilterJMI.R'
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ export(FilterCMIM)
export(FilterCarScore)
export(FilterCorrelation)
export(FilterDISR)
export(FilterFindCorrelation)
export(FilterImportance)
export(FilterInformationGain)
export(FilterJMI)
Expand Down
91 changes: 91 additions & 0 deletions R/FilterFindCorrelation.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#' @title Correlation Filter
#'
#' @usage NULL
#' @name mlr_filters_find_correlation
#' @format [R6::R6Class] inheriting from [Filter].
#' @include Filter.R
#'
#' @section Construction:
#' ```
#' FilterFindCorrelation$new()
#' mlr_filters$get("findCorrelation")
#' flt("findCorrelation")
#' ```
#'
#' @description
#' Simple filter emulating `caret::findCorrelation(exact = FALSE)`.
#'
#' This gives each feature a score between 0 and 1 that is *one minus* the
#' cutoff value for which it is excluded when using [caret::findCorrelation()].
#' The negative is used because [caret::findCorrelation()] excludes everything
#' *above* a cutoff, while filters exclude everything below a cutoff.
#' Here the filter scores are shifted by +1 to get positive values, in order to
#' align with the way other filters work.
#'
#' Consequently, `caret::findCorrelation(cutoff = 0.9)` lists the same features
#' that are excluded with `FilterFindCorrelation` at score 0.1 (= 1 - 0.9).
#'
#' @family Filter
#' @template seealso_filter
#' @export
#' @examples
#' ## Pearson (default)
#' task = mlr3::tsk("mtcars")
#' filter = flt("findCorrelation")
#' filter$calculate(task)
#' as.data.table(filter)
#'
#' ## Spearman
#' filter = flt("findCorrelation", method = "spearman")
#' filter$calculate(task)
#' as.data.table(filter)
FilterFindCorrelation = R6Class("FilterFindCorrelation", inherit = Filter,
  public = list(
    # Constructor: registers the filter under the id "findCorrelation" and
    # exposes the `use` and `method` arguments of stats::cor() as
    # hyperparameters.
    initialize = function() {
      cor_params = ParamSet$new(list(
        ParamFct$new("use", default = "everything",
          levels = c("everything", "all.obs", "complete.obs", "na.or.complete", "pairwise.complete.obs")),
        ParamFct$new("method", default = "pearson",
          levels = c("pearson", "kendall", "spearman"))
      ))

      super$initialize(
        id = "findCorrelation",
        packages = "stats",
        feature_types = c("integer", "numeric"),
        # every task type known to mlr_reflections, i.e. basically any task
        task_type = mlr_reflections$task_types$type,
        param_set = cor_params
      )
    },

    # Computes one score per feature of `task`, emulating
    # caret::findCorrelation(exact = FALSE). `nfeat` is accepted but not used
    # by this implementation. Returns a named numeric vector holding
    # 1 - (cutoff at which the feature would be dropped).
    calculate_internal = function(task, nfeat) {

      features = task$feature_names
      # Absolute feature-by-feature correlations; the currently set
      # hyperparameter values are handed to stats::cor() through invoke().
      abs_cor = abs(invoke(stats::cor,
        x = task$data(cols = features),
        .args = self$param_set$values))

      # Within a correlated pair, the feature with the higher average
      # correlation is the one that gets removed. Sorting by that average,
      # highest first (decreasing = TRUE also emulates the tie breaking, where
      # the /later/ feature is removed first), puts every feature ahead of all
      # features that are able to knock it out.
      by_mean_desc = order(colMeans(abs_cor), decreasing = TRUE)
      abs_cor = abs_cor[by_mean_desc, by_mean_desc, drop = FALSE]

      # Feature i is dropped as soon as some lower-average feature correlates
      # with it above the cutoff, so its drop-out cutoff is the maximum of its
      # correlations with all lower-average features. After the reordering
      # these sit in the strict lower triangle of column i; everything else is
      # blanked out. The lowest-average feature has nothing below it, which
      # yields a cutoff of 0 - it is never removed by caret.
      abs_cor[upper.tri(abs_cor, diag = TRUE)] = 0

      # Column-wise maxima carry the right names and values, but the ranking
      # has to be reversed for a filter; 1 - x does that and keeps the scores
      # positive.
      1 - apply(abs_cor, 2, max)
    }
  )
)

# Register the filter class in the `mlr_filters` dictionary under the key
# "findCorrelation", the key used by `mlr_filters$get()` / `flt()` in the
# class documentation. The @include tag requests that mlr_filters.R is
# collated before this file.
#' @include mlr_filters.R
mlr_filters$add("findCorrelation", FilterFindCorrelation)
1 change: 1 addition & 0 deletions man/Filter.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_filters.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_filters_anova.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_filters_auc.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_filters_carscore.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_filters_cmim.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_filters_correlation.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_filters_disr.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

63 changes: 63 additions & 0 deletions man/mlr_filters_find_correlation.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_filters_information_gain.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_filters_jmi.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_filters_jmim.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_filters_kruskal_test.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_filters_mim.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_filters_mrmr.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_filters_njmim.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_filters_performance.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_filters_variable_importance.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_filters_variance.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 16 additions & 0 deletions tests/testthat/test_FilterFindCorrelation.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
context("FilterFindCorrelation")

# Checks that thresholding the filter scores at `1 - cutoff` removes exactly
# the features that caret::findCorrelation(exact = FALSE) removes at `cutoff`,
# for every cutoff on a grid from 0 to 1.
test_that("FilterFindCorrelation matches caret::findCorrelation", {
  # caret is only a suggested dependency; skip instead of erroring without it
  skip_if_not_installed("caret")

  task = mlr3::mlr_tasks$get("sonar")

  # Extra columns appended to the task. a-d are one-hot indicators over a
  # repeating 4-row pattern, so all of their pairwise correlations are equal;
  # e and f are two further short patterns. The length-4 vectors are recycled
  # by cbind() to the length of `a`.
  equalcor = cbind(
    a = rep(c(1, 0, 0, 0), task$nrow / 4),
    b = c(0, 1, 0, 0),
    c = c(0, 0, 1, 0),
    d = c(0, 0, 0, 1),
    e = c(0.1, -0.1, 0.1, 0.99),
    f = c(-0.1, 0.1, 0.1, 0.99)
  )
  task$cbind(as.data.frame(equalcor))

  # Reference result: column indices caret removes at each cutoff.
  data = task$data(cols = task$feature_names)
  cm = cor(data)
  checkpoints = (0:100) / 100
  remove_caret = lapply(checkpoints, caret::findCorrelation, x = cm, exact = FALSE)

  # Filter result: features scoring below 1 - cutoff, translated to positions
  # in task$feature_names so they are comparable to caret's column indices.
  f = FilterFindCorrelation$new()
  f$calculate(task)
  remove_filter = lapply(checkpoints, function(cutoff) {
    match(names(f$scores)[f$scores < 1 - cutoff], task$feature_names)
  })

  mapply(expect_set_equal, remove_caret, remove_filter)
})

0 comments on commit 8c9bc21

Please sign in to comment.