Make sure that optimized hyperparameters are applied in the performance level of a CV (#2479)
berndbischl authored and pat-s committed Apr 15, 2019
1 parent fa6328e commit 7ea4a57
Showing 3 changed files with 44 additions and 1 deletion.
3 changes: 3 additions & 0 deletions NEWS.md
@@ -58,6 +58,9 @@ In this case, the package name is omitted.
* regr.h2o.gbm: Various parameters added, `"h2o.use.data.table" = TRUE` is now the default (@j-hartshorn, #2508)
* h2o learners now support getting feature importance (@markusdumke, #2434)

## learners - fixes
* In some cases the optimized hyperparameters were not applied in the performance level of a nested CV (@berndbischl, #2479)

## featSel - general
* The FeatSelResult object now contains an additional slot `x.bit.names` that stores the optimal bits
* The slot `x` now always contains the real feature names and not the bit.names
11 changes: 10 additions & 1 deletion R/TuneWrapper.R
@@ -74,8 +74,17 @@ trainLearner.TuneWrapper = function(.learner, .task, .subset = NULL, ...) {

#' @export
predictLearner.TuneWrapper = function(.learner, .model, .newdata, ...) {
# setHyperPars is called just for completeness; the actual hyperparameters are passed via ...
lrn = setHyperPars(.learner$next.learner, par.vals = .model$learner.model$opt.result$x)
predictLearner(lrn, .model$learner.model$next.model, .newdata, ...)
arglist = list(.learner = lrn, .model = .model$learner.model$next.model, .newdata = .newdata)
arglist = insert(arglist, list(...))

# get x from the opt result and select only those parameters that are used for prediction
opt.x = .model$learner.model$opt.result$x
ps = getParamSet(lrn)
ns = Filter(function(x) ps$pars[[x]]$when %in% c("both", "predict"), getParamIds(ps))
arglist = insert(arglist, opt.x[ns])
do.call(predictLearner, arglist)
}

#' @export
31 changes: 31 additions & 0 deletions tests/testthat/test_base_TuneWrapper.R
@@ -120,3 +120,34 @@ test_that("TuneWrapper with glmnet (#958)", {
expect_error(pred, NA)
})

test_that("TuneWrapper respects train parameters (#2472)", {

# make task with only 0 as y
tsk = makeRegrTask("dummy", data = data.frame(y = rep(0L, 100), x = rep(1L, 100)), target = "y")

ps = makeParamSet(
makeNumericLearnerParam("p1", when = "train", lower = 0, upper = 10),
makeNumericLearnerParam("p2", when = "predict", lower = 0, upper = 10),
makeNumericLearnerParam("p3", when = "both", lower = 0, upper = 10)
)

lrn = makeLearner("regr.__mlrmocklearners__4", predict.type = "response", p1 = 10, p2 = 10, p3 = 10)
# the prediction of this learner is always:
# train_part = p1 + p3
# y = train_part + p2 + p3
# therefore p1 = p2 = p3 = 0 is the optimal setting
# we set params to bad values p1 = p2 = p3 = 10, meaning |y_hat-y| would be 40

lrn2 = makeTuneWrapper(lrn, resampling = makeResampleDesc("Holdout"),
par.set = ps,
control = makeTuneControlGrid(resolution = 2L))
mod = train(lrn2, tsk)
# we expect that the optimal parameters are found by the grid search.
expect_equal(mod$learner.model$opt.result$x, list(p1 = 0, p2 = 0, p3 = 0))
expect_true(mod$learner.model$opt.result$y == 0)
pred = predict(mod, tsk)
# we expect that the optimal parameters are also applied for prediction, therefore y_hat = p1+p2+p3+p3 should be 0
expect_true(all(pred$data$response == 0))
})
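The core of the fix in `predictLearner.TuneWrapper` is that tuned values are filtered by each parameter's `when` setting, so only those relevant at predict time are forwarded. A minimal standalone sketch of that filtering, using plain lists instead of mlr's `ParamSet` (the parameter names and `when` values here are hypothetical, chosen to mirror the test above):

```r
# Plain-list stand-in for ps$pars: each parameter records when it applies.
pars = list(
  p1 = list(when = "train"),
  p2 = list(when = "predict"),
  p3 = list(when = "both")
)
# Stand-in for .model$learner.model$opt.result$x: the tuned values.
opt.x = list(p1 = 0, p2 = 0, p3 = 0)

# Keep only the hyperparameters that matter at predict time,
# mirroring the Filter() call in the patched predictLearner.TuneWrapper.
ns = Filter(function(x) pars[[x]]$when %in% c("both", "predict"), names(pars))
opt.x[ns]  # p2 and p3 are forwarded into the predictLearner arglist; p1 is dropped
```

Train-only parameters (`when = "train"`) are already baked into the fitted `next.model`, so forwarding them again at predict time would be redundant; the filter ensures only `"predict"` and `"both"` parameters reach the underlying `predictLearner` call.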
