check for sampled training task in resampling (fixes #1357) (#1362)
* check for sampled training task in resampling (fixes #1357)

* fix test for sampling wrapper
larskotthoff committed Dec 5, 2016
1 parent fa4258f commit d40a2a6
Showing 6 changed files with 55 additions and 5 deletions.
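What this fixes: a learner wrapped in a sampling wrapper (DownsampleWrapper, UndersampleWrapper, OversampleWrapper) is trained on a modified copy of the task, but resample() previously computed training-set predictions against the original task and training subset, so train-set performance was measured on the wrong observations (#1357). The wrappers now store the sampled task on the inner model, and resample() predicts on that task when predict = "train" or "both". A minimal sketch of the fixed behaviour, assuming mlr and rpart are installed (an iris task stands in for the multiclass.task used in the tests):

library(mlr)

task = makeClassifTask(data = iris, target = "Species")
lrn = makeDownsampleWrapper("classif.rpart", dw.perc = 0.1)
rdesc = makeResampleDesc("Holdout", predict = "both")

# With dw.perc = 0.1 the model sees only ~10% of the training split, so the
# train-set prediction (and any train.mean measure) should reflect that size.
r = resample(lrn, task, rdesc, measures = list(setAggregation(acc, train.mean), acc))
r$measures.train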
1 change: 1 addition & 0 deletions R/DownsampleWrapper.R
@@ -45,5 +45,6 @@ trainLearner.DownsampleWrapper = function(.learner, .task, .subset, .weights = N
 .task = subsetTask(.task, .subset)
 .task = downsample(.task, perc = dw.perc, stratify = dw.stratify)
 m = train(.learner$next.learner, .task, weights = .task$weights)
+m$train.task = .task
 makeChainModel(next.model = m, cl = "DownsampleModel")
 }
2 changes: 2 additions & 0 deletions R/OverUndersampleWrapper.R
@@ -75,6 +75,7 @@ trainLearner.UndersampleWrapper = function(.learner, .task, .subset, .weights =
 .task = subsetTask(.task, .subset)
 .task = undersample(.task, rate = usw.rate, cl = usw.cl)
 m = train(.learner$next.learner, .task, weights = .weights)
+m$train.task = .task
 makeChainModel(next.model = m, cl = "UndersampleModel")
 }

@@ -83,6 +84,7 @@ trainLearner.OversampleWrapper = function(.learner, .task, .subset, .weights = N
 .task = subsetTask(.task, .subset)
 .task = oversample(.task, rate = osw.rate, cl = osw.cl)
 m = train(.learner$next.learner, .task, weights = .weights)
+m$train.task = .task
 makeChainModel(next.model = m, cl = "OversampleModel")
 }
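Each of the three sampling wrappers now attaches the task it actually trained on to the inner model as m$train.task. A rough sketch of inspecting that field on a trained wrapper model, assuming a binary task built from iris; getLearnerModel() returns the inner WrappedModel, which also carries the subset of rows used:

library(mlr)

df = iris[iris$Species != "virginica", ]
df$Species = droplevels(df$Species)
task = makeClassifTask(data = df, target = "Species")

lrn = makeUndersampleWrapper("classif.rpart", usw.rate = 0.5)
mod = train(lrn, task)

# The inner model records the undersampled task and the rows it was trained on;
# resample() reads exactly these two fields when predicting on the training set.
inner = getLearnerModel(mod)
getTaskSize(inner$train.task)
head(inner$subset)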

4 changes: 1 addition & 3 deletions R/generateLearningCurve.R
@@ -58,11 +58,9 @@ generateLearningCurveData = function(learners, task, resampling = NULL,
 else
 assert(checkClass(resampling, "ResampleDesc"), checkClass(resampling, "ResampleInstance"))

-perc.ids = seq_along(percs)
-
 # create downsampled versions for all learners
 lrnds1 = lapply(learners, function(lrn) {
-lapply(perc.ids, function(p.id) {
+lapply(seq_along(percs), function(p.id) {
 perc = percs[p.id]
 dsw = makeDownsampleWrapper(learner = lrn, dw.perc = perc, dw.stratify = stratify)
 list(
18 changes: 16 additions & 2 deletions R/resample.R
@@ -132,8 +132,15 @@ doResampleIteration = function(learner, task, rin, i, measures, weights, model,
 pred.train = NULL
 pred.test = NULL
 pp = rin$desc$predict
+train.task = task
 if (pp == "train") {
-pred.train = predict(m, task, subset = train.i)
+lm = getLearnerModel(m)
+if ("BaseWrapper" %in% class(learner) && !is.null(lm$train.task)) {
+# the learner was wrapped in a sampling wrapper
+train.task = lm$train.task
+train.i = lm$subset
+}
+pred.train = predict(m, train.task, subset = train.i)
 if (!is.na(pred.train$error)) err.msgs[2L] = pred.train$error
 ms.train = performance(task = task, model = m, pred = pred.train, measures = measures)
 names(ms.train) = vcapply(measures, measureAggrName)
@@ -143,10 +150,17 @@ doResampleIteration = function(learner, task, rin, i, measures, weights, model,
 ms.test = performance(task = task, model = m, pred = pred.test, measures = measures)
 names(ms.test) = vcapply(measures, measureAggrName)
 } else { # "both"
-pred.train = predict(m, task, subset = train.i)
+lm = getLearnerModel(m)
+if ("BaseWrapper" %in% class(learner) && !is.null(lm$train.task)) {
+# the learner was wrapped in a sampling wrapper
+train.task = lm$train.task
+train.i = lm$subset
+}
+pred.train = predict(m, train.task, subset = train.i)
 if (!is.na(pred.train$error)) err.msgs[2L] = pred.train$error
 ms.train = performance(task = task, model = m, pred = pred.train, measures = measures)
 names(ms.train) = vcapply(measures, measureAggrName)
+
 pred.test = predict(m, task, subset = test.i)
 if (!is.na(pred.test$error)) err.msgs[2L] = paste(err.msgs[2L], pred.test$error)
 ms.test = performance(task = task, model = m, pred = pred.test, measures = measures)
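The same guard appears in both the predict = "train" and predict = "both" branches of doResampleIteration(): if the learner is a BaseWrapper and its inner model recorded a sampled training task, the train-set prediction is made on that task and its subset rather than on the original ones. A condensed sketch of that check as a standalone helper; the name getTrainPredictionTask is illustrative only and not part of this commit:

# Illustrative helper (assumes mlr's getLearnerModel); mirrors the check that
# doResampleIteration() performs before predicting on the training set.
getTrainPredictionTask = function(learner, model, task, train.i) {
  lm = getLearnerModel(model)
  if ("BaseWrapper" %in% class(learner) && !is.null(lm$train.task)) {
    # the learner was wrapped in a sampling wrapper: use the task it was
    # actually trained on, and the rows of that task it was trained on
    list(task = lm$train.task, subset = lm$subset)
  } else {
    list(task = task, subset = train.i)
  }
}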
13 changes: 13 additions & 0 deletions tests/testthat/test_base_downsample.R
@@ -62,4 +62,17 @@ test_that("downsample wrapper works with weights, we had issue #838", {
 expect_true(length(u) == 5 && all(u %in% 1:10))
 })

+test_that("training performance works as expected (#1357)", {
+num = makeMeasure(id = "num", minimize = FALSE,
+properties = c("classif", "classif.multi", "req.pred", "req.truth"),
+name = "Number",
+fun = function(task, model, pred, feats, extra.args) {
+length(pred$data$response)
+}
+)
+
+rdesc = makeResampleDesc("Holdout", predict = "both")
+lrn = makeDownsampleWrapper("classif.rpart", dw.perc = 0.1)
+r = resample(lrn, multiclass.task, rdesc, measures = list(setAggregation(num, train.mean)))
+expect_lte(r$measures.train$num, getTaskSize(multiclass.task) * 0.1)
+})
22 changes: 22 additions & 0 deletions tests/testthat/test_base_imbal_overundersample.R
@@ -74,3 +74,25 @@ test_that("control which class gets over or under sampled", {
 r = resample(lrn2, binaryclass.task, rdesc)
 expect_true(!is.na(r$aggr))
 })
+
+test_that("training performance works as expected (#1357)", {
+num = makeMeasure(id = "num", minimize = FALSE,
+properties = c("classif", "classif.multi", "req.pred", "req.truth"),
+name = "Number",
+fun = function(task, model, pred, feats, extra.args) {
+length(pred$data$response)
+}
+)
+
+y = binaryclass.df[, binaryclass.target]
+z = getMinMaxClass(y)
+rdesc = makeResampleDesc("Holdout", split = .5, predict = "both")
+
+lrn = makeUndersampleWrapper("classif.rpart", usw.rate = 0.1, usw.cl = z$max.name)
+r = resample(lrn, binaryclass.task, rdesc, measures = list(setAggregation(num, train.mean)))
+expect_lt(r$measures.train$num, getTaskSize(binaryclass.task) * 0.5 - 1)
+
+lrn = makeOversampleWrapper("classif.rpart", osw.rate = 2, osw.cl = z$max.name)
+r = resample(lrn, binaryclass.task, rdesc, measures = list(setAggregation(num, train.mean)))
+expect_gt(r$measures.train$num, getTaskSize(binaryclass.task) * 0.5 + 1)
+})
