Skip to content

Commit

Permalink
implement trafo and check
Browse files Browse the repository at this point in the history
  • Loading branch information
smilesun committed Jul 6, 2018
1 parent 3872f79 commit 7f45285
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 5 deletions.
6 changes: 3 additions & 3 deletions R/ParamSet.R
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ ParamSet = R6Class("ParamSet",

# set member variables
assertList(params, types = "ParamNode")
for (i in seq_along(params)) {
params[[i]]$handle$setRoot(self) # FIXME: are we sure? (p.s. members in handle will mainly be private in the future)
}
# for (i in seq_along(params)) {
# params[[i]]$handle$setRoot(self) # FIXME: are we sure? (p.s. members in handle will mainly be private in the future)
# }
self$params = params
self$trafo = assertFunction(trafo, args = c("x", "dict", "tags"), null.ok = TRUE)
self$restriction = assertClass(restriction, "call", null.ok = TRUE)
Expand Down
29 changes: 27 additions & 2 deletions R/ParamSetTree.R
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,35 @@ ParamSetTree = R6Class("ParamSetTree",
child.set = NULL,
context = NULL,

initialize = function(ns.id, ..., context = NULL) {
initialize = function(ns.id, ..., trafo = NULL, dictionary = NULL, tags = NULL, context = NULL) {
self$ns.id = assertNames(ns.id)
self$rt.hinge = ParamTreeFac(ns.id, ...)
self$context = context
# check function that checks the whole param set by simply iterating
check = function(x, na.ok = FALSE, null.ok = FALSE) {
assertSetEqual(names(x), names(self$params)) # self$params are list of ParamNodeSimple
if (is.data.table(x)) x = as.list(x)
res = checkList(x, names = "named")
if (!isTRUE(self$rt.hinge$visitor$checkValidFromFlat(x))) {
return(sprintf("Value Violation Found!"))
}
for (par.name in names(x)) {
res = self$params[[par.name]]$check(x[[par.name]], na.ok = na.ok, null.ok = null.ok)
if (!isTRUE(res)) return(res)
}
return(res)
}
super$initialize(ns.id, storage.type = "list", check = check, params = list(), dictionary = dictionary, tags = tags, restriction = NULL, trafo = trafo)
},

transform = function(x) {
x = ensureDataTable(x)
assertSetEqual(names(x), self$ids)
if (is.null(self$trafo))
return(x)
xs = self$trafo(x = x, dict = self$dictionary, tags = self$member.tags)
xs = ensureDataTable(xs)
return(xs)
},

# public methods
Expand Down Expand Up @@ -57,7 +82,7 @@ ParamSetTree = R6Class("ParamSetTree",


# This class is to extend the functionality of ParamSetTree and should not be exported!
ParamSetTreeX = R6Class("ParamSetTree",
ParamSetTreeX = R6Class("ParamSetTreeX",
inherit = ParamSetTree,
public = list(
sampleList = function(annotate = FALSE, sep = "_", recursive = FALSE) {
Expand Down
14 changes: 14 additions & 0 deletions tests/testthat/test_ParamTree.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,20 @@ test_that("test if ParamFac parse from flat", {
expect_true(TRUE)
})


test_that("test trafo works", {
pt = ParamSetTree$new("pt1",
ParamCategorical$new(id = "model", values = c("SVM", "RF")),
makeCondTreeNode(ParamReal$new(id = "C", lower = 0, upper = 100), depend = list(id = "model", fun = quote(model == "SVM"))),
makeCondTreeNode(ParamInt$new(id = "n_tree", lower = 1L, upper = 10L), depend = list(id = "model", fun = quote(model == "RF"))),
trafo = function(x, dict, tags) {
x$C = 2 * x$C
return(x)
}
)
expect_class(pt, "ParamSetTree")
})

test_that("test if two ParamTree works", {
pt = ParamSetTreeX$new("pt1",
ParamCategorical$new(id = "model", values = c("SVM", "RF")),
Expand Down

0 comments on commit 7f45285

Please sign in to comment.