Skip to content

Commit

Permalink
[R-package] introduce Dataset methods set_field() and get_field() (#4571
Browse files Browse the repository at this point in the history
)

* [R-package] introduce Dataset set_field() and get_field()

* fix incorrect fields

* update pkgdown

* fix example

* fix another example

* Apply suggestions from code review

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* update docs

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
  • Loading branch information
jameslamb and StrikerRUS committed Sep 25, 2021
1 parent 74c7904 commit d462972
Show file tree
Hide file tree
Showing 21 changed files with 344 additions and 69 deletions.
2 changes: 1 addition & 1 deletion R-package/DESCRIPTION
Expand Up @@ -58,4 +58,4 @@ Imports:
utils
SystemRequirements:
C++11
RoxygenNote: 7.1.1
RoxygenNote: 7.1.2
4 changes: 4 additions & 0 deletions R-package/NAMESPACE
Expand Up @@ -3,10 +3,13 @@
S3method("dimnames<-",lgb.Dataset)
S3method(dim,lgb.Dataset)
S3method(dimnames,lgb.Dataset)
S3method(get_field,lgb.Dataset)
S3method(getinfo,lgb.Dataset)
S3method(predict,lgb.Booster)
S3method(set_field,lgb.Dataset)
S3method(setinfo,lgb.Dataset)
S3method(slice,lgb.Dataset)
export(get_field)
export(getinfo)
export(lgb.Dataset)
export(lgb.Dataset.construct)
Expand All @@ -30,6 +33,7 @@ export(lgb.unloader)
export(lightgbm)
export(readRDS.lgb.Booster)
export(saveRDS.lgb.Booster)
export(set_field)
export(setinfo)
export(slice)
import(methods)
Expand Down
187 changes: 159 additions & 28 deletions R-package/R/lgb.Dataset.R
Expand Up @@ -335,14 +335,17 @@ Dataset <- R6::R6Class(
for (i in seq_along(private$info)) {

p <- private$info[i]
self$setinfo(name = names(p), info = p[[1L]])
self$set_field(
field_name = names(p)
, data = p[[1L]]
)

}

}

# Get label information existence
if (is.null(self$getinfo(name = "label"))) {
if (is.null(self$get_field(field_name = "label"))) {
stop("lgb.Dataset.construct: label should be set")
}

Expand Down Expand Up @@ -452,27 +455,41 @@ Dataset <- R6::R6Class(

},

# Get information
getinfo = function(name) {
warning(paste0(
"Dataset$getinfo() is deprecated and will be removed in a future release. "
, "Use Dataset$get_field() instead."
))
return(
self$get_field(
field_name = name
)
)
},

get_field = function(field_name) {

# Check if attribute key is in the known attribute list
if (!is.character(name) || length(name) != 1L || !name %in% .INFO_KEYS()) {
stop("getinfo: name must one of the following: ", paste0(sQuote(.INFO_KEYS()), collapse = ", "))
if (!is.character(field_name) || length(field_name) != 1L || !field_name %in% .INFO_KEYS()) {
stop(
"Dataset$get_field(): field_name must one of the following: "
, paste0(sQuote(.INFO_KEYS()), collapse = ", ")
)
}

# Check for info name and handle
if (is.null(private$info[[name]])) {
if (is.null(private$info[[field_name]])) {

if (lgb.is.null.handle(x = private$handle)) {
stop("Cannot perform getinfo before constructing Dataset.")
stop("Cannot perform Dataset$get_field() before constructing Dataset.")
}

# Get field size of info
info_len <- 0L
.Call(
LGBM_DatasetGetFieldSize_R
, private$handle
, name
, field_name
, info_len
)

Expand All @@ -481,7 +498,7 @@ Dataset <- R6::R6Class(

# Get back fields
ret <- NULL
ret <- if (name == "group") {
ret <- if (field_name == "group") {
integer(info_len) # Integer
} else {
numeric(info_len) # Numeric
Expand All @@ -490,47 +507,62 @@ Dataset <- R6::R6Class(
.Call(
LGBM_DatasetGetField_R
, private$handle
, name
, field_name
, ret
)

private$info[[name]] <- ret
private$info[[field_name]] <- ret

}
}

return(private$info[[name]])
return(private$info[[field_name]])

},

# Set information
setinfo = function(name, info) {
warning(paste0(
"Dataset$setinfo() is deprecated and will be removed in a future release. "
, "Use Dataset$set_field() instead."
))
return(
self$set_field(
field_name = name
, data = info
)
)
},

set_field = function(field_name, data) {

# Check if attribute key is in the known attribute list
if (!is.character(name) || length(name) != 1L || !name %in% .INFO_KEYS()) {
stop("setinfo: name must one of the following: ", paste0(sQuote(.INFO_KEYS()), collapse = ", "))
if (!is.character(field_name) || length(field_name) != 1L || !field_name %in% .INFO_KEYS()) {
stop(
"Dataset$set_field(): field_name must one of the following: "
, paste0(sQuote(.INFO_KEYS()), collapse = ", ")
)
}

# Check for type of information
info <- if (name == "group") {
as.integer(info) # Integer
data <- if (field_name == "group") {
as.integer(data) # Integer
} else {
as.numeric(info) # Numeric
as.numeric(data) # Numeric
}

# Store information privately
private$info[[name]] <- info
private$info[[field_name]] <- data

if (!lgb.is.null.handle(x = private$handle) && !is.null(info)) {
if (!lgb.is.null.handle(x = private$handle) && !is.null(data)) {

if (length(info) > 0L) {
if (length(data) > 0L) {

.Call(
LGBM_DatasetSetField_R
, private$handle
, name
, info
, length(info)
, field_name
, data
, length(data)
)

private$version <- private$version + 1L
Expand All @@ -554,7 +586,7 @@ Dataset <- R6::R6Class(
, paste(names(additional_keyword_args), collapse = ", ")
, ". These are ignored and should be removed. "
, "To change the parameters of a Dataset produced by Dataset$slice(), use Dataset$set_params(). "
, "To modify attributes like 'init_score', use Dataset$setinfo(). "
, "To modify attributes like 'init_score', use Dataset$set_field(). "
, "In future releases of lightgbm, this warning will become an error."
))
}
Expand Down Expand Up @@ -1110,7 +1142,7 @@ dimnames.lgb.Dataset <- function(x) {
#'
#' dsub <- lightgbm::slice(dtrain, seq_len(42L))
#' lgb.Dataset.construct(dsub)
#' labels <- lightgbm::getinfo(dsub, "label")
#' labels <- lightgbm::get_field(dsub, "label")
#' }
#' @export
slice <- function(dataset, ...) {
Expand Down Expand Up @@ -1173,6 +1205,8 @@ getinfo <- function(dataset, ...) {
#' @export
getinfo.lgb.Dataset <- function(dataset, name, ...) {

warning("Calling getinfo() on a lgb.Dataset is deprecated. Use get_field() instead.")

additional_args <- list(...)
if (length(additional_args) > 0L) {
warning(paste0(
Expand All @@ -1187,7 +1221,7 @@ getinfo.lgb.Dataset <- function(dataset, name, ...) {
stop("getinfo.lgb.Dataset: input dataset should be an lgb.Dataset object")
}

return(dataset$getinfo(name = name))
return(dataset$get_field(field_name = name))

}

Expand Down Expand Up @@ -1236,6 +1270,8 @@ setinfo <- function(dataset, ...) {
#' @export
setinfo.lgb.Dataset <- function(dataset, name, info, ...) {

warning("Calling setinfo() on a lgb.Dataset is deprecated. Use set_field() instead.")

additional_args <- list(...)
if (length(additional_args) > 0L) {
warning(paste0(
Expand All @@ -1250,7 +1286,102 @@ setinfo.lgb.Dataset <- function(dataset, name, info, ...) {
stop("setinfo.lgb.Dataset: input dataset should be an lgb.Dataset object")
}

return(invisible(dataset$setinfo(name = name, info = info)))
return(invisible(dataset$set_field(field_name = name, data = info)))
}

#' @name get_field
#' @title Get one attribute of a \code{lgb.Dataset}
#' @description Get one attribute of a \code{lgb.Dataset}
#' @param dataset Object of class \code{lgb.Dataset}
#' @param field_name String with the name of the attribute to get. One of the following.
#' \itemize{
#' \item \code{label}: label lightgbm learns from ;
#' \item \code{weight}: to do a weight rescale ;
#' \item{\code{group}: used for learning-to-rank tasks. An integer vector describing how to
#' group rows together as ordered results from the same set of candidate results to be ranked.
#' For example, if you have a 100-document dataset with \code{group = c(10, 20, 40, 10, 10, 10)},
#' that means that you have 6 groups, where the first 10 records are in the first group,
#' records 11-30 are in the second group, etc.}
#' \item \code{init_score}: initial score is the base prediction lightgbm will boost from.
#' }
#' @return requested attribute
#'
#' @examples
#' \donttest{
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' lgb.Dataset.construct(dtrain)
#'
#' labels <- lightgbm::get_field(dtrain, "label")
#' lightgbm::set_field(dtrain, "label", 1 - labels)
#'
#' labels2 <- lightgbm::get_field(dtrain, "label")
#' stopifnot(all(labels2 == 1 - labels))
#' }
#' @export
get_field <- function(dataset, field_name) {
UseMethod("get_field")
}

#' @rdname get_field
#' @export
get_field.lgb.Dataset <- function(dataset, field_name) {

# Check if dataset is not a dataset
if (!lgb.is.Dataset(x = dataset)) {
stop("get_field.lgb.Dataset(): input dataset should be an lgb.Dataset object")
}

return(dataset$get_field(field_name = field_name))

}

#' @name set_field
#' @title Set one attribute of a \code{lgb.Dataset} object
#' @description Set one attribute of a \code{lgb.Dataset}
#' @param dataset Object of class \code{lgb.Dataset}
#' @param field_name String with the name of the attribute to set. One of the following.
#' \itemize{
#' \item \code{label}: label lightgbm learns from ;
#' \item \code{weight}: to do a weight rescale ;
#' \item{\code{group}: used for learning-to-rank tasks. An integer vector describing how to
#' group rows together as ordered results from the same set of candidate results to be ranked.
#' For example, if you have a 100-document dataset with \code{group = c(10, 20, 40, 10, 10, 10)},
#' that means that you have 6 groups, where the first 10 records are in the first group,
#' records 11-30 are in the second group, etc.}
#' \item \code{init_score}: initial score is the base prediction lightgbm will boost from.
#' }
#' @param data The data for the field. See examples.
#' @return The \code{lgb.Dataset} you passed in.
#'
#' @examples
#' \donttest{
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' lgb.Dataset.construct(dtrain)
#'
#' labels <- lightgbm::get_field(dtrain, "label")
#' lightgbm::set_field(dtrain, "label", 1 - labels)
#'
#' labels2 <- lightgbm::get_field(dtrain, "label")
#' stopifnot(all.equal(labels2, 1 - labels))
#' }
#' @export
set_field <- function(dataset, field_name, data) {
UseMethod("set_field")
}

#' @rdname set_field
#' @export
set_field.lgb.Dataset <- function(dataset, field_name, data) {

if (!lgb.is.Dataset(x = dataset)) {
stop("set_field.lgb.Dataset: input dataset should be an lgb.Dataset object")
}

return(invisible(dataset$set_field(field_name = field_name, data = data)))
}

#' @name lgb.Dataset.set.categorical
Expand Down

0 comments on commit d462972

Please sign in to comment.