Skip to content

Commit

Permalink
Add ResamplingRepeatedSpCVEnv (#32)
Browse files Browse the repository at this point in the history
  • Loading branch information
pat-s committed Jan 19, 2020
1 parent 34bf45e commit c798813
Show file tree
Hide file tree
Showing 31 changed files with 5,685 additions and 44 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
renv
renv.lock
inst/doc
man/figures/README*
2 changes: 2 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ VignetteBuilder:
RdMacros:
mlr3misc
Remotes:
mlr-org/mlr3misc,
rvalavi/blockCV
Encoding: UTF-8
LazyData: true
Expand All @@ -56,6 +57,7 @@ Collate:
'ResamplingSpCVCoords.R'
'ResamplingSpCVEnv.R'
'ResamplingRepeatedSpCVCoords.R'
'ResamplingRepeatedSpCVEnv.R'
'TaskClassifST.R'
'TaskRegrST.R'
'helper.R'
Expand Down
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# Generated by roxygen2: do not edit by hand

S3method(autoplot,ResamplingRepeatedSpCVCoords)
S3method(autoplot,ResamplingRepeatedSpCVEnv)
S3method(autoplot,ResamplingSpCVBlock)
S3method(autoplot,ResamplingSpCVBuffer)
S3method(autoplot,ResamplingSpCVCoords)
S3method(autoplot,ResamplingSpCVEnv)
export(ResamplingRepeatedSpCVCoords)
export(ResamplingRepeatedSpCVEnv)
export(ResamplingSpCVBlock)
export(ResamplingSpCVBuffer)
export(ResamplingSpCVCoords)
Expand Down
10 changes: 10 additions & 0 deletions R/ResamplingRepeatedSpCVCoords.R
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,17 @@ ResamplingRepeatedSpCVCoords = R6Class("ResamplingRepeatedSpCVCoords",

},

#' @description Translates iteration numbers to fold number.
#' @param iters `integer()`\cr
#' Iteration number.
folds = function(iters) {
iters = assert_integerish(iters, any.missing = FALSE, coerce = TRUE)
((iters - 1L) %% as.integer(self$param_set$values$repeats)) + 1L
},

#' @description Translates iteration numbers to repetition number.
#' @param iters `integer()`\cr
#' Iteration number.
repeats = function(iters) {
iters = assert_integerish(iters, any.missing = FALSE, coerce = TRUE)
((iters - 1L) %/% as.integer(self$param_set$values$folds)) + 1L
Expand All @@ -83,6 +89,10 @@ ResamplingRepeatedSpCVCoords = R6Class("ResamplingRepeatedSpCVCoords",
),

active = list(

#' @field iters `integer(1)`\cr
#' Returns the number of resampling iterations, depending on the
#' values stored in the `param_set`.
iters = function() {
pv = self$param_set$values
as.integer(pv$repeats) * as.integer(pv$folds)
Expand Down
153 changes: 153 additions & 0 deletions R/ResamplingRepeatedSpCVEnv.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
#' @title Repeated Environmental Block Cross Validation Resampling
#'
#' @import mlr3
#'
#' @description Environmental Block Cross Validation. This strategy uses k-means
#' clustering to specify blocks of similar environmental conditions. Only
#' numeric features can be used. The `features` used for building blocks can
#' be specified in the `param_set`. By default, all numeric features are used.
#'
#' @references
#' \cite{mlr3spatiotempcv}{valavi2018}
#'
#' @export
#' @examples
#' library(mlr3)
#' task = tsk("ecuador")
#'
#' # Instantiate Resampling
#' rrcv = rsmp("repeated-spcv-env")
#' rrcv$param_set$values = list(folds = 4, repeats = 2)
#' rrcv$instantiate(task)
#'
#' # Individual sets:
#' rrcv$train_set(1)
#' rrcv$test_set(1)
#' intersect(rrcv$train_set(1), rrcv$test_set(1))
#'
#' # Internal storage:
#' rrcv$instance
ResamplingRepeatedSpCVEnv = R6Class("ResamplingRepeatedSpCVEnv",
inherit = mlr3::Resampling,

public = list(
#' @description
#' Create an "coordinate-based" repeated resampling instance.
#' @param id `character(1)`\cr
#' Identifier for the resampling strategy.
initialize = function(id = "repeated-spcv-env") {
ps = ParamSet$new(params = list(
ParamInt$new("repeats", lower = 1),
ParamInt$new("folds", lower = 1L, tags = "required")
))
ps$values = list(folds = 10L)
super$initialize(
id = id,
param_set = ps,
man = "mlr3spatiotempcv::mlr_resamplings_repeated_spcv_env"
)

},

#' @description Translates iteration numbers to fold number.
#' @param iters `integer()`\cr
#' Iteration number.
folds = function(iters) {
iters = assert_integerish(iters, any.missing = FALSE, coerce = TRUE)
((iters - 1L) %% as.integer(self$param_set$values$repeats)) + 1L
},

#' @description Translates iteration numbers to repetition number.
#' @param iters `integer()`\cr
#' Iteration number.
repeats = function(iters) {
iters = assert_integerish(iters, any.missing = FALSE, coerce = TRUE)
((iters - 1L) %/% as.integer(self$param_set$values$folds)) + 1L
},

#' @description
#' Materializes fixed training and test splits for a given task.
#' @param task [Task]\cr
#' A task to instantiate.
instantiate = function(task) {

assert_task(task)
pv = self$param_set$values

# Set values to default if missing
if (is.null(pv$rows)) {
pv$rows = self$param_set$default[["rows"]]
}
if (is.null(pv$cols)) {
pv$cols = self$param_set$default[["cols"]]
}
if (is.null(pv$features)) {
pv$features = task$feature_names
}

# Remove non-numeric features, target and coordinates
columns = task$col_info[!id %in%
c(task$target_names, "x", "y")][type == "numeric"]

# Check for selected features that are not in task
diff = setdiff(pv$features, columns[, id])
if (length(diff) > 0) {
stopf("'spcv-env' requires numeric features for clustering.
Feature '%s' is either non-numeric or does not exist in the data.",
diff, wrap = TRUE)
}
columns = columns[id %in% pv$features]
columns = columns[, id]

data = task$data()[, columns, with = FALSE]

instance = private$.sample(task$row_ids, data)

self$instance = instance
self$task_hash = task$hash
invisible(self)
}
),

active = list(

#' @field iters `integer(1)`\cr
#' Returns the number of resampling iterations, depending on the
#' values stored in the `param_set`.
iters = function() {
pv = self$param_set$values
as.integer(pv$repeats) * as.integer(pv$folds)
}
),

private = list(
.sample = function(ids, coords) {
pv = self$param_set$values
folds = as.integer(pv$folds)

map_dtr(seq_len(pv$repeats), function(i) {
data.table(row_id = ids, rep = i,
fold = kmeans(coords, centers = folds)$cluster
)
})
},

.get_train = function(i) {
i = as.integer(i) - 1L
folds = as.integer(self$param_set$values$folds)
rep = i %/% folds + 1L
fold = i %% folds + 1L
ii = data.table(rep = rep, fold = seq_len(folds)[-fold])
self$instance[ii, "row_id", on = names(ii), nomatch = 0L][[1L]]
},

.get_test = function(i) {
i = as.integer(i) - 1L
folds = as.integer(self$param_set$values$folds)
rep = i %/% folds + 1L
fold = i %% folds + 1L
ii = data.table(rep = rep, fold = fold)
self$instance[ii, "row_id", on = names(ii), nomatch = 0L][[1L]]
}
)
)
20 changes: 10 additions & 10 deletions R/ResamplingSpCVEnv.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@
#'
#' @import mlr3
#'
#' @description Environmental Block Cross Validation. This strategy uses k-means
#' clustering to specify blocks of smilar environmental conditions. Only numeric
#' features can be used. The `features` used for building blocks can be
#' specified in the `param_set`. By default, all numeric features are used.
#' @description
#' Environmental Block Cross Validation. This strategy uses k-means clustering
#' to specify blocks of similar environmental conditions. Only numeric features
#' can be used. The `features` used for building blocks can be specified in the
#' `param_set`. By default, all numeric features are used.
#'
#' @references Valavi R, Elith J, Lahoz-Monfort JJ, Guillera-Arroita G. blockCV:
#' An r package for generating spatially or environmentally separated folds for
#' k-fold cross-validation of species distribution models. Methods Ecol Evol.
#' 2019; 10:225–232. https://doi.org/10.1111/2041-210X.13107
#' @references
#' \cite{mlr3spatiotempcv}{valavi2018}
#'
#' @export
#' @examples
Expand Down Expand Up @@ -80,8 +79,9 @@ ResamplingSpCVEnv = R6Class("ResamplingSpCVEnv", inherit = mlr3::Resampling,
# Check for selected features that are not in task
diff = setdiff(pv$features, columns[, id])
if (length(diff) > 0) {
stop(sprintf("'spcv-env' requires numeric features for clustering. Feature '%s' is either non-numeric or does not exist in the data",
diff))
stopf("'spcv-env' requires numeric features for clustering.
Feature '%s' is either non-numeric or does not exist in the data.",
diff, wrap = TRUE)
}
columns = columns[id %in% pv$features]
columns = columns[, id]
Expand Down
33 changes: 33 additions & 0 deletions R/autoplot.R
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,39 @@ autoplot.ResamplingRepeatedSpCVCoords = function(
grid = grid)
}

#' @title Plot for Repeated Spatial Resampling
#'
#' @rdname autoplot_spatial_resampling
#' @export
#' @examples
#' #####
#' # RepeatedSpCVEnv
#' #####
#' \donttest{
#' task = tsk("ecuador")
#' resampling = rsmp("repeated-spcv-env", folds = 10, repeats = 2)
#' resampling$instantiate(task)
#' autoplot(resampling, task)
#' autoplot(resampling, task, 1)
#' autoplot(resampling, task, fold_id = 2, repeats_id = 2)
#' autoplot(resampling, task, c(1, 2, 3, 4))
#' }
autoplot.ResamplingRepeatedSpCVEnv = function(
object,
task,
fold_id = NULL,
repeats_id = 1,
grid = TRUE,
train_color = "#0072B5",
test_color = "#E18727",
...) {
autoplot_spatial(resampling = object,
task = task,
fold_id = fold_id,
repeats_id = repeats_id,
grid = grid)
}

autoplot_spatial = function(
resampling = NULL,
task = NULL,
Expand Down
2 changes: 2 additions & 0 deletions R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ register_mlr3 = function() {
mlr_resamplings$add("spcv-buffer", ResamplingSpCVBuffer)
mlr_resamplings$add("spcv-coords", ResamplingSpCVCoords)
mlr_resamplings$add("spcv-env", ResamplingSpCVEnv)

mlr_resamplings$add("repeated-spcv-coords", ResamplingRepeatedSpCVCoords)
mlr_resamplings$add("repeated-spcv-env", ResamplingRepeatedSpCVEnv)
}

}
Expand Down
3 changes: 2 additions & 1 deletion README.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ Currently, the following ones are implemented:
| Spatial CV | [sperrorest](https://github.com/giscience-fsu/sperrorest) | Brenning 2012 | `ResamplingSpCVCoords` | `rsmp("spcv-coords")` |
| Environmental Blocking | [blockCV](https://github.com/rvalavi/blockCV) | Valavi 2019 | `ResamplingSpCVEnv` | `rsmp("spcv-env")` |
| --- | --- | --- | --- | --- |
| Repeated Spatial CV | [sperrorest](https://github.com/giscience-fsu/sperrorest) | Brenning 2012 | `RepeatedResamplingSpCVCoords` | `rsmp("repeated-spcv-coords")` |
| Repeated Spatial CV | [sperrorest](https://github.com/giscience-fsu/sperrorest) | Brenning 2012 | `RepeatedResamplingSpCVCoords` | `rsmp("repeated-spcv-coords")` |
| Repeated Env Blocking | [blockCV](https://github.com/rvalavi/blockCV) | Valavi 2019 | `RepeatedResamplingSpCVEnv` | `rsmp("repeated-spcv-env")` |

## Spatial tasks

Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ Currently, the following ones are implemented:
| Environmental Blocking | [blockCV](https://github.com/rvalavi/blockCV) | Valavi 2019 | `ResamplingSpCVEnv` | `rsmp("spcv-env")` |
||||||
| Repeated Spatial CV | [sperrorest](https://github.com/giscience-fsu/sperrorest) | Brenning 2012 | `RepeatedResamplingSpCVCoords` | `rsmp("repeated-spcv-coords")` |
| Repeated Env Blocking | [blockCV](https://github.com/rvalavi/blockCV) | Valavi 2019 | `RepeatedResamplingSpCVEnv` | `rsmp("repeated-spcv-env")` |

## Spatial tasks

Expand Down
27 changes: 27 additions & 0 deletions man/ResamplingRepeatedSpCVCoords.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit c798813

Please sign in to comment.