Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[R-package] enable saving Booster with saveRDS() and loading it with readRDS() (fixes #4296) #4685

Merged
merged 42 commits into from
Dec 4, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
4c177a5
idiomatic serialization
david-cortes Oct 14, 2021
ada260b
linter
david-cortes Oct 14, 2021
0d551a0
linter, namespace
david-cortes Oct 14, 2021
088eef3
comments, linter, fix failing test
david-cortes Oct 15, 2021
5e2a922
standardize error messages for null handles
david-cortes Oct 15, 2021
5c1d260
auto-restore handle in more functions
david-cortes Oct 15, 2021
8ed14a6
linter
david-cortes Oct 15, 2021
840de5e
missing declaration
david-cortes Oct 15, 2021
af16b2d
correct wrong signature
david-cortes Oct 15, 2021
9d7e6f8
fix docs
david-cortes Oct 15, 2021
4428a23
Update R-package/R/lgb.train.R
david-cortes Oct 15, 2021
730f2e6
Update R-package/R/lgb.drop_serialized.R
david-cortes Oct 15, 2021
719af93
Update R-package/R/lgb.restore_handle.R
david-cortes Oct 15, 2021
41a75bd
Update R-package/R/lgb.restore_handle.R
david-cortes Oct 15, 2021
9b5de4d
Update R-package/R/lgb.make_serializable.R
david-cortes Oct 15, 2021
1f4aa91
move 'restore_handle' from feature importance to dump method
david-cortes Oct 15, 2021
84af4e7
missing header
david-cortes Oct 15, 2021
25557f7
move arguments order, update docs
david-cortes Oct 15, 2021
ff78dd2
linter
david-cortes Oct 15, 2021
19f3c4a
avoid leaving files in working directory
david-cortes Oct 15, 2021
2f3a334
add test for save_model=NULL
david-cortes Oct 15, 2021
6e7b852
missing comma
david-cortes Oct 15, 2021
617b226
Update R-package/R/lgb.restore_handle.R
david-cortes Oct 16, 2021
8a078f4
Update R-package/src/lightgbm_R.cpp
david-cortes Oct 16, 2021
8e194af
change name of error function
david-cortes Oct 16, 2021
d4c8ef1
update comment
david-cortes Oct 16, 2021
44ca8db
restore old serialization functions but set as deprecated
david-cortes Oct 16, 2021
d6f4c74
Update R-package/R/readRDS.lgb.Booster.R
david-cortes Oct 17, 2021
8d282e4
Update R-package/R/saveRDS.lgb.Booster.R
david-cortes Oct 17, 2021
0817eb0
update docs
david-cortes Oct 17, 2021
f845554
Update R-package/R/readRDS.lgb.Booster.R
david-cortes Oct 26, 2021
51fa088
Update R-package/R/saveRDS.lgb.Booster.R
david-cortes Oct 26, 2021
8522ce7
Update R-package/tests/testthat/test_basic.R
david-cortes Oct 26, 2021
c116270
Update R-package/R/readRDS.lgb.Booster.R
david-cortes Oct 26, 2021
bee5bc1
comments
david-cortes Oct 26, 2021
b0f9f93
fix variable name
david-cortes Oct 26, 2021
2d3a132
restore serialization test for linear models
david-cortes Oct 26, 2021
c534952
Update R-package/R/lightgbm.R
david-cortes Nov 18, 2021
58fd21f
Merge branch 'master' into serial
david-cortes Nov 18, 2021
b1b4e2b
update docs
david-cortes Nov 18, 2021
eb7fd32
fix issues with null terminator
david-cortes Dec 4, 2021
34707ae
solve conflicts
david-cortes Dec 4, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
54 changes: 41 additions & 13 deletions R-package/R/lgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,11 @@ Booster <- R6::R6Class(

} else if (!is.null(model_str)) {

# Do we have a model_str as character?
if (!is.character(model_str)) {
stop("lgb.Booster: Can only use a string as model_str")
# Do we have a model_str as character/raw?
if (is.character(model_str))
jameslamb marked this conversation as resolved.
Show resolved Hide resolved
model_str <- charToRaw(model_str)
if (!is.raw(model_str)) {
stop("lgb.Booster: Can only use a character/raw vector as model_str")
}

# Create booster from model
Expand Down Expand Up @@ -436,7 +438,7 @@ Booster <- R6::R6Class(
return(invisible(self))
},

save_model_to_string = function(num_iteration = NULL, feature_importance_type = 0L) {
save_model_to_string = function(num_iteration = NULL, feature_importance_type = 0L, as_char = TRUE) {

if (is.null(num_iteration)) {
num_iteration <- self$best_iter
Expand All @@ -449,6 +451,9 @@ Booster <- R6::R6Class(
, as.integer(feature_importance_type)
)

if (as_char)
model_str <- rawToChar(model_str)

return(model_str)

},
Expand Down Expand Up @@ -527,17 +532,37 @@ Booster <- R6::R6Class(
return(Predictor$new(modelfile = private$handle))
},

# Used for save
raw = NA,
# Used for serialization
raw = NULL,

# Store serialized raw bytes in model object
save_raw = function() {
if (is.null(self$raw))
self$raw <- self$save_model_to_string(NULL, as_char=FALSE)
return(invisible(NULL))

},

# Save model to temporary file for in-memory saving
save = function() {
drop_raw = function() {
self$raw <- NULL
return(invisible(NULL))
},

# Overwrite model in object
self$raw <- self$save_model_to_string(NULL)
check_null_handle = function() {
return(lgb.is.null.handle(private$handle))
},

restore_handle = function() {
if (self$check_null_handle()) {
if (is.null(self$raw))
stop("LightGBM model is not de-serializable. Try using 'serializable=TRUE'.")
private$handle <- .Call(LGBM_BoosterLoadModelFromString_R, self$raw)
}
return(invisible(NULL))
},

get_handle = function() {
return(private$handle)
}

),
Expand Down Expand Up @@ -784,6 +809,7 @@ predict.lgb.Booster <- function(object,
if (!lgb.is.Booster(x = object)) {
stop("predict.lgb.Booster: object should be an ", sQuote("lgb.Booster"))
}
object$restore_handle()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this be moved inside Booster$predict()?

That way, it'll be guaranteed to run regardless of whether someone uses predict(bst, data) or bst$predict().

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it's placed earlier it can throw an error before doing any other long operations with the data. I also assume that since the R6 methods are not documented, they are meant for internal usage, and a user trying to call them directly would likely need to examine the code in any case.

Copy link
Collaborator

@jameslamb jameslamb Oct 26, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it's placed earlier it can throw an error before doing any other long operations with the data.

The only code between this call and Booster$predict() is conversion of ... to a list and possibly raising a deprecation warning.

I'd prefer to concentrate these $restore_handle() calls in the Booster object as much as possible, to minimize how many places in the package's code need to know about managing the raw model object.

I also assume that since the R6 methods are not documented, they are meant for internal usage

The fact that those methods are not documented is a gap that should be filled (not in this PR, please).

However, all of the Booster's public methods except initialize() are treated as part of the public API of the R package. We treat them that way because other exported functions can return Booster instance. For example, lgb.train() returns a Booster instance, and then user code can call any public methods on that instance without needing to use ::: or reach into $.__enclos_env__$private.


additional_params <- list(...)
if (length(additional_params) > 0L) {
Expand Down Expand Up @@ -815,7 +841,7 @@ predict.lgb.Booster <- function(object,
#' @description Load LightGBM takes in either a file path or model string.
#' If both are provided, Load will default to loading from file
#' @param filename path of model file
#' @param model_str a str containing the model
#' @param model_str a str containing the model (as a `character` or `raw` vector)
#'
#' @return lgb.Booster
#'
Expand Down Expand Up @@ -863,9 +889,11 @@ lgb.load <- function(filename = NULL, model_str = NULL) {
return(invisible(Booster$new(modelfile = filename)))
}

if (is.character(model_str))
model_str <- charToRaw(model_str)
if (model_str_provided) {
if (!is.character(model_str)) {
stop("lgb.load: model_str should be character")
if (!is.raw(model_str)) {
stop("lgb.load: model_str should be a character/raw vector")
}
return(invisible(Booster$new(model_str = model_str)))
}
Expand Down
7 changes: 6 additions & 1 deletion R-package/R/lgb.Predictor.R
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,17 @@ Predictor <- R6::R6Class(
)
private$need_free_handle <- TRUE

} else if (methods::is(modelfile, "lgb.Booster.handle")) {
} else if (methods::is(modelfile, "lgb.Booster.handle") || inherits(modelfile, "externalptr")) {

# Check if model file is a booster handle already
handle <- modelfile
private$need_free_handle <- FALSE

} else if (lgb.is.Booster(modelfile)) {

handle <- modelfile$get_handle()
private$need_free_handle <- FALSE

} else {

stop("lgb.Predictor: modelfile must be either a character filename or an lgb.Booster.handle")
Expand Down
5 changes: 5 additions & 0 deletions R-package/R/lgb.cv.R
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ lgb.cv <- function(params = list()
, categorical_feature = NULL
, early_stopping_rounds = NULL
, callbacks = list()
, serializable = TRUE
jameslamb marked this conversation as resolved.
Show resolved Hide resolved
, reset_data = FALSE
, ...
) {
Expand Down Expand Up @@ -456,6 +457,10 @@ lgb.cv <- function(params = list()
})
}

if (serializable) {
lapply(cv_booster$boosters, function(model) model$booster$save_raw())
}

return(cv_booster)

}
Expand Down
17 changes: 17 additions & 0 deletions R-package/R/lgb.drop_serialized.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#' @name lgb.drop_serialized
#' @title Drop serialized raw bytes in a LightGBM model object
#' @description If a LightGBM model object was produced with argument `serializable=TRUE`, the R object will keep
#' a copy of the underlying C++ object as raw bytes, which can be used to reconstruct such object after getting
#' serialized and de-serialized, but at the cost of extra memory usage. If these raw bytes are not needed anymore,
#' they can be dropped through this function in order to save memory. Note that the object will be modified in-place.
#' @param model \code{lgb.Booster} object which was produced with `serializable=TRUE`.
#'
#' @return \code{lgb.Booster} (the same `model` object that was passed as input, as invisible).
#' @seealso \link{lgb.restore_handle}, \link{lgb.make_serializable}.
#' @examples
#' @export
lgb.drop_serialized <- function(model) {
stopifnot(lgb.is.Booster(model))
david-cortes marked this conversation as resolved.
Show resolved Hide resolved
model$drop_raw()
return(invisible(model))
}
1 change: 1 addition & 0 deletions R-package/R/lgb.importance.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ lgb.importance <- function(model, percentage = TRUE) {
if (!lgb.is.Booster(x = model)) {
stop("'model' has to be an object of class lgb.Booster")
}
model$restore_handle()
jameslamb marked this conversation as resolved.
Show resolved Hide resolved

# Setup importance
tree_dt <- lgb.model.dt.tree(model = model)
Expand Down
17 changes: 17 additions & 0 deletions R-package/R/lgb.make_serializable.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#' @name lgb.make_serializable
#' @title Make a LightGBM object serializable by keeping raw bytes
#' @description If a LightGBM model object was produced with argument `serializable=FALSE`, the R object will not
#' be serializable (e.g. cannot save and load with \code{saveRDS} and \code{readRDS}) as it will lack the raw bytes
#' needed to reconstruct its underlying C++ object. This function can be used to forcibly produce those serialized
#' raw bytes and make the object serializable. Note that the object will be modified in-place.
#' @param model \code{lgb.Booster} object which was produced with `serializable=FALSE`.
#'
#' @return \code{lgb.Booster} (the same `model` object that was passed as input, as invisible).
#' @seealso \link{lgb.restore_handle}, \link{lgb.drop_serialized}.
#' @examples
#' @export
lgb.make_serializable <- function(model) {
stopifnot(lgb.is.Booster(model))
david-cortes marked this conversation as resolved.
Show resolved Hide resolved
model$save_raw()
return(invisible(model))
}
18 changes: 18 additions & 0 deletions R-package/R/lgb.restore_handle.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#' @name lgb.restore_handle
jameslamb marked this conversation as resolved.
Show resolved Hide resolved
#' @title Restore the C++ component of a deserialized LGB model
david-cortes marked this conversation as resolved.
Show resolved Hide resolved
#' @description After a LightGBM model object is de-serialized through functions such as \code{save} or
#' \code{saveRDS}, its underlying C++ object will be blank and needs to be restored to able to use it. Such
#' object is restored automatically when calling functions such as \code{predict}, but this function can be
#' used to forcibly restore it beforehand. Note that the object will be modified in-place.
#' @param model \code{lgb.Booster} object which was de-serialized and whose underlying C++ object and R handle
#' need to be restored.
#'
#' @return \code{lgb.Booster} (the same `model` object that was passed as input, as invisible).
david-cortes marked this conversation as resolved.
Show resolved Hide resolved
#' @seealso \link{lgb.make_serializable}, \link{lgb.drop_serialized}.
#' @examples
#' @export
lgb.restore_handle <- function(model) {
stopifnot(lgb.is.Booster(model))
david-cortes marked this conversation as resolved.
Show resolved Hide resolved
model$restore_handle()
return(invisible(model))
}
4 changes: 4 additions & 0 deletions R-package/R/lgb.train.R
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ lgb.train <- function(params = list(),
categorical_feature = NULL,
early_stopping_rounds = NULL,
callbacks = list(),
serializable = TRUE,
jameslamb marked this conversation as resolved.
Show resolved Hide resolved
reset_data = FALSE,
...) {

Expand Down Expand Up @@ -395,6 +396,9 @@ lgb.train <- function(params = list(),

}

if (serializable)
booster$save_raw()
david-cortes marked this conversation as resolved.
Show resolved Hide resolved

return(booster)

}
19 changes: 19 additions & 0 deletions R-package/R/lightgbm.R
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@
#' @param params a list of parameters. See \href{https://lightgbm.readthedocs.io/en/latest/Parameters.html}{
#' the "Parameters" section of the documentation} for a list of parameters and valid values.
#' @param verbose verbosity for output, if <= 0, also will disable the print of evaluation during training
#' @param serializable whether to make the resulting objects serializable through functions such as
#' \code{save} or \code{saveRDS} (see section "Model serialization").
#' @section Early Stopping:
#'
#' "early stopping" refers to stopping the training process if the model's performance on a given
Expand All @@ -66,6 +68,21 @@
#' in \code{params}, that metric will be considered the "first" one. If you omit \code{metric},
#' a default metric will be used based on your choice for the parameter \code{obj} (keyword argument)
#' or \code{objective} (passed into \code{params}).
#' @section Model serialization:
#'
#' LightGBM models objects can be serialized and de-serialized through functions such as \code{save}
david-cortes marked this conversation as resolved.
Show resolved Hide resolved
#' or \code{saveRDS}, but similarly to libraries such as 'xgboost', serialization works a bit differently
#' from typical R objects. In order to make models serializable in R, a copy of the underlying C++ object
#' as serialized raw bytes is produced and stored in the R model object, and when this R object is
#' de-serialized, the underlying C++ model object gets reconstructed from these raw bytes, but will only
#' do so once some function that uses it is called, such as \code{predict}. In order to forcibly
#' reconstruct the C++ object after deserialization (e.g. after calling \code{readRDS} or similar), one
#' can use the function \link{lgb.restore_handle} (for example, if one makes predictions in parallel or in
#' forked processes, it will be faster to restore the handle beforehand).
#'
#' Producing and keeping these raw bytes however uses extra memory, and if they are not required,
#' it is possible to avoid producing them by passing `serializable=FALSE`. In such cases, these raw
#' bytes can be added to the model on demand through function \link{lgb.make_serializable}.
#' @keywords internal
NULL

Expand Down Expand Up @@ -113,6 +130,7 @@ lightgbm <- function(data,
save_name = "lightgbm.model",
init_model = NULL,
callbacks = list(),
serializable = TRUE,
...) {

# validate inputs early to avoid unnecessary computation
Expand All @@ -137,6 +155,7 @@ lightgbm <- function(data,
, "early_stopping_rounds" = early_stopping_rounds
, "init_model" = init_model
, "callbacks" = callbacks
, "serializable" = serializable
)
train_args <- append(train_args, list(...))

Expand Down
62 changes: 0 additions & 62 deletions R-package/R/readRDS.lgb.Booster.R

This file was deleted.