enabled quantile regression in gbm #2603

Merged on Jun 25, 2019 (14 commits)
1 change: 1 addition & 0 deletions NEWS.md
@@ -11,6 +11,7 @@
See `?regr.randomForest` for more details.
`regr.ranger` relies on the functions provided by the package ("jackknife" and "infjackknife" (default))
(@jakob-r, #1784)
- `regr.gbm` now supports `quantile distribution` (@bthieurmel, #2603)

## functions - general
- `getClassWeightParam()` now also works for Wrapper* Models and ensemble models (@ja-thomas, #891)
26 changes: 20 additions & 6 deletions R/RLearner_regr_gbm.R
@@ -4,7 +4,7 @@ makeRLearner.regr.gbm = function() {
cl = "regr.gbm",
package = "gbm",
par.set = makeParamSet(
makeDiscreteLearnerParam(id = "distribution", default = "gaussian", values = c("gaussian", "laplace", "poisson", "tdist")),
makeDiscreteLearnerParam(id = "distribution", default = "gaussian", values = c("gaussian", "laplace", "poisson", "tdist", "quantile")),
# FIXME default for distribution in gbm() is bernoulli
makeIntegerLearnerParam(id = "n.trees", default = 100L, lower = 1L),
makeIntegerLearnerParam(id = "cv.folds", default = 0L),
@@ -13,6 +13,8 @@ makeRLearner.regr.gbm = function() {
makeNumericLearnerParam(id = "shrinkage", default = 0.001, lower = 0),
makeNumericLearnerParam(id = "bag.fraction", default = 0.5, lower = 0, upper = 1),
makeNumericLearnerParam(id = "train.fraction", default = 1, lower = 0, upper = 1),
makeNumericLearnerParam(id = "alpha", default = 0.5, lower = 0, upper = 1,
requires = quote(distribution == "quantile")),
makeLogicalLearnerParam(id = "keep.data", default = TRUE, tunable = FALSE),
makeLogicalLearnerParam(id = "verbose", default = FALSE, tunable = FALSE)
),
@@ -28,13 +30,25 @@ makeRLearner.regr.gbm = function() {
#' @export
trainLearner.regr.gbm = function(.learner, .task, .subset, .weights = NULL, ...) {
f = getTaskFormula(.task)
if (is.null(.weights)) {
f = getTaskFormula(.task)
gbm::gbm(f, data = getTaskData(.task, .subset), ...)

params = list(...)
if("alpha" %in% names(params)) {
Review comment (project member): Add a space after `if`. I would also check whether `params$alpha` is `NULL`, for consistency with how it is set below.

alpha = params$alpha
params$alpha = NULL
} else {
f = getTaskFormula(.task)
gbm::gbm(f, data = getTaskData(.task, .subset), weights = .weights, ...)
alpha = 0.5
}
if(params$distribution %in% "quantile") {
Review comment (project member): Add a space after `if`. Should this check simply use `==`?

params$distribution = list(name = "quantile", alpha = alpha)
}
params$formula = f
params$data = getTaskData(.task, .subset)

if (!is.null(.weights)) {
params$weights = .weights
}

do.call(gbm::gbm, params)
}

pat-s marked this conversation as resolved.
#' @export
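
The two review comments on `trainLearner.regr.gbm` can be illustrated with a minimal base-R sketch of the parameter handling alone. `prepare_gbm_params` is a hypothetical helper written for this note; it is not part of mlr or gbm and it does not call `gbm::gbm()`. Using `identical()` for the distribution check is an assumption of the sketch, not what the PR does: both `==` and `%in%` return a zero-length logical when `params$distribution` is `NULL`, which `if` rejects. Whether `distribution` can actually be missing at this point depends on learner defaults that are not visible in this diff.

```r
# Hypothetical helper (not part of mlr or gbm): turns mlr-style flat
# arguments into the argument list shape that the diff passes to gbm::gbm().
prepare_gbm_params = function(...) {
  params = list(...)
  # NULL check as suggested in review, instead of testing names(params).
  if (is.null(params$alpha)) {
    alpha = 0.5
  } else {
    alpha = params$alpha
    params$alpha = NULL # alpha is not passed on as a top-level argument
  }
  # Assumption of this sketch: identical() yields a single FALSE when
  # `distribution` is NULL, whereas `==` and `%in%` yield logical(0).
  if (identical(params$distribution, "quantile")) {
    params$distribution = list(name = "quantile", alpha = alpha)
  }
  params
}

quant = prepare_gbm_params(distribution = "quantile", alpha = 0.2, n.trees = 600)
gauss = prepare_gbm_params(distribution = "gaussian", n.trees = 600)
str(quant$distribution)
str(gauss$distribution)
```

In the quantile case the flat `alpha` is folded into the `distribution` list; for any other distribution the value is left as a plain string.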
11 changes: 8 additions & 3 deletions tests/testthat/test_regr_gbm.R
@@ -5,16 +5,17 @@ test_that("regr_gbm", {

parset.list = list(
list(),
list(n.trees = 600),
list(interaction.depth = 2)
list(n.trees = 600, distribution = "gaussian"),
list(interaction.depth = 2, distribution = "gaussian"),
list(distribution = list(name = "quantile", alpha = 0.2))
)


old.predicts.list = list()

for (i in seq_along(parset.list)) {
parset = parset.list[[i]]
pars = list(regr.formula, data = regr.train, distribution = "gaussian")
pars = list(regr.formula, data = regr.train)
pars = c(pars, parset)
set.seed(getOption("mlr.debug.seed"))
capture.output({
@@ -25,6 +26,10 @@ test_that("regr_gbm", {
old.predicts.list[[i]] = p
}

# Different way to pass quantile distribution in mlr
parset.list[[4]]$distribution = "quantile"
parset.list[[4]]$alpha = 0.2

testSimpleParsets("regr.gbm", regr.df, regr.target, regr.train.inds, old.predicts.list, parset.list)
})
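
For readers unfamiliar with the `alpha` value used in the test, here is a small base-R sketch of the pinball (check) loss that quantile regression is commonly defined by. It assumes this is the objective behind `distribution = "quantile"` and does not call gbm; `pinball_loss` and the toy data are illustrative names and values only.

```r
# Pinball loss of a constant prediction q at quantile level alpha.
# Under-predictions (y > q) are weighted by alpha,
# over-predictions (y < q) by 1 - alpha.
pinball_loss = function(y, q, alpha) {
  r = y - q
  mean(ifelse(r >= 0, alpha * r, (alpha - 1) * r))
}

y = 1:10
candidates = seq(0, 11, by = 0.5)

# Evaluate the loss for every candidate constant prediction.
losses = vapply(candidates, function(q) pinball_loss(y, q, alpha = 0.2), numeric(1))
best = candidates[which.min(losses)]

# The same search at alpha = 0.5 targets the median instead.
losses_median = vapply(candidates, function(q) pinball_loss(y, q, alpha = 0.5), numeric(1))
best_median = candidates[which.min(losses_median)]
print(c(best = best, best_median = best_median))
```

With `alpha = 0.2` the minimising constant sits in the lower part of the data, while `alpha = 0.5` moves it to the middle, which is why the parameter is restricted to the interval from 0 to 1 in the learner's parameter set.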
