Skip to content

Commit

Permalink
Max bin by feature (#2190)
Browse files Browse the repository at this point in the history
* Add parameter max_bin_by_feature.

* Fix minor bug.

* Fix minor bug.

* Fix calculation of header size for writing binary file.

* Fix style issues.

* Fix python style issue.

* Fix test and python style issue.
  • Loading branch information
btrotta authored and guolinke committed Jul 8, 2019
1 parent 1bd15b9 commit 291752d
Show file tree
Hide file tree
Showing 7 changed files with 131 additions and 10 deletions.
6 changes: 6 additions & 0 deletions docs/Parameters.rst
Expand Up @@ -447,6 +447,12 @@ IO Parameters

- LightGBM will auto compress memory according to ``max_bin``. For example, LightGBM will use ``uint8_t`` for feature value if ``max_bin=255``

- ``max_bin_by_feature`` :raw-html:`<a id="max_bin_by_feature" title="Permalink to this parameter" href="#max_bin_by_feature">&#x1F517;&#xFE0E;</a>`, default = ``None``, type = multi-int

- max number of bins for each feature

- if not specified, will use ``max_bin`` for all features

- ``min_data_in_bin`` :raw-html:`<a id="min_data_in_bin" title="Permalink to this parameter" href="#min_data_in_bin">&#x1F517;&#xFE0E;</a>`, default = ``3``, type = int, constraints: ``min_data_in_bin > 0``

- minimal number of data inside one bin
Expand Down
6 changes: 6 additions & 0 deletions include/LightGBM/config.h
Expand Up @@ -443,6 +443,12 @@ struct Config {
// desc = LightGBM will auto compress memory according to ``max_bin``. For example, LightGBM will use ``uint8_t`` for feature value if ``max_bin=255``
int max_bin = 255;

// type = multi-int
// default = None
// desc = max number of bins for each feature
// desc = if not specified, will use ``max_bin`` for all features
std::vector<int32_t> max_bin_by_feature;

// check = >0
// desc = minimal number of data inside one bin
// desc = use this to avoid one-data-one-bin (potential over-fitting)
Expand Down
1 change: 1 addition & 0 deletions include/LightGBM/dataset.h
Expand Up @@ -624,6 +624,7 @@ class Dataset {
std::vector<double> feature_penalty_;
bool is_finish_load_;
int max_bin_;
std::vector<int32_t> max_bin_by_feature_;
int bin_construct_sample_cnt_;
int min_data_in_bin_;
bool use_missing_;
Expand Down
6 changes: 6 additions & 0 deletions src/io/config_auto.cpp
Expand Up @@ -218,6 +218,7 @@ std::unordered_set<std::string> Config::parameter_set({
"cegb_penalty_feature_coupled",
"verbosity",
"max_bin",
"max_bin_by_feature",
"min_data_in_bin",
"bin_construct_sample_cnt",
"histogram_pool_size",
Expand Down Expand Up @@ -418,6 +419,10 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str
GetInt(params, "max_bin", &max_bin);
CHECK(max_bin >1);

if (GetString(params, "max_bin_by_feature", &tmp_str)) {
max_bin_by_feature = Common::StringToArray<int32_t>(tmp_str, ',');
}

GetInt(params, "min_data_in_bin", &min_data_in_bin);
CHECK(min_data_in_bin >0);

Expand Down Expand Up @@ -610,6 +615,7 @@ std::string Config::SaveMembersToString() const {
str_buf << "[cegb_penalty_feature_coupled: " << Common::Join(cegb_penalty_feature_coupled, ",") << "]\n";
str_buf << "[verbosity: " << verbosity << "]\n";
str_buf << "[max_bin: " << max_bin << "]\n";
str_buf << "[max_bin_by_feature: " << Common::Join(max_bin_by_feature, ",") << "]\n";
str_buf << "[min_data_in_bin: " << min_data_in_bin << "]\n";
str_buf << "[bin_construct_sample_cnt: " << bin_construct_sample_cnt << "]\n";
str_buf << "[histogram_pool_size: " << histogram_pool_size << "]\n";
Expand Down
22 changes: 21 additions & 1 deletion src/io/dataset.cpp
Expand Up @@ -318,6 +318,12 @@ void Dataset::Construct(
feature_penalty_.clear();
}
}
if (!io_config.max_bin_by_feature.empty()) {
CHECK(static_cast<size_t>(num_total_features_) == io_config.max_bin_by_feature.size());
CHECK(*(std::min_element(io_config.max_bin_by_feature.begin(), io_config.max_bin_by_feature.end())) > 1);
max_bin_by_feature_.resize(num_total_features_);
max_bin_by_feature_.assign(io_config.max_bin_by_feature.begin(), io_config.max_bin_by_feature.end());
}
max_bin_ = io_config.max_bin;
min_data_in_bin_ = io_config.min_data_in_bin;
bin_construct_sample_cnt_ = io_config.bin_construct_sample_cnt;
Expand All @@ -332,6 +338,9 @@ void Dataset::ResetConfig(const char* parameters) {
if (param.count("max_bin") && io_config.max_bin != max_bin_) {
Log::Warning("Cannot change max_bin after constructed Dataset handle.");
}
if (param.count("max_bin_by_feature") && io_config.max_bin_by_feature != max_bin_by_feature_) {
Log::Warning("Cannot change max_bin_by_feature after constructed Dataset handle.");
}
if (param.count("bin_construct_sample_cnt") && io_config.bin_construct_sample_cnt != bin_construct_sample_cnt_) {
Log::Warning("Cannot change bin_construct_sample_cnt after constructed Dataset handle.");
}
Expand Down Expand Up @@ -643,7 +652,7 @@ void Dataset::SaveBinaryFile(const char* bin_filename) {
size_t size_of_header = sizeof(num_data_) + sizeof(num_features_) + sizeof(num_total_features_)
+ sizeof(int) * num_total_features_ + sizeof(label_idx_) + sizeof(num_groups_) + sizeof(sparse_threshold_)
+ 3 * sizeof(int) * num_features_ + sizeof(uint64_t) * (num_groups_ + 1) + 2 * sizeof(int) * num_groups_ + sizeof(int8_t) * num_features_
+ sizeof(double) * num_features_ + sizeof(int) * 3 + sizeof(bool) * 2;
+ sizeof(double) * num_features_ + sizeof(int32_t) * num_total_features_ + sizeof(int) * 3 + sizeof(bool) * 2;
// size of feature names
for (int i = 0; i < num_total_features_; ++i) {
size_of_header += feature_names_[i].size() + sizeof(int);
Expand Down Expand Up @@ -682,6 +691,13 @@ void Dataset::SaveBinaryFile(const char* bin_filename) {
if (ArrayArgs<double>::CheckAll(feature_penalty_, 1.0)) {
feature_penalty_.clear();
}
if (max_bin_by_feature_.empty()) {
ArrayArgs<int32_t>::Assign(&max_bin_by_feature_, -1, num_total_features_);
}
writer->Write(max_bin_by_feature_.data(), sizeof(int32_t) * num_total_features_);
if (ArrayArgs<int32_t>::CheckAll(max_bin_by_feature_, -1)) {
max_bin_by_feature_.clear();
}
// write feature names
for (int i = 0; i < num_total_features_; ++i) {
int str_len = static_cast<int>(feature_names_[i].size());
Expand Down Expand Up @@ -730,6 +746,10 @@ void Dataset::DumpTextFile(const char* text_filename) {
for (auto i : feature_penalty_) {
fprintf(file, "%lf, ", i);
}
fprintf(file, "\nmax_bin_by_feature: ");
for (auto i : max_bin_by_feature_) {
fprintf(file, "%d, ", i);
}
fprintf(file, "\n");
for (auto n : feature_names_) {
fprintf(file, "%s, ", n.c_str());
Expand Down
77 changes: 68 additions & 9 deletions src/io/dataset_loader.cpp
Expand Up @@ -427,6 +427,23 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
dataset->feature_penalty_.clear();
}

if (!config_.max_bin_by_feature.empty()) {
CHECK(static_cast<size_t>(dataset->num_total_features_) == config_.max_bin_by_feature.size());
CHECK(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())) > 1);
dataset->max_bin_by_feature_.resize(dataset->num_total_features_);
dataset->max_bin_by_feature_.assign(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end());
} else {
const int32_t* tmp_ptr_max_bin_by_feature = reinterpret_cast<const int32_t*>(mem_ptr);
dataset->max_bin_by_feature_.clear();
for (int i = 0; i < dataset->num_total_features_; ++i) {
dataset->max_bin_by_feature_.push_back(tmp_ptr_max_bin_by_feature[i]);
}
}
mem_ptr += sizeof(int32_t) * (dataset->num_total_features_);
if (ArrayArgs<int32_t>::CheckAll(dataset->max_bin_by_feature_, -1)) {
dataset->max_bin_by_feature_.clear();
}

// get feature names
dataset->feature_names_.clear();
// write feature names
Expand Down Expand Up @@ -544,7 +561,10 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values,
feature_names_.push_back(str_buf.str());
}
}

if (!config_.max_bin_by_feature.empty()) {
CHECK(static_cast<size_t>(num_col) == config_.max_bin_by_feature.size());
CHECK(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())) > 1);
}
const data_size_t filter_cnt = static_cast<data_size_t>(
static_cast<double>(config_.min_data_in_leaf * total_sample_size) / num_data);
if (Network::num_machines() == 1) {
Expand All @@ -562,8 +582,16 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values,
bin_type = BinType::CategoricalBin;
}
bin_mappers[i].reset(new BinMapper());
bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size,
config_.max_bin, config_.min_data_in_bin, filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing);
if (config_.max_bin_by_feature.empty()) {
bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size,
config_.max_bin, config_.min_data_in_bin, filter_cnt,
bin_type, config_.use_missing, config_.zero_as_missing);
} else {
bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size,
config_.max_bin_by_feature[i], config_.min_data_in_bin,
filter_cnt, bin_type, config_.use_missing,
config_.zero_as_missing);
}
OMP_LOOP_EX_END();
}
OMP_THROW_EX();
Expand Down Expand Up @@ -599,8 +627,16 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values,
bin_type = BinType::CategoricalBin;
}
bin_mappers[i].reset(new BinMapper());
bin_mappers[i]->FindBin(sample_values[start[rank] + i], num_per_col[start[rank] + i], total_sample_size,
config_.max_bin, config_.min_data_in_bin, filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing);
if (config_.max_bin_by_feature.empty()) {
bin_mappers[i]->FindBin(sample_values[start[rank] + i], num_per_col[start[rank] + i],
total_sample_size, config_.max_bin, config_.min_data_in_bin,
filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing);
} else {
bin_mappers[i]->FindBin(sample_values[start[rank] + i], num_per_col[start[rank] + i],
total_sample_size, config_.max_bin_by_feature[start[rank] + i],
config_.min_data_in_bin, filter_cnt, bin_type, config_.use_missing,
config_.zero_as_missing);
}
OMP_LOOP_EX_END();
}
OMP_THROW_EX();
Expand Down Expand Up @@ -831,6 +867,11 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
dataset->num_total_features_ = static_cast<int>(feature_names_.size());
}

if (!config_.max_bin_by_feature.empty()) {
CHECK(static_cast<size_t>(dataset->num_total_features_) == config_.max_bin_by_feature.size());
CHECK(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())) > 1);
}

// check the range of label_idx, weight_idx and group_idx
CHECK(label_idx_ >= 0 && label_idx_ <= dataset->num_total_features_);
CHECK(weight_idx_ < 0 || weight_idx_ < dataset->num_total_features_);
Expand Down Expand Up @@ -865,8 +906,16 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
bin_type = BinType::CategoricalBin;
}
bin_mappers[i].reset(new BinMapper());
bin_mappers[i]->FindBin(sample_values[i].data(), static_cast<int>(sample_values[i].size()),
sample_data.size(), config_.max_bin, config_.min_data_in_bin, filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing);
if (config_.max_bin_by_feature.empty()) {
bin_mappers[i]->FindBin(sample_values[i].data(), static_cast<int>(sample_values[i].size()),
sample_data.size(), config_.max_bin, config_.min_data_in_bin,
filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing);
} else {
bin_mappers[i]->FindBin(sample_values[i].data(), static_cast<int>(sample_values[i].size()),
sample_data.size(), config_.max_bin_by_feature[i],
config_.min_data_in_bin, filter_cnt, bin_type, config_.use_missing,
config_.zero_as_missing);
}
OMP_LOOP_EX_END();
}
OMP_THROW_EX();
Expand Down Expand Up @@ -902,8 +951,18 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
bin_type = BinType::CategoricalBin;
}
bin_mappers[i].reset(new BinMapper());
bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(), static_cast<int>(sample_values[start[rank] + i].size()),
sample_data.size(), config_.max_bin, config_.min_data_in_bin, filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing);
if (config_.max_bin_by_feature.empty()) {
bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(),
static_cast<int>(sample_values[start[rank] + i].size()),
sample_data.size(), config_.max_bin, config_.min_data_in_bin,
filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing);
} else {
bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(),
static_cast<int>(sample_values[start[rank] + i].size()),
sample_data.size(), config_.max_bin_by_feature[i],
config_.min_data_in_bin, filter_cnt, bin_type,
config_.use_missing, config_.zero_as_missing);
}
OMP_LOOP_EX_END();
}
OMP_THROW_EX();
Expand Down
23 changes: 23 additions & 0 deletions tests/python_package_test/test_engine.py
Expand Up @@ -813,6 +813,29 @@ def is_correctly_constrained(learner):
constrained_model = lgb.train(params, trainset)
self.assertTrue(is_correctly_constrained(constrained_model))

def test_max_bin_by_feature(self):
col1 = np.arange(0, 100)[:, np.newaxis]
col2 = np.zeros((100, 1))
col2[20:] = 1
X = np.concatenate([col1, col2], axis=1)
y = np.arange(0, 100)
params = {
'objective': 'regression_l2',
'verbose': -1,
'num_leaves': 100,
'min_data_in_leaf': 1,
'min_sum_hessian_in_leaf': 0,
'min_data_in_bin': 1,
'max_bin_by_feature': [100, 2]
}
lgb_data = lgb.Dataset(X, label=y)
est = lgb.train(params, lgb_data, num_boost_round=1)
self.assertEqual(len(np.unique(est.predict(X))), 100)
params['max_bin_by_feature'] = [2, 100]
lgb_data = lgb.Dataset(X, label=y)
est = lgb.train(params, lgb_data, num_boost_round=1)
self.assertEqual(len(np.unique(est.predict(X))), 3)

def test_refit(self):
X, y = load_breast_cancer(True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
Expand Down

0 comments on commit 291752d

Please sign in to comment.