Skip to content

Commit

Permalink
docs: better warning when mlr3 versions mismatch
Browse files Browse the repository at this point in the history
  • Loading branch information
sebffischer committed Jan 3, 2024
1 parent 3c5f9c6 commit 5eaf280
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 1 deletion.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,4 @@ Suggests:
testthat
Encoding: UTF-8
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.3
RoxygenNote: 7.2.3.9000
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# mlr3batchmark (development version)

* docs: A warning is now given when the loaded mlr3 version differs from the
mlr3 version stored in the trained learners

# mlr3batchmark 0.1.1

* feat: `mlr3batchmark` now depends on package `batchtools` to avoid having to load `batchtools` explicitly.
Expand Down
12 changes: 12 additions & 0 deletions R/reduceResultsBatchmark.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ reduceResultsBatchmark = function(ids = NULL, store_backends = TRUE, reg = batch
tabs = split(tabs, by = "job.name")
bmr = mlr3::BenchmarkResult$new()

no_warning_given = TRUE

for (tab in tabs) {
job = batchtools::makeJob(tab$job.id[1L], reg = reg)
bmr_tasks = bmr$tasks
Expand Down Expand Up @@ -60,6 +62,16 @@ reduceResultsBatchmark = function(ids = NULL, store_backends = TRUE, reg = batch
}

results = batchtools::reduceResultsList(tab$job.id, reg = reg)

if (no_warning_given & length(results) && mlr_reflections$package_version != results[[1]]$learner_state$mlr3_version) {
lg$warn(paste(sep = "\n",
"The mlr3 version (%s) from one of the trained learners differs from the currently loaded mlr3 version (%s).",
"This can lead to unexpected behavior and we recommend installing the package versions used during experiment exectution."),
results[[1]]$learner_state$mlr3_version, mlr_reflections$package_version)

no_warning_given = FALSE
}

rdata = mlr3::ResultData$new(data.table(
task = list(task),
learner = list(learner),
Expand Down
5 changes: 5 additions & 0 deletions R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,8 @@
#' @rawNamespace import(batchtools, except = chunk)
#' @importFrom uuid UUIDgenerate
"_PACKAGE"


.onLoad = function(libname, pkgname) {
assign("lg", lgr::get_logger(pkgname), envir = parent.env(environment()))
}
18 changes: 18 additions & 0 deletions tests/testthat/test_reduceResultsBatchmark.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,21 @@ test_that("reduceResultsBatchmark", {
expect_data_table(tab, nrow = 4)
expect_set_equal(tab$resampling_id, ids(resamplings))
})

test_that("warning is given when mlr3 versions mismatch", {
test_version_warning = function() {
mlr3_version = mlr_reflections$package_version
reg = makeExperimentRegistry(NA)
batchmark(benchmark_grid(tsk("mtcars"), lrns(c("regr.rpart", "regr.featureless")), rsmp("holdout")))
submitJobs()
waitForJobs()

on.exit({mlr_reflections$package_version = mlr3_version}, add = TRUE)
mlr_reflections$package_version = "100.0.0"

capture.output(reduceResultsBatchmark(reg = reg))
expect_true(grepl("The mlr3 version", lg$last_event$msg, fixed = TRUE))
}

test_version_warning()
})

0 comments on commit 5eaf280

Please sign in to comment.