Error handling (#78)
If a learner fails during train or predict, there are now only two options:

- get the exception (default) 
- define a fallback learner
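
A minimal sketch of both options, based on the API used in the updated Log.R example in this commit; the learner ids ("classif.crashtest", "classif.featureless") and call signatures are taken from that example and should be treated as illustrative, not as a stable interface:

library(mlr3)

task = mlr_tasks$get("sonar")
learner = mlr_learners$get("classif.crashtest")   # toy learner that always errors

# Option 1 (default): the exception propagates and stops execution.
e = Experiment$new(task, learner)
# e$train()   # would re-throw the learner's error

# Option 2: assign a fallback learner; its model is used if train/predict fails.
learner$fallback = mlr_learners$get("classif.featureless")
e = Experiment$new(task, learner)
e$train(ctrl = mlr_control(use_evaluate = TRUE))  # error is captured in the train log
e$logs$train$has_condition("error")               # TRUE
# e$predict(); e$score()                          # proceed using the fallback model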
mllg committed Dec 5, 2018
1 parent ec33305 commit 9c18991
Showing 18 changed files with 104 additions and 213 deletions.
2 changes: 2 additions & 0 deletions .Rbuildignore
@@ -1,3 +1,5 @@
^Meta$
^doc$
^tic\.R$
^appveyor\.yml$
.ignore
3 changes: 2 additions & 1 deletion .gitignore
@@ -1,3 +1,5 @@
Meta
doc
# History files
.Rhistory
.Rapp.history
@@ -21,6 +23,5 @@
*.utf8.md
*.knit.md
inst/doc

.DS_Store
.Rproj.user
35 changes: 18 additions & 17 deletions R/Experiment.R
@@ -118,7 +118,6 @@ Experiment = R6Class("Experiment",
self$data = named_list(mlr_reflections$experiment_slots$name)
self$data$task = assert_task(task)
self$data$learner = assert_learner(learner, task = task)
self$data$state = as_experiment_state("defined")
if (...length()) {
dots = list(...)
assert_names(names(dots), type = "unique", subset.of = names(self$data))
@@ -217,7 +216,7 @@
},

state = function() {
self$data$state
experiment_state(self)
},

hash = function() {
@@ -269,9 +268,8 @@ experiment_train = function(self, row_ids, ctrl = mlr_control()) {
debug("Running train_worker()")
value = train_worker(self, ctrl = ctrl)
}
experiment_set_state(self, "trained")
self$data = insert_named(self$data, value)
return(self)
return(experiment_reset_state(self, "trained"))
}

experiment_predict = function(self, row_ids = NULL, newdata = NULL, ctrl = mlr_control()) {
@@ -293,9 +291,8 @@ experiment_predict = function(self, row_ids = NULL, newdata = NULL, ctrl = mlr_c
debug("Running predict_worker()")
value = predict_worker(self, ctrl = ctrl)
}
experiment_set_state(self, "predicted")
self$data = insert_named(self$data, value)
return(self)
return(experiment_reset_state(self, "predicted"))
}

experiment_score = function(self, measures = NULL, ctrl = mlr_control()) {
@@ -310,7 +307,6 @@ experiment_score = function(self, measures = NULL, ctrl = mlr_control()) {
value = score_worker(self, ctrl = ctrl)
}

experiment_set_state(self, "scored")
self$data = insert_named(self$data, value)
return(self)
}
@@ -325,16 +321,21 @@ combine_experiments = function(x) {
})
}

experiment_set_state = function(self, new_state) {
new_state = as_experiment_state(new_state)
reset = mlr_reflections$experiment_slots[get("state") > new_state, "name", with = FALSE][[1L]]
self$data = insert_named(self$data, named_list(reset))
self$data$state = new_state
invisible(self)
experiment_state = function(self) {
as_state = function(state) ordered(state, levels = mlr_reflections$experiment_states)
d = self$data

if (!is.null(d$performance))
return(as_state("scored"))
if (!is.null(d$prediction))
return(as_state("predicted"))
if (!is.null(d$model))
return(as_state("trained"))
return(as_state("defined"))
}

as_experiment_state = function(state) {
states = levels(mlr_reflections$experiment_slots$state)
assert_choice(state, states)
ordered(state, levels = states)
experiment_reset_state = function(self, new_state) {
slots = mlr_reflections$experiment_slots[get("state") > new_state, "name", with = FALSE][[1L]]
self$data[slots] = list(NULL)
self
}
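
Since the state is now derived from which data slots are filled instead of being stored explicitly, the following sketch illustrates how $state advances; the task/learner ids and the availability of default arguments for $train()/$predict()/$score() are assumptions, not part of this diff:

library(mlr3)

e = Experiment$new(task = mlr_tasks$get("iris"), learner = mlr_learners$get("classif.rpart"))
e$state      # "defined":   no model, prediction, or performance present
e$train()
e$state      # "trained":   $data$model is filled
e$predict()
e$state      # "predicted": $data$prediction is filled
e$score()
e$state      # "scored":    $data$performance is filled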
4 changes: 4 additions & 0 deletions R/Learner.R
@@ -60,6 +60,9 @@
#' * `$properties` (`character()`) is a set of tags which describe the properties of the learner.
#' * `$train()` takes a task and returns a model fitted on all observations.
#' * `$predict()` takes a task and the model fitted in `$train()` to return predicted labels.
#' * `$fallback` stores the fallback learner which is used to generate predictions if this learner
#' fails to train or predict. This mechanism is disabled unless you explicitly
#' assign a learner to this slot.
#' * `$hash` stores a checksum (`character(1)`) calculated on the `id` and `param_vals`.
#'
#' @name Learner
@@ -76,6 +79,7 @@ Learner = R6Class("Learner",
packages = NULL,
properties = NULL,
param_set = NULL,
fallback = NULL,

initialize = function(id, task_type, feature_types = character(0L), predict_types = character(0L), packages = character(0L), param_set = ParamSet$new(), param_vals = list(), properties = character(0L)) {
self$id = assert_id(id)
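
The new $fallback slot documented above is disabled by default; a short sketch (learner ids are assumptions):

learner = mlr_learners$get("classif.rpart")
is.null(learner$fallback)    # TRUE: errors during train/predict are raised as usual
learner$fallback = mlr_learners$get("classif.featureless")   # enable the fallback mechanism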
10 changes: 5 additions & 5 deletions R/Log.R
@@ -31,11 +31,11 @@
#' @name Log
#' @examples
#' # Create a simple experiment and extract the train log:
#' e = Experiment$new(
#' task = mlr_tasks$get("sonar"),
#' learner = mlr_learners$get("classif.crashtest")
#' )
#' e$train(ctrl = mlr_control(error_handling = "catch"))
#' task = mlr_tasks$get("sonar")
#' learner = mlr_learners$get("classif.crashtest")
#' learner$fallback = mlr_learners$get("classif.featureless")
#' e = Experiment$new(task, learner)
#' e$train(ctrl = mlr_control(use_evaluate = TRUE))
#' log = e$logs$train
#'
#' log$has_condition("error")
1 change: 0 additions & 1 deletion R/benchmark.R
@@ -99,7 +99,6 @@ benchmark = function(tasks, learners, resamplings, measures = NULL, ctrl = mlr_c

res = data.table(task = tasks[grid$task], learner = learners[grid$learner], resampling = instances[grid$instance], measures = measures[grid$task], hash = grid$hash)
ref_cbind(res, combine_experiments(tmp))
res$state = as_experiment_state("scored")

BenchmarkResult$new(res)
}
23 changes: 3 additions & 20 deletions R/mlr_control.R
@@ -7,19 +7,7 @@
#' * `store_model`: If `FALSE`, the model returned by the learner is discarded in order to save some memory after the experiment is completed.
#' Note that you will be unable to further predict on new data.
#' * `store_prediction`: If `FALSE`, the predictions are discarded in order to save some memory after the experiment is completed.
#' Note that you will be unable to calculate more performance measures.
#' * `error_handling`: How to deal with models raising exceptions during `train` or `predict`?
#' - `"off"` (default). An exception is raised, stopping the execution.
#' - `"catch"`. Exceptions are caught and logged. There will be no predictions available, and the performance will be `NA`.
#' All output is stored in the [Experiment] as a [Log].
#' - `"impute_worst"`. This is similar to the `"catch"` approach, but instead of predicting `NA`, the worst
#' possible performance is predicted.
#' - `"fallback_train"`. If the learner fails to fit a model during train, fit a fallback model, e.g. with a featureless learner.
#' The fallback learner is in this case used to generate predictions which are then scored.
#' Note that this mechanism does not guard you from models which successfully train, but raise exceptions during predict.
#' This would result in missing predictions and `NA` scores.
#' - `"fallback"`. Always fit a fallback model and use it if the learner fails to train or predict.
#' * `fallback_learner`: If `"error_handling"` is set to `"fallback_train"` or `"fallback"`, use this learner as fallback learner.
#' * `use_evaluate`: Capture output via \pkg{evaluate} and store it as log.
#'
#' @param ... Named arguments to overwrite the defaults / options.
#'
@@ -55,11 +43,6 @@ use_future = function(ctrl = NULL) {
isTRUE(opt) && requireNamespace("future", quietly = TRUE) && requireNamespace("future.apply", quietly = TRUE)
}

use_evaluate = function(ctrl = NULL) {
opt = if (is.null(ctrl)) getOption("mlr3.error_handling") else ctrl$error_handling
assert_choice(opt, c("off", "catch", "impute_worst", "fallback_train", "fallback"), .var.name = "Option 'error_handling'")
if (opt == "off")
return(FALSE)
require_namespaces("evaluate")
return(TRUE)
use_evaluate = function(ctrl) {
ctrl$use_evaluate && requireNamespace("evaluate", quietly = TRUE)
}
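
A short sketch of constructing a control object with the options documented above (task/learner ids are assumptions; see reflections.R below for the defaults):

ctrl = mlr_control(store_model = FALSE, store_prediction = TRUE, use_evaluate = TRUE)
e = Experiment$new(task = mlr_tasks$get("iris"), learner = mlr_learners$get("classif.rpart"))
e$train(ctrl = ctrl)
# store_model = FALSE discards the fitted model once the experiment is completed,
# so predicting on new data afterwards is not possible.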
15 changes: 8 additions & 7 deletions R/reflections.R
@@ -42,14 +42,16 @@ mlr_reflections$predict_types = list(
regr = c("response", "se")
)

mlr_reflections$experiment_states = c("defined", "trained", "predicted", "scored")

mlr_reflections$experiment_slots = data.table(
name = c("state", "task", "learner", "resampling", "iteration", "model", "fallback", "train_log", "train_time", "predict_log", "predict_time", "prediction", "measures", "performance", "score_time"),
type = c("ordered", "Task", "Learner", "Resampling", "integer", NA_character_, "Learner", "data.table", "numeric", "data.table", "numeric", "data.table", "list", "list", "numeric"),
atomic = c(TRUE, FALSE, FALSE, FALSE, TRUE, FALSE, FALSE, FALSE, TRUE, FALSE, TRUE, FALSE, FALSE, FALSE, TRUE),
state = c("defined", "defined", "defined", "trained", "trained", "trained", "trained", "trained", "trained", "predicted", "predicted", "predicted", "scored", "scored", "scored")
name = c("task", "learner", "resampling", "iteration", "model", "train_log", "train_time", "predict_log", "predict_time", "prediction", "measures", "performance", "score_time"),
type = c("Task", "Learner", "Resampling", "integer", NA_character_, "data.table", "numeric", "data.table", "numeric", "data.table", "list", "list", "numeric"),
atomic = c(FALSE, FALSE, FALSE, TRUE, FALSE, FALSE, TRUE, FALSE, TRUE, FALSE, FALSE, FALSE, TRUE),
state = c("defined", "defined", "trained", "trained", "trained", "trained", "trained", "predicted", "predicted", "predicted", "scored", "scored", "scored")
)

mlr_reflections$experiment_slots$state = ordered(mlr_reflections$experiment_slots$state, levels = c("defined", "trained", "predicted", "scored"))
mlr_reflections$experiment_slots$state = ordered(mlr_reflections$experiment_slots$state, levels = mlr_reflections$experiment_states)

mlr_reflections$log_classes = c("output", "message", "warning", "error")

@@ -62,6 +64,5 @@ mlr_reflections$default_mlr_options = list(
mlr_reflections$default_mlr_control = list(
store_model = TRUE,
store_prediction = TRUE,
error_handling = "off",
fallback_learner = NULL
use_evaluate = FALSE
)
4 changes: 2 additions & 2 deletions R/resample.R
@@ -61,7 +61,7 @@ resample = function(task, learner, resampling, measures = NULL, ctrl = mlr_contr
}

res = combine_experiments(res)
res[, c("state", "task", "learner", "resampling", "measures") :=
list(as_experiment_state("scored"), list(task), list(learner), list(instance), list(measures))]
res[, c("task", "learner", "resampling", "measures") :=
list(list(task), list(learner), list(instance), list(measures))]
ResampleResult$new(res)
}
28 changes: 8 additions & 20 deletions R/worker.R
@@ -34,31 +34,27 @@ train_worker = function(e, ctrl) {
res = set_names(ecall(learner$train, pars, ctrl),
c("model", "train_log", "train_time"))

if (ctrl$error_handling == "fallback" || (ctrl$error_handling == "fallback_train" && res$train_log$has_condition("error"))) {
fb = assert_learner(ctrl$fallback_learner)
if (!is.null(learner$fallback)) {
fb = assert_learner(learner$fallback)
message(sprintf("Training fallback learner '%s' on task '%s' ...", fb$id, task$id))
require_namespaces(fb$packages, sprintf("The following packages are required for fallback learner %s: %%s", fb$id))
fb_model = try(fb$train(task))
if (inherits(fb_model, "try-error"))
stopf("Fallback learner '%s' failed during train", fb$id)
res$fallback = list(
learner = fb,
model = fb_model
)
res$model = fb_model
}

res
}

predict_worker = function(e, ctrl) {
data = e$data
if (is.null(data$model)) {
if (is.null(data$fallback))
return(list(predict_time = NA_real_))
data$learner = data$fallback$learner
data$model = data$fallback$model
}
learner = data$learner
if (data$train_log$has_condition("error")) {
if (is.null(learner$fallback))
stop(sprintf("Unable to predict learner '%s' without model", learner$id))
learner = learner$fallback
}
require_namespaces(learner$packages, sprintf("The following packages are required for learner %s: %%s", learner$id))

task = data$task$clone(deep = TRUE)$filter(e$test_set)
@@ -77,14 +73,6 @@
score_worker = function(e, ctrl) {
data = e$data
measures = data$measures
if (is.null(data$prediction)) {
perf = if (ctrl$error_handling == "impute_worst") {
map(measures, function(x) x$range[2L])
} else {
replicate(length(measures), NA_real_, simplify = FALSE)
}
return(list(score_time = NA_real_, performance = set_names(perf, ids(measures))))
}
require_namespaces(unlist(lapply(measures, "[[", "packages")), "The following packages are required for the measures: %s")

if (ctrl$verbose)
21 changes: 9 additions & 12 deletions inst/testthat/helper_expectations.R
@@ -294,7 +294,7 @@ expect_experiment = function(e) {
testthat::expect_output(print(e), "Experiment")
state = e$state
checkmate::expect_factor(state, ordered = TRUE)
checkmate::expect_subset(as.character(state), levels(mlr3::mlr_reflections$experiment_slots$state))
checkmate::expect_subset(as.character(state), mlr3::mlr_reflections$experiment_states)
checkmate::expect_list(e$data, len = nrow(mlr3::mlr_reflections$experiment_slots))
checkmate::expect_names(names(e$data), permutation.of = mlr3::mlr_reflections$experiment_slots$name)

@@ -312,17 +312,15 @@
}

if (state >= "predicted") {
checkmate::expect_class(e$data$predict_log, "Log", null.ok = TRUE)
checkmate::expect_number(e$data$predict_time, na.ok = e$has_errors)
if (!is.null(e$data$prediction)) { # may be null, depending on options
checkmate::expect_class(e$data$prediction, "Prediction")
checkmate::expect_atomic_vector(e$data$prediction$response, len = length(e$test_set), any.missing = FALSE)
}
checkmate::expect_class(e$data$predict_log, "Log")
checkmate::expect_number(e$data$predict_time)
checkmate::expect_class(e$data$prediction, "Prediction")
checkmate::expect_atomic_vector(e$data$prediction$response, len = length(e$test_set), any.missing = FALSE)
}

if (state >= "scored") {
checkmate::expect_list(e$data$performance, names = "unique")
checkmate::qassertr(e$data$performance, "n1")
checkmate::qassertr(e$data$performance, "N1")
}
}

@@ -334,7 +332,7 @@ expect_resample_result = function(rr) {
expect_resampling(rr$resampling, task = rr$task)

data = rr$data
checkmate::expect_data_table(rr$data, nrow = rr$resampling$iters, min.cols = nrow(mlr3::mlr_reflections$experiment_slots), any.missing = TRUE)
checkmate::expect_data_table(rr$data, nrow = rr$resampling$iters, min.cols = nrow(mlr3::mlr_reflections$experiment_slots), any.missing = FALSE)
checkmate::expect_names(names(rr$data), must.include = mlr3::mlr_reflections$experiment_slots$name)
expect_hash(rr$hash, 1L)

Expand All @@ -344,11 +342,10 @@ expect_resample_result = function(rr) {

measures = rr$measures$measure
aggr = rr$aggregated
errors = any(rr$errors)
for (m in measures) {
y = rr$performance(m$id)
checkmate::expect_numeric(y, lower = m$range[1], upper = m$range[2], any.missing = errors, label = sprintf("measure %s", m$id))
checkmate::expect_number(aggr[[m$id]], na.ok = errors, lower = m$range[1L], upper = m$range[2L], label = sprintf("measure %s", m$id))
checkmate::expect_numeric(y, lower = m$range[1], upper = m$range[2], any.missing = FALSE, label = sprintf("measure %s", m$id))
checkmate::expect_number(aggr[[m$id]], lower = m$range[1L], upper = m$range[2L], label = sprintf("measure %s", m$id))
}
}

3 changes: 3 additions & 0 deletions man/Learner.Rd


10 changes: 5 additions & 5 deletions man/Log.Rd


16 changes: 1 addition & 15 deletions man/mlr_control.Rd


2 changes: 1 addition & 1 deletion man/mlr_reflections.Rd


4 changes: 2 additions & 2 deletions tests/testthat/test_evaluate.R
@@ -1,8 +1,8 @@
context("evaluate")

is_empty_log = function(log) { test_data_table(log$messages, nrow = 0L, ncol = 2L) && test_factor(log$messages$class, levels = mlr_reflections$log_classes) }
disabled = mlr_control(use_future = FALSE, error_handling = "off")
enabled = mlr_control(use_future = FALSE, error_handling = "catch", verbose = FALSE)
disabled = mlr_control(use_future = FALSE, use_evaluate = FALSE)
enabled = mlr_control(use_future = FALSE, use_evaluate = TRUE, verbose = FALSE)
task = mlr_tasks$get("iris")
learner = get_verbose_learner()
learner$param_vals = list(message = TRUE, warning = TRUE)