Skip to content

Commit

Permalink
[python] [R-package] Use the same address when updating label/weight/q…
Browse files Browse the repository at this point in the history
…uery (#2662)

* Update metadata.cpp

* add a version counter to the training set, to efficiently update label/weight/... during training.

* Update lgb.Booster.R
  • Loading branch information
guolinke committed Jan 14, 2020
1 parent 350d56d commit 82886ba
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 9 deletions.
9 changes: 9 additions & 0 deletions R-package/R/lgb.Booster.R
Expand Up @@ -55,6 +55,7 @@ Booster <- R6::R6Class(

# Create private booster information
private$train_set <- train_set
private$train_set_version <- train_set$.__enclos_env__$private$version
private$num_dataset <- 1L
private$init_predictor <- train_set$.__enclos_env__$private$predictor

Expand Down Expand Up @@ -207,6 +208,12 @@ Booster <- R6::R6Class(
# Perform boosting update iteration
update = function(train_set = NULL, fobj = NULL) {

if (is.null(train_set)) {
if (private$train_set$.__enclos_env__$private$version != private$train_set_version) {
train_set <- private$train_set
}
}

# Check if training set is not null
if (!is.null(train_set)) {

Expand All @@ -230,6 +237,7 @@ Booster <- R6::R6Class(

# Store private train set
private$train_set <- train_set
private$train_set_version <- train_set$.__enclos_env__$private$version

}

Expand Down Expand Up @@ -497,6 +505,7 @@ Booster <- R6::R6Class(
eval_names = NULL,
higher_better_inner_eval = NULL,
set_objective_to_none = FALSE,
train_set_version = 0L,
# Predict data
inner_predict = function(idx) {

Expand Down
4 changes: 4 additions & 0 deletions R-package/R/lgb.Dataset.R
Expand Up @@ -89,6 +89,7 @@ Dataset <- R6::R6Class(
private$free_raw_data <- free_raw_data
private$used_indices <- sort(used_indices, decreasing = FALSE)
private$info <- info
private$version <- 0L

},

Expand Down Expand Up @@ -503,6 +504,8 @@ Dataset <- R6::R6Class(
, length(info)
)

private$version <- private$version + 1L

}

}
Expand Down Expand Up @@ -638,6 +641,7 @@ Dataset <- R6::R6Class(
free_raw_data = TRUE,
used_indices = NULL,
info = NULL,
version = 0L,

# Get handle
get_handle = function() {
Expand Down
11 changes: 10 additions & 1 deletion python-package/lightgbm/basic.py
Expand Up @@ -771,6 +771,7 @@ def __init__(self, data, label=None, reference=None,
self.params_back_up = None
self.feature_penalty = None
self.monotone_constraints = None
self.version = 0

def __del__(self):
try:
Expand Down Expand Up @@ -1233,6 +1234,7 @@ def set_field(self, field_name, data):
ptr_data,
ctypes.c_int(len(data)),
ctypes.c_int(type_data)))
self.version += 1
return self

def get_field(self, field_name):
Expand Down Expand Up @@ -1740,6 +1742,7 @@ def __init__(self, params=None, train_set=None, model_file=None, model_str=None,
self.__is_predicted_cur_iter = [False]
self.__get_eval_info()
self.pandas_categorical = train_set.pandas_categorical
self.train_set_version = train_set.version
elif model_file is not None:
# Prediction task
out_num_iterations = ctypes.c_int(0)
Expand Down Expand Up @@ -2076,7 +2079,12 @@ def update(self, train_set=None, fobj=None):
Whether the update was successfully finished.
"""
# need reset training data
if train_set is not None and train_set is not self.train_set:
if train_set is None and self.train_set_version != self.train_set.version:
train_set = self.train_set
is_the_same_train_set = False
else:
is_the_same_train_set = train_set is self.train_set and self.train_set_version == train_set.version
if train_set is not None and not is_the_same_train_set:
if not isinstance(train_set, Dataset):
raise TypeError('Training data should be Dataset instance, met {}'
.format(type(train_set).__name__))
Expand All @@ -2088,6 +2096,7 @@ def update(self, train_set=None, fobj=None):
self.handle,
self.train_set.construct().handle))
self.__inner_predict_buffer[0] = None
self.train_set_version = self.train_set.version
is_finished = ctypes.c_int(0)
if fobj is None:
if self.__set_objective_to_none:
Expand Down
15 changes: 7 additions & 8 deletions src/io/metadata.cpp
Expand Up @@ -290,9 +290,9 @@ void Metadata::SetInitScore(const double* init_score, data_size_t len) {
if ((len % num_data_) != 0) {
Log::Fatal("Initial score size doesn't match data size");
}
if (!init_score_.empty()) { init_score_.clear(); }
if (init_score_.empty()) { init_score_.resize(len); }
num_init_score_ = len;
init_score_ = std::vector<double>(len);

#pragma omp parallel for schedule(static)
for (int64_t i = 0; i < num_init_score_; ++i) {
init_score_[i] = Common::AvoidInf(init_score[i]);
Expand All @@ -308,8 +308,8 @@ void Metadata::SetLabel(const label_t* label, data_size_t len) {
if (num_data_ != len) {
Log::Fatal("Length of label is not same with #data");
}
if (!label_.empty()) { label_.clear(); }
label_ = std::vector<label_t>(num_data_);
if (label_.empty()) { label_.resize(num_data_); }

#pragma omp parallel for schedule(static)
for (data_size_t i = 0; i < num_data_; ++i) {
label_[i] = Common::AvoidInf(label[i]);
Expand All @@ -327,9 +327,9 @@ void Metadata::SetWeights(const label_t* weights, data_size_t len) {
if (num_data_ != len) {
Log::Fatal("Length of weights is not same with #data");
}
if (!weights_.empty()) { weights_.clear(); }
if (weights_.empty()) { weights_.resize(num_data_); }
num_weights_ = num_data_;
weights_ = std::vector<label_t>(num_weights_);

#pragma omp parallel for schedule(static)
for (data_size_t i = 0; i < num_weights_; ++i) {
weights_[i] = Common::AvoidInf(weights[i]);
Expand All @@ -354,9 +354,8 @@ void Metadata::SetQuery(const data_size_t* query, data_size_t len) {
if (num_data_ != sum) {
Log::Fatal("Sum of query counts is not same with #data");
}
if (!query_boundaries_.empty()) { query_boundaries_.clear(); }
num_queries_ = len;
query_boundaries_ = std::vector<data_size_t>(num_queries_ + 1);
query_boundaries_.resize(num_queries_ + 1);
query_boundaries_[0] = 0;
for (data_size_t i = 0; i < num_queries_; ++i) {
query_boundaries_[i + 1] = query_boundaries_[i] + query[i];
Expand Down

0 comments on commit 82886ba

Please sign in to comment.