Skip to content
Branch: master
Find file Copy path
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
84 lines (76 sloc) 3.43 KB
# FIXME: use learnerparam or ordinary params?
#' Fuse learner with preprocessing.
#'
#' Fuses a base learner with a preprocessing method. Creates a learner object, which can be
#' used like any other learner object, but which internally preprocesses the data as requested.
#' If the train or predict function is called on data / a task, the preprocessing is always
#' performed automatically.
#'
#' @template arg_learner
#' @param train (`function(data, target, args)`)\cr
#'   Function to preprocess the data before training.
#'   `data` is the (already subsetted) task data, with functional features
#'   passed as matrix columns.
#'   `target` is a string and denotes the target variable in `data`.
#'   `args` is a list of further arguments and parameters to influence the
#'   preprocessing.
#'   Must return a `list(data, control)`, where `data` is the preprocessed
#'   data (a `data.frame`) and `control` (a `list`) stores all information
#'   necessary to do the preprocessing before predictions.
#' @param predict (`function(data, target, args, control)`)\cr
#'   Function to preprocess the data before prediction.
#'   `target` is a string and denotes the target variable in `data`.
#'   `args` are the args that were passed to `train`.
#'   `control` is the object you returned in `train`.
#'   Must return the processed data as a `data.frame`.
#' @param par.set ([ParamHelpers::ParamSet])\cr
#'   Parameter set of [ParamHelpers::LearnerParam] objects to describe the
#'   parameters in `args`.
#'   Default is empty set.
#' @param par.vals ([list])\cr
#'   Named list of default values for params in `args` respectively `par.set`.
#'   Default is empty list.
#' @return ([Learner]).
#' @family wrapper
#' @export
makePreprocWrapper = function(learner, train, predict, par.set = makeParamSet(), par.vals = list()) {
  learner = checkLearner(learner)
  # The wrapper methods call these functions with exactly these argument names.
  assertFunction(train, args = c("data", "target", "args"))
  assertFunction(predict, args = c("data", "target", "args", "control"))
  assertClass(par.set, classes = "ParamSet")
  # par.vals is documented as a named list; reject other containers early.
  assertList(par.vals)
  if (!isProperlyNamed(par.vals)) {
    stop("'par.vals' must be a properly named list!")
  }
  id = stri_paste(learner$id, "preproc", sep = ".")
  x = makeBaseWrapper(id, type = learner$type, next.learner = learner, par.set = par.set,
    par.vals = par.vals, learner.subclass = "PreprocWrapper", model.subclass = "PreprocModel")
  # Store the user functions on the wrapper; the "PreprocWrapper" train and
  # predict methods look them up as .learner$train / .learner$predict.
  x$train = train
  x$predict = predict
  x
}
#' @export
trainLearner.PreprocWrapper = function(.learner, .task, .subset = NULL, ...) {
  # Run the user-supplied preprocessing on the subsetted training data;
  # functional features are handed over as matrix columns.
  pp = .learner$train(data = getTaskData(.task, .subset, functionals.as = "matrix"),
    target = getTaskTargetNames(.task), args = .learner$par.vals)
  # pp is only accessed by name below, so the order of its two elements is
  # irrelevant; we require exactly the names "data" and "control" with the
  # right types.
  ok = is.list(pp) && length(pp) == 2L &&
    setequal(names(pp), c("data", "control")) &&
    is.data.frame(pp[["data"]]) && is.list(pp[["control"]])
  if (!ok) {
    stop("Preprocessing train must result in list wil elements data[data.frame] and control[list]!")
  }
  .task = changeData(.task, pp[["data"]])
  # we have already subsetted!
  m = train(.learner$next.learner, .task)
  # FIXME: time and can we do this better?
  # We don't really know which subset was used after preprocessing, and the
  # features will have changed.
  x = makeChainModel(next.model = m, cl = "PreprocModel")
  # Kept so that prediction can replay the same preprocessing.
  x$control = pp[["control"]]
  x
}
#' @export
predictLearner.PreprocWrapper = function(.learner, .model, .newdata, ...) {
  # Apply the user-supplied prediction-time preprocessing, passing the target
  # name, the wrapper's par.vals and the stored control object.
  .newdata = .learner$predict(.newdata, .model$task.desc$target,
    .learner$par.vals, .model$learner.model$control)
  if (!is.data.frame(.newdata)) {
    stop("Preprocessing must result in a data.frame!")
  }
  # Continue S3 dispatch with the preprocessed data replacing .newdata.
  NextMethod(.newdata = .newdata)
}
You can’t perform that action at this time.