From 9534394d8cda3ddb313d7648643cc62a1424ef77 Mon Sep 17 00:00:00 2001 From: Florian Fendt Date: Wed, 11 Jan 2017 18:30:29 +0100 Subject: [PATCH] adds test for print.ParamSet. Closes #1352 (#1419) * test for print.ParamSet closes #1352 * rm unwanted bin file * adjust other tests in learners_all * update function docu --- tests/testthat/helper_learners_all.R | 20 +++++++++++++------ tests/testthat/test_learners_all_classif.R | 4 ++-- tests/testthat/test_learners_all_clusters.R | 4 ++-- tests/testthat/test_learners_all_multilabel.R | 4 ++-- tests/testthat/test_learners_all_regr.R | 4 ++-- tests/testthat/test_learners_all_surv.R | 2 +- 6 files changed, 23 insertions(+), 15 deletions(-) diff --git a/tests/testthat/helper_learners_all.R b/tests/testthat/helper_learners_all.R index b1b5a6a6bf..0cba6aa264 100644 --- a/tests/testthat/helper_learners_all.R +++ b/tests/testthat/helper_learners_all.R @@ -39,8 +39,8 @@ testThatLearnerRespectsWeights = function(lrn, task, train.inds, test.inds, weig } -# Test that learner produces output on the console, can be trained, can predict -# and that a performance measure is calculated. +# Test that learner produces output on the console, its ParamSet can be printed, +# can be trained, can predict and that a performance measure is calculated. # This function is being used to test learners in general and in the other # helper functions testing learners that claim to handle missings, factors,... # It also tests if the learner can predict probabilities or standard errors. @@ -54,14 +54,22 @@ testThatLearnerRespectsWeights = function(lrn, task, train.inds, test.inds, weig # can predict probabilities or specification "se" when testing learner can # predict standard errors.) -testThatLearnerCanTrainPredict = function(lrn, task, hyperpars, pred.type = "response") { +testBasicLearnerProperties = function(lrn, task, hyperpars, pred.type = "response") { + # handling special par.vals and predict type info = lrn$id if (lrn$id %in% names(hyperpars)) lrn = setHyperPars(lrn, par.vals = hyperpars[[lrn$id]]) lrn = setPredictType(lrn, pred.type) + # check that learner prints expect_output(info = info, print(lrn), lrn$id) + + # check that param set prints + par.set = getParamSet(lrn) + expect_output(info = info, print(par.set)) + + # check that learner trains, predicts m = train(lrn, task) p = predict(m, task) expect_true(info = info, !is.na(performance(pred = p, task = task))) @@ -108,7 +116,7 @@ testThatLearnerHandlesFactors = function(lrn, task, hyperpars) { d[,f] = as.factor(rep_len(c("a", "b"), length.out = nrow(d))) new.task = changeData(task = task, data = d) - testThatLearnerCanTrainPredict(lrn = lrn, task = task, hyperpars = hyperpars) + testBasicLearnerProperties(lrn = lrn, task = task, hyperpars = hyperpars) } @@ -126,7 +134,7 @@ testThatLearnerHandlesOrderedFactors = function(lrn, task, hyperpars) { d[,f] = as.ordered(rep_len(c("a", "b", "c"), length.out = nrow(d))) new.task = changeData(task = task, data = d) - testThatLearnerCanTrainPredict(lrn = lrn, task = task, hyperpars = hyperpars) + testBasicLearnerProperties(lrn = lrn, task = task, hyperpars = hyperpars) } @@ -145,7 +153,7 @@ testThatLearnerHandlesMissings = function(lrn, task, hyperpars) { d[1,f] = NA new.task = changeData(task = task, data = d) - testThatLearnerCanTrainPredict(lrn = lrn, task = task, hyperpars = hyperpars) + testBasicLearnerProperties(lrn = lrn, task = task, hyperpars = hyperpars) } testThatLearnerCanCalculateImportance = function(lrn, task, hyperpars) { diff --git a/tests/testthat/test_learners_all_classif.R b/tests/testthat/test_learners_all_classif.R index 931fd85b89..52df487a08 100644 --- a/tests/testthat/test_learners_all_classif.R +++ b/tests/testthat/test_learners_all_classif.R @@ -25,7 +25,7 @@ test_that("learners work: classif ", { features = getTaskFeatureNames(binaryclass.task)[12:15]) lrns = mylist(task, create = TRUE) lapply(lrns, testThatLearnerParamDefaultsAreInParamSet) - lapply(lrns, testThatLearnerCanTrainPredict, task = task, hyperpars = hyperpars) + lapply(lrns, testBasicLearnerProperties, task = task, hyperpars = hyperpars) # binary classif with factors lrns = mylist("classif", properties = "factors", create = TRUE) @@ -37,7 +37,7 @@ test_that("learners work: classif ", { # binary classif with prob lrns = mylist(binaryclass.task, properties = "prob", create = TRUE) - lapply(lrns, testThatLearnerCanTrainPredict, task = binaryclass.task, + lapply(lrns, testBasicLearnerProperties, task = binaryclass.task, hyperpars = hyperpars, pred.type = "prob") # binary classif with weights diff --git a/tests/testthat/test_learners_all_clusters.R b/tests/testthat/test_learners_all_clusters.R index 94cfd64dca..8180f478f4 100644 --- a/tests/testthat/test_learners_all_clusters.R +++ b/tests/testthat/test_learners_all_clusters.R @@ -9,12 +9,12 @@ test_that("learners work: cluster", { task = noclass.task lrns = mylist(task, create = TRUE) lapply(lrns, testThatLearnerParamDefaultsAreInParamSet) - lapply(lrns, testThatLearnerCanTrainPredict, task = task, hyperpars = hyperpars) + lapply(lrns, testBasicLearnerProperties, task = task, hyperpars = hyperpars) # clustering, prob task = subsetTask(noclass.task, subset = 1:20, features = getTaskFeatureNames(noclass.task)[1:2]) lrns = mylist(task, properties = "prob", create = TRUE) - lapply(lrns, testThatLearnerCanTrainPredict, task = task, hyperpars = hyperpars, + lapply(lrns, testBasicLearnerProperties, task = task, hyperpars = hyperpars, pred.type = "prob") # cluster with weights diff --git a/tests/testthat/test_learners_all_multilabel.R b/tests/testthat/test_learners_all_multilabel.R index a7a5766e75..c56dad3b76 100644 --- a/tests/testthat/test_learners_all_multilabel.R +++ b/tests/testthat/test_learners_all_multilabel.R @@ -8,11 +8,11 @@ test_that("learners work: multilabel", { # multiabel lrns = mylist("multilabel", create = TRUE) lapply(lrns, testThatLearnerParamDefaultsAreInParamSet) - lapply(lrns, testThatLearnerCanTrainPredict, task = multilabel.task, hyperpars = hyperpars) + lapply(lrns, testBasicLearnerProperties, task = multilabel.task, hyperpars = hyperpars) # multilabel, probs lrns = mylist("multilabel", properties = "prob", create = TRUE) - lapply(lrns, testThatLearnerCanTrainPredict, task = multilabel.task, + lapply(lrns, testBasicLearnerProperties, task = multilabel.task, hyperpars = hyperpars, pred.type = "prob") # multilabel, factors diff --git a/tests/testthat/test_learners_all_regr.R b/tests/testthat/test_learners_all_regr.R index 636af5ec19..1f526a4c44 100644 --- a/tests/testthat/test_learners_all_regr.R +++ b/tests/testthat/test_learners_all_regr.R @@ -21,7 +21,7 @@ test_that("learners work: regr ", { # normal regr lrns = mylist("regr", create = TRUE) lapply(lrns, testThatLearnerParamDefaultsAreInParamSet) - lapply(lrns, testThatLearnerCanTrainPredict, task = task, hyperpars = hyperpars) + lapply(lrns, testBasicLearnerProperties, task = task, hyperpars = hyperpars) # regr with factors lrns = mylist("regr", properties = "factors", create = TRUE) @@ -33,7 +33,7 @@ test_that("learners work: regr ", { # regr with se lrns = mylist(task, properties = "se", create = TRUE) - lapply(lrns, testThatLearnerCanTrainPredict, task = task, hyperpars = hyperpars, + lapply(lrns, testBasicLearnerProperties, task = task, hyperpars = hyperpars, pred.type = "se") # regr with weights diff --git a/tests/testthat/test_learners_all_surv.R b/tests/testthat/test_learners_all_surv.R index 44ca0cceff..78e6d4c83d 100644 --- a/tests/testthat/test_learners_all_surv.R +++ b/tests/testthat/test_learners_all_surv.R @@ -12,7 +12,7 @@ test_that("learners work: surv ", { features = getTaskFeatureNames(surv.task)[c(1,2)]) lrns = mylist("surv", create = TRUE) lapply(lrns, testThatLearnerParamDefaultsAreInParamSet) - lapply(lrns, testThatLearnerCanTrainPredict, task = sub.task, hyperpars = hyperpars) + lapply(lrns, testBasicLearnerProperties, task = sub.task, hyperpars = hyperpars) # survival analysis with factors lrns = mylist("surv", properties = "factors", create = TRUE)