Skip to content

Commit

Permalink
fix test for sampling wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
Lars Kotthoff committed Dec 5, 2016
1 parent 6800a1b commit 69b59f1
Showing 1 changed file with 10 additions and 12 deletions.
22 changes: 10 additions & 12 deletions R/resample.R
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,11 @@ doResampleIteration = function(learner, task, rin, i, measures, weights, model,
pp = rin$desc$predict
train.task = task
if (pp == "train") {
if (!is.null(m$learner.model$next.model)) {
if (!is.null(m$learner.model$next.model$train.task)) {
# the learner was wrapped in a sampling wrapper
train.task = m$learner.model$next.model$train.task
train.i = m$learner.model$next.model$subset
}
lm = getLearnerModel(m)
if ("BaseWrapper" %in% class(learner) && !is.null(lm$train.task)) {
# the learner was wrapped in a sampling wrapper
train.task = lm$train.task
train.i = lm$subset
}
pred.train = predict(m, train.task, subset = train.i)
if (!is.na(pred.train$error)) err.msgs[2L] = pred.train$error
Expand All @@ -151,12 +150,11 @@ doResampleIteration = function(learner, task, rin, i, measures, weights, model,
ms.test = performance(task = task, model = m, pred = pred.test, measures = measures)
names(ms.test) = vcapply(measures, measureAggrName)
} else { # "both"
if (!is.null(m$learner.model$next.model)) {
if (!is.null(m$learner.model$next.model$train.task)) {
# the learner was wrapped in a sampling wrapper
train.task = m$learner.model$next.model$train.task
train.i = m$learner.model$next.model$subset
}
lm = getLearnerModel(m)
if ("BaseWrapper" %in% class(learner) && !is.null(lm$train.task)) {
# the learner was wrapped in a sampling wrapper
train.task = lm$train.task
train.i = lm$subset
}
pred.train = predict(m, train.task, subset = train.i)
if (!is.na(pred.train$error)) err.msgs[2L] = pred.train$error
Expand Down

0 comments on commit 69b59f1

Please sign in to comment.