Skip to content

Commit

Permalink
fix max_drop. add many checks for parameters.
Browse files Browse the repository at this point in the history
  • Loading branch information
guolinke committed Dec 29, 2017
1 parent 0af44ac commit 4271082
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 7 deletions.
1 change: 0 additions & 1 deletion include/LightGBM/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,6 @@ struct TreeConfig: public ConfigBase {
struct BoostingConfig: public ConfigBase {
public:
virtual ~BoostingConfig() {}
double sigmoid = 1.0f;
int output_freq = 1;
bool is_provide_training_metric = false;
int num_iterations = 100;
Expand Down
6 changes: 6 additions & 0 deletions src/boosting/dart.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ class DART: public GBDT {
for (int i = 0; i < iter_; ++i) {
if (random_for_drop_.NextFloat() < drop_rate * tree_weight_[i] * inv_average_weight) {
drop_index_.push_back(num_init_iteration_ + i);
if (drop_index_.size() >= static_cast<size_t>(gbdt_config_->max_drop)) {
break;
}
}
}
} else {
Expand All @@ -105,6 +108,9 @@ class DART: public GBDT {
for (int i = 0; i < iter_; ++i) {
if (random_for_drop_.NextFloat() < drop_rate) {
drop_index_.push_back(num_init_iteration_ + i);
if (drop_index_.size() >= static_cast<size_t>(gbdt_config_->max_drop)) {
break;
}
}
}
}
Expand Down
31 changes: 25 additions & 6 deletions src/io/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -251,12 +251,14 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetInt(params, "max_bin", &max_bin);
CHECK(max_bin > 0);
GetInt(params, "num_class", &num_class);
CHECK(num_class > 0);
GetInt(params, "data_random_seed", &data_random_seed);
GetString(params, "data", &data_filename);
GetString(params, "init_score_file", &initscore_filename);
GetInt(params, "verbose", &verbosity);
GetInt(params, "num_iteration_predict", &num_iteration_predict);
GetInt(params, "bin_construct_sample_cnt", &bin_construct_sample_cnt);
CHECK(bin_construct_sample_cnt > 0);
GetBool(params, "is_pre_partition", &is_pre_partition);
GetBool(params, "is_enable_sparse", &is_enable_sparse);
GetDouble(params, "sparse_threshold", &sparse_threshold);
Expand Down Expand Up @@ -290,9 +292,10 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetInt(params, "min_data_in_leaf", &min_data_in_leaf);
GetInt(params, "min_data_in_bin", &min_data_in_bin);
CHECK(min_data_in_bin > 0);
CHECK(min_data_in_leaf >= 0);
GetDouble(params, "max_conflict_rate", &max_conflict_rate);
CHECK(max_conflict_rate >= 0);
GetBool(params, "enable_bundle", &enable_bundle);

GetBool(params, "pred_early_stop", &pred_early_stop);
GetInt(params, "pred_early_stop_freq", &pred_early_stop_freq);
GetDouble(params, "pred_early_stop_margin", &pred_early_stop_margin);
Expand All @@ -304,15 +307,21 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
void ObjectiveConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetBool(params, "is_unbalance", &is_unbalance);
GetDouble(params, "sigmoid", &sigmoid);
CHECK(sigmoid > 0);
GetDouble(params, "fair_c", &fair_c);
CHECK(fair_c > 0);
GetDouble(params, "gaussian_eta", &gaussian_eta);
CHECK(gaussian_eta > 0);
GetDouble(params, "poisson_max_delta_step", &poisson_max_delta_step);
CHECK(poisson_max_delta_step > 0);
GetInt(params, "max_position", &max_position);
CHECK(max_position > 0);
GetInt(params, "num_class", &num_class);
CHECK(num_class >= 1);
CHECK(num_class > 0);
GetDouble(params, "scale_pos_weight", &scale_pos_weight);
CHECK(scale_pos_weight > 0);
GetDouble(params, "alpha", &alpha);
CHECK(alpha > 0 && alpha < 1);
GetBool(params, "reg_sqrt", &reg_sqrt);
std::string tmp_str = "";
if (GetString(params, "label_gain", &tmp_str)) {
Expand All @@ -331,9 +340,13 @@ void ObjectiveConfig::Set(const std::unordered_map<std::string, std::string>& pa

void MetricConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetDouble(params, "sigmoid", &sigmoid);
CHECK(sigmoid > 0);
GetDouble(params, "fair_c", &fair_c);
CHECK(fair_c > 0);
GetInt(params, "num_class", &num_class);
CHECK(num_class > 0);
GetDouble(params, "alpha", &alpha);
CHECK(alpha > 0 && alpha < 1);
std::string tmp_str = "";
if (GetString(params, "label_gain", &tmp_str)) {
label_gain = Common::StringToArray<double>(tmp_str, ',');
Expand Down Expand Up @@ -365,7 +378,8 @@ void MetricConfig::Set(const std::unordered_map<std::string, std::string>& param
void TreeConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetInt(params, "min_data_in_leaf", &min_data_in_leaf);
GetDouble(params, "min_sum_hessian_in_leaf", &min_sum_hessian_in_leaf);
CHECK(min_sum_hessian_in_leaf > 0 || min_data_in_leaf > 0);
CHECK(min_data_in_leaf > 0);
CHECK(min_sum_hessian_in_leaf >= 0);
GetDouble(params, "lambda_l1", &lambda_l1);
CHECK(lambda_l1 >= 0.0f);
GetDouble(params, "lambda_l2", &lambda_l2);
Expand All @@ -380,6 +394,7 @@ void TreeConfig::Set(const std::unordered_map<std::string, std::string>& params)
GetDouble(params, "histogram_pool_size", &histogram_pool_size);
GetInt(params, "max_depth", &max_depth);
GetInt(params, "top_k", &top_k);
CHECK(top_k > 0);
GetInt(params, "gpu_platform_id", &gpu_platform_id);
GetInt(params, "gpu_device_id", &gpu_device_id);
GetBool(params, "gpu_use_dp", &gpu_use_dp);
Expand All @@ -397,7 +412,6 @@ void TreeConfig::Set(const std::unordered_map<std::string, std::string>& params)

void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetInt(params, "num_iterations", &num_iterations);
GetDouble(params, "sigmoid", &sigmoid);
CHECK(num_iterations >= 0);
GetInt(params, "bagging_seed", &bagging_seed);
GetInt(params, "bagging_freq", &bagging_freq);
Expand All @@ -412,17 +426,22 @@ void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& par
CHECK(output_freq >= 0);
GetBool(params, "is_training_metric", &is_provide_training_metric);
GetInt(params, "num_class", &num_class);
CHECK(num_class > 0);
GetInt(params, "drop_seed", &drop_seed);
GetDouble(params, "drop_rate", &drop_rate);
GetDouble(params, "skip_drop", &skip_drop);
CHECK(drop_rate <= 1.0 && drop_rate >= 0.0);
CHECK(skip_drop <= 1.0 && skip_drop >= 0.0);
GetInt(params, "max_drop", &max_drop);
CHECK(max_drop > 0);
GetBool(params, "xgboost_dart_mode", &xgboost_dart_mode);
GetBool(params, "uniform_drop", &uniform_drop);
GetDouble(params, "top_rate", &top_rate);
GetDouble(params, "other_rate", &other_rate);
CHECK(top_rate > 0);
CHECK(top_rate > 0);
CHECK(top_rate + top_rate <= 1.0);
GetBool(params, "boost_from_average", &boost_from_average);
CHECK(drop_rate <= 1.0 && drop_rate >= 0.0);
CHECK(skip_drop <= 1.0 && skip_drop >= 0.0);
GetDeviceType(params, &device_type);
GetTreeLearnerType(params, &tree_learner_type);
tree_config.Set(params);
Expand Down

0 comments on commit 4271082

Please sign in to comment.