Skip to content

Commit

Permalink
[R-package] avoid unnecessary computation and add tests for Dataset s…
Browse files Browse the repository at this point in the history
…et_reference() method (#4587)

* [R-package] avoid unnecessary computation in Dataset set_reference() method

* re-arrange conditions

* do more validation upfront and add tests

* Update R-package/tests/testthat/test_dataset.R

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* Update R-package/tests/testthat/test_dataset.R

Co-authored-by: Nikita Titov <nekit94-12@hotmail.com>
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
  • Loading branch information
3 people committed Sep 10, 2021
1 parent 79463df commit a08c37f
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 17 deletions.
27 changes: 10 additions & 17 deletions R-package/R/lgb.Dataset.R
Expand Up @@ -663,34 +663,27 @@ Dataset <- R6::R6Class(
# Set reference
set_reference = function(reference) {

# Set known references
self$set_categorical_feature(categorical_feature = reference$.__enclos_env__$private$categorical_feature)
self$set_colnames(colnames = reference$get_colnames())
private$set_predictor(predictor = reference$.__enclos_env__$private$predictor)

# Check for identical references
# setting reference to this same Dataset object doesn't require any changes
if (identical(private$reference, reference)) {
return(invisible(self))
}

# Check for empty data
# changing the reference removes the Dataset object on the C++ side, so it should only
# be done if you still have the raw_data available, so that the new Dataset can be reconstructed
if (is.null(private$raw_data)) {

stop("set_reference: cannot set reference after freeing raw data,
please set ", sQuote("free_raw_data = FALSE"), " when you construct lgb.Dataset")

}

# Check for non-existing reference
if (!is.null(reference)) {

# Reference is unknown
if (!lgb.is.Dataset(reference)) {
stop("set_reference: Can only use lgb.Dataset as a reference")
}

if (!lgb.is.Dataset(reference)) {
stop("set_reference: Can only use lgb.Dataset as a reference")
}

# Set known references
self$set_categorical_feature(categorical_feature = reference$.__enclos_env__$private$categorical_feature)
self$set_colnames(colnames = reference$get_colnames())
private$set_predictor(predictor = reference$.__enclos_env__$private$predictor)

# Store reference
private$reference <- reference

Expand Down
116 changes: 116 additions & 0 deletions R-package/tests/testthat/test_dataset.R
@@ -1,5 +1,9 @@
context("testing lgb.Dataset functionality")

# Load both agaricus datasets shipped with the lightgbm package up front,
# then carve out the small subsets shared by the tests in this file.
data(agaricus.train, package = "lightgbm")
data(agaricus.test, package = "lightgbm")

# first 1000 training rows and their labels
train_data <- agaricus.train$data[seq_len(1000L), ]
train_label <- agaricus.train$label[seq_len(1000L)]

# first 100 test rows and their labels
test_data <- agaricus.test$data[seq_len(100L), ]
test_label <- agaricus.test$label[seq_len(100L)]
Expand Down Expand Up @@ -74,6 +78,118 @@ test_that("Dataset$slice() supports passing Dataset attributes through '...'", {
expect_identical(dsub1$getinfo("init_score"), init_score)
})

test_that("Dataset$set_reference() on a constructed Dataset fails if raw data has been freed", {
    # Neither Dataset is created with free_raw_data = FALSE, so after
    # construct() has run, set_reference() is expected to refuse the change.
    reference_ds <- lgb.Dataset(train_data, label = train_label)
    reference_ds$construct()
    target_ds <- lgb.Dataset(test_data, label = test_label)
    target_ds$construct()
    expect_error(
        target_ds$set_reference(reference_ds)
        , regexp = "cannot set reference after freeing raw data"
    )
})

test_that("Dataset$set_reference() fails if reference is not a Dataset", {
    ds <- lgb.Dataset(
        train_data
        , label = train_label
        , free_raw_data = FALSE
    )

    # a data.frame is not an lgb.Dataset, so it has to be rejected
    not_a_dataset <- data.frame(x = rnorm(10L))
    expect_error(
        ds$set_reference(reference = not_a_dataset)
        , regexp = "Can only use lgb.Dataset as a reference"
    )

    # NULL is rejected as well, even once the Dataset already has a reference
    other_ds <- lgb.Dataset(
        test_data
        , label = test_label
        , free_raw_data = FALSE
    )
    ds$set_reference(other_ds)
    expect_error(
        ds$set_reference(reference = NULL)
        , regexp = "Can only use lgb.Dataset as a reference"
    )
})

test_that("Dataset$set_reference() setting reference to the same Dataset has no side effects", {
    ds <- lgb.Dataset(
        train_data
        , label = train_label
        , free_raw_data = FALSE
        , categorical_feature = c(2L, 3L)
    )
    ds$construct()

    # snapshot the three attributes that set_reference() copies from a reference
    before <- list(
        categorical_feature = ds$.__enclos_env__$private$categorical_feature
        , colnames = ds$get_colnames()
        , predictor = ds$.__enclos_env__$private$predictor
    )

    # using the Dataset as its own reference
    ds$set_reference(ds)

    # every snapshotted attribute must be unchanged afterwards
    expect_identical(
        before[["categorical_feature"]]
        , ds$.__enclos_env__$private$categorical_feature
    )
    expect_identical(
        before[["colnames"]]
        , ds$get_colnames()
    )
    expect_identical(
        before[["predictor"]]
        , ds$.__enclos_env__$private$predictor
    )
})

test_that("Dataset$set_reference() updates categorical_feature, colnames, and predictor", {
    # reference Dataset: carries categorical features and gets a predictor
    # attached directly to its private state
    ref_ds <- lgb.Dataset(
        train_data
        , label = train_label
        , free_raw_data = FALSE
        , categorical_feature = c(2L, 3L)
    )
    ref_ds$construct()
    booster <- Booster$new(
        train_set = ref_ds
        , params = list(verbose = -1L)
    )
    ref_ds$.__enclos_env__$private$predictor <- booster$to_predictor()

    # target Dataset: custom column names, no categorical features, no predictor
    original_names <- paste0("feature_col_", seq_len(ncol(test_data)))
    target_ds <- lgb.Dataset(
        test_data
        , label = test_label
        , free_raw_data = FALSE
        , colnames = original_names
    )
    target_ds$construct()

    # before set_reference(): predictor and categorical_feature are unset and
    # the column names are the custom ones passed at creation
    expect_null(target_ds$.__enclos_env__$private$predictor)
    expect_null(target_ds$.__enclos_env__$private$categorical_feature)
    expect_identical(
        target_ds$get_colnames()
        , original_names
    )

    target_ds$set_reference(ref_ds)

    # after set_reference(): all three attributes should match the reference
    expect_is(target_ds$.__enclos_env__$private$predictor, "lgb.Predictor")
    expect_identical(
        target_ds$.__enclos_env__$private$predictor$.__enclos_env__$private$handle
        , ref_ds$.__enclos_env__$private$predictor$.__enclos_env__$private$handle
    )
    expect_identical(
        target_ds$.__enclos_env__$private$categorical_feature
        , ref_ds$.__enclos_env__$private$categorical_feature
    )
    expect_identical(
        target_ds$get_colnames()
        , ref_ds$get_colnames()
    )
    # ...and the custom column names must have been replaced
    expect_false(
        identical(target_ds$get_colnames(), original_names)
    )
})

test_that("lgb.Dataset: colnames", {
dtest <- lgb.Dataset(test_data, label = test_label)
expect_equal(colnames(dtest), colnames(test_data))
Expand Down

0 comments on commit a08c37f

Please sign in to comment.