fix a bug when bagging with reset_config (#2149)
* fix a bug when bagging with reset_config

* clean code
guolinke committed May 6, 2019
1 parent: 2c41d15 · commit: 46d2147
Showing 7 changed files with 30 additions and 159 deletions.
include/LightGBM/objective_function.h (9 changes: 1 addition & 8 deletions)

@@ -43,18 +43,11 @@ class ObjectiveFunction {
 
   virtual bool IsRenewTreeOutput() const { return false; }
 
-  virtual double RenewTreeOutput(double ori_output, const double*,
+  virtual double RenewTreeOutput(double ori_output, std::function<double(const label_t*, int)> residual_getter,
                                  const data_size_t*,
                                  const data_size_t*,
                                  data_size_t) const { return ori_output; }
 
-  virtual double RenewTreeOutput(double ori_output, double,
-                                 const data_size_t*,
-                                 const data_size_t*,
-                                 data_size_t) const {
-    return ori_output;
-  }
-
   virtual double BoostFromScore(int /*class_id*/) const { return 0.0; }
 
   virtual bool ClassNeedTrain(int /*class_id*/) const { return true; }
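Note: this interface change is the core of the refactor. The two RenewTreeOutput overloads (one taking per-row predictions as const double*, one taking a single constant prediction as double) collapse into one virtual that receives a callable returning the residual for a given row. A minimal sketch of the idea, with illustrative names that are not LightGBM code; label_t is assumed to be LightGBM's default float label type:

    #include <functional>

    using label_t = float;  // assumption: LightGBM's default label type
    using ResidualGetter = std::function<double(const label_t*, int)>;

    // A caller that holds per-row scores (the GBDT case) builds its getter
    // by capturing the score array; the interface no longer needs to know
    // how predictions are stored.
    ResidualGetter FromScoreArray(const double* score) {
      return [score](const label_t* label, int i) {
        return static_cast<double>(label[i]) - score[i];
      };
    }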
include/LightGBM/tree_learner.h (5 changes: 1 addition & 4 deletions)

@@ -77,12 +77,9 @@ class TreeLearner {
    */
   virtual void AddPredictionToScore(const Tree* tree, double* out_score) const = 0;
 
-  virtual void RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj, const double* prediction,
+  virtual void RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj, std::function<double(const label_t*, int)> residual_getter,
                                data_size_t total_num_data, const data_size_t* bag_indices, data_size_t bag_cnt) const = 0;
 
-  virtual void RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj, double prediction,
-                               data_size_t total_num_data, const data_size_t* bag_indices, data_size_t bag_cnt) const = 0;
-
   TreeLearner() = default;
   /*! \brief Disable copy */
   TreeLearner& operator=(const TreeLearner&) = delete;
src/boosting/gbdt.cpp (13 changes: 9 additions & 4 deletions)

@@ -364,7 +364,9 @@ bool GBDT::TrainOneIter(const score_t* gradients, const score_t* hessians) {
 
     if (new_tree->num_leaves() > 1) {
       should_continue = true;
-      tree_learner_->RenewTreeOutput(new_tree.get(), objective_function_, train_score_updater_->score() + bias,
+      auto score_ptr = train_score_updater_->score() + bias;
+      auto residual_getter = [score_ptr](const label_t* label, int i) { return static_cast<double>(label[i]) - score_ptr[i]; };
+      tree_learner_->RenewTreeOutput(new_tree.get(), objective_function_, residual_getter,
                                      num_data_, bag_data_indices_.data(), bag_data_cnt_);
       // shrinkage by learning rate
       new_tree->Shrinkage(shrinkage_rate_);

@@ -688,6 +690,11 @@ void GBDT::ResetConfig(const Config* config) {
 void GBDT::ResetBaggingConfig(const Config* config, bool is_change_dataset) {
   // if need bagging, create buffer
   if (config->bagging_fraction < 1.0 && config->bagging_freq > 0) {
+    need_re_bagging_ = false;
+    if (!is_change_dataset &&
+        config_.get() != nullptr && config_->bagging_fraction == config->bagging_fraction && config_->bagging_freq == config->bagging_freq) {
+      return;
+    }
     bag_data_cnt_ =
         static_cast<data_size_t>(config->bagging_fraction * num_data_);
     bag_data_indices_.resize(num_data_);

@@ -719,9 +726,7 @@ void GBDT::ResetBaggingConfig(const Config* config, bool is_change_dataset) {
       Log::Debug("Use subset for bagging");
    }
 
-    if (is_change_dataset) {
-      need_re_bagging_ = true;
-    }
+    need_re_bagging_ = true;
 
     if (is_use_subset_ && bag_data_cnt_ < num_data_) {
       if (objective_function_ == nullptr) {
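These last two hunks are the actual bug fix named in the commit title. Before this change, need_re_bagging_ was set only when the dataset changed, so a ResetConfig call while bagging rebuilt the bagging buffers without forcing the next iteration to re-sample its bag. Now the method returns early when neither bagging_fraction nor bagging_freq changed (nothing to redo), and unconditionally requests re-bagging whenever the buffers are rebuilt. A condensed, illustrative sketch of the resulting control flow; the names are stand-ins, and the real GBDT::ResetBaggingConfig also rebuilds index buffers and the subset dataset:

    struct BaggingConfig { double bagging_fraction; int bagging_freq; };

    struct Bagger {
      BaggingConfig current{1.0, 0};
      bool has_config = false;
      bool need_re_bagging = false;

      void Reset(const BaggingConfig& cfg, bool dataset_changed) {
        if (cfg.bagging_fraction < 1.0 && cfg.bagging_freq > 0) {
          need_re_bagging = false;
          // Same bagging settings on the same dataset: keep the current bag.
          if (!dataset_changed && has_config &&
              current.bagging_fraction == cfg.bagging_fraction &&
              current.bagging_freq == cfg.bagging_freq) {
            return;
          }
          // (buffer rebuilding elided)
          need_re_bagging = true;  // before the fix: set only when dataset_changed
        }
        current = cfg;
        has_config = true;
      }
    };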
src/boosting/rf.hpp (4 changes: 3 additions & 1 deletion)

@@ -130,7 +130,9 @@ class RF : public GBDT {
     }
 
     if (new_tree->num_leaves() > 1) {
-      tree_learner_->RenewTreeOutput(new_tree.get(), objective_function_, init_scores_[cur_tree_id],
+      double pred = init_scores_[cur_tree_id];
+      auto residual_getter = [pred](const label_t* label, int i) { return static_cast<double>(label[i]) - pred; };
+      tree_learner_->RenewTreeOutput(new_tree.get(), objective_function_, residual_getter,
                                      num_data_, bag_data_indices_.data(), bag_data_cnt_);
       if (std::fabs(init_scores_[cur_tree_id]) > kEpsilon) {
         new_tree->AddBias(init_scores_[cur_tree_id]);
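The random-forest path covers the case the deleted double overload used to serve: each RF tree is trained against a constant per-class init score rather than a per-row score vector, so its getter simply closes over a scalar. An illustrative counterpart to the array-based getter sketched earlier (again not LightGBM code):

    #include <functional>

    using label_t = float;  // assumption: LightGBM's default label type
    using ResidualGetter = std::function<double(const label_t*, int)>;

    // RF-style getter: residuals against one constant initial score.
    ResidualGetter FromConstantScore(double init_score) {
      return [init_score](const label_t* label, int i) {
        return static_cast<double>(label[i]) - init_score;
      };
    }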
src/objective/regression_objective.hpp (109 changes: 13 additions & 96 deletions)

@@ -232,62 +232,30 @@ class RegressionL1loss: public RegressionL2loss {
 
   bool IsRenewTreeOutput() const override { return true; }
 
-  double RenewTreeOutput(double, const double* pred,
+  double RenewTreeOutput(double, std::function<double(const label_t*, int)> residual_getter,
                          const data_size_t* index_mapper,
                          const data_size_t* bagging_mapper,
                          data_size_t num_data_in_leaf) const override {
     const double alpha = 0.5;
     if (weights_ == nullptr) {
       if (bagging_mapper == nullptr) {
-        #define data_reader(i) (label_[index_mapper[i]] - pred[index_mapper[i]])
+        #define data_reader(i) (residual_getter(label_, index_mapper[i]))
         PercentileFun(double, data_reader, num_data_in_leaf, alpha);
         #undef data_reader
       } else {
-        #define data_reader(i) (label_[bagging_mapper[index_mapper[i]]] - pred[bagging_mapper[index_mapper[i]]])
+        #define data_reader(i) (residual_getter(label_, bagging_mapper[index_mapper[i]]))
         PercentileFun(double, data_reader, num_data_in_leaf, alpha);
         #undef data_reader
       }
     } else {
       if (bagging_mapper == nullptr) {
-        #define data_reader(i) (label_[index_mapper[i]] - pred[index_mapper[i]])
+        #define data_reader(i) (residual_getter(label_, index_mapper[i]))
         #define weight_reader(i) (weights_[index_mapper[i]])
         WeightedPercentileFun(double, data_reader, weight_reader, num_data_in_leaf, alpha);
         #undef data_reader
         #undef weight_reader
       } else {
-        #define data_reader(i) (label_[bagging_mapper[index_mapper[i]]] - pred[bagging_mapper[index_mapper[i]]])
-        #define weight_reader(i) (weights_[bagging_mapper[index_mapper[i]]])
-        WeightedPercentileFun(double, data_reader, weight_reader, num_data_in_leaf, alpha);
-        #undef data_reader
-        #undef weight_reader
-      }
-    }
-  }
-
-  double RenewTreeOutput(double, double pred,
-                         const data_size_t* index_mapper,
-                         const data_size_t* bagging_mapper,
-                         data_size_t num_data_in_leaf) const override {
-    const double alpha = 0.5;
-    if (weights_ == nullptr) {
-      if (bagging_mapper == nullptr) {
-        #define data_reader(i) (label_[index_mapper[i]] - pred)
-        PercentileFun(double, data_reader, num_data_in_leaf, alpha);
-        #undef data_reader
-      } else {
-        #define data_reader(i) (label_[bagging_mapper[index_mapper[i]]] - pred)
-        PercentileFun(double, data_reader, num_data_in_leaf, alpha);
-        #undef data_reader
-      }
-    } else {
-      if (bagging_mapper == nullptr) {
-        #define data_reader(i) (label_[index_mapper[i]] - pred)
-        #define weight_reader(i) (weights_[index_mapper[i]])
-        WeightedPercentileFun(double, data_reader, weight_reader, num_data_in_leaf, alpha);
-        #undef data_reader
-        #undef weight_reader
-      } else {
-        #define data_reader(i) (label_[bagging_mapper[index_mapper[i]]] - pred)
+        #define data_reader(i) (residual_getter(label_, bagging_mapper[index_mapper[i]]))
         #define weight_reader(i) (weights_[bagging_mapper[index_mapper[i]]])
         WeightedPercentileFun(double, data_reader, weight_reader, num_data_in_leaf, alpha);
         #undef data_reader

@@ -552,60 +520,29 @@ class RegressionQuantileloss : public RegressionL2loss {
 
   bool IsRenewTreeOutput() const override { return true; }
 
-  double RenewTreeOutput(double, const double* pred,
-                         const data_size_t* index_mapper,
-                         const data_size_t* bagging_mapper,
-                         data_size_t num_data_in_leaf) const override {
-    if (weights_ == nullptr) {
-      if (bagging_mapper == nullptr) {
-        #define data_reader(i) (label_[index_mapper[i]] - pred[index_mapper[i]])
-        PercentileFun(double, data_reader, num_data_in_leaf, alpha_);
-        #undef data_reader
-      } else {
-        #define data_reader(i) (label_[bagging_mapper[index_mapper[i]]] - pred[bagging_mapper[index_mapper[i]]])
-        PercentileFun(double, data_reader, num_data_in_leaf, alpha_);
-        #undef data_reader
-      }
-    } else {
-      if (bagging_mapper == nullptr) {
-        #define data_reader(i) (label_[index_mapper[i]] - pred[index_mapper[i]])
-        #define weight_reader(i) (weights_[index_mapper[i]])
-        WeightedPercentileFun(double, data_reader, weight_reader, num_data_in_leaf, alpha_);
-        #undef data_reader
-        #undef weight_reader
-      } else {
-        #define data_reader(i) (label_[bagging_mapper[index_mapper[i]]] - pred[bagging_mapper[index_mapper[i]]])
-        #define weight_reader(i) (weights_[bagging_mapper[index_mapper[i]]])
-        WeightedPercentileFun(double, data_reader, weight_reader, num_data_in_leaf, alpha_);
-        #undef data_reader
-        #undef weight_reader
-      }
-    }
-  }
-
-  double RenewTreeOutput(double, double pred,
+  double RenewTreeOutput(double, std::function<double(const label_t*, int)> residual_getter,
                          const data_size_t* index_mapper,
                          const data_size_t* bagging_mapper,
                          data_size_t num_data_in_leaf) const override {
     if (weights_ == nullptr) {
       if (bagging_mapper == nullptr) {
-        #define data_reader(i) (label_[index_mapper[i]] - pred)
+        #define data_reader(i) (residual_getter(label_, index_mapper[i]))
         PercentileFun(double, data_reader, num_data_in_leaf, alpha_);
         #undef data_reader
       } else {
-        #define data_reader(i) (label_[bagging_mapper[index_mapper[i]]] - pred)
+        #define data_reader(i) (residual_getter(label_, bagging_mapper[index_mapper[i]]))
         PercentileFun(double, data_reader, num_data_in_leaf, alpha_);
         #undef data_reader
       }
     } else {
       if (bagging_mapper == nullptr) {
-        #define data_reader(i) (label_[index_mapper[i]] - pred)
+        #define data_reader(i) (residual_getter(label_, index_mapper[i]))
         #define weight_reader(i) (weights_[index_mapper[i]])
         WeightedPercentileFun(double, data_reader, weight_reader, num_data_in_leaf, alpha_);
         #undef data_reader
         #undef weight_reader
       } else {
-        #define data_reader(i) (label_[bagging_mapper[index_mapper[i]]] - pred)
+        #define data_reader(i) (residual_getter(label_, bagging_mapper[index_mapper[i]]))
         #define weight_reader(i) (weights_[bagging_mapper[index_mapper[i]]])
         WeightedPercentileFun(double, data_reader, weight_reader, num_data_in_leaf, alpha_);
         #undef data_reader

@@ -684,39 +621,19 @@ class RegressionMAPELOSS : public RegressionL1loss {
 
   bool IsRenewTreeOutput() const override { return true; }
 
-  double RenewTreeOutput(double, const double* pred,
-                         const data_size_t* index_mapper,
-                         const data_size_t* bagging_mapper,
-                         data_size_t num_data_in_leaf) const override {
-    const double alpha = 0.5;
-    if (bagging_mapper == nullptr) {
-      #define data_reader(i) (label_[index_mapper[i]] - pred[index_mapper[i]])
-      #define weight_reader(i) (label_weight_[index_mapper[i]])
-      WeightedPercentileFun(double, data_reader, weight_reader, num_data_in_leaf, alpha);
-      #undef data_reader
-      #undef weight_reader
-    } else {
-      #define data_reader(i) (label_[bagging_mapper[index_mapper[i]]] - pred[bagging_mapper[index_mapper[i]]])
-      #define weight_reader(i) (label_weight_[bagging_mapper[index_mapper[i]]])
-      WeightedPercentileFun(double, data_reader, weight_reader, num_data_in_leaf, alpha);
-      #undef data_reader
-      #undef weight_reader
-    }
-  }
-
-  double RenewTreeOutput(double, double pred,
+  double RenewTreeOutput(double, std::function<double(const label_t*, int)> residual_getter,
                          const data_size_t* index_mapper,
                          const data_size_t* bagging_mapper,
                          data_size_t num_data_in_leaf) const override {
     const double alpha = 0.5;
     if (bagging_mapper == nullptr) {
-      #define data_reader(i) (label_[index_mapper[i]] - pred)
+      #define data_reader(i) (residual_getter(label_, index_mapper[i]))
       #define weight_reader(i) (label_weight_[index_mapper[i]])
       WeightedPercentileFun(double, data_reader, weight_reader, num_data_in_leaf, alpha);
       #undef data_reader
       #undef weight_reader
     } else {
-      #define data_reader(i) (label_[bagging_mapper[index_mapper[i]]] - pred)
+      #define data_reader(i) (residual_getter(label_, bagging_mapper[index_mapper[i]]))
       #define weight_reader(i) (label_weight_[bagging_mapper[index_mapper[i]]])
       WeightedPercentileFun(double, data_reader, weight_reader, num_data_in_leaf, alpha);
       #undef data_reader
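All three objectives renew a leaf's output to a (weighted) percentile of the residuals of the samples falling in that leaf: the median (alpha = 0.5) for L1 and MAPE, the configured alpha_ quantile for quantile regression. Only how a residual is obtained changes, which is why the data_reader macros are the only edited lines. The double indexing is worth spelling out: with bagging, index_mapper yields positions within the bag, and bagging_mapper translates those back to full-dataset row ids. A sketch of that composition, as a hypothetical helper rather than anything in LightGBM:

    #include <cstdint>

    using data_size_t = int32_t;  // assumption: LightGBM's row index type

    // Leaf position -> (bag position ->) full-dataset row id, mirroring the
    // two branches the data_reader macros are defined for.
    inline data_size_t RowId(const data_size_t* index_mapper,
                             const data_size_t* bagging_mapper,
                             data_size_t i) {
      const data_size_t pos = index_mapper[i];  // position inside the leaf
      return bagging_mapper == nullptr
                 ? pos                   // no bagging: already a row id
                 : bagging_mapper[pos];  // bagging: map back to a row id
    }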
src/treelearner/serial_tree_learner.cpp (44 changes: 2 additions & 42 deletions)

@@ -851,7 +851,7 @@ void SerialTreeLearner::Split(Tree* tree, int best_leaf, int* left_leaf, int* ri
 }
 
 
-void SerialTreeLearner::RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj, const double* prediction,
+void SerialTreeLearner::RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj, std::function<double(const label_t*, int)> residual_getter,
                                         data_size_t total_num_data, const data_size_t* bag_indices, data_size_t bag_cnt) const {
   if (obj != nullptr && obj->IsRenewTreeOutput()) {
     CHECK(tree->num_leaves() <= data_partition_->num_leaves());

@@ -869,47 +869,7 @@ void SerialTreeLearner::RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj
       auto index_mapper = data_partition_->GetIndexOnLeaf(i, &cnt_leaf_data);
       if (cnt_leaf_data > 0) {
         // bag_mapper[index_mapper[i]]
-        const double new_output = obj->RenewTreeOutput(output, prediction, index_mapper, bag_mapper, cnt_leaf_data);
-        tree->SetLeafOutput(i, new_output);
-      } else {
-        CHECK(num_machines > 1);
-        tree->SetLeafOutput(i, 0.0);
-        n_nozeroworker_perleaf[i] = 0;
-      }
-    }
-    if (num_machines > 1) {
-      std::vector<double> outputs(tree->num_leaves());
-      for (int i = 0; i < tree->num_leaves(); ++i) {
-        outputs[i] = static_cast<double>(tree->LeafOutput(i));
-      }
-      Network::GlobalSum(outputs);
-      Network::GlobalSum(n_nozeroworker_perleaf);
-      for (int i = 0; i < tree->num_leaves(); ++i) {
-        tree->SetLeafOutput(i, outputs[i] / n_nozeroworker_perleaf[i]);
-      }
-    }
-  }
-}
-
-void SerialTreeLearner::RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj, double prediction,
-                                        data_size_t total_num_data, const data_size_t* bag_indices, data_size_t bag_cnt) const {
-  if (obj != nullptr && obj->IsRenewTreeOutput()) {
-    CHECK(tree->num_leaves() <= data_partition_->num_leaves());
-    const data_size_t* bag_mapper = nullptr;
-    if (total_num_data != num_data_) {
-      CHECK(bag_cnt == num_data_);
-      bag_mapper = bag_indices;
-    }
-    std::vector<int> n_nozeroworker_perleaf(tree->num_leaves(), 1);
-    int num_machines = Network::num_machines();
-#pragma omp parallel for schedule(static)
-    for (int i = 0; i < tree->num_leaves(); ++i) {
-      const double output = static_cast<double>(tree->LeafOutput(i));
-      data_size_t cnt_leaf_data = 0;
-      auto index_mapper = data_partition_->GetIndexOnLeaf(i, &cnt_leaf_data);
-      if (cnt_leaf_data > 0) {
-        // bag_mapper[index_mapper[i]]
-        const double new_output = obj->RenewTreeOutput(output, prediction, index_mapper, bag_mapper, cnt_leaf_data);
+        const double new_output = obj->RenewTreeOutput(output, residual_getter, index_mapper, bag_mapper, cnt_leaf_data);
         tree->SetLeafOutput(i, new_output);
       } else {
         CHECK(num_machines > 1);
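With a single getter-based entry point, the duplicate copy of this leaf-renewal loop disappears and the surviving loop is untouched: for each leaf it fetches the sample indices from the data partition, asks the objective for a renewed output, and, in distributed runs, averages each leaf over the workers that actually held data for it (two Network::GlobalSum calls, one over outputs and one over per-leaf worker counts). A simplified sketch of just that averaging step, with plain vectors standing in for the results of the collective sums:

    #include <cstddef>
    #include <vector>

    // After the element-wise sums across machines (Network::GlobalSum in the
    // real code), each leaf holds the sum of worker outputs and the number of
    // workers that had data for it; dividing recovers the mean. Assumes every
    // leaf had data on at least one worker, as the CHECKs above enforce.
    void AverageLeafOutputs(std::vector<double>* outputs,
                            const std::vector<int>& n_nonzero_workers) {
      for (std::size_t i = 0; i < outputs->size(); ++i) {
        (*outputs)[i] /= n_nonzero_workers[i];
      }
    }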
src/treelearner/serial_tree_learner.h (5 changes: 1 addition & 4 deletions)

@@ -74,10 +74,7 @@ class SerialTreeLearner: public TreeLearner {
     }
   }
 
-  void RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj, const double* prediction,
-                       data_size_t total_num_data, const data_size_t* bag_indices, data_size_t bag_cnt) const override;
-
-  void RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj, double prediction,
+  void RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj, std::function<double(const label_t*, int)> residual_getter,
                        data_size_t total_num_data, const data_size_t* bag_indices, data_size_t bag_cnt) const override;
 
 protected:
