Skip to content

Commit

Permalink
Merge branch 'main' of github.com:mlr-org/mlr3mbo
Browse files Browse the repository at this point in the history
  • Loading branch information
be-marc committed Aug 13, 2024
2 parents 4d2dd5d + d49b117 commit 5db34a1
Show file tree
Hide file tree
Showing 16 changed files with 112 additions and 18 deletions.
5 changes: 3 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Type: Package
Package: mlr3mbo
Title: Flexible Bayesian Optimization
Version: 0.2.3.9000
Version: 0.2.4.9000
Authors@R: c(
person("Lennart", "Schneider", , "lennart.sch@web.de", role = c("cre", "aut"),
comment = c(ORCID = "0000-0003-4152-5308")),
Expand Down Expand Up @@ -66,13 +66,14 @@ Suggests:
rpart,
stringi,
testthat (>= 3.0.0)
Remotes: mlr-org/bbotk
ByteCompile: no
Encoding: UTF-8
Config/testthat/edition: 3
Config/testthat/parallel: false
NeedsCompilation: yes
Roxygen: list(markdown = TRUE, r6 = TRUE)
RoxygenNote: 7.3.1
RoxygenNote: 7.3.2
Collate:
'mlr_acqfunctions.R'
'AcqFunction.R'
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# mlr3mbo (development version)

# mlr3mbo 0.2.4

* fix: Improve runtime of `AcqOptimizer` by setting `check_values = FALSE`.

# mlr3mbo 0.2.3

* compatibility: Work with new bbotk and mlr3tuning version 1.0.0.
Expand Down
4 changes: 2 additions & 2 deletions R/AcqFunctionCB.R
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ AcqFunctionCB = R6Class("AcqFunctionCB",
constants = list(...)
lambda = constants$lambda
p = self$surrogate$predict(xdt)
res = p$mean - self$surrogate_max_to_min * lambda * p$se
data.table(acq_cb = res)
cb = p$mean - self$surrogate_max_to_min * lambda * p$se
data.table(acq_cb = cb)
}
)
)
Expand Down
23 changes: 19 additions & 4 deletions R/AcqFunctionEI.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@
#' @description
#' Expected Improvement.
#'
#' @section Parameters:
#' * `"epsilon"` (`numeric(1)`)\cr
#' \eqn{\epsilon} value used to determine the amount of exploration.
#' Higher values result in the importance of improvements predicted by the posterior mean
#' decreasing relative to the importance of potential improvements in regions of high predictive uncertainty.
#' Defaults to `0` (standard Expected Improvement).
#'
#' @references
#' * `r format_bib("jones_1998")`
#'
Expand Down Expand Up @@ -60,9 +67,15 @@ AcqFunctionEI = R6Class("AcqFunctionEI",
#' Creates a new instance of this [R6][R6::R6Class] class.
#'
#' @param surrogate (`NULL` | [SurrogateLearner]).
initialize = function(surrogate = NULL) {
#' @param epsilon (`numeric(1)`).
initialize = function(surrogate = NULL, epsilon = 0) {
assert_r6(surrogate, "SurrogateLearner", null.ok = TRUE)
super$initialize("acq_ei", surrogate = surrogate, requires_predict_type_se = TRUE, direction = "maximize", label = "Expected Improvement", man = "mlr3mbo::mlr_acqfunctions_ei")
assert_number(epsilon, lower = 0, finite = TRUE)

constants = ps(epsilon = p_dbl(lower = 0, default = 0))
constants$values$epsilon = epsilon

super$initialize("acq_ei", constants = constants, surrogate = surrogate, requires_predict_type_se = TRUE, direction = "maximize", label = "Expected Improvement", man = "mlr3mbo::mlr_acqfunctions_ei")
},

#' @description
Expand All @@ -73,14 +86,16 @@ AcqFunctionEI = R6Class("AcqFunctionEI",
),

private = list(
.fun = function(xdt) {
.fun = function(xdt, ...) {
if (is.null(self$y_best)) {
stop("$y_best is not set. Missed to call $update()?")
}
constants = list(...)
epsilon = constants$epsilon
p = self$surrogate$predict(xdt)
mu = p$mean
se = p$se
d = self$y_best - self$surrogate_max_to_min * mu
d = (self$y_best - self$surrogate_max_to_min * mu) - epsilon
d_norm = d / se
ei = d * pnorm(d_norm) + se * dnorm(d_norm)
ei = ifelse(se < 1e-20, 0, ei)
Expand Down
9 changes: 7 additions & 2 deletions R/AcqOptimizer.R
Original file line number Diff line number Diff line change
Expand Up @@ -93,16 +93,21 @@ AcqOptimizer = R6Class("AcqOptimizer",
#' @field acq_function ([AcqFunction]).
acq_function = NULL,

#' @field callbacks (`NULL` | list of [mlr3misc::Callback]).
callbacks = NULL,

#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
#'
#' @param optimizer ([bbotk::Optimizer]).
#' @param terminator ([bbotk::Terminator]).
#' @param acq_function (`NULL` | [AcqFunction]).
initialize = function(optimizer, terminator, acq_function = NULL) {
#' @param callbacks (`NULL` | list of [mlr3misc::Callback])
initialize = function(optimizer, terminator, acq_function = NULL, callbacks = NULL) {
self$optimizer = assert_r6(optimizer, "Optimizer")
self$terminator = assert_r6(terminator, "Terminator")
self$acq_function = assert_r6(acq_function, "AcqFunction", null.ok = TRUE)
self$callbacks = assert_callbacks(as_callbacks(callbacks))
ps = ps(
n_candidates = p_int(lower = 1, default = 1L),
logging_level = p_fct(levels = c("fatal", "error", "warn", "info", "debug", "trace"), default = "warn"),
Expand Down Expand Up @@ -146,7 +151,7 @@ AcqOptimizer = R6Class("AcqOptimizer",
logger$set_threshold(self$param_set$values$logging_level)
on.exit(logger$set_threshold(old_threshold))

instance = OptimInstanceBatchSingleCrit$new(objective = self$acq_function, search_space = self$acq_function$domain, terminator = self$terminator, check_values = FALSE)
instance = OptimInstanceBatchSingleCrit$new(objective = self$acq_function, search_space = self$acq_function$domain, terminator = self$terminator, check_values = FALSE, callbacks = self$callbacks)

# warmstart
if (self$param_set$values$warmstart) {
Expand Down
8 changes: 5 additions & 3 deletions R/sugar.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
#' @param cols_y (`NULL` | `character()`)\cr
#' Column id(s) in the [bbotk::Archive] that should be used as a target.
#' If a list of [mlr3::LearnerRegr] is provided as the `learner` argument and `cols_y` is
#' specified as well, as many column names as learners must be provided.
#' specified as well, as many column names as learners must be provided.
#' Can also be `NULL` in which case this is automatically inferred based on the archive.
#' @param ... (named `list()`)\cr
#' Named arguments passed to the constructor, to be set as parameters in the
Expand Down Expand Up @@ -90,6 +90,8 @@ acqf = function(.key, ...) {
#' @param acq_function (`NULL` | [AcqFunction])\cr
#' [AcqFunction] that is to be used.
#' Can also be `NULL`.
#' @param callbacks (`NULL` | list of [mlr3misc::Callback])
#' Callbacks used during acquisition function optimization.
#' @param ... (named `list()`)\cr
#' Named arguments passed to the constructor, to be set as parameters in the
#' [paradox::ParamSet].
Expand All @@ -101,9 +103,9 @@ acqf = function(.key, ...) {
#' library(bbotk)
#' acqo(opt("random_search"), trm("evals"), catch_errors = FALSE)
#' @export
acqo = function(optimizer, terminator, acq_function = NULL, ...) {
acqo = function(optimizer, terminator, acq_function = NULL, callbacks = NULL, ...) {
dots = list(...)
acqopt = AcqOptimizer$new(optimizer = optimizer, terminator = terminator, acq_function = acq_function)
acqopt = AcqOptimizer$new(optimizer = optimizer, terminator = terminator, acq_function = acq_function, callbacks = callbacks)
acqopt$param_set$values = insert_named(acqopt$param_set$values, dots)
acqopt
}
Expand Down
6 changes: 5 additions & 1 deletion man/AcqOptimizer.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion man/acqo.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 14 additions & 1 deletion man/mlr_acqfunctions_ei.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions pkgdown/_pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ toc:

navbar:
structure:
left: [reference, news, book]
right: [github, mattermost, stackoverflow, rss, lightswitch]
left: [reference, intro, news, book]
right: [search, github, mattermost, stackoverflow, rss, lightswitch]
components:
home: ~
reference:
Expand Down
6 changes: 6 additions & 0 deletions tests/testthat/helper.R
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,18 @@ PS_1D_MIXED_DEPS = PS_1D_MIXED$clone(deep = TRUE)
PS_1D_MIXED_DEPS$add_dep("x2", on = "x4", cond = CondEqual$new(TRUE))

FUN_1D_MIXED = function(xs) {
if (is.null(xs$x2)) {
xs$x2 = "a"
}
list(y = (xs$x1 - switch(xs$x2, "a" = 0, "b" = 1, "c" = 2)) %% xs$x3 + (if (xs$x4) xs$x1 else pi))
}
OBJ_1D_MIXED = ObjectiveRFun$new(fun = FUN_1D_MIXED, domain = PS_1D_MIXED, properties = "single-crit")
OBJ_1D_MIXED_DEPS = ObjectiveRFun$new(fun = FUN_1D_MIXED, domain = PS_1D_MIXED_DEPS, properties = "single-crit")

FUN_1D_2_MIXED = function(xs) {
if (is.null(xs$x2)) {
xs$x2 = "a"
}
list(y1 = (xs$x1 - switch(xs$x2, "a" = 0, "b" = 1, "c" = 2)) %% xs$x3 + (if (xs$x4) xs$x1 else pi), y2 = xs$x1)
}
OBJ_1D_2_MIXED = ObjectiveRFun$new(fun = FUN_1D_2_MIXED, domain = PS_1D_MIXED, codomain = FUN_1D_2_CODOMAIN, properties = "multi-crit")
Expand Down
3 changes: 3 additions & 0 deletions tests/testthat/test_AcqFunctionCB.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ test_that("AcqFunctionCB works", {
expect_learner(acqf$surrogate$learner)
expect_true(acqf$requires_predict_type_se)

expect_r6(acqf$constants, "ParamSet")
expect_equal(acqf$constants$ids(), "lambda")

design = MAKE_DESIGN(inst)
inst$eval_batch(design)

Expand Down
3 changes: 3 additions & 0 deletions tests/testthat/test_AcqFunctionEHVIGH.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ test_that("AcqFunctionEHVIGH works", {
expect_true(acqf$requires_predict_type_se)
expect_setequal(acqf$packages, c("emoa", "fastGHQuad"))

expect_r6(acqf$constants, "ParamSet")
expect_equal(acqf$constants$ids(), c("k", "r"))

design = MAKE_DESIGN(inst)
inst$eval_batch(design)

Expand Down
3 changes: 3 additions & 0 deletions tests/testthat/test_AcqFunctionEI.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ test_that("AcqFunctionEI works", {
expect_learner(acqf$surrogate$learner)
expect_true(acqf$requires_predict_type_se)

expect_r6(acqf$constants, "ParamSet")
expect_equal(acqf$constants$ids(), "epsilon")

design = MAKE_DESIGN(inst)
inst$eval_batch(design)

Expand Down
3 changes: 3 additions & 0 deletions tests/testthat/test_AcqFunctionSmsEgo.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ test_that("AcqFunctionSmsEgo works", {
expect_list(acqf$surrogate$learner, types = "Learner")
expect_true(acqf$requires_predict_type_se)

expect_r6(acqf$constants, "ParamSet")
expect_equal(acqf$constants$ids(), c("lambda", "epsilon"))

design = MAKE_DESIGN(inst)
inst$eval_batch(design)

Expand Down
29 changes: 29 additions & 0 deletions tests/testthat/test_AcqOptimizer.R
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,32 @@ test_that("AcqOptimizer deep clone", {
expect_true(address(acqopt1$terminator) != address(acqopt2$terminator))
})

test_that("AcqOptimizer callbacks", {
domain = ps(x = p_dbl(lower = 10, upper = 20, trafo = function(x) x - 15))
objective = ObjectiveRFunDt$new(
fun = function(xdt) data.table(y = xdt$x ^ 2),
domain = domain,
codomain = ps(y = p_dbl(tags = "minimize")),
check_values = FALSE
)
instance = MAKE_INST(objective = objective, search_space = domain, terminator = trm("evals", n_evals = 5L))
design = MAKE_DESIGN(instance)
instance$eval_batch(design)
callback = callback_batch("mlr3mbo.acqopt_time",
on_optimization_begin = function(callback, context) {
callback$state$begin = Sys.time()
},
on_optimization_end = function(callback, context) {
callback$state$end = Sys.time()
attr(callback$state$outer_instance, "acq_opt_runtime") = as.numeric(callback$state$end - callback$state$begin)
}
)
callback$state$outer_instance = instance
acqfun = AcqFunctionEI$new(SurrogateLearner$new(REGR_FEATURELESS, archive = instance$archive))
acqopt = AcqOptimizer$new(opt("random_search", batch_size = 10L), trm("evals", n_evals = 10L), acq_function = acqfun, callbacks = callback)
acqfun$surrogate$update()
acqfun$update()
res = acqopt$optimize()
expect_number(attr(instance, "acq_opt_runtime"))
})

0 comments on commit 5db34a1

Please sign in to comment.