-
Notifications
You must be signed in to change notification settings - Fork 3.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[R-package] Updated lgb.Booster.R with keyword arguments #3496
Changes from 22 commits
aad38bb
00f0524
8ab9fe5
07ac370
219b646
b6d1185
49ddde1
c629688
81a2e60
3a88050
9b81ce8
46fca40
9929cae
0e88211
1667eca
e558aff
4f2cc94
04254c9
621c9fb
c98d527
7037e82
69cc5a1
f2be0d8
d3fd5e3
0cedcd7
40708da
e84a31a
028b3f6
986906d
14e8ebe
6a944c4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -39,12 +39,12 @@ Booster <- R6::R6Class( | |||||
# Check if training dataset is not null | ||||||
if (!is.null(train_set)) { | ||||||
# Check if training dataset is lgb.Dataset or not | ||||||
if (!lgb.check.r6.class(train_set, "lgb.Dataset")) { | ||||||
if (!lgb.check.r6.class(object = train_set, name = "lgb.Dataset")) { | ||||||
stop("lgb.Booster: Can only use lgb.Dataset as training data") | ||||||
} | ||||||
train_set_handle <- train_set$.__enclos_env__$private$get_handle() | ||||||
params <- modifyList(params, train_set$get_params()) | ||||||
params_str <- lgb.params2str(params) | ||||||
params_str <- lgb.params2str(params = params) | ||||||
# Store booster handle | ||||||
handle <- lgb.call( | ||||||
"LGBM_BoosterCreate_R" | ||||||
|
@@ -84,7 +84,7 @@ Booster <- R6::R6Class( | |||||
|
||||||
# Create booster from model | ||||||
handle <- lgb.call( | ||||||
"LGBM_BoosterCreateFromModelfile_R" | ||||||
fun_name = "LGBM_BoosterCreateFromModelfile_R" | ||||||
, ret = handle | ||||||
, lgb.c_str(modelfile) | ||||||
) | ||||||
|
@@ -98,7 +98,7 @@ Booster <- R6::R6Class( | |||||
|
||||||
# Create booster from model | ||||||
handle <- lgb.call( | ||||||
"LGBM_BoosterLoadModelFromString_R" | ||||||
fun_name = "LGBM_BoosterLoadModelFromString_R" | ||||||
, ret = handle | ||||||
, lgb.c_str(model_str) | ||||||
) | ||||||
|
@@ -116,7 +116,7 @@ Booster <- R6::R6Class( | |||||
}) | ||||||
|
||||||
# Check whether the handle was created properly if it was not stopped earlier by a stop call | ||||||
if (lgb.is.null.handle(handle)) { | ||||||
if (isTRUE(lgb.is.null.handle(handle))) { | ||||||
|
||||||
stop("lgb.Booster: cannot create Booster handle") | ||||||
|
||||||
|
@@ -127,7 +127,7 @@ Booster <- R6::R6Class( | |||||
private$handle <- handle | ||||||
private$num_class <- 1L | ||||||
private$num_class <- lgb.call( | ||||||
"LGBM_BoosterGetNumClasses_R" | ||||||
fun_name = "LGBM_BoosterGetNumClasses_R" | ||||||
, ret = private$num_class | ||||||
, private$handle | ||||||
) | ||||||
|
@@ -149,7 +149,7 @@ Booster <- R6::R6Class( | |||||
add_valid = function(data, name) { | ||||||
|
||||||
# Check if data is lgb.Dataset | ||||||
if (!lgb.check.r6.class(data, "lgb.Dataset")) { | ||||||
if (!(lgb.Dataset = lgb.check.r6.class(data, "lgb.Dataset"))) { | ||||||
stop("lgb.Booster.add_valid: Can only use lgb.Dataset as validation data") | ||||||
} | ||||||
|
||||||
|
@@ -189,7 +189,7 @@ Booster <- R6::R6Class( | |||||
|
||||||
# Append parameters | ||||||
params <- append(params, list(...)) | ||||||
params_str <- lgb.params2str(params) | ||||||
params_str <- lgb.params2str(params = params) | ||||||
|
||||||
# Reset parameters | ||||||
lgb.call( | ||||||
|
@@ -216,7 +216,7 @@ Booster <- R6::R6Class( | |||||
if (!is.null(train_set)) { | ||||||
|
||||||
# Check if training set is lgb.Dataset | ||||||
if (!lgb.check.r6.class(train_set, "lgb.Dataset")) { | ||||||
if (!lgb.check.r6.class(object = train_set, name = "lgb.Dataset")) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
It looks like this |
||||||
stop("lgb.Booster.update: Only can use lgb.Dataset as training data") | ||||||
} | ||||||
|
||||||
|
@@ -268,7 +268,7 @@ Booster <- R6::R6Class( | |||||
|
||||||
# Return custom boosting gradient/hessian | ||||||
ret <- lgb.call( | ||||||
"LGBM_BoosterUpdateOneIterCustom_R" | ||||||
fun_name = "LGBM_BoosterUpdateOneIterCustom_R" | ||||||
, ret = NULL | ||||||
, private$handle | ||||||
, gpair$grad | ||||||
|
@@ -311,7 +311,7 @@ Booster <- R6::R6Class( | |||||
|
||||||
cur_iter <- 0L | ||||||
lgb.call( | ||||||
"LGBM_BoosterGetCurrentIteration_R" | ||||||
fun_name = "LGBM_BoosterGetCurrentIteration_R" | ||||||
, ret = cur_iter | ||||||
, private$handle | ||||||
) | ||||||
|
@@ -323,7 +323,7 @@ Booster <- R6::R6Class( | |||||
|
||||||
upper_bound <- 0.0 | ||||||
lgb.call( | ||||||
"LGBM_BoosterGetUpperBoundValue_R" | ||||||
fun_name = "LGBM_BoosterGetUpperBoundValue_R" | ||||||
, ret = upper_bound | ||||||
, private$handle | ||||||
) | ||||||
|
@@ -335,7 +335,7 @@ Booster <- R6::R6Class( | |||||
|
||||||
lower_bound <- 0.0 | ||||||
lgb.call( | ||||||
"LGBM_BoosterGetLowerBoundValue_R" | ||||||
fun_name = "LGBM_BoosterGetLowerBoundValue_R" | ||||||
, ret = lower_bound | ||||||
, private$handle | ||||||
) | ||||||
|
@@ -346,7 +346,7 @@ Booster <- R6::R6Class( | |||||
eval = function(data, name, feval = NULL) { | ||||||
|
||||||
# Check if dataset is lgb.Dataset | ||||||
if (!lgb.check.r6.class(data, "lgb.Dataset")) { | ||||||
if (!lgb.check.r6.class(object = data, name = "lgb.Dataset")) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
It looks like the |
||||||
stop("lgb.Booster.eval: Can only use lgb.Dataset to eval") | ||||||
} | ||||||
|
||||||
|
@@ -387,7 +387,11 @@ Booster <- R6::R6Class( | |||||
} | ||||||
|
||||||
# Evaluate data | ||||||
private$inner_eval(name, data_idx, feval) | ||||||
private$inner_eval( | ||||||
data_name = name | ||||||
, data_idx = data_idx | ||||||
, feval = feval | ||||||
) | ||||||
|
||||||
}, | ||||||
|
||||||
|
@@ -429,7 +433,7 @@ Booster <- R6::R6Class( | |||||
|
||||||
# Save booster model | ||||||
lgb.call( | ||||||
"LGBM_BoosterSaveModel_R" | ||||||
fun_name = "LGBM_BoosterSaveModel_R" | ||||||
, ret = NULL | ||||||
, private$handle | ||||||
, as.integer(num_iteration) | ||||||
|
@@ -450,7 +454,7 @@ Booster <- R6::R6Class( | |||||
|
||||||
# Return model string | ||||||
return(lgb.call.return.str( | ||||||
"LGBM_BoosterSaveModelToString_R" | ||||||
fun_name = "LGBM_BoosterSaveModelToString_R" | ||||||
, private$handle | ||||||
, as.integer(num_iteration) | ||||||
, as.integer(feature_importance_type) | ||||||
|
@@ -467,7 +471,7 @@ Booster <- R6::R6Class( | |||||
} | ||||||
|
||||||
lgb.call.return.str( | ||||||
"LGBM_BoosterDumpModel_R" | ||||||
fun_name = "LGBM_BoosterDumpModel_R" | ||||||
, private$handle | ||||||
, as.integer(num_iteration) | ||||||
, as.integer(feature_importance_type) | ||||||
|
@@ -496,7 +500,16 @@ Booster <- R6::R6Class( | |||||
|
||||||
# Predict on new data | ||||||
predictor <- Predictor$new(private$handle, ...) | ||||||
predictor$predict(data, start_iteration, num_iteration, rawscore, predleaf, predcontrib, header, reshape) | ||||||
predictor$predict( | ||||||
data = data | ||||||
, start_iteration = start_iteration | ||||||
, num_iteration = num_iteration | ||||||
, rawscore = rawscore | ||||||
, predleaf = predleaf | ||||||
, predcontrib = predcontrib | ||||||
, header = header | ||||||
, reshape = reshape | ||||||
) | ||||||
|
||||||
}, | ||||||
|
||||||
|
@@ -554,7 +567,7 @@ Booster <- R6::R6Class( | |||||
# Store predictions | ||||||
npred <- 0L | ||||||
npred <- lgb.call( | ||||||
"LGBM_BoosterGetNumPredict_R" | ||||||
fun_name = "LGBM_BoosterGetNumPredict_R" | ||||||
, ret = npred | ||||||
, private$handle | ||||||
, as.integer(idx - 1L) | ||||||
|
@@ -587,7 +600,7 @@ Booster <- R6::R6Class( | |||||
|
||||||
# Get evaluation names | ||||||
names <- lgb.call.return.str( | ||||||
"LGBM_BoosterGetEvalNames_R" | ||||||
fun_name = "LGBM_BoosterGetEvalNames_R" | ||||||
, private$handle | ||||||
) | ||||||
|
||||||
|
@@ -631,7 +644,7 @@ Booster <- R6::R6Class( | |||||
# Create evaluation values | ||||||
tmp_vals <- numeric(length(private$eval_names)) | ||||||
tmp_vals <- lgb.call( | ||||||
"LGBM_BoosterGetEval_R" | ||||||
fun_name = "LGBM_BoosterGetEval_R" | ||||||
, ret = tmp_vals | ||||||
, private$handle | ||||||
, as.integer(data_idx - 1L) | ||||||
|
@@ -758,14 +771,14 @@ predict.lgb.Booster <- function(object, | |||||
|
||||||
# Return booster predictions | ||||||
object$predict( | ||||||
data | ||||||
, start_iteration | ||||||
, num_iteration | ||||||
, rawscore | ||||||
, predleaf | ||||||
, predcontrib | ||||||
, header | ||||||
, reshape | ||||||
data = data | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
, start_iteration = start_iteration | ||||||
, num_iteration = num_iteration | ||||||
, rawscore = rawscore | ||||||
, predleaf = predleaf | ||||||
, predcontrib = predcontrib | ||||||
, header = header | ||||||
, reshape = reshape | ||||||
, ... | ||||||
) | ||||||
} | ||||||
|
@@ -873,7 +886,10 @@ lgb.save <- function(booster, filename, num_iteration = NULL) { | |||||
} | ||||||
|
||||||
# Store booster | ||||||
invisible(booster$save_model(filename, num_iteration)) | ||||||
invisible(booster$save_model( | ||||||
filename = filename | ||||||
, num_iteration = num_iteration | ||||||
)) | ||||||
|
||||||
} | ||||||
|
||||||
|
@@ -915,7 +931,7 @@ lgb.dump <- function(booster, num_iteration = NULL) { | |||||
} | ||||||
|
||||||
# Return booster at requested iteration | ||||||
booster$dump_model(num_iteration) | ||||||
booster$dump_model(num_iteration = num_iteration) | ||||||
|
||||||
} | ||||||
|
||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this check is still incorrect. Please accept this suggestion in your browser.