Skip to content

Commit

Permalink
fix: xgboost nrounds default value (#282)
Browse files Browse the repository at this point in the history
  • Loading branch information
be-marc committed Jan 23, 2024
1 parent 1caa762 commit 0e41d2c
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 1 deletion.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ Config/testthat/edition: 3
Encoding: UTF-8
NeedsCompilation: no
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.3
RoxygenNote: 7.2.3.9000
Collate:
'aaa.R'
'LearnerClassifCVGlmnet.R'
Expand Down
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

S3method(default_values,LearnerClassifRanger)
S3method(default_values,LearnerClassifSVM)
S3method(default_values,LearnerClassifXgboost)
S3method(default_values,LearnerRegrRanger)
S3method(default_values,LearnerRegrSVM)
S3method(default_values,LearnerRegrXgboost)
export(LearnerClassifCVGlmnet)
export(LearnerClassifGlmnet)
export(LearnerClassifKKNN)
Expand Down
9 changes: 9 additions & 0 deletions R/LearnerClassifXgboost.R
Original file line number Diff line number Diff line change
Expand Up @@ -316,5 +316,14 @@ LearnerClassifXgboost = R6Class("LearnerClassifXgboost",
)
)

#' @export
default_values.LearnerClassifXgboost = function(x, search_space, task, ...) { # nolint
special_defaults = list(
nrounds = 1L
)
defaults = insert_named(default_values(x$param_set), special_defaults)
defaults[search_space$ids()]
}

#' @include aaa.R
learners[["classif.xgboost"]] = LearnerClassifXgboost
10 changes: 10 additions & 0 deletions R/LearnerRegrXgboost.R
Original file line number Diff line number Diff line change
Expand Up @@ -231,5 +231,15 @@ LearnerRegrXgboost = R6Class("LearnerRegrXgboost",
)
)

#' @export
default_values.LearnerRegrXgboost = function(x, search_space, task, ...) { # nolint
special_defaults = list(
nrounds = 1L
)
defaults = insert_named(default_values(x$param_set), special_defaults)
defaults[search_space$ids()]
}


#' @include aaa.R
learners[["regr.xgboost"]] = LearnerRegrXgboost

0 comments on commit 0e41d2c

Please sign in to comment.