Skip to content

Commit

Permalink
Fix bug in resampling when using predict = "train" (#1284) (#1315)
Browse files Browse the repository at this point in the history
* enable printing pred of resample results when predict type is train

* Finish bug fix and adapt test

* better tests
  • Loading branch information
MariaErdmann authored and larskotthoff committed Oct 31, 2016
1 parent 16ac057 commit 4dbec91
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 3 deletions.
12 changes: 10 additions & 2 deletions R/ResamplePrediction.R
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,22 @@ makeResamplePrediction = function(instance, preds.test, preds.train) {
rbindlist(lapply(seq_along(pr.tr), function(X) cbind(pr.tr[[X]]$data, iter = X, set = "train")))
))

p1 = preds.test[[1L]]
if (!any(tenull) && instance$desc$predict %in% c("test", "both")) {
p1 = preds.test[[1L]]
pall = preds.test
} else if (!any(trnull) && instance$desc$predict == "train") {
p1 = preds.train[[1L]]
pall = preds.train
}


makeS3Obj(c("ResamplePrediction", class(p1)),
instance = instance,
predict.type = p1$predict.type,
data = data,
threshold = p1$threshold,
task.desc = p1$task.desc,
time = extractSubList(preds.test, "time")
time = extractSubList(pall, "time")
)
}

Expand Down
12 changes: 11 additions & 1 deletion tests/testthat/test_base_resample.R
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,24 @@ test_that("resampling, predicting train set works", {
m = setAggregation(mmce, train.mean)
r = resample(lrn, multiclass.task, rdesc, measures = m)
expect_false(is.na(r$aggr["mmce.train.mean"]))

expect_false(anyNA(r$pred$time))
expect_false(is.null(r$pred$predict.type))
expect_false(is.null(r$pred$threshold))
expect_equal(getTaskDescription(multiclass.task), r$pred$task.desc)

rdesc = makeResampleDesc("CV", iters = 2, predict = "both")
lrn = makeLearner("classif.rpart")
m1 = setAggregation(mmce, train.mean)
m2 = setAggregation(mmce, test.mean)
r = resample(lrn, multiclass.task, rdesc, measures = list(m1, m2))
expect_false(is.na(r$aggr["mmce.train.mean"]))
expect_false(is.na(r$aggr["mmce.test.mean"]))
expect_false(anyNA(r$pred$time))
expect_false(is.null(r$pred$predict.type))
expect_false(is.null(r$pred$threshold))
expect_equal(getTaskDescription(multiclass.task), r$pred$task.desc)



})

Expand Down

0 comments on commit 4dbec91

Please sign in to comment.