refactor: internal tuning (#432)
* fix: get best internal tuning

* refactor: internal tuning

* chore: news

* chore: remotes

* tests: first 20
be-marc committed Jul 24, 2024
1 parent d6da69f commit 2dcd33d
Showing 35 changed files with 513 additions and 302 deletions.
2 changes: 2 additions & 0 deletions DESCRIPTION
@@ -52,6 +52,8 @@ Suggests:
     xgboost
 VignetteBuilder:
     knitr
+Remotes:
+    mlr-org/bbotk
 Config/testthat/edition: 3
 Config/testthat/parallel: false
 Encoding: UTF-8
1 change: 1 addition & 0 deletions NEWS.md
@@ -1,5 +1,6 @@
 # mlr3tuning (development version)
 
+* refactor: Replace internal tuning callback.
 * fix: Delete intermediate `BenchmarkResult` in `ObjectiveTuningBatch` after optimization.
 
 # mlr3tuning 1.0.0
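
For context, a minimal sketch of the user-facing API this refactor targets; this is not part of the diff, and it assumes mlr3, mlr3learners, and mlr3tuning with an XGBoost learner whose nrounds is tuned internally via early stopping. set_validate() and to_tune(internal = TRUE) are the existing mlr3/paradox entry points.

library(mlr3)
library(mlr3learners)
library(mlr3tuning)

# nrounds is tagged for internal tuning; eta stays in the regular search space
learner = lrn("classif.xgboost",
  nrounds = to_tune(upper = 500, internal = TRUE),
  early_stopping_rounds = 10,
  eta = to_tune(1e-4, 1e-1, logscale = TRUE)
)

# early stopping validates on the test split of each resampling iteration
set_validate(learner, validate = "test")

instance = tune(
  tuner = tnr("random_search"),
  task = tsk("sonar"),
  learner = learner,
  resampling = rsmp("cv", folds = 3),
  measures = msr("classif.ce"),
  term_evals = 10
)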
13 changes: 9 additions & 4 deletions R/ArchiveAsyncTuning.R
@@ -61,12 +61,13 @@ ArchiveAsyncTuning = R6Class("ArchiveAsyncTuning",
       rush,
       internal_search_space = NULL
       ) {
-      init_internal_search_space_archive(self, private, super, search_space, internal_search_space)
+      if (!is.null(internal_search_space)) private$.internal_search_space = assert_param_set(internal_search_space)
 
       super$initialize(
         search_space = search_space,
         codomain = codomain,
         rush = rush)
 
+      private$.benchmark_result = BenchmarkResult$new()
     },
 
@@ -183,7 +184,7 @@ ArchiveAsyncTuning = R6Class("ArchiveAsyncTuning",
 )
 
 #' @export
-as.data.table.ArchiveAsyncTuning = function(x, ..., unnest = "x_domain", exclude_columns = NULL, measures = NULL) {
+as.data.table.ArchiveAsyncTuning = function(x, ..., unnest = c("x_domain", "internal_tuned_values"), exclude_columns = NULL, measures = NULL) {
   data = x$data_with_state()
   if (!nrow(data)) return(data.table())
 
@@ -207,7 +208,11 @@ as.data.table.ArchiveAsyncTuning
     setdiff(x_domain_ids, exclude_columns)
   }
 
-  setcolorder(tab, c(x$cols_x, if (length(x$internal_search_space$ids())) "internal_tuned_values", x$cols_y, cols_y_extra, cols_x_domain,
-    "runtime_learners", "timestamp_xs", "timestamp_ys"))
+  cols_internal_tuned_values = if ("internal_tuned_values" %in% cols) {
+    internal_tuned_values_ids = paste0("internal_tuned_values_", unique(unlist(map(x$data$internal_tuned_values, names))))
+    setdiff(internal_tuned_values_ids, exclude_columns)
+  }
+
+  setcolorder(tab, c(x$cols_x, x$cols_y, cols_y_extra, cols_internal_tuned_values, cols_x_domain, "runtime_learners", "timestamp_xs", "timestamp_ys"))
   tab[, setdiff(names(tab), exclude_columns), with = FALSE]
 }
13 changes: 8 additions & 5 deletions R/ArchiveBatchTuning.R
@@ -81,10 +81,9 @@ ArchiveBatchTuning = R6Class("ArchiveBatchTuning",
       check_values = FALSE,
       internal_search_space = NULL
       ) {
+      if (!is.null(internal_search_space)) private$.internal_search_space = assert_param_set(internal_search_space)
       super$initialize(search_space, codomain, check_values)
 
-      init_internal_search_space_archive(self, private, super, search_space, internal_search_space)
-
       # initialize empty benchmark result
       self$benchmark_result = BenchmarkResult$new()
     },
@@ -181,7 +180,7 @@ ArchiveBatchTuning = R6Class("ArchiveBatchTuning",
 )
 
 #' @export
-as.data.table.ArchiveBatchTuning = function(x, ..., unnest = "x_domain", exclude_columns = "uhash", measures = NULL) {
+as.data.table.ArchiveBatchTuning = function(x, ..., unnest = c("x_domain", "internal_tuned_values"), exclude_columns = "uhash", measures = NULL) {
   if (!nrow(x$data)) return(data.table())
   data = copy(x$data)
 
@@ -208,7 +207,11 @@ as.data.table.ArchiveBatchTuning
     setdiff(x_domain_ids, exclude_columns)
   }
 
-  setcolorder(tab, c(x$cols_x, if (length(x$internal_search_space$ids())) "internal_tuned_values", x$cols_y, cols_y_extra, cols_x_domain,
-    "runtime_learners", "timestamp", "batch_nr"))
+  cols_internal_tuned_values = if ("internal_tuned_values" %in% cols) {
+    internal_tuned_values_ids = paste0("internal_tuned_values_", unique(unlist(map(x$data$internal_tuned_values, names))))
+    setdiff(internal_tuned_values_ids, exclude_columns)
+  }
+
+  setcolorder(tab, c(x$cols_x, x$cols_y, cols_y_extra, cols_internal_tuned_values, cols_x_domain, "runtime_learners", "timestamp"))
   tab[, setdiff(names(tab), exclude_columns), with = FALSE]
 }
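
Both converters now unnest internal_tuned_values by default, so each internally tuned parameter gets its own internal_tuned_values_<id> column in the archive table. A hypothetical sketch of reading the results back; the eta, classif.ce, and nrounds column names follow the earlier example and are not taken from this diff:

# one row per evaluated configuration; internal results sit in unnested columns
tab = as.data.table(instance$archive)
tab[, .(eta, classif.ce, internal_tuned_values_nrounds)]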
9 changes: 8 additions & 1 deletion R/ObjectiveTuning.R
@@ -12,6 +12,7 @@
 #' @template param_check_values
 #' @template param_store_benchmark_result
 #' @template param_callbacks
+#' @template param_internal_search_space
 #'
 #' @export
 ObjectiveTuning = R6Class("ObjectiveTuning",
@@ -42,6 +43,10 @@ ObjectiveTuning = R6Class("ObjectiveTuning",
     #' @field default_values (named `list()`).
     default_values = NULL,
 
+    #' @field internal_search_space ([paradox::ParamSet]).
+    #' Internal search space for internal tuning.
+    internal_search_space = NULL,
+
     #' @description
     #' Creates a new instance of this [R6][R6::R6Class] class.
     initialize = function(
@@ -52,7 +57,8 @@ ObjectiveTuning = R6Class("ObjectiveTuning",
       store_benchmark_result = TRUE,
       store_models = FALSE,
       check_values = FALSE,
-      callbacks = NULL
+      callbacks = NULL,
+      internal_search_space = NULL
       ) {
       self$task = assert_task(as_task(task, clone = TRUE))
       self$learner = assert_learner(as_learner(learner, clone = TRUE))
@@ -61,6 +67,7 @@ ObjectiveTuning = R6Class("ObjectiveTuning",
       self$store_models = assert_flag(store_models)
       self$store_benchmark_result = assert_flag(store_benchmark_result) || self$store_models
       self$callbacks = assert_callbacks(as_callbacks(callbacks))
+      self$internal_search_space = if (!is.null(internal_search_space)) assert_param_set(internal_search_space)
 
       self$default_values = self$learner$param_set$values
 
8 changes: 8 additions & 0 deletions R/ObjectiveTuningAsync.R
@@ -40,8 +40,16 @@ ObjectiveTuningAsync = R6Class("ObjectiveTuningAsync",
       warnings = sum(map_int(get_private(private$.resample_result)$.data$learner_states(), function(s) sum(s$log$class == "warning")))
       errors = sum(map_int(get_private(private$.resample_result)$.data$learner_states(), function(s) sum(s$log$class == "error")))
       runtime_learners = extract_runtime(private$.resample_result)
+
       private$.aggregated_performance = c(private$.aggregated_performance, list(runtime_learners = runtime_learners, warnings = warnings, errors = errors))
 
+      # add internal tuned values
+      if (!is.null(self$internal_search_space)) {
+        lg$debug("Extracting internal tuned values")
+        internal_tuned_values = extract_inner_tuned_values(private$.resample_result, self$internal_search_space)
+        private$.aggregated_performance = c(private$.aggregated_performance, list(internal_tuned_values = list(internal_tuned_values)))
+      }
+
       # add benchmark result and models
       if (!self$store_models) {
         lg$debug("Discarding models.")
17 changes: 15 additions & 2 deletions R/ObjectiveTuningBatch.R
@@ -12,6 +12,7 @@
 #' @template param_check_values
 #' @template param_store_benchmark_result
 #' @template param_callbacks
+#' @template param_internal_search_space
 #'
 #' @export
 ObjectiveTuningBatch = R6Class("ObjectiveTuningBatch",
@@ -36,7 +37,8 @@ ObjectiveTuningBatch = R6Class("ObjectiveTuningBatch",
       store_models = FALSE,
       check_values = FALSE,
       archive = NULL,
-      callbacks = NULL
+      callbacks = NULL,
+      internal_search_space = NULL
       ) {
       self$archive = assert_r6(archive, "ArchiveBatchTuning", null.ok = TRUE)
       if (is.null(self$archive)) store_benchmark_result = store_models = FALSE
@@ -49,7 +51,8 @@
         store_benchmark_result = store_benchmark_result,
         store_models = store_models,
         check_values = check_values,
-        callbacks = callbacks
+        callbacks = callbacks,
+        internal_search_space = internal_search_space
       )
     }
   ),
@@ -81,8 +84,18 @@ ObjectiveTuningBatch = R6Class("ObjectiveTuningBatch",
       time = map_dbl(private$.benchmark_result$resample_results$resample_result, function(rr) {
         extract_runtime(rr)
       })
+
       set(private$.aggregated_performance, j = "runtime_learners", value = time)
 
+      # add internal tuned values
+      if (!is.null(self$internal_search_space)) {
+        internal_tuned_values = map(private$.benchmark_result$resample_results$resample_result, function(resample_result) {
+          extract_inner_tuned_values(resample_result, self$internal_search_space)
+        })
+
+        set(private$.aggregated_performance, j = "internal_tuned_values", value = list(internal_tuned_values))
+      }
+
       call_back("on_eval_before_archive", self$callbacks, self$context)
 
       # store benchmark result in archive
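
extract_inner_tuned_values (the helper keeps the older "inner" naming) condenses the per-iteration internal results of a resample result into one value per parameter. How the per-fold values are combined is chosen on the user side; a hedged sketch using the aggr argument of paradox's to_tune(), here averaging the early-stopped nrounds across folds, assuming the learner from the earlier example:

# combine per-fold early-stopping results by averaging instead of the default
learner$param_set$set_values(
  nrounds = to_tune(upper = 500, internal = TRUE, aggr = function(x) as.integer(mean(unlist(x))))
)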
2 changes: 1 addition & 1 deletion R/TunerAsyncFromOptimizerAsync.R
@@ -44,7 +44,7 @@ TunerAsyncFromOptimizerAsync = R6Class("TunerAsyncFromOptimizerAsync",
     #' @return [data.table::data.table].
     optimize = function(inst) {
       assert_tuning_instance_async(inst)
-      if (!inst$search_space$length && inst$internal_search_space$length) {
+      if (!inst$search_space$length && !is.null(inst$internal_search_space)) {
         stopf("To only conduct internal parameter tuning, use tnr('internal')")
       }
       private$.optimizer$optimize(inst)
3 changes: 2 additions & 1 deletion R/TunerBatchFromBatchOptimizer.R
@@ -42,7 +42,8 @@ TunerBatchFromOptimizerBatch = R6Class("TunerBatchFromOptimizerBatch",
     #' @return [data.table::data.table].
     optimize = function(inst) {
       assert_tuning_instance_batch(inst)
-      if (!inst$search_space$length && inst$internal_search_space$length && !test_class(self, "TunerBatchInternal")) {
+
+      if (!inst$search_space$length && !is.null(inst$internal_search_space) && !test_class(self, "TunerBatchInternal")) {
         stopf("To only conduct internal parameter tuning, use tnr('internal')")
       }
       result = private$.optimizer$optimize(inst)
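
The guard above redirects runs whose search space is empty but whose learner still carries internal tuning tags; tnr("internal") handles that case by delegating all tuning to the learner itself. A minimal sketch, reusing the learner setup from the earlier example:

# only internally tuned parameters: no outer search space needed
learner = lrn("classif.xgboost",
  nrounds = to_tune(upper = 500, internal = TRUE),
  early_stopping_rounds = 10
)
set_validate(learner, validate = "test")

instance = tune(
  tuner = tnr("internal"),
  task = tsk("sonar"),
  learner = learner,
  resampling = rsmp("cv", folds = 3),
  measures = msr("classif.ce")
)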
56 changes: 39 additions & 17 deletions R/TuningInstanceAsyncMulticrit.R
@@ -24,12 +24,17 @@
 #'
 #' @template param_xdt
 #' @template param_learner_param_vals
+#' @template param_internal_tuned_values
 #'
+#' @template field_internal_search_space
+#'
 #' @export
 TuningInstanceAsyncMultiCrit = R6Class("TuningInstanceAsyncMultiCrit",
   inherit = OptimInstanceAsyncMultiCrit,
   public = list(
 
+    internal_search_space = NULL,
+
     #' @description
     #' Creates a new instance of this [R6][R6::R6Class] class.
     initialize = function(
@@ -48,6 +53,7 @@ TuningInstanceAsyncMultiCrit = R6Class("TuningInstanceAsyncMultiCrit",
       require_namespaces("rush")
       learner = assert_learner(as_learner(learner, clone = TRUE))
 
+      # tune token and search space
       if (!is.null(search_space) && length(learner$param_set$get_values(type = "only_token"))) {
         stop("If the values of the ParamSet of the Learner contain TuneTokens you cannot supply a search_space.")
       }
@@ -58,13 +64,20 @@ TuningInstanceAsyncMultiCrit = R6Class("TuningInstanceAsyncMultiCrit",
         search_space = as_search_space(search_space)
       }
 
-      # modifies tuning instance in-place and adds the internal tuning callback
-      res = init_internal_search_space(self, private, super, search_space, store_benchmark_result, learner,
-        callbacks, batch = FALSE)
+      # internal search space
+      internal_tune_ids = keep(names(search_space$tags), map_lgl(search_space$tags, function(tag) "internal_tuning" %in% tag))
+      if (length(internal_tune_ids)) {
+        self$internal_search_space = search_space$subset(internal_tune_ids)
+
+        if (self$internal_search_space$has_trafo) {
+          stopf("Inner tuning and parameter transformations are currently not supported.")
+        }
 
-      private$.internal_search_space = res$internal_search_space
-      callbacks = res$callbacks
-      search_space = res$search_space
+        search_space = search_space$subset(setdiff(search_space$ids(), internal_tune_ids))
+
+        # the learner dictates how to interpret the to_tune(..., inner)
+        learner$param_set$set_values(.values = learner$param_set$convert_internal_search_space(self$internal_search_space))
+      }
 
       if (is.null(rush)) rush = rush::rsh()
 
@@ -76,7 +89,7 @@
         search_space = search_space,
         codomain = codomain,
         rush = rush,
-        internal_search_space = private$.internal_search_space
+        internal_search_space = self$internal_search_space
       )
 
       objective = ObjectiveTuningAsync$new(
objective = ObjectiveTuningAsync$new(
Expand All @@ -87,7 +100,8 @@ TuningInstanceAsyncMultiCrit = R6Class("TuningInstanceAsyncMultiCrit",
store_benchmark_result = store_benchmark_result,
store_models = store_models,
check_values = check_values,
callbacks = callbacks)
callbacks = callbacks,
internal_search_space = self$internal_search_space)

super$initialize(
objective = objective,
@@ -104,7 +118,9 @@
     #'
     #' @param ydt (`numeric(1)`)\cr
     #'   Optimal outcomes, e.g. the Pareto front.
-    assign_result = function(xdt, ydt, learner_param_vals = NULL) {
+    #' @param ... (`any`)\cr
+    #'   ignored.
+    assign_result = function(xdt, ydt, learner_param_vals = NULL, ...) {
       # set the column with the learner param_vals that were not optimized over but set implicitly
       if (is.null(learner_param_vals)) {
         learner_param_vals = self$objective$learner$param_set$values
@@ -115,7 +131,20 @@
       opt_x = transform_xdt_to_xss(xdt, self$search_space)
       if (length(opt_x) == 0) opt_x = replicate(length(ydt), list())
       learner_param_vals = Map(insert_named, learner_param_vals, opt_x)
-      xdt = cbind(xdt, learner_param_vals)
+
+      # disable internal tuning
+      if (!is.null(xdt$internal_tuned_values)) {
+        learner = self$objective$learner$clone(deep = TRUE)
+        learner_param_vals = pmap(list(learner_param_vals, xdt$internal_tuned_values), function(lpv, itv) {
+          values = insert_named(lpv, itv)
+          learner$param_set$set_values(.values = values, .insert = FALSE)
+          learner$param_set$disable_internal_tuning(self$internal_search_space$ids())
+          learner$param_set$values
+        })
+      }
+
+      set(xdt, j = "learner_param_vals", value = list(learner_param_vals))
+
       super$assign_result(xdt, ydt)
     }
   ),
@@ -126,17 +155,10 @@
     #'   List of param values for the optimal learner call.
     result_learner_param_vals = function() {
      private$.result$learner_param_vals
-    },
-    #' @field internal_search_space ([paradox::ParamSet])\cr
-    #'   The search space containing those parameters that are internally optimized by the [`mlr3::Learner`].
-    internal_search_space = function(rhs) {
-      assert_ro_binding(rhs)
-      private$.internal_search_space
     }
   ),
 
   private = list(
-    .internal_search_space = NULL,
-
     # initialize context for optimization
     .initialize_context = function(optimizer) {
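
With this change, assign_result folds the internal tuned values into learner_param_vals and disables internal tuning on the stored values, so each result row can be refit directly. A hypothetical sketch of consuming a multi-crit result, continuing the earlier example:

# pick one Pareto-optimal configuration; nrounds is fixed, early stopping is off
best = instance$result_learner_param_vals[[1]]
final = lrn("classif.xgboost")
final$param_set$set_values(.values = best)
final$train(tsk("sonar"))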
(Diff truncated; the remaining 25 of 35 changed files are not shown.)
