Skip to content

Commit

Permalink
[R-package] Add $raw to LightGBM (#347)
Browse files Browse the repository at this point in the history
* Update lgb.Booster.R

* Add saveRDS (manual)

* Add documentation

* Change arguments passed from debug.

* Add readRDS, change way of saving.

* Change documentation.

* Add better documentation.
  • Loading branch information
Laurae2 authored and guolinke committed Mar 17, 2017
1 parent cc62b1c commit 06a915a
Show file tree
Hide file tree
Showing 6 changed files with 175 additions and 1 deletion.
2 changes: 2 additions & 0 deletions R-package/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ export(lgb.plot.interpretation)
export(lgb.save)
export(lgb.train)
export(lightgbm)
export(readRDS.lgb.Booster)
export(saveRDS.lgb.Booster)
export(setinfo)
export(slice)
import(methods)
Expand Down
9 changes: 8 additions & 1 deletion R-package/R/lgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,14 @@ Booster <- R6Class(
predictor <- Predictor$new(private$handle)
predictor$predict(data, num_iteration, rawscore, predleaf, header, reshape)
},
to_predictor = function() { Predictor$new(private$handle) }
to_predictor = function() { Predictor$new(private$handle) },
raw = NA,
save = function() {
temp <- tempfile()
lgb.save(self, temp)
self$raw <- readChar(temp, file.info(temp)$size)
file.remove(temp)
}
),
private = list(
handle = NULL,
Expand Down
42 changes: 42 additions & 0 deletions R-package/R/readRDS.lgb.Booster.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#' readRDS for lgb.Booster models
#'
#' Attemps to load a model using RDS.
#'
#' @param file a connection or the name of the file where the R object is saved to or read from.
#' @param refhook a hook function for handling reference objects.
#'
#' @return an R object.
#'
#' @examples
#' \dontrun{
#' library(lightgbm)
#' data(agaricus.train, package='lightgbm')
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label=train$label)
#' data(agaricus.test, package='lightgbm')
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label=test$label)
#' params <- list(objective="regression", metric="l2")
#' valids <- list(test=dtest)
#' model <- lgb.train(params, dtrain, 100, valids, min_data=1, learning_rate=1, early_stopping_rounds=10)
#' saveRDS.lgb.Booster(model, "model.rds")
#' new_model <- readRDS.lgb.Booster("model.rds")
#' }
#' @export

readRDS.lgb.Booster <- function(file = "", refhook = NULL) {

object <- readRDS(file = file, refhook = refhook)
if (!is.na(object$raw)) {
temp <- tempfile()
write(object$raw, temp)
object2 <- lgb.load(temp)
file.remove(temp)
object2$best_iter <- object$best_iter
object2$record_evals <- object$record_evals
return(object2)
} else {
return(object)
}

}
41 changes: 41 additions & 0 deletions R-package/R/saveRDS.lgb.Booster.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#' saveRDS for lgb.Booster models
#'
#' Attemps to save a model using RDS. Has an additional parameter (\code{raw}) which decides whether to save the raw model or not.
#'
#' @param object R object to serialize.
#' @param file a connection or the name of the file where the R object is saved to or read from.
#' @param ascii a logical. If TRUE or NA, an ASCII representation is written; otherwise (default), a binary one is used. See the comments in the help for save.
#' @param version the workspace format version to use. \code{NULL} specifies the current default version (2). Versions prior to 2 are not supported, so this will only be relevant when there are later versions.
#' @param compress a logical specifying whether saving to a named file is to use "gzip" compression, or one of \code{"gzip"}, \code{"bzip2"} or \code{"xz"} to indicate the type of compression to be used. Ignored if file is a connection.
#' @param refhook a hook function for handling reference objects.
#' @param raw whether to save the model in a raw variable or not, recommended to leave it to \code{TRUE}.
#'
#' @return NULL invisibly.
#'
#' @examples
#' \dontrun{
#' library(lightgbm)
#' data(agaricus.train, package='lightgbm')
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label=train$label)
#' data(agaricus.test, package='lightgbm')
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label=test$label)
#' params <- list(objective="regression", metric="l2")
#' valids <- list(test=dtest)
#' model <- lgb.train(params, dtrain, 100, valids, min_data=1, learning_rate=1, early_stopping_rounds=10)
#' saveRDS.lgb.Booster(model, "model.rds")
#' }
#' @export

saveRDS.lgb.Booster <- function(object, file = "", ascii = FALSE, version = NULL, compress = TRUE, refhook = NULL, raw = TRUE) {

if (is.na(object$raw) & (raw)) {
object$save()
saveRDS(object, file = file, ascii = ascii, version = version, compress = compress, refhook = refhook)
object$raw <- NA
} else {
saveRDS(object, file = file, ascii = ascii, version = version, compress = compress, refhook = refhook)
}

}
36 changes: 36 additions & 0 deletions R-package/man/readRDS.lgb.Booster.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

46 changes: 46 additions & 0 deletions R-package/man/saveRDS.lgb.Booster.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 06a915a

Please sign in to comment.