fix bug in TuneWrapper for params used in the predict function (issue #2472) #2479

Merged · 11 commits · Apr 15, 2019
12 changes: 10 additions & 2 deletions R/TuneWrapper.R
@@ -74,8 +74,16 @@ trainLearner.TuneWrapper = function(.learner, .task, .subset = NULL, ...) {

#' @export
predictLearner.TuneWrapper = function(.learner, .model, .newdata, ...) {
lrn = setHyperPars(.learner$next.learner, par.vals = .model$learner.model$opt.result$x)
Contributor:
FYI I am relatively sure that removing this line (and not putting .learner = lrn further down etc) will break things.

Sponsor Member:

You mean .learner$next.learner? That moved to another line. But I have now added setHyperPars again.

Contributor:

2a0d9d0 is exactly what I meant; with .learner = lrn I meant the list entry ".learner" 👍

predictLearner(lrn, .model$learner.model$next.model, .newdata, ...)
# setHyperPars is not used because par.vals are not accessed anymore
arglist = list(.learner = .learner$next.learner, .model = .model$learner.model$next.model, .newdata = .newdata)
arglist = insert(arglist, list(...))

# get x from the opt result and select only those params that are used for prediction
opt.x = .model$learner.model$opt.result$x
ps = getParamSet(.learner)
ns = Filter(function(x) ps$pars[[x]]$when %in% c("both", "predict"), getParamIds(ps))
arglist = insert(arglist, opt.x[ns])
do.call(predictLearner, arglist)
}
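The core of the fix is the filter on the parameter set's `when` field: only tuned values whose parameter applies at predict time (`"predict"` or `"both"`) are forwarded to `predictLearner`, while train-only parameters are dropped. A minimal standalone sketch of that filtering idea, using a hypothetical plain-list param structure rather than mlr's real `ParamSet` objects:

```r
# Hypothetical param descriptions: each entry records when the param is used
# (this mimics ps$pars[[x]]$when from the diff above, not mlr's actual classes).
pars = list(
  p1 = list(when = "train"),
  p2 = list(when = "predict"),
  p3 = list(when = "both")
)
# tuned values as they would come out of opt.result$x
opt.x = list(p1 = 1, p2 = 2, p3 = 3)

# keep only the names of params that matter at predict time
ns = Filter(function(x) pars[[x]]$when %in% c("both", "predict"), names(pars))

# the values actually passed on to predictLearner: p2 and p3 survive,
# the train-only p1 is dropped
opt.x[ns]
```

With this in place the predict-time arguments are supplied explicitly via `arglist`, which is why the earlier `setHyperPars` call on `next.learner` is no longer needed for these values.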

#' @export
31 changes: 31 additions & 0 deletions tests/testthat/test_base_TuneWrapper.R
@@ -120,3 +120,34 @@ test_that("TuneWrapper with glmnet (#958)", {
expect_error(pred, NA)
})

test_that("TuneWrapper respects train parameters (#2472)", {

# make task with only 0 as y
tsk = makeRegrTask("dummy", data = data.frame(y = rep(0L, 100), x = rep(1L, 100)), target = "y")

ps = makeParamSet(
makeNumericLearnerParam("p1", when = "train", lower = 0, upper = 10),
makeNumericLearnerParam("p2", when = "predict", lower = 0, upper = 10),
makeNumericLearnerParam("p3", when = "both", lower = 0, upper = 10)
)

lrn = makeLearner("regr.__mlrmocklearners__4", predict.type = "response", p1 = 10, p2 = 10, p3 = 10)
# the prediction of this learner is always
# train_part = p1 + p3
# y_hat = train_part + p2 + p3
# therefore p1 = p2 = p3 = 0 is the optimal setting
# we set the params to the bad values p1 = p2 = p3 = 10, so |y_hat - y| would be 40

lrn2 = makeTuneWrapper(lrn, resampling = makeResampleDesc("Holdout"),
par.set = ps,
control = makeTuneControlGrid(resolution = 2L))
mod = train(lrn2, tsk)
# we expect that the optimal parameters are found by the grid search.
expect_equal(mod$learner.model$opt.result$x, list(p1 = 0, p2 = 0, p3 = 0))
expect_true(mod$learner.model$opt.result$y == 0)
pred = predict(mod, tsk)
# we expect that the optimal parameters are also applied at prediction, so y_hat = p1 + p2 + p3 + p3 should be 0
expect_true(all(pred$data$response == 0))
})
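The test's comments describe the mock learner's arithmetic: training stores `p1 + p3`, and prediction adds `p2 + p3` on top. A hypothetical illustration of that arithmetic (this is not the actual `regr.__mlrmocklearners__4` implementation, just the formula the test comments state):

```r
# train stores train_part = p1 + p3 (train-time params only)
mock_train = function(p1, p3) p1 + p3

# predict adds the predict-time params: y_hat = train_part + p2 + p3
mock_predict = function(train_part, p2, p3) train_part + p2 + p3

# with the bad start values p1 = p2 = p3 = 10, against y = 0 the
# absolute error |y_hat - y| is 10 + 10 + 10 + 10 = 40
mock_predict(mock_train(10, 10), 10, 10)

# with the tuned optimum p1 = p2 = p3 = 0 the prediction is exactly 0
mock_predict(mock_train(0, 0), 0, 0)
```

This makes the failure mode of #2472 visible: if the tuned predict-time values (`p2`, `p3` here) were not forwarded to `predictLearner`, prediction would silently fall back to the bad constructor values and the final assertion would fail.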