Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Max bin by feature #2190

Merged
merged 7 commits into from
Jul 8, 2019
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/Parameters.rst
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,12 @@ IO Parameters

- LightGBM will auto compress memory according to ``max_bin``. For example, LightGBM will use ``uint8_t`` for feature value if ``max_bin=255``

- ``max_bin_by_feature`` :raw-html:`<a id="max_bin_by_feature" title="Permalink to this parameter" href="#max_bin_by_feature">&#x1F517;&#xFE0E;</a>`, default = ``None``, type = multi-int

- max number of bins for each feature

- if not specified, will use ``max_bin`` for all features

- ``min_data_in_bin`` :raw-html:`<a id="min_data_in_bin" title="Permalink to this parameter" href="#min_data_in_bin">&#x1F517;&#xFE0E;</a>`, default = ``3``, type = int, constraints: ``min_data_in_bin > 0``

- minimal number of data inside one bin
Expand Down
6 changes: 6 additions & 0 deletions include/LightGBM/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,12 @@ struct Config {
// desc = LightGBM will auto compress memory according to ``max_bin``. For example, LightGBM will use ``uint8_t`` for feature value if ``max_bin=255``
int max_bin = 255;

// type = multi-int
// default = None
// desc = max number of bins for each feature
// desc = if not specified, will use ``max_bin`` for all features
std::vector<int32_t> max_bin_by_feature;

// check = >0
// desc = minimal number of data inside one bin
// desc = use this to avoid one-data-one-bin (potential over-fitting)
Expand Down
1 change: 1 addition & 0 deletions include/LightGBM/dataset.h
Original file line number Diff line number Diff line change
Expand Up @@ -624,6 +624,7 @@ class Dataset {
std::vector<double> feature_penalty_;
bool is_finish_load_;
int max_bin_;
std::vector<int32_t> max_bin_by_feature_;
int bin_construct_sample_cnt_;
int min_data_in_bin_;
bool use_missing_;
Expand Down
6 changes: 6 additions & 0 deletions src/io/config_auto.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ std::unordered_set<std::string> Config::parameter_set({
"cegb_penalty_feature_coupled",
"verbosity",
"max_bin",
"max_bin_by_feature",
"min_data_in_bin",
"bin_construct_sample_cnt",
"histogram_pool_size",
Expand Down Expand Up @@ -401,6 +402,10 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str
GetInt(params, "max_bin", &max_bin);
CHECK(max_bin >1);

if (GetString(params, "max_bin_by_feature", &tmp_str)) {
max_bin_by_feature = Common::StringToArray<int32_t>(tmp_str, ',');
}

GetInt(params, "min_data_in_bin", &min_data_in_bin);
CHECK(min_data_in_bin >0);

Expand Down Expand Up @@ -588,6 +593,7 @@ std::string Config::SaveMembersToString() const {
str_buf << "[cegb_penalty_feature_coupled: " << Common::Join(cegb_penalty_feature_coupled, ",") << "]\n";
str_buf << "[verbosity: " << verbosity << "]\n";
str_buf << "[max_bin: " << max_bin << "]\n";
str_buf << "[max_bin_by_feature: " << Common::Join(max_bin_by_feature, ",") << "]\n";
str_buf << "[min_data_in_bin: " << min_data_in_bin << "]\n";
str_buf << "[bin_construct_sample_cnt: " << bin_construct_sample_cnt << "]\n";
str_buf << "[histogram_pool_size: " << histogram_pool_size << "]\n";
Expand Down
22 changes: 21 additions & 1 deletion src/io/dataset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,12 @@ void Dataset::Construct(
feature_penalty_.clear();
}
}
// Per-feature bin limits are optional: an empty list leaves max_bin_by_feature_ untouched.
if (!io_config.max_bin_by_feature.empty()) {
  // Exactly one limit is required per input feature, and every limit must exceed 1.
  CHECK(static_cast<size_t>(num_total_features_) == io_config.max_bin_by_feature.size());
  CHECK(*(std::min_element(io_config.max_bin_by_feature.begin(), io_config.max_bin_by_feature.end())) > 1);
  // A plain copy sets both size and contents; the size was validated just above,
  // so a separate resize() beforehand would be redundant.
  max_bin_by_feature_ = io_config.max_bin_by_feature;
}
max_bin_ = io_config.max_bin;
min_data_in_bin_ = io_config.min_data_in_bin;
bin_construct_sample_cnt_ = io_config.bin_construct_sample_cnt;
Expand All @@ -332,6 +338,9 @@ void Dataset::ResetConfig(const char* parameters) {
if (param.count("max_bin") && io_config.max_bin != max_bin_) {
Log::Warning("Cannot change max_bin after constructed Dataset handle.");
}
if (param.count("max_bin_by_feature") && io_config.max_bin_by_feature != max_bin_by_feature_) {
Log::Warning("Cannot change max_bin_by_feature after constructed Dataset handle.");
}
if (param.count("bin_construct_sample_cnt") && io_config.bin_construct_sample_cnt != bin_construct_sample_cnt_) {
Log::Warning("Cannot change bin_construct_sample_cnt after constructed Dataset handle.");
}
Expand Down Expand Up @@ -643,7 +652,7 @@ void Dataset::SaveBinaryFile(const char* bin_filename) {
size_t size_of_header = sizeof(num_data_) + sizeof(num_features_) + sizeof(num_total_features_)
+ sizeof(int) * num_total_features_ + sizeof(label_idx_) + sizeof(num_groups_) + sizeof(sparse_threshold_)
+ 3 * sizeof(int) * num_features_ + sizeof(uint64_t) * (num_groups_ + 1) + 2 * sizeof(int) * num_groups_ + sizeof(int8_t) * num_features_
+ sizeof(double) * num_features_ + sizeof(int) * 3 + sizeof(bool) * 2;
+ sizeof(double) * num_features_ + sizeof(int32_t) * num_total_features_ + sizeof(int) * 3 + sizeof(bool) * 2;
// size of feature names
for (int i = 0; i < num_total_features_; ++i) {
size_of_header += feature_names_[i].size() + sizeof(int);
Expand Down Expand Up @@ -682,6 +691,13 @@ void Dataset::SaveBinaryFile(const char* bin_filename) {
if (ArrayArgs<double>::CheckAll(feature_penalty_, 1.0)) {
feature_penalty_.clear();
}
// The write below always emits sizeof(int32_t) * num_total_features_ bytes, so an
// unset (empty) list is first expanded with -1 placeholders.
// NOTE(review): presumably ArrayArgs::Assign fills the vector with
// num_total_features_ copies of -1 -- confirm against its definition.
if (max_bin_by_feature_.empty()) {
ArrayArgs<int32_t>::Assign(&max_bin_by_feature_, -1, num_total_features_);
}
writer->Write(max_bin_by_feature_.data(), sizeof(int32_t) * num_total_features_);
// If the list holds only -1 placeholders, clear it again so the in-memory
// state goes back to "not specified" after saving.
if (ArrayArgs<int32_t>::CheckAll(max_bin_by_feature_, -1)) {
max_bin_by_feature_.clear();
}
// write feature names
for (int i = 0; i < num_total_features_; ++i) {
int str_len = static_cast<int>(feature_names_[i].size());
Expand Down Expand Up @@ -730,6 +746,10 @@ void Dataset::DumpTextFile(const char* text_filename) {
for (auto i : feature_penalty_) {
fprintf(file, "%lf, ", i);
}
fprintf(file, "\nmax_bin_by_feature: ");
for (auto i : max_bin_by_feature_) {
fprintf(file, "%d, ", i);
}
fprintf(file, "\n");
for (auto n : feature_names_) {
fprintf(file, "%s, ", n.c_str());
Expand Down
77 changes: 68 additions & 9 deletions src/io/dataset_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,23 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
dataset->feature_penalty_.clear();
}

if (!config_.max_bin_by_feature.empty()) {
  // Limits given in the current config take precedence over the ones stored in the file.
  // Exactly one limit is required per input feature, and every limit must exceed 1.
  CHECK(static_cast<size_t>(dataset->num_total_features_) == config_.max_bin_by_feature.size());
  CHECK(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())) > 1);
  // A plain copy sets both size and contents; no resize() is needed first.
  dataset->max_bin_by_feature_ = config_.max_bin_by_feature;
} else {
  // No limits in the config: read the num_total_features_ int32 values stored in the header.
  const int32_t* stored_max_bins = reinterpret_cast<const int32_t*>(mem_ptr);
  dataset->max_bin_by_feature_.assign(stored_max_bins, stored_max_bins + dataset->num_total_features_);
}
// The stored array occupies the header in either case, so always step past it.
mem_ptr += sizeof(int32_t) * (dataset->num_total_features_);
// A list made up only of -1 entries is treated as "not specified" and kept empty in memory.
if (ArrayArgs<int32_t>::CheckAll(dataset->max_bin_by_feature_, -1)) {
  dataset->max_bin_by_feature_.clear();
}

// get feature names
dataset->feature_names_.clear();
// write feature names
Expand Down Expand Up @@ -544,7 +561,10 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values,
feature_names_.push_back(str_buf.str());
}
}

if (!config_.max_bin_by_feature.empty()) {
CHECK(static_cast<size_t>(num_col) == config_.max_bin_by_feature.size());
CHECK(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())) > 1);
}
const data_size_t filter_cnt = static_cast<data_size_t>(
static_cast<double>(config_.min_data_in_leaf * total_sample_size) / num_data);
if (Network::num_machines() == 1) {
Expand All @@ -562,8 +582,16 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values,
bin_type = BinType::CategoricalBin;
}
bin_mappers[i].reset(new BinMapper());
bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size,
config_.max_bin, config_.min_data_in_bin, filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing);
if (config_.max_bin_by_feature.empty()) {
bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size,
config_.max_bin, config_.min_data_in_bin, filter_cnt,
bin_type, config_.use_missing, config_.zero_as_missing);
} else {
bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size,
config_.max_bin_by_feature[i], config_.min_data_in_bin,
filter_cnt, bin_type, config_.use_missing,
config_.zero_as_missing);
}
OMP_LOOP_EX_END();
}
OMP_THROW_EX();
Expand Down Expand Up @@ -599,8 +627,16 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values,
bin_type = BinType::CategoricalBin;
}
bin_mappers[i].reset(new BinMapper());
bin_mappers[i]->FindBin(sample_values[start[rank] + i], num_per_col[start[rank] + i], total_sample_size,
config_.max_bin, config_.min_data_in_bin, filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing);
if (config_.max_bin_by_feature.empty()) {
bin_mappers[i]->FindBin(sample_values[start[rank] + i], num_per_col[start[rank] + i],
total_sample_size, config_.max_bin, config_.min_data_in_bin,
filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing);
} else {
bin_mappers[i]->FindBin(sample_values[start[rank] + i], num_per_col[start[rank] + i],
total_sample_size, config_.max_bin_by_feature[start[rank] + i],
config_.min_data_in_bin, filter_cnt, bin_type, config_.use_missing,
config_.zero_as_missing);
}
OMP_LOOP_EX_END();
}
OMP_THROW_EX();
Expand Down Expand Up @@ -830,6 +866,11 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
dataset->num_total_features_ = static_cast<int>(feature_names_.size());
}

if (!config_.max_bin_by_feature.empty()) {
CHECK(static_cast<size_t>(dataset->num_total_features_) == config_.max_bin_by_feature.size());
CHECK(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())) > 1);
}

// check the range of label_idx, weight_idx and group_idx
CHECK(label_idx_ >= 0 && label_idx_ <= dataset->num_total_features_);
CHECK(weight_idx_ < 0 || weight_idx_ < dataset->num_total_features_);
Expand Down Expand Up @@ -864,8 +905,16 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
bin_type = BinType::CategoricalBin;
}
bin_mappers[i].reset(new BinMapper());
bin_mappers[i]->FindBin(sample_values[i].data(), static_cast<int>(sample_values[i].size()),
sample_data.size(), config_.max_bin, config_.min_data_in_bin, filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing);
if (config_.max_bin_by_feature.empty()) {
bin_mappers[i]->FindBin(sample_values[i].data(), static_cast<int>(sample_values[i].size()),
sample_data.size(), config_.max_bin, config_.min_data_in_bin,
filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing);
} else {
bin_mappers[i]->FindBin(sample_values[i].data(), static_cast<int>(sample_values[i].size()),
sample_data.size(), config_.max_bin_by_feature[i],
config_.min_data_in_bin, filter_cnt, bin_type, config_.use_missing,
config_.zero_as_missing);
}
OMP_LOOP_EX_END();
}
OMP_THROW_EX();
Expand Down Expand Up @@ -901,8 +950,18 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
bin_type = BinType::CategoricalBin;
}
bin_mappers[i].reset(new BinMapper());
bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(), static_cast<int>(sample_values[start[rank] + i].size()),
sample_data.size(), config_.max_bin, config_.min_data_in_bin, filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing);
if (config_.max_bin_by_feature.empty()) {
bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(),
static_cast<int>(sample_values[start[rank] + i].size()),
sample_data.size(), config_.max_bin, config_.min_data_in_bin,
filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing);
} else {
bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(),
static_cast<int>(sample_values[start[rank] + i].size()),
sample_data.size(), config_.max_bin_by_feature[i],
config_.min_data_in_bin, filter_cnt, bin_type,
config_.use_missing, config_.zero_as_missing);
}
OMP_LOOP_EX_END();
}
OMP_THROW_EX();
Expand Down
23 changes: 23 additions & 0 deletions tests/python_package_test/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,29 @@ def is_correctly_constrained(learner):
constrained_model = lgb.train(params, trainset)
self.assertTrue(is_correctly_constrained(constrained_model))

def test_max_bin_by_feature(self):
    """Check that ``max_bin_by_feature`` limits the bin count of each feature separately."""
    # Feature 0 takes 100 distinct values; feature 1 is a step (0 for rows 0-19, 1 afterwards).
    col1 = np.arange(0, 100)[:, np.newaxis]
    col2 = np.zeros((100, 1))
    col2[20:] = 1
    X = np.concatenate([col1, col2], axis=1)
    y = np.arange(0, 100)
    params = {
        'objective': 'regression_l2',
        'verbose': -1,
        'num_leaves': 100,
        'min_data_in_leaf': 1,
        'min_sum_hessian_in_leaf': 0,
        'min_data_in_bin': 1,
        'max_bin_by_feature': [100, 2],
    }
    # With 100 bins on feature 0, a single tree is expected to give every row its own prediction.
    lgb_data = lgb.Dataset(X, label=y)
    est = lgb.train(params, lgb_data, num_boost_round=1)
    self.assertEqual(len(np.unique(est.predict(X))), 100)
    # With only 2 bins on feature 0, exactly 3 distinct predictions are expected.
    # Train a single round here as well so both halves of the test are set up the same way.
    params['max_bin_by_feature'] = [2, 100]
    lgb_data = lgb.Dataset(X, label=y)
    est = lgb.train(params, lgb_data, num_boost_round=1)
    self.assertEqual(len(np.unique(est.predict(X))), 3)

def test_refit(self):
X, y = load_breast_cancer(True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
Expand Down