diff --git a/docs/Parameters.rst b/docs/Parameters.rst
index 695da7b6b77..7d613919d8a 100644
--- a/docs/Parameters.rst
+++ b/docs/Parameters.rst
@@ -447,6 +447,12 @@ IO Parameters
 
    -  LightGBM will auto compress memory according to ``max_bin``. For example, LightGBM will use ``uint8_t`` for feature value if ``max_bin=255``
 
+-  ``max_bin_by_feature`` :raw-html:`🔗︎`, default = ``None``, type = multi-int
+
+   -  max number of bins for each feature
+
+   -  if not specified, will use ``max_bin`` for all features
+
 -  ``min_data_in_bin`` :raw-html:`🔗︎`, default = ``3``, type = int, constraints: ``min_data_in_bin > 0``
 
    -  minimal number of data inside one bin
diff --git a/include/LightGBM/config.h b/include/LightGBM/config.h
index 9244e31c493..e8d7003c589 100644
--- a/include/LightGBM/config.h
+++ b/include/LightGBM/config.h
@@ -443,6 +443,12 @@
   // desc = LightGBM will auto compress memory according to ``max_bin``. For example, LightGBM will use ``uint8_t`` for feature value if ``max_bin=255``
   int max_bin = 255;
 
+  // type = multi-int
+  // default = None
+  // desc = max number of bins for each feature
+  // desc = if not specified, will use ``max_bin`` for all features
+  std::vector<int32_t> max_bin_by_feature;
+
   // check = >0
   // desc = minimal number of data inside one bin
   // desc = use this to avoid one-data-one-bin (potential over-fitting)
diff --git a/include/LightGBM/dataset.h b/include/LightGBM/dataset.h
index 12dbe6c98ce..92cc201cc62 100644
--- a/include/LightGBM/dataset.h
+++ b/include/LightGBM/dataset.h
@@ -624,6 +624,7 @@ class Dataset {
   std::vector<double> feature_penalty_;
   bool is_finish_load_;
   int max_bin_;
+  std::vector<int32_t> max_bin_by_feature_;
   int bin_construct_sample_cnt_;
   int min_data_in_bin_;
   bool use_missing_;
diff --git a/src/io/config_auto.cpp b/src/io/config_auto.cpp
index 4a88e26723c..8d75b1cde3d 100644
--- a/src/io/config_auto.cpp
+++ b/src/io/config_auto.cpp
@@ -218,6 +218,7 @@ std::unordered_set<std::string> Config::parameter_set({
   "cegb_penalty_feature_coupled",
   "verbosity",
   "max_bin",
+  "max_bin_by_feature",
   "min_data_in_bin",
   "bin_construct_sample_cnt",
   "histogram_pool_size",
@@ -418,6 +419,10 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::string>& params) {
   GetInt(params, "max_bin", &max_bin);
   CHECK(max_bin >1);
 
+  if (GetString(params, "max_bin_by_feature", &tmp_str)) {
+    max_bin_by_feature = Common::StringToArray<int32_t>(tmp_str, ',');
+  }
+
   GetInt(params, "min_data_in_bin", &min_data_in_bin);
   CHECK(min_data_in_bin >0);
 
@@ -610,6 +615,7 @@ std::string Config::SaveMembersToString() const {
   str_buf << "[cegb_penalty_feature_coupled: " << Common::Join(cegb_penalty_feature_coupled, ",") << "]\n";
   str_buf << "[verbosity: " << verbosity << "]\n";
   str_buf << "[max_bin: " << max_bin << "]\n";
+  str_buf << "[max_bin_by_feature: " << Common::Join(max_bin_by_feature, ",") << "]\n";
   str_buf << "[min_data_in_bin: " << min_data_in_bin << "]\n";
   str_buf << "[bin_construct_sample_cnt: " << bin_construct_sample_cnt << "]\n";
   str_buf << "[histogram_pool_size: " << histogram_pool_size << "]\n";
diff --git a/src/io/dataset.cpp b/src/io/dataset.cpp
index ccb72b9ccc7..f201a40a1a7 100644
--- a/src/io/dataset.cpp
+++ b/src/io/dataset.cpp
@@ -318,6 +318,12 @@ void Dataset::Construct(
       feature_penalty_.clear();
     }
   }
+  if (!io_config.max_bin_by_feature.empty()) {
+    CHECK(static_cast<size_t>(num_total_features_) == io_config.max_bin_by_feature.size());
+    CHECK(*(std::min_element(io_config.max_bin_by_feature.begin(), io_config.max_bin_by_feature.end())) > 1);
+    max_bin_by_feature_.resize(num_total_features_);
+    max_bin_by_feature_.assign(io_config.max_bin_by_feature.begin(), io_config.max_bin_by_feature.end());
+  }
   max_bin_ = io_config.max_bin;
   min_data_in_bin_ = io_config.min_data_in_bin;
   bin_construct_sample_cnt_ = io_config.bin_construct_sample_cnt;
@@ -332,6 +338,9 @@ void Dataset::ResetConfig(const char* parameters) {
   if (param.count("max_bin") && io_config.max_bin != max_bin_) {
     Log::Warning("Cannot change max_bin after constructed Dataset handle.");
   }
+  if (param.count("max_bin_by_feature") && io_config.max_bin_by_feature != max_bin_by_feature_) {
+    Log::Warning("Cannot change max_bin_by_feature after constructed Dataset handle.");
+  }
   if (param.count("bin_construct_sample_cnt") && io_config.bin_construct_sample_cnt != bin_construct_sample_cnt_) {
     Log::Warning("Cannot change bin_construct_sample_cnt after constructed Dataset handle.");
   }
@@ -643,7 +652,7 @@ void Dataset::SaveBinaryFile(const char* bin_filename) {
     size_t size_of_header = sizeof(num_data_) + sizeof(num_features_) + sizeof(num_total_features_)
       + sizeof(int) * num_total_features_ + sizeof(label_idx_) + sizeof(num_groups_) + sizeof(sparse_threshold_)
       + 3 * sizeof(int) * num_features_ + sizeof(uint64_t) * (num_groups_ + 1) + 2 * sizeof(int) * num_groups_ + sizeof(int8_t) * num_features_
-      + sizeof(double) * num_features_ + sizeof(int) * 3 + sizeof(bool) * 2;
+      + sizeof(double) * num_features_ + sizeof(int32_t) * num_total_features_ + sizeof(int) * 3 + sizeof(bool) * 2;
     // size of feature names
     for (int i = 0; i < num_total_features_; ++i) {
       size_of_header += feature_names_[i].size() + sizeof(int);
@@ -682,6 +691,13 @@ void Dataset::SaveBinaryFile(const char* bin_filename) {
     if (ArrayArgs::CheckAll(feature_penalty_, 1.0)) {
       feature_penalty_.clear();
    }
+    if (max_bin_by_feature_.empty()) {
+      ArrayArgs::Assign(&max_bin_by_feature_, -1, num_total_features_);
+    }
+    writer->Write(max_bin_by_feature_.data(), sizeof(int32_t) * num_total_features_);
+    if (ArrayArgs::CheckAll(max_bin_by_feature_, -1)) {
+      max_bin_by_feature_.clear();
+    }
     // write feature names
     for (int i = 0; i < num_total_features_; ++i) {
       int str_len = static_cast<int>(feature_names_[i].size());
@@ -730,6 +746,10 @@ void Dataset::DumpTextFile(const char* text_filename) {
   for (auto i : feature_penalty_) {
     fprintf(file, "%lf, ", i);
   }
+  fprintf(file, "\nmax_bin_by_feature: ");
+  for (auto i : max_bin_by_feature_) {
+    fprintf(file, "%d, ", i);
+  }
   fprintf(file, "\n");
   for (auto n : feature_names_) {
     fprintf(file, "%s, ", n.c_str());
diff --git a/src/io/dataset_loader.cpp b/src/io/dataset_loader.cpp
index 94b41b26a8d..bb6d252e726 100644
--- a/src/io/dataset_loader.cpp
+++ b/src/io/dataset_loader.cpp
@@ -427,6 +427,23 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* bin_filename,
     dataset->feature_penalty_.clear();
   }
 
+  if (!config_.max_bin_by_feature.empty()) {
+    CHECK(static_cast<size_t>(dataset->num_total_features_) == config_.max_bin_by_feature.size());
+    CHECK(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())) > 1);
+    dataset->max_bin_by_feature_.resize(dataset->num_total_features_);
+    dataset->max_bin_by_feature_.assign(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end());
+  } else {
+    const int32_t* tmp_ptr_max_bin_by_feature = reinterpret_cast<const int32_t*>(mem_ptr);
+    dataset->max_bin_by_feature_.clear();
+    for (int i = 0; i < dataset->num_total_features_; ++i) {
+      dataset->max_bin_by_feature_.push_back(tmp_ptr_max_bin_by_feature[i]);
+    }
+  }
+  mem_ptr += sizeof(int32_t) * (dataset->num_total_features_);
+  if (ArrayArgs::CheckAll(dataset->max_bin_by_feature_, -1)) {
+    dataset->max_bin_by_feature_.clear();
+  }
+
   // get feature names
   dataset->feature_names_.clear();
   // write feature names
@@ -544,7 +561,10 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values,
       feature_names_.push_back(str_buf.str());
     }
   }
-
+  if (!config_.max_bin_by_feature.empty()) {
+    CHECK(static_cast<size_t>(num_col) == config_.max_bin_by_feature.size());
+    CHECK(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())) > 1);
+  }
   const data_size_t filter_cnt = static_cast<data_size_t>(
     static_cast<double>(config_.min_data_in_leaf * total_sample_size) / num_data);
   if (Network::num_machines() == 1) {
@@ -562,8 +582,16 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values,
         bin_type = BinType::CategoricalBin;
       }
       bin_mappers[i].reset(new BinMapper());
-      bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size,
-                              config_.max_bin, config_.min_data_in_bin, filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing);
+      if (config_.max_bin_by_feature.empty()) {
+        bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size,
+                                config_.max_bin, config_.min_data_in_bin, filter_cnt,
+                                bin_type, config_.use_missing, config_.zero_as_missing);
+      } else {
+        bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size,
+                                config_.max_bin_by_feature[i], config_.min_data_in_bin,
+                                filter_cnt, bin_type, config_.use_missing,
+                                config_.zero_as_missing);
+      }
       OMP_LOOP_EX_END();
     }
     OMP_THROW_EX();
@@ -599,8 +627,16 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values,
         bin_type = BinType::CategoricalBin;
       }
       bin_mappers[i].reset(new BinMapper());
-      bin_mappers[i]->FindBin(sample_values[start[rank] + i], num_per_col[start[rank] + i], total_sample_size,
-                              config_.max_bin, config_.min_data_in_bin, filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing);
+      if (config_.max_bin_by_feature.empty()) {
+        bin_mappers[i]->FindBin(sample_values[start[rank] + i], num_per_col[start[rank] + i],
+                                total_sample_size, config_.max_bin, config_.min_data_in_bin,
+                                filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing);
+      } else {
+        bin_mappers[i]->FindBin(sample_values[start[rank] + i], num_per_col[start[rank] + i],
+                                total_sample_size, config_.max_bin_by_feature[start[rank] + i],
+                                config_.min_data_in_bin, filter_cnt, bin_type, config_.use_missing,
+                                config_.zero_as_missing);
+      }
       OMP_LOOP_EX_END();
     }
     OMP_THROW_EX();
@@ -831,6 +867,11 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
     dataset->num_total_features_ = static_cast<int>(feature_names_.size());
   }
 
+  if (!config_.max_bin_by_feature.empty()) {
+    CHECK(static_cast<size_t>(dataset->num_total_features_) == config_.max_bin_by_feature.size());
+    CHECK(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())) > 1);
+  }
+
   // check the range of label_idx, weight_idx and group_idx
   CHECK(label_idx_ >= 0 && label_idx_ <= dataset->num_total_features_);
   CHECK(weight_idx_ < 0 || weight_idx_ < dataset->num_total_features_);
@@ -865,8 +906,16 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
         bin_type = BinType::CategoricalBin;
       }
       bin_mappers[i].reset(new BinMapper());
-      bin_mappers[i]->FindBin(sample_values[i].data(), static_cast<int>(sample_values[i].size()),
-                              sample_data.size(), config_.max_bin, config_.min_data_in_bin, filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing);
+      if (config_.max_bin_by_feature.empty()) {
+        bin_mappers[i]->FindBin(sample_values[i].data(), static_cast<int>(sample_values[i].size()),
+                                sample_data.size(), config_.max_bin, config_.min_data_in_bin,
+                                filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing);
+      } else {
+        bin_mappers[i]->FindBin(sample_values[i].data(), static_cast<int>(sample_values[i].size()),
+                                sample_data.size(), config_.max_bin_by_feature[i],
+                                config_.min_data_in_bin, filter_cnt, bin_type, config_.use_missing,
+                                config_.zero_as_missing);
+      }
       OMP_LOOP_EX_END();
     }
     OMP_THROW_EX();
@@ -902,8 +951,18 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
         bin_type = BinType::CategoricalBin;
       }
       bin_mappers[i].reset(new BinMapper());
-      bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(), static_cast<int>(sample_values[start[rank] + i].size()),
-                              sample_data.size(), config_.max_bin, config_.min_data_in_bin, filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing);
+      if (config_.max_bin_by_feature.empty()) {
+        bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(),
+                                static_cast<int>(sample_values[start[rank] + i].size()),
+                                sample_data.size(), config_.max_bin, config_.min_data_in_bin,
+                                filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing);
+      } else {
+        bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(),
+                                static_cast<int>(sample_values[start[rank] + i].size()),
+                                sample_data.size(), config_.max_bin_by_feature[i],
+                                config_.min_data_in_bin, filter_cnt, bin_type,
+                                config_.use_missing, config_.zero_as_missing);
+      }
       OMP_LOOP_EX_END();
     }
     OMP_THROW_EX();
diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py
index c7cb6695926..d453cc2ba9b 100644
--- a/tests/python_package_test/test_engine.py
+++ b/tests/python_package_test/test_engine.py
@@ -813,6 +813,29 @@ def is_correctly_constrained(learner):
         constrained_model = lgb.train(params, trainset)
         self.assertTrue(is_correctly_constrained(constrained_model))
 
+    def test_max_bin_by_feature(self):
+        col1 = np.arange(0, 100)[:, np.newaxis]
+        col2 = np.zeros((100, 1))
+        col2[20:] = 1
+        X = np.concatenate([col1, col2], axis=1)
+        y = np.arange(0, 100)
+        params = {
+            'objective': 'regression_l2',
+            'verbose': -1,
+            'num_leaves': 100,
+            'min_data_in_leaf': 1,
+            'min_sum_hessian_in_leaf': 0,
+            'min_data_in_bin': 1,
+            'max_bin_by_feature': [100, 2]
+        }
+        lgb_data = lgb.Dataset(X, label=y)
+        est = lgb.train(params, lgb_data, num_boost_round=1)
+        self.assertEqual(len(np.unique(est.predict(X))), 100)
+        params['max_bin_by_feature'] = [2, 100]
+        lgb_data = lgb.Dataset(X, label=y)
+        est = lgb.train(params, lgb_data, num_boost_round=1)
+        self.assertEqual(len(np.unique(est.predict(X))), 3)
+
     def test_refit(self):
         X, y = load_breast_cancer(True)
         X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
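
For reference, a minimal sketch of how the new parameter would be used from the Python package, mirroring the test added above; the synthetic two-column data and variable names are illustrative only and not part of the patch:

import numpy as np
import lightgbm as lgb

# Illustrative data: the first column is fine-grained, the second is effectively binary.
X = np.concatenate([np.arange(100)[:, np.newaxis],
                    (np.arange(100) >= 20).astype(float)[:, np.newaxis]], axis=1)
y = np.arange(100, dtype=float)

params = {
    'objective': 'regression_l2',
    'verbose': -1,
    'min_data_in_bin': 1,
    'min_data_in_leaf': 1,
    # one entry per feature, in column order; each entry must be greater than 1
    'max_bin_by_feature': [100, 2],
}
booster = lgb.train(params, lgb.Dataset(X, label=y), num_boost_round=1)

In the config-file or CLI interface the same list is passed as a comma-separated string (for example max_bin_by_feature=100,2), which is the form Config::GetMembersFromString parses via Common::StringToArray.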