Skip to content

Commit

Permalink
...
Browse files Browse the repository at this point in the history
  • Loading branch information
berndbischl committed Dec 9, 2018
1 parent 597c3c8 commit b9117f8
Show file tree
Hide file tree
Showing 13 changed files with 65 additions and 66 deletions.
5 changes: 3 additions & 2 deletions R/ParamDbl.R
Expand Up @@ -33,7 +33,7 @@ ParamDbl = R6Class("ParamDbl", inherit = Parameter,

denorm_vector = function(x) {
assert_true(self$has_finite_bounds)
self$range[1] + x * diff(self$range)
self$lower + x * self$span
}

),
Expand All @@ -43,7 +43,8 @@ ParamDbl = R6Class("ParamDbl", inherit = Parameter,
center = function() {
assert_true(self$has_finite_bounds)
(self$lower + self$upper) / 2
}
},
span = function() self$upper - self$lower
),

private = list(
Expand Down
3 changes: 1 addition & 2 deletions R/ParamFct.R
Expand Up @@ -37,8 +37,7 @@ ParamFct = R6Class(
}
),
active = list(
nlevels = function() length(self$values),
has_finite_bounds = function() TRUE
nlevels = function() length(self$values)
),
private = list(
get_range_string = function() sprintf("{%s}", paste0(self$values, collapse = ",")),
Expand Down
4 changes: 1 addition & 3 deletions R/ParamLgl.R
Expand Up @@ -15,7 +15,7 @@ ParamLgl = R6Class("ParamLgl",
storage_type = "logical",
lower = NA_real_,
upper = NA_real_,
values = c("TRUE", "FALSE"),
values = NULL,
checker = function(x) check_flag(x),
special_vals = special_vals,
default = default,
Expand All @@ -28,8 +28,6 @@ ParamLgl = R6Class("ParamLgl",
}
),
active = list(
has_finite_bounds = function() TRUE,
values = function() c(TRUE, FALSE),
nlevels = function() 2L
),

Expand Down
50 changes: 17 additions & 33 deletions R/ParamSet.R
Expand Up @@ -70,6 +70,8 @@ ParamSet = R6Class("ParamSet",
initialize = function(params = list(), id = "paramset", tags = NULL, trafo = NULL) {
assert_list(params, types = "Parameter")
self$data = rbindlist(map(params, "data"))
# we set index not key, so we dont resort the table
setindex(self$data, "id")
# names(params) = map_chr(params, "id") # ensure we have a named list, with par ids
self$id = assert_string(id)
self$trafo = assert_function(trafo, args = c("x", "tags"), null.ok = TRUE)
Expand All @@ -81,14 +83,6 @@ ParamSet = R6Class("ParamSet",
}
},

denorm = function(x) {
assert_list(x, names = 'strict')
assert_set_equal(names(x), self$ids)
xs = lapply(self$ids, function(id) self$params[[id]]$denorm(x = x[id]))
names(xs) = NULL
as.data.table(xs)
},

# list --> list, named
transform = function(x) {
x = ensure_data_table(x)
Expand All @@ -104,9 +98,6 @@ ParamSet = R6Class("ParamSet",
return(xs)
},




# FIXME: subset und fix trennen

# in: * ids (character)
Expand Down Expand Up @@ -174,24 +165,16 @@ ParamSet = R6Class("ParamSet",
},

# check function that checks the whole param set by simply iterating
check = function(x, na.ok = FALSE, null.ok = FALSE) {
assert_set_equal(names(x), self$ids)
if (is.data.table(x)) x = as.list(x)
assert_list(x, names = "named")
for (par_name in names(x)) {
res = self$params[[par_name]]$check(x[[par_name]], na.ok = na.ok, null.ok = null.ok)
if(!isTRUE(res)) return(res)
}
return(res)
check = function(xs) {
ids = self$ids
assert_list(xs)
assert_names(names(xs), permutation.of = ids)
all(map_lgl(ids, function(id) self$get_param(id)$check(xs[[id]])))
},

test = function(...) {
makeTestFunction(self$check)(...)
},
test = function(xs) makeTestFunction(self$check)(xs),

assert = function(...) {
makeAssertionFunction(self$check)(...)
},
assert = function(xs) makeAssertionFunction(self$check)(xs),

add_dependency = function(dep) {
assert_r6(dep, "Dependency")
Expand All @@ -218,6 +201,12 @@ ParamSet = R6Class("ParamSet",
cat("Trafo is set:", "\n")
print(self$trafo)
}
},

get_param = function(id) {
assert_choice(id, self$ids)
r = self$data[id, on = "id"] # index single row by id
new_param_from_dt(r)
}
),

Expand All @@ -228,13 +217,8 @@ ParamSet = R6Class("ParamSet",
storage_types = function() private$get_col_with_idnames("storage_type"),
lowers = function() private$get_col_with_idnames("lower"),
uppers = function() private$get_col_with_idnames("upper"),
values = function() private$get_col_with_idnames("values")
# FIXME: reeanable?
# nlevels = function() map_int(self$params, function(param) param$nlevels %??% NA_integer_),
# FIXME: reeanable?
# has_finite_bounds = function() all(map_lgl(self$params, function(param) param$has_finite_bounds)),
# FIXME: reeanable?
# member_tags = function() lapply(self$params, function(param) param$tags)
values = function() private$get_col_with_idnames("values"),
tags = function() private$get_col_with_idnames("tags")
),

private = list(
Expand Down
26 changes: 25 additions & 1 deletion R/Parameter.R
Expand Up @@ -67,7 +67,8 @@ Parameter = R6Class("Parameter",

test = function(x) makeTestFunction(self$check)(x),

denorm = function(x) as_dt_cols(self$denorm_vector(x[[self$id]]), self$id),
# FIXME: what is this? remove?
# denorm = function(x) as_dt_cols(self$denorm_vector(x[[self$id]]), self$id),

denorm_vector = function(x) {
stop("denorm function not implemented!")
Expand Down Expand Up @@ -104,3 +105,26 @@ Parameter = R6Class("Parameter",
get_type_string = function() stop("abstract")
)
)


# private factory methods, creates a param from a single dt row
new_param_from_dt = function(dt) {
# get pclass constructor from namespace then call it on all (relevent) entries from the dt row
p = as.list(dt)
cl = getFromNamespace(p$pclass, "paradox")
# remove pclass and storage type, as the are not passed to contructor
p$pclass = NULL; p$storage_type = NULL
# FIXME: this is not that perfect code:
# - we untangle all list-cols
# - we remove all NULL or NA entries, as they cannot be passed to the constructor,
# NB: we might also handle this thru a reflection table, where we "know" which elements the Param-constructors take
p = map_if(p, is.list, function(x) x[[1L]])
p = Filter(function(x) !is.null(x) && !is.na(x), p)
do.call(cl$new, p)
}






2 changes: 1 addition & 1 deletion R/generate_design_grid.R
Expand Up @@ -28,7 +28,7 @@ generate_design_grid = function(param_set, resolution = NULL, param_resolutions
}

grid_vec = lapply(param_resolutions, function(r) seq(0, 1, length.out = r))
res = imap(grid_vec, function(value, id) unique(param_set$params[[id]]$denorm_vector(x = value)))
res = imap(grid_vec, function(value, id) unique(param_set$get_param(id)$denorm_vector(x = value)))
res = do.call(CJ, res)

return(res)
Expand Down
15 changes: 5 additions & 10 deletions R/generate_design_lhs.R
Expand Up @@ -14,15 +14,10 @@ generate_design_lhs = function(param_set, n, lhs_function = lhs::maximinLHS) {
n = assert_count(n, positive = TRUE, coerce = TRUE)
assert_function(lhs_function, args = c("n", "k"))

lhs_des = lhs_function(n, k = param_set$length)

# converts the LHS output to values of the parameters
sample_converter = function(lhs_des) {
vec_cols = lapply(seq_col(lhs_des), function(z) lhs_des[, z, drop = TRUE])
names(vec_cols) = param_set$ids
param_set$denorm(vec_cols)
}

sample_converter(lhs_des)
ids = param_set$ids
d = lhs_function(n, k = param_set$length)
colnames(d) = ids
d = map_dtc(ids, function(id) param_set$get_param(id)$denorm_vector(d[, id]))
set_names(d, ids)
}

2 changes: 0 additions & 2 deletions R/repeat_param.R
Expand Up @@ -11,8 +11,6 @@
#' The parameter that should be repeated_
#' @return List of Parameters
#' @export

#' FIXME return a param set
repeatParam = function(n = 1L, param) {
assert_int(n)
assert_r6(param, "Parameter")
Expand Down
5 changes: 3 additions & 2 deletions R/sampler.R
Expand Up @@ -9,7 +9,7 @@ Sampler = R6Class("Sampler",
# params.cl allows asserting params of only a certain type, vector of multiple entries is OK
initialize = function(param_set, params.cl = "Parameter") {
assert_r6(param_set, "ParamSet")
assert_list(param_set$params, types = params.cl)
assert_subset(param_set$pclasses, params.cl)
self$param_set = param_set
},

Expand All @@ -28,7 +28,8 @@ Sampler1D = R6Class("Sampler1D", inherit = Sampler,
),

active = list(
param = function() self$param_set$params[[1L]]
# retrieve the only param in the set, return Parameter object
param = function() self$param_set$get_param(self$param_set$ids[1L])
),

private = list(
Expand Down
File renamed without changes.
1 change: 0 additions & 1 deletion tests/testthat/test_ParamFct.R
Expand Up @@ -4,7 +4,6 @@ test_that("test if ParamFct constructor works", {
p = ParamFct$new(id = "test", values = c("a", "b"))
expect_equal(p$values, c("a", "b"))
expect_equal(p$nlevels, 2L)
expect_true(p$has_finite_bounds)

# we dont allow NAs as values
expect_error(ParamFct$new(id = "test", values = c("a", NA)))
Expand Down
2 changes: 0 additions & 2 deletions tests/testthat/test_ParamLgl.R
Expand Up @@ -3,8 +3,6 @@ context("ParamLgl")
test_that("constructor works", {
p = ParamLgl$new(id = "test")
expect_equal(p$id, "test")
expect_equal(p$values, c(TRUE, FALSE))
expect_true(p$has_finite_bounds)
expect_equal(p$nlevels, 2L)
})

Expand Down
16 changes: 9 additions & 7 deletions tests/testthat/test_generate_design.R
Expand Up @@ -34,13 +34,15 @@ test_that("generate_design_lhs", {
)

for (ps in ps_list) {
xl = generate_design_lhs(ps, 10)
expect_data_table(xl, nrows = 10, any.missing = FALSE)
expect_true(all(xl[, ps$test(.SD), by = seq_len(nrow(xl))]$V1))
xlt = ps$transform(xl)
expect_data_table(xlt, nrows = 10)
xltl = design_to_list(xlt)
expect_list(xltl, len = 10)
d = generate_design_lhs(ps, 10)
expect_data_table(d, nrows = 10, any.missing = FALSE)
xs = design_to_list(d)
all(map_lgl(xs, ps$test))
# FIXME: the next lines should not be here, they test transform and design_to_list
# xlt = ps$transform(xl)
# expect_data_table(xlt, nrows = 10)
# xltl = design_to_list(xlt)
# expect_list(xltl, len = 10)
}
})

0 comments on commit b9117f8

Please sign in to comment.