Skip to content

Commit

Permalink
set new default for rows and colds in spcv-block
Browse files Browse the repository at this point in the history
  • Loading branch information
pat-s committed Sep 4, 2019
1 parent 5348ac5 commit 5727d57
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 15 deletions.
15 changes: 11 additions & 4 deletions R/ResamplingSpCVBlock.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
#'
#' # Instantiate Resampling
#' rcv = rsmp("spcv-block")
#' rcv$param_set$values = list(folds = 4)
#' rcv$instantiate(task)
#'
#' # Individual sets:
Expand All @@ -43,8 +42,8 @@ ResamplingSpCVBlock = R6Class("ResamplingSpCVBlock",
param_set = ParamSet$new(params = list(
ParamUty$new("stratify", default = NULL),
ParamInt$new("folds", lower = 1L, tags = "required"),
ParamInt$new("rows", lower = 1L, default = 2),
ParamInt$new("cols", lower = 1L, default = 2),
ParamInt$new("rows", lower = 1L, default = 4),
ParamInt$new("cols", lower = 1L, default = 4),
ParamInt$new("range", lower = 1L),
ParamFct$new("selection", levels = c("random", "systematic", "checkerboard"), default = "random")

Expand All @@ -56,7 +55,6 @@ ResamplingSpCVBlock = R6Class("ResamplingSpCVBlock",

assert_task(task)

# Check combination
if (!is.null(self$param_set$values$range) & (!is.null(self$param_set$values$rows) | !is.null(self$param_set$values$cols))) {
warning("Cols and rows are ignored. Range is used to generated blocks.")
}
Expand All @@ -72,6 +70,15 @@ ResamplingSpCVBlock = R6Class("ResamplingSpCVBlock",
self$param_set$values$selection = self$param_set$default[["selection"]]
}

# Check for valid combinations of rows, cols and folds
if ((self$param_set$values$rows*self$param_set$values$cols) < self$param_set$values$folds) {
stopf("'nrow' * 'ncol' needs to be larger than 'folds'.")
}

if (!is.null(self$param_set$values$rows) && !is.null(self$param_set$values$cols)) {
warning("Hyperparameters 'rows' and 'cols' not set. Using the default value of '4' set by 'mlr3spatiotemporal' for both which results in a grid of 16. You might want to set these values yourself during resampling construction.")
}

groups = task$groups
stratify = self$param_set$values$stratify

Expand Down
5 changes: 4 additions & 1 deletion R/autoplot.R
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,10 @@ autoplot_spatial = function(resampling, task, fold_id = NULL, grid = TRUE) {
stop("More folds specified than stored in resampling.")
}
if (length(fold_id) == 1 && fold_id > resampling$iters) {
stop("Specified a fold id which is not available.")
stop("Specified a fold id which exceeds the total number of folds.")
}
if (any(fold_id > resampling$iters)) {
stop("Specified a fold id which exceeds the total number of folds.")
}
# Multiplot with train and test set
plot_list = list()
Expand Down
1 change: 0 additions & 1 deletion man/ResamplingSpCVBlock.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 8 additions & 9 deletions tests/testthat/test_autoplot.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,20 @@ vdiffr::expect_doppelganger("SpCVCoords - Fold 1-4", coords3)

# SpCVBlock --------------------------------------------------------------------

resa_Block = ResamplingSpCVBlock$new()
resa_Block$param_set$values = list(folds = 4)
resa_Block$instantiate(task)
Block1 = autoplot(resa_Block, task)
Block2 = autoplot(resa_Block, task, 1)
Block3 = autoplot(resa_Block, task, c(1, 2, 3, 4))
resa_block = ResamplingSpCVBlock$new()
resa_block$instantiate(task)
Block1 = autoplot(resa_block, task)
Block2 = autoplot(resa_block, task, 1)
Block3 = autoplot(resa_block, task, c(1, 2, 3, 4))

vdiffr::expect_doppelganger("SpCVBlock all test sets", Block1)
vdiffr::expect_doppelganger("SpCVBlock - Fold 1", Block2)
vdiffr::expect_doppelganger("SpCVBlock - Fold 1-4", Block3)

# these checks apply to all resampling methods.
expect_error(autoplot(resa_Block, task, 5))
expect_error(autoplot(resa_Block, task, c(1, 2, 3, 4, 5)))
expect_list(autoplot(resa_Block, task, c(1, 2, 3, 4), grid = FALSE))
expect_error(autoplot(resa_block, task, 20))
expect_error(autoplot(resa_block, task, c(1, 20)))
expect_list(autoplot(resa_block, task, c(1, 2, 3, 4), grid = FALSE))

# SpCVBuffer -------------------------------------------------------------------

Expand Down

0 comments on commit 5727d57

Please sign in to comment.