Skip to content

Commit

Permalink
Merge c7c39fd into 366cd11
Browse files Browse the repository at this point in the history
  • Loading branch information
jakob-r committed Oct 12, 2018
2 parents 366cd11 + c7c39fd commit c2ea8a5
Show file tree
Hide file tree
Showing 3 changed files with 203 additions and 2 deletions.
79 changes: 79 additions & 0 deletions R/ParamSet.R
Expand Up @@ -198,6 +198,85 @@ ParamSet = R6Class(
}
},

# in: * ids (character)
# ids of ParamSimple
# * fix (named list)
# names = ids of ParamSimple
# values = values of respective param
# out: ParamSet
subset = function(ids = NULL, fix = NULL) {
if (is.null(ids)) {
keep_ids = self$ids
} else {
assert_subset(ids, self$ids)
keep_ids = ids
}
if (!is.null(fix)) {
if (any(names(fix) %in% ids)) {
stop("You cannot keep ids and fix them at the same time!")
}
assert_list(fix, names = "named")
keep_ids = setdiff(keep_ids, names(fix))
}
# if we have fixed parameters we have to supply them to the trafo function in case there are needed there.
new_trafo = self$trafo
if (!is.null(fix) && !is.null(new_trafo)) {
assert_list(fix, names = "named")
old_trafo = force(self$trafo)
new_trafo = function(x, dict, tags) {
x = cbind(x, as.data.table(fix))
res = old_trafo(x, dict, tags)
res = x[, !names(fix), with = FALSE]
return(res)
}
}
# if we have fixed parameters we can substitute them in the restriction quote
new_restriction = self$restriction
if (!is.null(fix) && !is.null(new_restriction)) {
new_restriction = substituteDirect(new_restriction, fix)
}
ParamSet$new(
id = paste0(self$id,"_subset"),
handle = self$handle$clone(),
params = self$params[keep_ids],
dictionary = as.list(self$dictionary),
tags = self$tags,
restriction = new_restriction,
trafo = new_trafo)
},

combine = function(param_set) {
if (self$length == 0) {
return(param_set$clone())
} else if (param_set$length == 0) {
return(self$clone())
}
if (length(intersect(self$ids, param_set$ids)) > 0) {
stop ("Combine failed, because new param_set has at least one Param with the same id as in this ParamSet.")
}
new_restriction = self$restriction %??% param_set$restriction
if (!is.null(self$restriction) && !is.null(param_set$restriction)) {
new_restriction = substitute((a) && (b), list(a = self$restriction %??% TRUE, b = param_set$restriction %??% TRUE))
}
new_trafo = self$trafo %??% param_set$trafo
if (!is.null(self$trafo) && !is.null(param_set$trafo)) {
new_trafo = function(x, dict, tags) {
x = self$trafo(x, dict, tags)
x = param_set$trafo(x, dict, tags)
return(x)
}
}
ParamSet$new(
id = paste0(self$id, "_", param_set$id),
handle = self$handle$clone(), #FIXME: If the handle is actually used this might not be a good idea, throw error?
params = c(self$params, param_set$params),
dictionary = c(as.list(self$dictionary), as.list(param_set$dictionary)),
tags = union(self$tags, param_set$tags),
restriction = new_restriction,
trafo = new_trafo
)
},

print = function(...) {
cat("ParamSet:", self$id, "\n")
cat("Parameters:", "\n")
Expand Down
10 changes: 9 additions & 1 deletion R/design_to_list.R
@@ -1,5 +1,13 @@
#' @title Design to List
#' @description
#' Converts a design or any `data.table` to a list where each list element contains one row.
#'
#' @param design [`data.table`]:
#' The `data.table` that should be converted
#'
#' @return `list`
#' @export
design_to_list = function(design) {
assert_data_table(design)
.mapply(list, design, list())
}
}
116 changes: 115 additions & 1 deletion tests/testthat/test_ParamSet.R
Expand Up @@ -10,7 +10,7 @@ test_that("methods and active bindings work", {
th_paramset_numeric,
th_paramset_trafo,
th_paramset_trafo_dictionary
)
)
for (ps in ps_list) {
if (ps$id == "th_paramset_full") {
expect_equal(ps$ids, c('th_param_int', 'th_param_real', 'th_param_categorical', 'th_param_flag'))
Expand Down Expand Up @@ -101,3 +101,117 @@ test_that("repeated params in ParamSet works", {
xs_l = design_to_list(xs_t)
expect_list(xs_l, len = 10)
})

test_that("param subset in ParamSet works", {
# Define all the different subsets we want to try:
configs = list(
list(
ps = th_paramset_full,
ids = c("th_param_int", "th_param_flag"),
expected_ids = c("th_param_int", "th_param_flag"),
fix = NULL
),
list(
ps = th_paramset_full,
expected_ids = c("th_param_real", "th_param_categorical", "th_param_flag"),
fix = list("th_param_int" = 1L)
),
list(
ps = th_paramset_trafo,
ids = c("th_param_int"),
expected_ids = c("th_param_int"),
fix = list("th_param_real" = 1)
),
list(
ps = th_paramset_trafo,
ids = NULL,
expected_ids = c("th_param_int"),
fix = list("th_param_real" = 1)
),
list(
ps = th_paramset_trafo_dictionary,
ids = NULL,
expected_ids = c("th_param_int"),
fix = list("th_param_real" = 1)
),
list(
ps = th_paramset_restricted,
ids = NULL,
expected_ids = c("th_param_int", "th_param_categorical"),
fix = list("th_param_real" = 1)
)
)
# Test the different combinations:
for (conf in configs) {
paramset_sub = conf$ps$subset(ids = conf$ids, fix = conf$fix)
expect_equal(paramset_sub$ids, conf$expected_ids)
x = paramset_sub$sample(2)
expect_set_equal(colnames(x), conf$expected_ids)
expect_true(paramset_sub$check(x[1,]))
expect_true(paramset_sub$check(x[2,]))
x_trafo = paramset_sub$transform(x)


x = paramset_sub$sample(1)
expect_set_equal(colnames(x), conf$expected_ids)
expect_true(paramset_sub$check(x))
x_trafo = paramset_sub$transform(x)
}
})

test_that("Combine of ParamSet work", {
# define some ParamSets we will join to the th_ ones
new_param_sets = list(
normal = ParamSet$new(
id = "new_param_set",
params = list(
ParamReal$new("new_int", lower = 0L, upper = 10L)
)
),
trafo = ParamSet$new(
id = "new_param_set_trafo",
params = list(
ParamReal$new("new_real", lower = 0, upper = 10)
),
trafo = function(x, dict, tags) {
x$new_real = sqrt(x$new_real)
return(x)
}
),
restriction = ParamSet$new(
id = "new_param_set_requires",
params = list(
ParamReal$new("new_real", lower = 0, upper = 10),
ParamReal$new("new_int", lower = 0L, upper = 10L)
),
restriction = quote(new_real>=new_int)
)
)

ps_list = list(
th_paramset_empty,
th_paramset_full,
th_paramset_repeated,
th_paramset_restricted,
th_paramset_numeric,
th_paramset_trafo,
th_paramset_trafo_dictionary
)

for (ps in ps_list) {
for(ps_new in new_param_sets) {
ps_comb1 = ps$combine(ps_new)
ps_comb2 = ps_new$combine(ps)
expect_set_equal(ps_comb1$ids, ps_comb1$ids)
expect_set_equal(ps_comb1$ids, c(ps$ids, ps_new$ids))
x = ps_comb1$sample(1)
expect_data_table(x)
expect_true(ps_comb1$check(x))
expect_true(ps_comb2$check(x))
xt1 = ps_comb1$transform(x)
xt2 = ps_comb2$transform(x)[, colnames(xt1), with = FALSE]
expect_equal(xt1, xt2)
}
}

})

0 comments on commit c2ea8a5

Please sign in to comment.