Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ResamplingRepeatedSpCVEnv #32

Merged
merged 4 commits into from
Jan 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
renv
renv.lock
inst/doc
man/figures/README*
2 changes: 2 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ VignetteBuilder:
RdMacros:
mlr3misc
Remotes:
mlr-org/mlr3misc,
rvalavi/blockCV
Encoding: UTF-8
LazyData: true
Expand All @@ -56,6 +57,7 @@ Collate:
'ResamplingSpCVCoords.R'
'ResamplingSpCVEnv.R'
'ResamplingRepeatedSpCVCoords.R'
'ResamplingRepeatedSpCVEnv.R'
'TaskClassifST.R'
'TaskRegrST.R'
'helper.R'
Expand Down
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# Generated by roxygen2: do not edit by hand

S3method(autoplot,ResamplingRepeatedSpCVCoords)
S3method(autoplot,ResamplingRepeatedSpCVEnv)
S3method(autoplot,ResamplingSpCVBlock)
S3method(autoplot,ResamplingSpCVBuffer)
S3method(autoplot,ResamplingSpCVCoords)
S3method(autoplot,ResamplingSpCVEnv)
export(ResamplingRepeatedSpCVCoords)
export(ResamplingRepeatedSpCVEnv)
export(ResamplingSpCVBlock)
export(ResamplingSpCVBuffer)
export(ResamplingSpCVCoords)
Expand Down
10 changes: 10 additions & 0 deletions R/ResamplingRepeatedSpCVCoords.R
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,17 @@ ResamplingRepeatedSpCVCoords = R6Class("ResamplingRepeatedSpCVCoords",

},

#' @description Translates iteration numbers to fold number.
#' @param iters `integer()`\cr
#' Iteration number.
folds = function(iters) {
  iters = assert_integerish(iters, any.missing = FALSE, coerce = TRUE)
  # Iterations are laid out repetition by repetition, each repetition
  # spanning `folds` consecutive iterations. The fold number is therefore
  # the position inside that span, i.e. the remainder modulo `folds`
  # (not modulo `repeats`).
  ((iters - 1L) %% as.integer(self$param_set$values$folds)) + 1L
},

#' @description Translates iteration numbers to repetition number.
#' @param iters `integer()`\cr
#' Iteration number.
repeats = function(iters) {
iters = assert_integerish(iters, any.missing = FALSE, coerce = TRUE)
((iters - 1L) %/% as.integer(self$param_set$values$folds)) + 1L
Expand All @@ -83,6 +89,10 @@ ResamplingRepeatedSpCVCoords = R6Class("ResamplingRepeatedSpCVCoords",
),

active = list(

#' @field iters `integer(1)`\cr
#' Returns the number of resampling iterations, depending on the
#' values stored in the `param_set`.
iters = function() {
pv = self$param_set$values
as.integer(pv$repeats) * as.integer(pv$folds)
Expand Down
153 changes: 153 additions & 0 deletions R/ResamplingRepeatedSpCVEnv.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
#' @title Repeated Environmental Block Cross Validation Resampling
#'
#' @import mlr3
#'
#' @description Environmental Block Cross Validation. This strategy uses k-means
#' clustering to specify blocks of similar environmental conditions. Only
#' numeric features can be used. The `features` used for building blocks can
#' be specified in the `param_set`. By default, all numeric features are used.
#'
#' @references
#' \cite{mlr3spatiotempcv}{valavi2018}
#'
#' @export
#' @examples
#' library(mlr3)
#' task = tsk("ecuador")
#'
#' # Instantiate Resampling
#' rrcv = rsmp("repeated-spcv-env")
#' rrcv$param_set$values = list(folds = 4, repeats = 2)
#' rrcv$instantiate(task)
#'
#' # Individual sets:
#' rrcv$train_set(1)
#' rrcv$test_set(1)
#' intersect(rrcv$train_set(1), rrcv$test_set(1))
#'
#' # Internal storage:
#' rrcv$instance
ResamplingRepeatedSpCVEnv = R6Class("ResamplingRepeatedSpCVEnv",
  inherit = mlr3::Resampling,

  public = list(
    #' @description
    #' Create an "environmental block" repeated resampling instance.
    #' @param id `character(1)`\cr
    #'   Identifier for the resampling strategy.
    initialize = function(id = "repeated-spcv-env") {
      ps = ParamSet$new(params = list(
        ParamInt$new("repeats", lower = 1L),
        ParamInt$new("folds", lower = 1L, tags = "required"),
        # `features` is read in `$instantiate()`; it has to be declared here,
        # otherwise it can never be set through the param set.
        paradox::ParamUty$new("features")
      ))
      # Give `repeats` a value as well: `$iters` and the sampling both read it
      # and would otherwise operate on NULL until the user sets it.
      ps$values = list(folds = 10L, repeats = 1L)
      super$initialize(
        id = id,
        param_set = ps,
        man = "mlr3spatiotempcv::mlr_resamplings_repeated_spcv_env"
      )
    },

    #' @description Translates iteration numbers to fold number.
    #' @param iters `integer()`\cr
    #'   Iteration number.
    folds = function(iters) {
      iters = assert_integerish(iters, any.missing = FALSE, coerce = TRUE)
      # Same layout as in `.get_train()` / `.get_test()`: each repetition
      # spans `folds` consecutive iterations, so the fold is the remainder
      # modulo `folds`.
      ((iters - 1L) %% as.integer(self$param_set$values$folds)) + 1L
    },

    #' @description Translates iteration numbers to repetition number.
    #' @param iters `integer()`\cr
    #'   Iteration number.
    repeats = function(iters) {
      iters = assert_integerish(iters, any.missing = FALSE, coerce = TRUE)
      ((iters - 1L) %/% as.integer(self$param_set$values$folds)) + 1L
    },

    #' @description
    #' Materializes fixed training and test splits for a given task.
    #' @param task [Task]\cr
    #'   A task to instantiate.
    instantiate = function(task) {

      assert_task(task)
      pv = self$param_set$values

      # Columns usable for clustering: numeric, and neither the target nor
      # the coordinate columns "x" / "y".
      columns = task$col_info[!id %in%
        c(task$target_names, "x", "y")][type == "numeric"]

      # Default: all numeric features of the task.
      if (is.null(pv$features)) {
        pv$features = intersect(task$feature_names, columns[, id])
      }
      if (length(pv$features) == 0L) {
        stopf("'spcv-env' requires at least one numeric feature for clustering.")
      }

      # Check for selected features that are not usable. All offending names
      # are collapsed into one string so that a single message is raised.
      diff = setdiff(pv$features, columns[, id])
      if (length(diff) > 0) {
        stopf("'spcv-env' requires numeric features for clustering.
          Feature '%s' is either non-numeric or does not exist in the data.",
          paste(diff, collapse = "', '"), wrap = TRUE)
      }
      columns = columns[id %in% pv$features]
      columns = columns[, id]

      data = task$data()[, columns, with = FALSE]

      instance = private$.sample(task$row_ids, data)

      self$instance = instance
      self$task_hash = task$hash
      invisible(self)
    }
  ),

  active = list(

    #' @field iters `integer(1)`\cr
    #'   Returns the number of resampling iterations, depending on the
    #'   values stored in the `param_set`.
    iters = function() {
      pv = self$param_set$values
      as.integer(pv$repeats) * as.integer(pv$folds)
    }
  ),

  private = list(
    # Runs one k-means clustering (with `folds` centers) per repetition on
    # the feature table and returns a long table with the columns
    # `row_id`, `rep` and `fold`.
    .sample = function(ids, env_data) {
      pv = self$param_set$values
      folds = as.integer(pv$folds)
      repeats = as.integer(pv$repeats)

      map_dtr(seq_len(repeats), function(i) {
        data.table(row_id = ids, rep = i,
          fold = kmeans(env_data, centers = folds)$cluster
        )
      })
    },

    # Row ids of all folds except the test fold of iteration `i`, within
    # the repetition that iteration `i` belongs to.
    .get_train = function(i) {
      i = as.integer(i) - 1L
      folds = as.integer(self$param_set$values$folds)
      rep = i %/% folds + 1L
      fold = i %% folds + 1L
      ii = data.table(rep = rep, fold = seq_len(folds)[-fold])
      self$instance[ii, "row_id", on = names(ii), nomatch = 0L][[1L]]
    },

    # Row ids of the single fold used for testing in iteration `i`.
    .get_test = function(i) {
      i = as.integer(i) - 1L
      folds = as.integer(self$param_set$values$folds)
      rep = i %/% folds + 1L
      fold = i %% folds + 1L
      ii = data.table(rep = rep, fold = fold)
      self$instance[ii, "row_id", on = names(ii), nomatch = 0L][[1L]]
    }
  )
)
20 changes: 10 additions & 10 deletions R/ResamplingSpCVEnv.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@
#'
#' @import mlr3
#'
#' @description Environmental Block Cross Validation. This strategy uses k-means
#' clustering to specify blocks of smilar environmental conditions. Only numeric
#' features can be used. The `features` used for building blocks can be
#' specified in the `param_set`. By default, all numeric features are used.
#' @description
#' Environmental Block Cross Validation. This strategy uses k-means clustering
#' to specify blocks of similar environmental conditions. Only numeric features
#' can be used. The `features` used for building blocks can be specified in the
#' `param_set`. By default, all numeric features are used.
#'
#' @references Valavi R, Elith J, Lahoz-Monfort JJ, Guillera-Arroita G. blockCV:
#' An r package for generating spatially or environmentally separated folds for
#' k-fold cross-validation of species distribution models. Methods Ecol Evol.
#' 2019; 10:225–232. https://doi.org/10.1111/2041-210X.13107
#' @references
#' \cite{mlr3spatiotempcv}{valavi2018}
#'
#' @export
#' @examples
Expand Down Expand Up @@ -80,8 +79,9 @@ ResamplingSpCVEnv = R6Class("ResamplingSpCVEnv", inherit = mlr3::Resampling,
# Check for selected features that are not in task
diff = setdiff(pv$features, columns[, id])
if (length(diff) > 0) {
stop(sprintf("'spcv-env' requires numeric features for clustering. Feature '%s' is either non-numeric or does not exist in the data",
diff))
stopf("'spcv-env' requires numeric features for clustering.
Feature '%s' is either non-numeric or does not exist in the data.",
diff, wrap = TRUE)
}
columns = columns[id %in% pv$features]
columns = columns[, id]
Expand Down
33 changes: 33 additions & 0 deletions R/autoplot.R
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,39 @@ autoplot.ResamplingRepeatedSpCVCoords = function(
grid = grid)
}

#' @title Plot for Repeated Spatial Resampling
#'
#' @rdname autoplot_spatial_resampling
#' @export
#' @examples
#' #####
#' # RepeatedSpCVEnv
#' #####
#' \donttest{
#' task = tsk("ecuador")
#' resampling = rsmp("repeated-spcv-env", folds = 10, repeats = 2)
#' resampling$instantiate(task)
#' autoplot(resampling, task)
#' autoplot(resampling, task, 1)
#' autoplot(resampling, task, fold_id = 2, repeats_id = 2)
#' autoplot(resampling, task, c(1, 2, 3, 4))
#' }
autoplot.ResamplingRepeatedSpCVEnv = function(
  object,
  task,
  fold_id = NULL,
  repeats_id = 1,
  grid = TRUE,
  train_color = "#0072B5",
  test_color = "#E18727",
  ...) {
  # Thin S3 method: hands the resampling object over to the shared plotting
  # helper. `train_color`, `test_color` and `...` are part of the signature
  # but are not passed on by this method.
  autoplot_spatial(
    resampling = object,
    task = task,
    fold_id = fold_id,
    repeats_id = repeats_id,
    grid = grid
  )
}

autoplot_spatial = function(
resampling = NULL,
task = NULL,
Expand Down
2 changes: 2 additions & 0 deletions R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ register_mlr3 = function() {
mlr_resamplings$add("spcv-buffer", ResamplingSpCVBuffer)
mlr_resamplings$add("spcv-coords", ResamplingSpCVCoords)
mlr_resamplings$add("spcv-env", ResamplingSpCVEnv)

mlr_resamplings$add("repeated-spcv-coords", ResamplingRepeatedSpCVCoords)
mlr_resamplings$add("repeated-spcv-env", ResamplingRepeatedSpCVEnv)
}

}
Expand Down
3 changes: 2 additions & 1 deletion README.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ Currently, the following ones are implemented:
| Spatial CV | [sperrorest](https://github.com/giscience-fsu/sperrorest) | Brenning 2012 | `ResamplingSpCVCoords` | `rsmp("spcv-coords")` |
| Environmental Blocking | [blockCV](https://github.com/rvalavi/blockCV) | Valavi 2019 | `ResamplingSpCVEnv` | `rsmp("spcv-env")` |
| --- | --- | --- | --- | --- |
| Repeated Spatial CV | [sperrorest](https://github.com/giscience-fsu/sperrorest) | Brenning 2012 | `ResamplingRepeatedSpCVCoords` | `rsmp("repeated-spcv-coords")` |
| Repeated Spatial CV | [sperrorest](https://github.com/giscience-fsu/sperrorest) | Brenning 2012 | `ResamplingRepeatedSpCVCoords` | `rsmp("repeated-spcv-coords")` |
| Repeated Env Blocking | [blockCV](https://github.com/rvalavi/blockCV) | Valavi 2019 | `ResamplingRepeatedSpCVEnv` | `rsmp("repeated-spcv-env")` |

## Spatial tasks

Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ Currently, the following ones are implemented:
| Environmental Blocking | [blockCV](https://github.com/rvalavi/blockCV) | Valavi 2019 | `ResamplingSpCVEnv` | `rsmp("spcv-env")` |
| — | — | — | — | — |
| Repeated Spatial CV | [sperrorest](https://github.com/giscience-fsu/sperrorest) | Brenning 2012 | `ResamplingRepeatedSpCVCoords` | `rsmp("repeated-spcv-coords")` |
| Repeated Env Blocking | [blockCV](https://github.com/rvalavi/blockCV) | Valavi 2019 | `ResamplingRepeatedSpCVEnv` | `rsmp("repeated-spcv-env")` |

## Spatial tasks

Expand Down
27 changes: 27 additions & 0 deletions man/ResamplingRepeatedSpCVCoords.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading