Skip to content

Commit

Permalink
fix #865
Browse files Browse the repository at this point in the history
  • Loading branch information
guolinke committed Aug 27, 2017
1 parent 603bffc commit 5e1a513
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 8 deletions.
21 changes: 14 additions & 7 deletions src/boosting/gbdt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
feature_infos_ = train_data_->feature_infos();

// if need bagging, create buffer
ResetBaggingConfig(gbdt_config_.get());
ResetBaggingConfig(gbdt_config_.get(), true);

// reset config for tree learner
class_need_train_ = std::vector<bool>(num_tree_per_iteration_, true);
Expand Down Expand Up @@ -211,7 +211,7 @@ void GBDT::ResetTrainingData(const Dataset* train_data, const ObjectiveFunction*

feature_infos_ = train_data_->feature_infos();

ResetBaggingConfig(gbdt_config_.get());
ResetBaggingConfig(gbdt_config_.get(), true);

tree_learner_->ResetTrainingData(train_data);
}
Expand All @@ -222,13 +222,13 @@ void GBDT::ResetConfig(const BoostingConfig* config) {
early_stopping_round_ = new_config->early_stopping_round;
shrinkage_rate_ = new_config->learning_rate;
if (tree_learner_ != nullptr) {
ResetBaggingConfig(new_config.get());
ResetBaggingConfig(new_config.get(), false);
tree_learner_->ResetConfig(&new_config->tree_config);
}
gbdt_config_.reset(new_config.release());
}

void GBDT::ResetBaggingConfig(const BoostingConfig* config) {
void GBDT::ResetBaggingConfig(const BoostingConfig* config, bool is_change_dataset) {
// if need bagging, create buffer
if (config->bagging_fraction < 1.0 && config->bagging_freq > 0) {
bag_data_cnt_ =
Expand All @@ -252,8 +252,10 @@ void GBDT::ResetBaggingConfig(const BoostingConfig* config) {
const int sparse_group_threshold_usesubset = train_data_->num_feature_groups() / 4;
if (average_bag_rate <= 0.5
&& (train_data_->num_feature_groups() < group_threshold_usesubset || sparse_group < sparse_group_threshold_usesubset)) {
tmp_subset_.reset(new Dataset(bag_data_cnt_));
tmp_subset_->CopyFeatureMapperFrom(train_data_);
if (tmp_subset_ == nullptr || is_change_dataset) {
tmp_subset_.reset(new Dataset(bag_data_cnt_));
tmp_subset_->CopyFeatureMapperFrom(train_data_);
}
is_use_subset_ = true;
Log::Debug("use subset for bagging");
}
Expand All @@ -263,6 +265,10 @@ void GBDT::ResetBaggingConfig(const BoostingConfig* config) {
tmp_indices_.clear();
is_use_subset_ = false;
}

if (is_change_dataset) {
need_re_bagging_ = true;
}
}

void GBDT::AddValidDataset(const Dataset* valid_data,
Expand Down Expand Up @@ -322,7 +328,8 @@ data_size_t GBDT::BaggingHelper(Random& cur_rand, data_size_t start, data_size_t

void GBDT::Bagging(int iter) {
// if need bagging
if (bag_data_cnt_ < num_data_ && iter % gbdt_config_->bagging_freq == 0) {
if ( (bag_data_cnt_ < num_data_ && iter % gbdt_config_->bagging_freq == 0) || need_re_bagging_) {
need_re_bagging_ = false;
const data_size_t min_inner_size = 1000;
data_size_t inner_size = (num_data_ + num_threads_ - 1) / num_threads_;
if (inner_size < min_inner_size) { inner_size = min_inner_size; }
Expand Down
3 changes: 2 additions & 1 deletion src/boosting/gbdt.h
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ class GBDT: public Boosting {
virtual const char* SubModelName() const override { return "tree"; }

protected:
void ResetBaggingConfig(const BoostingConfig* config);
void ResetBaggingConfig(const BoostingConfig* config, bool is_change_dataset);
/*!
* \brief Implement bagging logic
* \param iter Current interation
Expand Down Expand Up @@ -388,6 +388,7 @@ class GBDT: public Boosting {
bool is_constant_hessian_;
std::unique_ptr<ObjectiveFunction> loaded_objective_;
bool average_output_;
bool need_re_bagging_ = false;
};

} // namespace LightGBM
Expand Down

0 comments on commit 5e1a513

Please sign in to comment.