Skip to content

Commit

Permalink
feat: support callbacks in AcqOptimizer (#153)
Browse files Browse the repository at this point in the history
* feat: support callbacks in AcqOptimizer
  • Loading branch information
sumny committed Aug 13, 2024
1 parent 2d99077 commit d49b117
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 7 deletions.
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ Suggests:
rpart,
stringi,
testthat (>= 3.0.0)
Remotes: mlr-org/bbotk
ByteCompile: no
Encoding: UTF-8
Config/testthat/edition: 3
Expand Down
9 changes: 7 additions & 2 deletions R/AcqOptimizer.R
Original file line number Diff line number Diff line change
Expand Up @@ -93,16 +93,21 @@ AcqOptimizer = R6Class("AcqOptimizer",
#' @field acq_function ([AcqFunction]).
acq_function = NULL,

#' @field callbacks (`NULL` | list of [mlr3misc::Callback]).
callbacks = NULL,

#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
#'
#' @param optimizer ([bbotk::Optimizer]).
#' @param terminator ([bbotk::Terminator]).
#' @param acq_function (`NULL` | [AcqFunction]).
initialize = function(optimizer, terminator, acq_function = NULL) {
#' @param callbacks (`NULL` | list of [mlr3misc::Callback])
initialize = function(optimizer, terminator, acq_function = NULL, callbacks = NULL) {
self$optimizer = assert_r6(optimizer, "Optimizer")
self$terminator = assert_r6(terminator, "Terminator")
self$acq_function = assert_r6(acq_function, "AcqFunction", null.ok = TRUE)
self$callbacks = assert_callbacks(as_callbacks(callbacks))
ps = ps(
n_candidates = p_int(lower = 1, default = 1L),
logging_level = p_fct(levels = c("fatal", "error", "warn", "info", "debug", "trace"), default = "warn"),
Expand Down Expand Up @@ -146,7 +151,7 @@ AcqOptimizer = R6Class("AcqOptimizer",
logger$set_threshold(self$param_set$values$logging_level)
on.exit(logger$set_threshold(old_threshold))

instance = OptimInstanceBatchSingleCrit$new(objective = self$acq_function, search_space = self$acq_function$domain, terminator = self$terminator, check_values = FALSE)
instance = OptimInstanceBatchSingleCrit$new(objective = self$acq_function, search_space = self$acq_function$domain, terminator = self$terminator, check_values = FALSE, callbacks = self$callbacks)

# warmstart
if (self$param_set$values$warmstart) {
Expand Down
8 changes: 5 additions & 3 deletions R/sugar.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
#' @param cols_y (`NULL` | `character()`)\cr
#' Column id(s) in the [bbotk::Archive] that should be used as a target.
#' If a list of [mlr3::LearnerRegr] is provided as the `learner` argument and `cols_y` is
#' specified as well, as many column names as learners must be provided.
#' specified as well, as many column names as learners must be provided.
#' Can also be `NULL` in which case this is automatically inferred based on the archive.
#' @param ... (named `list()`)\cr
#' Named arguments passed to the constructor, to be set as parameters in the
Expand Down Expand Up @@ -90,6 +90,8 @@ acqf = function(.key, ...) {
#' @param acq_function (`NULL` | [AcqFunction])\cr
#' [AcqFunction] that is to be used.
#' Can also be `NULL`.
#' @param callbacks (`NULL` | list of [mlr3misc::Callback])
#' Callbacks used during acquisition function optimization.
#' @param ... (named `list()`)\cr
#' Named arguments passed to the constructor, to be set as parameters in the
#' [paradox::ParamSet].
Expand All @@ -101,9 +103,9 @@ acqf = function(.key, ...) {
#' library(bbotk)
#' acqo(opt("random_search"), trm("evals"), catch_errors = FALSE)
#' @export
acqo = function(optimizer, terminator, acq_function = NULL, ...) {
acqo = function(optimizer, terminator, acq_function = NULL, callbacks = NULL, ...) {
dots = list(...)
acqopt = AcqOptimizer$new(optimizer = optimizer, terminator = terminator, acq_function = acq_function)
acqopt = AcqOptimizer$new(optimizer = optimizer, terminator = terminator, acq_function = acq_function, callbacks = callbacks)
acqopt$param_set$values = insert_named(acqopt$param_set$values, dots)
acqopt
}
Expand Down
6 changes: 5 additions & 1 deletion man/AcqOptimizer.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion man/acqo.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

29 changes: 29 additions & 0 deletions tests/testthat/test_AcqOptimizer.R
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,32 @@ test_that("AcqOptimizer deep clone", {
expect_true(address(acqopt1$terminator) != address(acqopt2$terminator))
})

test_that("AcqOptimizer callbacks", {
domain = ps(x = p_dbl(lower = 10, upper = 20, trafo = function(x) x - 15))
objective = ObjectiveRFunDt$new(
fun = function(xdt) data.table(y = xdt$x ^ 2),
domain = domain,
codomain = ps(y = p_dbl(tags = "minimize")),
check_values = FALSE
)
instance = MAKE_INST(objective = objective, search_space = domain, terminator = trm("evals", n_evals = 5L))
design = MAKE_DESIGN(instance)
instance$eval_batch(design)
callback = callback_batch("mlr3mbo.acqopt_time",
on_optimization_begin = function(callback, context) {
callback$state$begin = Sys.time()
},
on_optimization_end = function(callback, context) {
callback$state$end = Sys.time()
attr(callback$state$outer_instance, "acq_opt_runtime") = as.numeric(callback$state$end - callback$state$begin)
}
)
callback$state$outer_instance = instance
acqfun = AcqFunctionEI$new(SurrogateLearner$new(REGR_FEATURELESS, archive = instance$archive))
acqopt = AcqOptimizer$new(opt("random_search", batch_size = 10L), trm("evals", n_evals = 10L), acq_function = acqfun, callbacks = callback)
acqfun$surrogate$update()
acqfun$update()
res = acqopt$optimize()
expect_number(attr(instance, "acq_opt_runtime"))
})

0 comments on commit d49b117

Please sign in to comment.