clean code for Boosting::ResetTrainingData.
guolinke committed Jul 4, 2017
1 parent a98b23d commit 1e7ccbb
Showing 7 changed files with 204 additions and 153 deletions.
12 changes: 4 additions & 8 deletions include/LightGBM/boosting.h
@@ -43,14 +43,10 @@ class LIGHTGBM_EXPORT Boosting {
*/
virtual void MergeFrom(const Boosting* other) = 0;

/*!
* \brief Reset training data for current boosting
* \param config Configs for boosting
* \param train_data Training data
* \param objective_function Training objective function
* \param training_metrics Training metric
*/
virtual void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* objective_function, const std::vector<const Metric*>& training_metrics) = 0;
virtual void ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* objective_function,
const std::vector<const Metric*>& training_metrics) = 0;

virtual void ResetConfig(const BoostingConfig* config) = 0;

/*!
* \brief Add a validation data
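The header change above removes the BoostingConfig argument from ResetTrainingData and introduces ResetConfig as a separate entry point for configuration updates. A minimal sketch of the resulting call pattern (the names booster, train_data, objective, metrics, and boosting_config are illustrative assumptions, not identifiers from this diff):

    // Reset data-dependent state, then apply a config update separately.
    booster->ResetTrainingData(train_data, objective, metrics);
    booster->ResetConfig(&boosting_config);
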
2 changes: 1 addition & 1 deletion include/LightGBM/config.h
@@ -91,7 +91,7 @@ struct IOConfig: public ConfigBase {
int data_random_seed = 1;
std::string data_filename = "";
std::vector<std::string> valid_data_filenames;
int snapshot_freq = 100;
int snapshot_freq = -1;
std::string output_model = "LightGBM_model.txt";
std::string output_result = "LightGBM_predict_result.txt";
std::string convert_model = "gbdt_prediction.cpp";
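The snapshot_freq default changes from 100 to -1. The assumption here is that a non-positive frequency disables periodic model snapshots, so saving snapshots now requires an explicit opt-in. A hypothetical override:

    IOConfig io_config;            // snapshot_freq now defaults to -1
    io_config.snapshot_freq = 10;  // assumed opt-in: snapshot every 10 iterations
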
4 changes: 0 additions & 4 deletions src/boosting/dart.hpp
@@ -39,10 +39,6 @@ class DART: public GBDT {
sum_weight_ = 0.0f;
}

void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
const std::vector<const Metric*>& training_metrics) override {
GBDT::ResetTrainingData(config, train_data, objective_function, training_metrics);
}
/*!
* \brief one training iteration
*/
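The deleted DART::ResetTrainingData only forwarded to GBDT::ResetTrainingData, so removing the override changes no behavior: with no override declared, a call on a DART object dispatches to the inherited GBDT implementation. A generic sketch of that C++ rule, using toy types not taken from this repository:

    struct Base { virtual void Reset() { /* base behavior */ } };
    struct Derived : Base { };  // no override declared
    // Derived d; d.Reset();    // runs Base::Reset through the inherited virtual
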
273 changes: 160 additions & 113 deletions src/boosting/gbdt.cpp
@@ -44,9 +44,9 @@ GBDT::GBDT()
boost_from_average_(false) {
#pragma omp parallel
#pragma omp master
{
num_threads_ = omp_get_num_threads();
}
{
num_threads_ = omp_get_num_threads();
}
}

GBDT::~GBDT() {
@@ -64,24 +64,104 @@ GBDT::~GBDT() {

void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
const std::vector<const Metric*>& training_metrics) {
train_data_ = train_data;
iter_ = 0;
num_iteration_for_pred_ = 0;
max_feature_idx_ = 0;
num_class_ = config->num_class;
train_data_ = nullptr;
gbdt_config_ = nullptr;
tree_learner_ = nullptr;
ResetTrainingData(config, train_data, objective_function, training_metrics);
gbdt_config_ = std::unique_ptr<BoostingConfig>(new BoostingConfig(*config));
early_stopping_round_ = gbdt_config_->early_stopping_round;
shrinkage_rate_ = gbdt_config_->learning_rate;

objective_function_ = objective_function;
num_tree_per_iteration_ = num_class_;
if (objective_function_ != nullptr) {
is_constant_hessian_ = objective_function_->IsConstantHessian();
num_tree_per_iteration_ = objective_function_->NumTreePerIteration();
} else {
is_constant_hessian_ = false;
}

tree_learner_ = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(gbdt_config_->tree_learner_type, gbdt_config_->device_type, &gbdt_config_->tree_config));

// init tree learner
tree_learner_->Init(train_data_, is_constant_hessian_);

// push training metrics
training_metrics_.clear();
for (const auto& metric : training_metrics) {
training_metrics_.push_back(metric);
}
training_metrics_.shrink_to_fit();

train_score_updater_.reset(new ScoreUpdater(train_data_, num_tree_per_iteration_));

num_data_ = train_data_->num_data();
// create buffer for gradients and hessians
if (objective_function_ != nullptr) {
size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
gradients_.resize(total_size);
hessians_.resize(total_size);
}
// get max feature index
max_feature_idx_ = train_data_->num_total_features() - 1;
// get label index
label_idx_ = train_data_->label_idx();
// get feature names
feature_names_ = train_data_->feature_names();
feature_infos_ = train_data_->feature_infos();

// if need bagging, create buffer
ResetBaggingConfig(gbdt_config_.get());

// reset config for tree learner
class_need_train_ = std::vector<bool>(num_tree_per_iteration_, true);
if (objective_function_ != nullptr && objective_function_->SkipEmptyClass()) {
CHECK(num_tree_per_iteration_ == num_class_);
// + 1 here for the binary classification
class_default_output_ = std::vector<double>(num_tree_per_iteration_, 0.0f);
auto label = train_data_->metadata().label();
if (num_tree_per_iteration_ > 1) {
// multi-class
std::vector<data_size_t> cnt_per_class(num_tree_per_iteration_, 0);
for (data_size_t i = 0; i < num_data_; ++i) {
int index = static_cast<int>(label[i]);
CHECK(index < num_tree_per_iteration_);
++cnt_per_class[index];
}
for (int i = 0; i < num_tree_per_iteration_; ++i) {
if (cnt_per_class[i] == num_data_) {
class_need_train_[i] = false;
class_default_output_[i] = -std::log(kEpsilon);
} else if (cnt_per_class[i] == 0) {
class_need_train_[i] = false;
class_default_output_[i] = -std::log(1.0f / kEpsilon - 1.0f);
}
}
} else {
// binary class
data_size_t cnt_pos = 0;
for (data_size_t i = 0; i < num_data_; ++i) {
if (label[i] > 0) {
++cnt_pos;
}
}
if (cnt_pos == 0) {
class_need_train_[0] = false;
class_default_output_[0] = -std::log(1.0f / kEpsilon - 1.0f);
} else if (cnt_pos == num_data_) {
class_need_train_[0] = false;
class_default_output_[0] = -std::log(kEpsilon);
}
}
}
}

void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
void GBDT::ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* objective_function,
const std::vector<const Metric*>& training_metrics) {
auto new_config = std::unique_ptr<BoostingConfig>(new BoostingConfig(*config));
if (train_data_ != nullptr && !train_data_->CheckAlign(*train_data)) {
if (train_data != train_data_ && !train_data_->CheckAlign(*train_data)) {
Log::Fatal("cannot reset training data, since new training data has different bin mappers");
}
early_stopping_round_ = new_config->early_stopping_round;
shrinkage_rate_ = new_config->learning_rate;

objective_function_ = objective_function;
num_tree_per_iteration_ = num_class_;
@@ -92,129 +172,96 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
is_constant_hessian_ = false;
}

if (train_data_ != train_data && train_data != nullptr) {
if (tree_learner_ == nullptr) {
tree_learner_ = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(new_config->tree_learner_type, new_config->device_type, &new_config->tree_config));
}
// init tree learner
tree_learner_->Init(train_data, is_constant_hessian_);
// push training metrics
training_metrics_.clear();
for (const auto& metric : training_metrics) {
training_metrics_.push_back(metric);
}
training_metrics_.shrink_to_fit();

// push training metrics
training_metrics_.clear();
for (const auto& metric : training_metrics) {
training_metrics_.push_back(metric);
}
training_metrics_.shrink_to_fit();
if (train_data != train_data_) {
train_data_ = train_data;
// not same training data, need reset score and others
// create score tracker
train_score_updater_.reset(new ScoreUpdater(train_data, num_tree_per_iteration_));
train_score_updater_.reset(new ScoreUpdater(train_data_, num_tree_per_iteration_));
// update score
for (int i = 0; i < iter_; ++i) {
for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
auto curr_tree = (i + num_init_iteration_) * num_tree_per_iteration_ + cur_tree_id;
train_score_updater_->AddScore(models_[curr_tree].get(), cur_tree_id);
}
}
num_data_ = train_data->num_data();
num_data_ = train_data_->num_data();

// create buffer for gradients and hessians
if (objective_function_ != nullptr) {
size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
gradients_.resize(total_size);
hessians_.resize(total_size);
}

// get max feature index
max_feature_idx_ = train_data->num_total_features() - 1;
max_feature_idx_ = train_data_->num_total_features() - 1;
// get label index
label_idx_ = train_data->label_idx();
label_idx_ = train_data_->label_idx();
// get feature names
feature_names_ = train_data->feature_names();
feature_names_ = train_data_->feature_names();

feature_infos_ = train_data_->feature_infos();

ResetBaggingConfig(gbdt_config_.get());

feature_infos_ = train_data->feature_infos();
tree_learner_->ResetTrainingData(train_data);
}
}

if ((train_data_ != train_data && train_data != nullptr)
|| (gbdt_config_ != nullptr && gbdt_config_->bagging_fraction != new_config->bagging_fraction)) {
// if need bagging, create buffer
if (new_config->bagging_fraction < 1.0 && new_config->bagging_freq > 0) {
bag_data_cnt_ =
static_cast<data_size_t>(new_config->bagging_fraction * num_data_);
bag_data_indices_.resize(num_data_);
tmp_indices_.resize(num_data_);
offsets_buf_.resize(num_threads_);
left_cnts_buf_.resize(num_threads_);
right_cnts_buf_.resize(num_threads_);
left_write_pos_buf_.resize(num_threads_);
right_write_pos_buf_.resize(num_threads_);
double average_bag_rate = new_config->bagging_fraction / new_config->bagging_freq;
int sparse_group = 0;
for (int i = 0; i < train_data->num_feature_groups(); ++i) {
if (train_data->FeatureGroupIsSparse(i)) {
++sparse_group;
}
}
is_use_subset_ = false;
const int group_threshold_usesubset = 100;
const int sparse_group_threshold_usesubset = train_data->num_feature_groups() / 4;
if (average_bag_rate <= 0.5
&& (train_data->num_feature_groups() < group_threshold_usesubset || sparse_group < sparse_group_threshold_usesubset)) {
tmp_subset_.reset(new Dataset(bag_data_cnt_));
tmp_subset_->CopyFeatureMapperFrom(train_data);
is_use_subset_ = true;
Log::Debug("use subset for bagging");
void GBDT::ResetConfig(const BoostingConfig* config) {
auto new_config = std::unique_ptr<BoostingConfig>(new BoostingConfig(*config));

early_stopping_round_ = new_config->early_stopping_round;
shrinkage_rate_ = new_config->learning_rate;

ResetBaggingConfig(new_config.get());

tree_learner_->ResetConfig(&new_config->tree_config);
gbdt_config_.reset(new_config.release());
}

void GBDT::ResetBaggingConfig(const BoostingConfig* config) {
// if need bagging, create buffer
if (config->bagging_fraction < 1.0 && config->bagging_freq > 0) {
bag_data_cnt_ =
static_cast<data_size_t>(config->bagging_fraction * num_data_);
bag_data_indices_.resize(num_data_);
tmp_indices_.resize(num_data_);
offsets_buf_.resize(num_threads_);
left_cnts_buf_.resize(num_threads_);
right_cnts_buf_.resize(num_threads_);
left_write_pos_buf_.resize(num_threads_);
right_write_pos_buf_.resize(num_threads_);
double average_bag_rate = config->bagging_fraction / config->bagging_freq;
int sparse_group = 0;
for (int i = 0; i < train_data_->num_feature_groups(); ++i) {
if (train_data_->FeatureGroupIsSparse(i)) {
++sparse_group;
}
} else {
bag_data_cnt_ = num_data_;
bag_data_indices_.clear();
tmp_indices_.clear();
is_use_subset_ = false;
}
}
train_data_ = train_data;
if (train_data_ != nullptr) {
// reset config for tree learner
tree_learner_->ResetConfig(&new_config->tree_config);
class_need_train_ = std::vector<bool>(num_tree_per_iteration_, true);
if (objective_function_ != nullptr && objective_function_->SkipEmptyClass()) {
CHECK(num_tree_per_iteration_ == num_class_);
// + 1 here for the binary classification
class_default_output_ = std::vector<double>(num_tree_per_iteration_, 0.0f);
auto label = train_data_->metadata().label();
if (num_tree_per_iteration_ > 1) {
// multi-class
std::vector<data_size_t> cnt_per_class(num_tree_per_iteration_, 0);
for (data_size_t i = 0; i < num_data_; ++i) {
int index = static_cast<int>(label[i]);
CHECK(index < num_tree_per_iteration_);
++cnt_per_class[index];
}
for (int i = 0; i < num_tree_per_iteration_; ++i) {
if (cnt_per_class[i] == num_data_) {
class_need_train_[i] = false;
class_default_output_[i] = -std::log(kEpsilon);
} else if (cnt_per_class[i] == 0) {
class_need_train_[i] = false;
class_default_output_[i] = -std::log(1.0f / kEpsilon - 1.0f);
}
}
} else {
// binary class
data_size_t cnt_pos = 0;
for (data_size_t i = 0; i < num_data_; ++i) {
if (label[i] > 0) {
++cnt_pos;
}
}
if (cnt_pos == 0) {
class_need_train_[0] = false;
class_default_output_[0] = -std::log(1.0f / kEpsilon - 1.0f);
} else if (cnt_pos == num_data_) {
class_need_train_[0] = false;
class_default_output_[0] = -std::log(kEpsilon);
}
}
is_use_subset_ = false;
const int group_threshold_usesubset = 100;
const int sparse_group_threshold_usesubset = train_data_->num_feature_groups() / 4;
if (average_bag_rate <= 0.5
&& (train_data_->num_feature_groups() < group_threshold_usesubset || sparse_group < sparse_group_threshold_usesubset)) {
tmp_subset_.reset(new Dataset(bag_data_cnt_));
tmp_subset_->CopyFeatureMapperFrom(train_data_);
is_use_subset_ = true;
Log::Debug("use subset for bagging");
}
} else {
bag_data_cnt_ = num_data_;
bag_data_indices_.clear();
tmp_indices_.clear();
is_use_subset_ = false;
}
gbdt_config_.reset(new_config.release());
}

void GBDT::AddValidDataset(const Dataset* valid_data,
@@ -358,7 +405,7 @@ double LabelAverage(const float* label, data_size_t num_data) {
Network::Allreduce(reinterpret_cast<char*>(&init_score),
sizeof(init_score), sizeof(init_score),
reinterpret_cast<char*>(&global_init_score),
[](const char* src, char* dst, int len) {
[] (const char* src, char* dst, int len) {
int used_size = 0;
const int type_size = sizeof(double);
const double *p1;
@@ -833,7 +880,7 @@ bool GBDT::SaveModelToIfElse(int num_iteration, const char* filename) const {
std::ifstream ifs(filename);
if (ifs.good()) {
std::string origin((std::istreambuf_iterator<char>(ifs)),
(std::istreambuf_iterator<char>()));
(std::istreambuf_iterator<char>()));
output_file.open(filename);
output_file << "#define USE_HARD_CODE 0" << std::endl;
output_file << "#ifndef USE_HARD_CODE" << std::endl;
@@ -1027,8 +1074,8 @@ std::vector<std::pair<size_t, std::string>> GBDT::FeatureImportance() const {
}
// sort the importance
std::sort(pairs.begin(), pairs.end(),
[](const std::pair<size_t, std::string>& lhs,
const std::pair<size_t, std::string>& rhs) {
[] (const std::pair<size_t, std::string>& lhs,
const std::pair<size_t, std::string>& rhs) {
return lhs.first > rhs.first;
});
return pairs;
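Taken together, the gbdt.cpp changes split the old monolithic ResetTrainingData into focused pieces: Init builds all state from scratch, ResetTrainingData now handles only data-dependent state, and the new ResetConfig and ResetBaggingConfig re-derive configuration-dependent state. An outline of the resulting shape (bodies elided; a sketch of the code above, not a complete listing):

    void GBDT::Init(const BoostingConfig* config, const Dataset* train_data,
                    const ObjectiveFunction* objective_function,
                    const std::vector<const Metric*>& training_metrics) {
      gbdt_config_ = std::unique_ptr<BoostingConfig>(new BoostingConfig(*config));
      // ... tree learner, training metrics, score updater, gradient buffers ...
      ResetBaggingConfig(gbdt_config_.get());  // bagging buffers sized here
    }

    void GBDT::ResetConfig(const BoostingConfig* config) {
      auto new_config = std::unique_ptr<BoostingConfig>(new BoostingConfig(*config));
      early_stopping_round_ = new_config->early_stopping_round;
      shrinkage_rate_ = new_config->learning_rate;
      ResetBaggingConfig(new_config.get());    // re-derive bagging state
      tree_learner_->ResetConfig(&new_config->tree_config);
      gbdt_config_.reset(new_config.release());
    }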
