Skip to content

Commit

Permalink
doc
Browse files Browse the repository at this point in the history
  • Loading branch information
ja-thomas committed Apr 12, 2018
1 parent 3d998ff commit 96d15fe
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 9 deletions.
24 changes: 18 additions & 6 deletions R/autoxgboost.R
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,7 @@ autoxgboost = function(task, measure = NULL, control = NULL, iterations = 160L,
if (!is.null(nthread))
pv$nthread = nthread

rinst = makeResampleInstance(makeResampleDesc("Holdout", split = early.stopping.fraction), task)

task.test = subsetTask(task, rinst$test.inds[[1]])
task = subsetTask(task, rinst$train.inds[[1]])
# create base.learner

if (tt == "classif") {

Expand Down Expand Up @@ -131,6 +128,9 @@ autoxgboost = function(task, measure = NULL, control = NULL, iterations = 160L,
} else {
stop("Task must be regression or classification")
}

# Create pipeline

preproc.pipeline = NULLCPO

if (has.cat.feats) {
Expand All @@ -152,18 +152,28 @@ autoxgboost = function(task, measure = NULL, control = NULL, iterations = 160L,

preproc.pipeline %<>>% cpoDropConstants()

task.train = task %>>% preproc.pipeline

# process data and apply pipeline

# split early stopping data
rinst = makeResampleInstance(makeResampleDesc("Holdout", split = early.stopping.fraction), task)
task.test = subsetTask(task, rinst$test.inds[[1]])
task.train = subsetTask(task, rinst$train.inds[[1]])

task.train %<>>% preproc.pipeline
task.test %<>>% retrafo(task.train)
base.learner = setHyperPars(base.learner, early.stopping.data = task.test)

# Optimize

opt = smoof::makeSingleObjectiveFunction(name = "optimizeWrapper",
fn = function(x) {
lrn = setHyperPars(base.learner, par.vals = x)
mod = train(lrn, task.train)
pred = predict(mod, task.test)
nrounds = getBestIteration(mod)

if (tune.threshold && tt == "classif") {
if (tune.threshold && getTaskType(task.train) == "classif") {
tune.res = tuneThreshold(pred = pred, measure = measure)
res = tune.res$perf
attr(res, "extras") = list(nrounds = nrounds, .threshold = tune.res$th)
Expand All @@ -180,6 +190,8 @@ autoxgboost = function(task, measure = NULL, control = NULL, iterations = 160L,
des = generateDesign(n = design.size, par.set)

optim.result = mbo(fun = opt, control = control, design = des, learner = mbo.learner)


lrn = buildFinalLearner(optim.result, objective, predict.type, par.set = par.set,
dummy.cols = dummy.cols, impact.cols = impact.cols, preproc.pipeline = preproc.pipeline)

Expand Down
3 changes: 0 additions & 3 deletions R/helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,3 @@ getThreshold = function(optim.result) {
getBestIteration = function(mod) {
getLearnerModel(mod, more.unwrap = TRUE)$best_iteration
}



0 comments on commit 96d15fe

Please sign in to comment.