Skip to content

Commit

Permalink
adds test for print.ParamSet. Closes #1352 (#1419)
Browse files Browse the repository at this point in the history
* test for print.ParamSet closes #1352

* rm unwanted bin file

* adjust other tests in learners_all

* update function docu
  • Loading branch information
florianfendt authored and larskotthoff committed Jan 11, 2017
1 parent 67c25f4 commit 9534394
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 15 deletions.
20 changes: 14 additions & 6 deletions tests/testthat/helper_learners_all.R
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ testThatLearnerRespectsWeights = function(lrn, task, train.inds, test.inds, weig
}


# Test that learner produces output on the console, can be trained, can predict
# and that a performance measure is calculated.
# Test that learner produces output on the console, its ParamSet can be printed,
# can be trained, can predict and that a performance measure is calculated.
# This function is being used to test learners in general and in the other
# helper functions testing learners that claim to handle missings, factors,...
# It also tests if the learner can predict probabilities or standard errors.
Expand All @@ -54,14 +54,22 @@ testThatLearnerRespectsWeights = function(lrn, task, train.inds, test.inds, weig
# can predict probabilities or specification "se" when testing learner can
# predict standard errors.)

testThatLearnerCanTrainPredict = function(lrn, task, hyperpars, pred.type = "response") {
testBasicLearnerProperties = function(lrn, task, hyperpars, pred.type = "response") {
# handling special par.vals and predict type
info = lrn$id
if (lrn$id %in% names(hyperpars))
lrn = setHyperPars(lrn, par.vals = hyperpars[[lrn$id]])

lrn = setPredictType(lrn, pred.type)

# check that learner prints
expect_output(info = info, print(lrn), lrn$id)

# check that param set prints
par.set = getParamSet(lrn)
expect_output(info = info, print(par.set))

# check that learner trains, predicts
m = train(lrn, task)
p = predict(m, task)
expect_true(info = info, !is.na(performance(pred = p, task = task)))
Expand Down Expand Up @@ -108,7 +116,7 @@ testThatLearnerHandlesFactors = function(lrn, task, hyperpars) {
d[,f] = as.factor(rep_len(c("a", "b"), length.out = nrow(d)))
new.task = changeData(task = task, data = d)

testThatLearnerCanTrainPredict(lrn = lrn, task = task, hyperpars = hyperpars)
testBasicLearnerProperties(lrn = lrn, task = task, hyperpars = hyperpars)
}


Expand All @@ -126,7 +134,7 @@ testThatLearnerHandlesOrderedFactors = function(lrn, task, hyperpars) {
d[,f] = as.ordered(rep_len(c("a", "b", "c"), length.out = nrow(d)))
new.task = changeData(task = task, data = d)

testThatLearnerCanTrainPredict(lrn = lrn, task = task, hyperpars = hyperpars)
testBasicLearnerProperties(lrn = lrn, task = task, hyperpars = hyperpars)

}

Expand All @@ -145,7 +153,7 @@ testThatLearnerHandlesMissings = function(lrn, task, hyperpars) {
d[1,f] = NA
new.task = changeData(task = task, data = d)

testThatLearnerCanTrainPredict(lrn = lrn, task = task, hyperpars = hyperpars)
testBasicLearnerProperties(lrn = lrn, task = task, hyperpars = hyperpars)
}

testThatLearnerCanCalculateImportance = function(lrn, task, hyperpars) {
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/test_learners_all_classif.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ test_that("learners work: classif ", {
features = getTaskFeatureNames(binaryclass.task)[12:15])
lrns = mylist(task, create = TRUE)
lapply(lrns, testThatLearnerParamDefaultsAreInParamSet)
lapply(lrns, testThatLearnerCanTrainPredict, task = task, hyperpars = hyperpars)
lapply(lrns, testBasicLearnerProperties, task = task, hyperpars = hyperpars)

# binary classif with factors
lrns = mylist("classif", properties = "factors", create = TRUE)
Expand All @@ -37,7 +37,7 @@ test_that("learners work: classif ", {

# binary classif with prob
lrns = mylist(binaryclass.task, properties = "prob", create = TRUE)
lapply(lrns, testThatLearnerCanTrainPredict, task = binaryclass.task,
lapply(lrns, testBasicLearnerProperties, task = binaryclass.task,
hyperpars = hyperpars, pred.type = "prob")

# binary classif with weights
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/test_learners_all_clusters.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ test_that("learners work: cluster", {
task = noclass.task
lrns = mylist(task, create = TRUE)
lapply(lrns, testThatLearnerParamDefaultsAreInParamSet)
lapply(lrns, testThatLearnerCanTrainPredict, task = task, hyperpars = hyperpars)
lapply(lrns, testBasicLearnerProperties, task = task, hyperpars = hyperpars)

# clustering, prob
task = subsetTask(noclass.task, subset = 1:20, features = getTaskFeatureNames(noclass.task)[1:2])
lrns = mylist(task, properties = "prob", create = TRUE)
lapply(lrns, testThatLearnerCanTrainPredict, task = task, hyperpars = hyperpars,
lapply(lrns, testBasicLearnerProperties, task = task, hyperpars = hyperpars,
pred.type = "prob")

# cluster with weights
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/test_learners_all_multilabel.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ test_that("learners work: multilabel", {
# multiabel
lrns = mylist("multilabel", create = TRUE)
lapply(lrns, testThatLearnerParamDefaultsAreInParamSet)
lapply(lrns, testThatLearnerCanTrainPredict, task = multilabel.task, hyperpars = hyperpars)
lapply(lrns, testBasicLearnerProperties, task = multilabel.task, hyperpars = hyperpars)

# multilabel, probs
lrns = mylist("multilabel", properties = "prob", create = TRUE)
lapply(lrns, testThatLearnerCanTrainPredict, task = multilabel.task,
lapply(lrns, testBasicLearnerProperties, task = multilabel.task,
hyperpars = hyperpars, pred.type = "prob")

# multilabel, factors
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/test_learners_all_regr.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ test_that("learners work: regr ", {
# normal regr
lrns = mylist("regr", create = TRUE)
lapply(lrns, testThatLearnerParamDefaultsAreInParamSet)
lapply(lrns, testThatLearnerCanTrainPredict, task = task, hyperpars = hyperpars)
lapply(lrns, testBasicLearnerProperties, task = task, hyperpars = hyperpars)

# regr with factors
lrns = mylist("regr", properties = "factors", create = TRUE)
Expand All @@ -33,7 +33,7 @@ test_that("learners work: regr ", {

# regr with se
lrns = mylist(task, properties = "se", create = TRUE)
lapply(lrns, testThatLearnerCanTrainPredict, task = task, hyperpars = hyperpars,
lapply(lrns, testBasicLearnerProperties, task = task, hyperpars = hyperpars,
pred.type = "se")

# regr with weights
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test_learners_all_surv.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ test_that("learners work: surv ", {
features = getTaskFeatureNames(surv.task)[c(1,2)])
lrns = mylist("surv", create = TRUE)
lapply(lrns, testThatLearnerParamDefaultsAreInParamSet)
lapply(lrns, testThatLearnerCanTrainPredict, task = sub.task, hyperpars = hyperpars)
lapply(lrns, testBasicLearnerProperties, task = sub.task, hyperpars = hyperpars)

# survival analysis with factors
lrns = mylist("surv", properties = "factors", create = TRUE)
Expand Down

0 comments on commit 9534394

Please sign in to comment.