diff --git a/docs/Parameters.rst b/docs/Parameters.rst
index 695da7b6b77..7d613919d8a 100644
--- a/docs/Parameters.rst
+++ b/docs/Parameters.rst
@@ -447,6 +447,12 @@ IO Parameters
- LightGBM will auto compress memory according to ``max_bin``. For example, LightGBM will use ``uint8_t`` for feature value if ``max_bin=255``
+- ``max_bin_by_feature`` :raw-html:`🔗︎`, default = ``None``, type = multi-int
+
+ - max number of bins for each feature
+
+ - if not specified, will use ``max_bin`` for all features
+
- ``min_data_in_bin`` :raw-html:`🔗︎`, default = ``3``, type = int, constraints: ``min_data_in_bin > 0``
- minimal number of data inside one bin
diff --git a/include/LightGBM/config.h b/include/LightGBM/config.h
index 9244e31c493..e8d7003c589 100644
--- a/include/LightGBM/config.h
+++ b/include/LightGBM/config.h
@@ -443,6 +443,12 @@ struct Config {
// desc = LightGBM will auto compress memory according to ``max_bin``. For example, LightGBM will use ``uint8_t`` for feature value if ``max_bin=255``
int max_bin = 255;
+ // type = multi-int
+ // default = None
+ // desc = max number of bins for each feature
+ // desc = if not specified, will use ``max_bin`` for all features
+ std::vector max_bin_by_feature;
+
// check = >0
// desc = minimal number of data inside one bin
// desc = use this to avoid one-data-one-bin (potential over-fitting)
diff --git a/include/LightGBM/dataset.h b/include/LightGBM/dataset.h
index 12dbe6c98ce..92cc201cc62 100644
--- a/include/LightGBM/dataset.h
+++ b/include/LightGBM/dataset.h
@@ -624,6 +624,7 @@ class Dataset {
std::vector feature_penalty_;
bool is_finish_load_;
int max_bin_;
+ std::vector max_bin_by_feature_;
int bin_construct_sample_cnt_;
int min_data_in_bin_;
bool use_missing_;
diff --git a/src/io/config_auto.cpp b/src/io/config_auto.cpp
index 4a88e26723c..8d75b1cde3d 100644
--- a/src/io/config_auto.cpp
+++ b/src/io/config_auto.cpp
@@ -218,6 +218,7 @@ std::unordered_set Config::parameter_set({
"cegb_penalty_feature_coupled",
"verbosity",
"max_bin",
+ "max_bin_by_feature",
"min_data_in_bin",
"bin_construct_sample_cnt",
"histogram_pool_size",
@@ -418,6 +419,10 @@ void Config::GetMembersFromString(const std::unordered_map1);
+ if (GetString(params, "max_bin_by_feature", &tmp_str)) {
+ max_bin_by_feature = Common::StringToArray(tmp_str, ',');
+ }
+
GetInt(params, "min_data_in_bin", &min_data_in_bin);
CHECK(min_data_in_bin >0);
@@ -610,6 +615,7 @@ std::string Config::SaveMembersToString() const {
str_buf << "[cegb_penalty_feature_coupled: " << Common::Join(cegb_penalty_feature_coupled, ",") << "]\n";
str_buf << "[verbosity: " << verbosity << "]\n";
str_buf << "[max_bin: " << max_bin << "]\n";
+ str_buf << "[max_bin_by_feature: " << Common::Join(max_bin_by_feature, ",") << "]\n";
str_buf << "[min_data_in_bin: " << min_data_in_bin << "]\n";
str_buf << "[bin_construct_sample_cnt: " << bin_construct_sample_cnt << "]\n";
str_buf << "[histogram_pool_size: " << histogram_pool_size << "]\n";
diff --git a/src/io/dataset.cpp b/src/io/dataset.cpp
index ccb72b9ccc7..f201a40a1a7 100644
--- a/src/io/dataset.cpp
+++ b/src/io/dataset.cpp
@@ -318,6 +318,12 @@ void Dataset::Construct(
feature_penalty_.clear();
}
}
+ if (!io_config.max_bin_by_feature.empty()) {
+ CHECK(static_cast(num_total_features_) == io_config.max_bin_by_feature.size());
+ CHECK(*(std::min_element(io_config.max_bin_by_feature.begin(), io_config.max_bin_by_feature.end())) > 1);
+ max_bin_by_feature_.resize(num_total_features_);
+ max_bin_by_feature_.assign(io_config.max_bin_by_feature.begin(), io_config.max_bin_by_feature.end());
+ }
max_bin_ = io_config.max_bin;
min_data_in_bin_ = io_config.min_data_in_bin;
bin_construct_sample_cnt_ = io_config.bin_construct_sample_cnt;
@@ -332,6 +338,9 @@ void Dataset::ResetConfig(const char* parameters) {
if (param.count("max_bin") && io_config.max_bin != max_bin_) {
Log::Warning("Cannot change max_bin after constructed Dataset handle.");
}
+ if (param.count("max_bin_by_feature") && io_config.max_bin_by_feature != max_bin_by_feature_) {
+ Log::Warning("Cannot change max_bin_by_feature after constructed Dataset handle.");
+ }
if (param.count("bin_construct_sample_cnt") && io_config.bin_construct_sample_cnt != bin_construct_sample_cnt_) {
Log::Warning("Cannot change bin_construct_sample_cnt after constructed Dataset handle.");
}
@@ -643,7 +652,7 @@ void Dataset::SaveBinaryFile(const char* bin_filename) {
size_t size_of_header = sizeof(num_data_) + sizeof(num_features_) + sizeof(num_total_features_)
+ sizeof(int) * num_total_features_ + sizeof(label_idx_) + sizeof(num_groups_) + sizeof(sparse_threshold_)
+ 3 * sizeof(int) * num_features_ + sizeof(uint64_t) * (num_groups_ + 1) + 2 * sizeof(int) * num_groups_ + sizeof(int8_t) * num_features_
- + sizeof(double) * num_features_ + sizeof(int) * 3 + sizeof(bool) * 2;
+ + sizeof(double) * num_features_ + sizeof(int32_t) * num_total_features_ + sizeof(int) * 3 + sizeof(bool) * 2;
// size of feature names
for (int i = 0; i < num_total_features_; ++i) {
size_of_header += feature_names_[i].size() + sizeof(int);
@@ -682,6 +691,13 @@ void Dataset::SaveBinaryFile(const char* bin_filename) {
if (ArrayArgs::CheckAll(feature_penalty_, 1.0)) {
feature_penalty_.clear();
}
+ if (max_bin_by_feature_.empty()) {
+ ArrayArgs::Assign(&max_bin_by_feature_, -1, num_total_features_);
+ }
+ writer->Write(max_bin_by_feature_.data(), sizeof(int32_t) * num_total_features_);
+ if (ArrayArgs::CheckAll(max_bin_by_feature_, -1)) {
+ max_bin_by_feature_.clear();
+ }
// write feature names
for (int i = 0; i < num_total_features_; ++i) {
int str_len = static_cast(feature_names_[i].size());
@@ -730,6 +746,10 @@ void Dataset::DumpTextFile(const char* text_filename) {
for (auto i : feature_penalty_) {
fprintf(file, "%lf, ", i);
}
+ fprintf(file, "\nmax_bin_by_feature: ");
+ for (auto i : max_bin_by_feature_) {
+ fprintf(file, "%d, ", i);
+ }
fprintf(file, "\n");
for (auto n : feature_names_) {
fprintf(file, "%s, ", n.c_str());
diff --git a/src/io/dataset_loader.cpp b/src/io/dataset_loader.cpp
index 94b41b26a8d..bb6d252e726 100644
--- a/src/io/dataset_loader.cpp
+++ b/src/io/dataset_loader.cpp
@@ -427,6 +427,23 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
dataset->feature_penalty_.clear();
}
+ if (!config_.max_bin_by_feature.empty()) {
+ CHECK(static_cast(dataset->num_total_features_) == config_.max_bin_by_feature.size());
+ CHECK(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())) > 1);
+ dataset->max_bin_by_feature_.resize(dataset->num_total_features_);
+ dataset->max_bin_by_feature_.assign(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end());
+ } else {
+ const int32_t* tmp_ptr_max_bin_by_feature = reinterpret_cast(mem_ptr);
+ dataset->max_bin_by_feature_.clear();
+ for (int i = 0; i < dataset->num_total_features_; ++i) {
+ dataset->max_bin_by_feature_.push_back(tmp_ptr_max_bin_by_feature[i]);
+ }
+ }
+ mem_ptr += sizeof(int32_t) * (dataset->num_total_features_);
+ if (ArrayArgs::CheckAll(dataset->max_bin_by_feature_, -1)) {
+ dataset->max_bin_by_feature_.clear();
+ }
+
// get feature names
dataset->feature_names_.clear();
// write feature names
@@ -544,7 +561,10 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values,
feature_names_.push_back(str_buf.str());
}
}
-
+ if (!config_.max_bin_by_feature.empty()) {
+ CHECK(static_cast(num_col) == config_.max_bin_by_feature.size());
+ CHECK(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())) > 1);
+ }
const data_size_t filter_cnt = static_cast(
static_cast(config_.min_data_in_leaf * total_sample_size) / num_data);
if (Network::num_machines() == 1) {
@@ -562,8 +582,16 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values,
bin_type = BinType::CategoricalBin;
}
bin_mappers[i].reset(new BinMapper());
- bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size,
- config_.max_bin, config_.min_data_in_bin, filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing);
+ if (config_.max_bin_by_feature.empty()) {
+ bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size,
+ config_.max_bin, config_.min_data_in_bin, filter_cnt,
+ bin_type, config_.use_missing, config_.zero_as_missing);
+ } else {
+ bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size,
+ config_.max_bin_by_feature[i], config_.min_data_in_bin,
+ filter_cnt, bin_type, config_.use_missing,
+ config_.zero_as_missing);
+ }
OMP_LOOP_EX_END();
}
OMP_THROW_EX();
@@ -599,8 +627,16 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values,
bin_type = BinType::CategoricalBin;
}
bin_mappers[i].reset(new BinMapper());
- bin_mappers[i]->FindBin(sample_values[start[rank] + i], num_per_col[start[rank] + i], total_sample_size,
- config_.max_bin, config_.min_data_in_bin, filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing);
+ if (config_.max_bin_by_feature.empty()) {
+ bin_mappers[i]->FindBin(sample_values[start[rank] + i], num_per_col[start[rank] + i],
+ total_sample_size, config_.max_bin, config_.min_data_in_bin,
+ filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing);
+ } else {
+ bin_mappers[i]->FindBin(sample_values[start[rank] + i], num_per_col[start[rank] + i],
+ total_sample_size, config_.max_bin_by_feature[start[rank] + i],
+ config_.min_data_in_bin, filter_cnt, bin_type, config_.use_missing,
+ config_.zero_as_missing);
+ }
OMP_LOOP_EX_END();
}
OMP_THROW_EX();
@@ -831,6 +867,11 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
dataset->num_total_features_ = static_cast(feature_names_.size());
}
+ if (!config_.max_bin_by_feature.empty()) {
+ CHECK(static_cast(dataset->num_total_features_) == config_.max_bin_by_feature.size());
+ CHECK(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())) > 1);
+ }
+
// check the range of label_idx, weight_idx and group_idx
CHECK(label_idx_ >= 0 && label_idx_ <= dataset->num_total_features_);
CHECK(weight_idx_ < 0 || weight_idx_ < dataset->num_total_features_);
@@ -865,8 +906,16 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
bin_type = BinType::CategoricalBin;
}
bin_mappers[i].reset(new BinMapper());
- bin_mappers[i]->FindBin(sample_values[i].data(), static_cast(sample_values[i].size()),
- sample_data.size(), config_.max_bin, config_.min_data_in_bin, filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing);
+ if (config_.max_bin_by_feature.empty()) {
+ bin_mappers[i]->FindBin(sample_values[i].data(), static_cast(sample_values[i].size()),
+ sample_data.size(), config_.max_bin, config_.min_data_in_bin,
+ filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing);
+ } else {
+ bin_mappers[i]->FindBin(sample_values[i].data(), static_cast(sample_values[i].size()),
+ sample_data.size(), config_.max_bin_by_feature[i],
+ config_.min_data_in_bin, filter_cnt, bin_type, config_.use_missing,
+ config_.zero_as_missing);
+ }
OMP_LOOP_EX_END();
}
OMP_THROW_EX();
@@ -902,8 +951,18 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
bin_type = BinType::CategoricalBin;
}
bin_mappers[i].reset(new BinMapper());
- bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(), static_cast(sample_values[start[rank] + i].size()),
- sample_data.size(), config_.max_bin, config_.min_data_in_bin, filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing);
+ if (config_.max_bin_by_feature.empty()) {
+ bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(),
+ static_cast(sample_values[start[rank] + i].size()),
+ sample_data.size(), config_.max_bin, config_.min_data_in_bin,
+ filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing);
+ } else {
+ bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(),
+ static_cast(sample_values[start[rank] + i].size()),
+ sample_data.size(), config_.max_bin_by_feature[i],
+ config_.min_data_in_bin, filter_cnt, bin_type,
+ config_.use_missing, config_.zero_as_missing);
+ }
OMP_LOOP_EX_END();
}
OMP_THROW_EX();
diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py
index c7cb6695926..d453cc2ba9b 100644
--- a/tests/python_package_test/test_engine.py
+++ b/tests/python_package_test/test_engine.py
@@ -813,6 +813,29 @@ def is_correctly_constrained(learner):
constrained_model = lgb.train(params, trainset)
self.assertTrue(is_correctly_constrained(constrained_model))
+ def test_max_bin_by_feature(self):
+ col1 = np.arange(0, 100)[:, np.newaxis]
+ col2 = np.zeros((100, 1))
+ col2[20:] = 1
+ X = np.concatenate([col1, col2], axis=1)
+ y = np.arange(0, 100)
+ params = {
+ 'objective': 'regression_l2',
+ 'verbose': -1,
+ 'num_leaves': 100,
+ 'min_data_in_leaf': 1,
+ 'min_sum_hessian_in_leaf': 0,
+ 'min_data_in_bin': 1,
+ 'max_bin_by_feature': [100, 2]
+ }
+ lgb_data = lgb.Dataset(X, label=y)
+ est = lgb.train(params, lgb_data, num_boost_round=1)
+ self.assertEqual(len(np.unique(est.predict(X))), 100)
+ params['max_bin_by_feature'] = [2, 100]
+ lgb_data = lgb.Dataset(X, label=y)
+ est = lgb.train(params, lgb_data, num_boost_round=1)
+ self.assertEqual(len(np.unique(est.predict(X))), 3)
+
def test_refit(self):
X, y = load_breast_cancer(True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)