Skip to content
This repository has been archived by the owner on Jan 6, 2022. It is now read-only.

Commit

Permalink
Merge be7a6f8 into 662654c
Browse files Browse the repository at this point in the history
  • Loading branch information
jakob-r committed Sep 7, 2017
2 parents 662654c + be7a6f8 commit 29a0689
Show file tree
Hide file tree
Showing 6 changed files with 140 additions and 3 deletions.
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Generated by roxygen2: do not edit by hand

S3method(print,ParConfig)
S3method(trainLearner,HyperoptWrapper)
export(downloadParConfig)
export(downloadParConfigs)
export(generateHyperControl)
Expand All @@ -21,6 +22,7 @@ export(getParConfigParVals)
export(getTaskDictionary)
export(hyperopt)
export(makeHyperControl)
export(makeHyperoptWrapper)
export(makeParConfig)
export(setHyperControlMeasures)
export(setHyperControlMlrControl)
Expand All @@ -40,3 +42,4 @@ import(lhs)
import(methods)
import(mlr)
import(stringi)
importFrom(utils,getFromNamespace)
58 changes: 58 additions & 0 deletions R/HyperoptWrapper.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#' @title Fuse learner with mlrHyperopt tuning.
#'
#' @description
#' Fuses an mlr base learner with mlrHyperopt tuning.
#' Creates a learner object, which can be used like any other learner object.
#' If the train function is called on it, \code{\link{hyperopt}} is invoked to select an optimal set of hyperparameter values.
#' Finally, a model is fitted on the complete training data with these optimal hyperparameters and returned.
#'
#' @template arg_learner
#' @template arg_parconfig
#' @template arg_hypercontrol
#' @template arg_showinfo
#' @return [\code{\link{Learner}}].
#' @export
#' @family tune
#' @family wrapper
#' @examples
#' \donttest{
#' task = makeClassifTask(data = iris, target = "Species")
#' lrn = makeLearner("classif.svm")
#' lrn = makeHyperoptWrapper(lrn)
#' mod = train(lrn, task)
#' print(getTuneResult(mod))
#' # nested resampling for evaluation
#' # we also extract tuned hyper pars in each iteration
#' r = resample(lrn, task, cv3, extract = getTuneResult)
#' getNestedTuneResultsX(r)
#' }
#' @importFrom utils getFromNamespace
makeHyperoptWrapper = function(learner, par.config = NULL, hyper.control = NULL, show.info = getMlrOptions()$show.info) {
learner = checkLearner(learner)
id = stri_paste(learner$id, "hyperopt", sep = ".")
# more or less just an empty dummy control
makeTuneControl = getFromNamespace("makeTuneControl", "mlr")
makeOptWrapper = getFromNamespace("makeOptWrapper", "mlr")
control = makeTuneControl(same.resampling.instance = FALSE, cl = "TuneControlHyperopt")
x = makeOptWrapper(id = id, learner = learner, resampling = NULL, measures = NULL, par.set = NULL, bit.names = character(0L), bits.to.features = function(){}, control = control, show.info = show.info, learner.subclass = c("HyperoptWrapper", "TuneWrapper"), model.subclass = "TuneModel")
x$hyper.control = hyper.control
x$par.config = par.config
return(x)
}

#' @export
trainLearner.HyperoptWrapper = function(.learner, .task, .subset = NULL, ...) {
.task = subsetTask(.task, .subset)
or = hyperopt(task = .task, learner = .learner$next.learner, par.config = .learner$par.config, hyper.control = .learner$hyper.control)
lrn = or$learner
or$learner = NULL
if ("DownsampleWrapper" %in% class(.learner$next.learner) && !is.null(.learner$control$final.dw.perc) && !is.null(getHyperPars(lrn)$dw.perc) && getHyperPars(lrn)$dw.perc < 1) {
messagef("Train model on %f on data.", .learner$control$final.dw.perc)
lrn = setHyperPars(lrn, par.vals = list(dw.perc = .learner$control$final.dw.perc))
}
m = train(lrn, .task)
makeChainModel = getFromNamespace("makeChainModel", "mlr")
x = makeChainModel(next.model = m, cl = "TuneModel")
x$opt.result = or
return(x)
}
4 changes: 1 addition & 3 deletions R/hyperopt.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@
#' If no learner is given the learner referenced in the \code{par.config} will be used, if available.
#' @template arg_parconfig
#' @template arg_hypercontrol
#' @param show.info [\code{logical(1)}]\cr
#' Print verbose output on console?
#' Default is set via \code{\link{configureMlr}}.
#' @template arg_showinfo
#' @return [\code{\link[mlr]{TuneResult}}]
#' @import mlr
#' @examples
Expand Down
3 changes: 3 additions & 0 deletions man-roxygen/arg_showinfo.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#' @param show.info [\code{logical(1)}]\cr
#' Print verbose output on console?
#' Default is set via \code{\link{configureMlr}}.
45 changes: 45 additions & 0 deletions man/makeHyperoptWrapper.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

30 changes: 30 additions & 0 deletions tests/testthat/test_HyperoptWrapper.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
context("HyperoptWrapper")

test_that("hyperoptWrapper works", {
mlr::configureMlr(show.info = FALSE, show.learner.output = FALSE)
lrn = makeLearner("classif.svm")
lrn2 = makeHyperoptWrapper(lrn)
task = iris.task
res = resample(learner = lrn2, task = task, resampling = cv2, extract = getTuneResult)
expect_class(res$extract[[1]], "TuneResult")
expect_data_frame(getNestedTuneResultsX(res))
expect_data_frame(getNestedTuneResultsOptPathDf(res))

# some random workflow
# triggers Random Search
hyper.control = makeHyperControl(
mlr.control = makeTuneControlRandom(maxit = 10),
resampling = makeResampleDesc("Holdout"),
measures = list(auc)
)
par.config = generateParConfig(lrn, sonar.task)
par.set = getParConfigParSet(par.config)
par.set = filterParams(par.set, ids = "cost")
par.config = setParConfigParSet(par.config, par.set)
par.config = setParConfigParVals(par.config, par.vals = list())
lrn2 = makeHyperoptWrapper(learner = lrn, par.config = par.config, hyper.control = hyper.control)
res = resample(learner = lrn2, task = sonar.task, resampling = cv2, extract = getTuneResult)
expect_class(res$extract[[1]], "TuneResult")
expect_data_frame(getNestedTuneResultsX(res))
expect_data_frame(getNestedTuneResultsOptPathDf(res))
})

0 comments on commit 29a0689

Please sign in to comment.