Skip to content

Commit

Permalink
Merge ee64936 into cf87ffe
Browse files Browse the repository at this point in the history
  • Loading branch information
mb706 committed Feb 13, 2019
2 parents cf87ffe + ee64936 commit 56d9b25
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 7 deletions.
16 changes: 9 additions & 7 deletions R/ParamSetCollection.R
Expand Up @@ -49,6 +49,11 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet,
stopf("Setid '%s' already present in collection!", p$set_id)
if (p$has_trafo)
stop("Building a collection out sets, where a ParamSet has a trafo is currently unsupported!")
nameclashes = intersect(sprintf("%s.%s", p$set_id, names(p$params)), names(self$params))
if (length(nameclashes)) {
stopf("Adding parameter set would lead to nameclashes: %s", str_collapse(nameclashes))
}

private$.sets[[length(private$.sets) + 1L]] = p
},

Expand Down Expand Up @@ -111,15 +116,12 @@ ParamSetCollection = R6Class("ParamSetCollection", inherit = ParamSet,
} else {
assert_list(xs)
self$assert(xs) # make sure everything is valid and feasible
# extract everything before 1st dot
set_ids = sub("^([^.]+)\\..+", "\\1", names(xs))
xs = split(xs, set_ids) # partition xs into parts wrt to setids

for (s in private$.sets) {
# retrieve sublist for each set, then assign it in set (after removing prefix)
pv = xs[[s$set_id]]
if (is.null(pv))
pv = list()
names(pv) = sub(sprintf("^%s.", s$set_id), "", names(pv))
psids = sprintf("%s.%s", s$set_id, names(s$params))
pv = xs[intersect(psids, names(xs))]
names(pv) = substr(names(pv), nchar(s$set_id) + 2, nchar(names(pv)))
s$values = pv
}
}
Expand Down
30 changes: 30 additions & 0 deletions tests/testthat/test_ParamSetCollection.R
Expand Up @@ -207,3 +207,33 @@ test_that("collection allows state-change setting of paramvals, see issue 205",
expect_equal(psc$values, list(s1.d1 = 1, s2.d2 = 2))
})

test_that("set_id inference in values assignment works now", {

psa = ParamSet$new(list(ParamDbl$new("parama")))
psa$set_id = "a.b"

psb = ParamSet$new(list(ParamDbl$new("paramb")))
psb$set_id = "b"

psc = ParamSet$new(list(ParamDbl$new("paramc")))
psc$set_id = "c"

pscol1 = ParamSetCollection$new(list(psb, psc))
pscol1$set_id = "a"

pscol2 = ParamSetCollection$new(list(psa, pscol1))

pstest = ParamSet$new(list(ParamDbl$new("paramc")))
pstest$set_id = "a.c"

expect_error(pscol2$add(pstest), "nameclashes.* a\\.c\\.paramc")

pscol2$values = list(a.c.paramc = 3, a.b.parama = 1, a.b.paramb = 2)

expect_equal(psa$values, list(parama = 1))
expect_equal(psb$values, list(paramb = 2))
expect_equal(psc$values, list(paramc = 3))
expect_equal(pscol1$values, list(b.paramb = 2, c.paramc = 3))
expect_equal(pscol2$values, list(a.b.parama = 1, a.b.paramb = 2, a.c.paramc = 3))

})

0 comments on commit 56d9b25

Please sign in to comment.