Skip to content

Commit

Permalink
fix #520
Browse files Browse the repository at this point in the history
  • Loading branch information
guolinke committed May 16, 2017
1 parent e984b0d commit 353286a
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 31 deletions.
5 changes: 4 additions & 1 deletion include/LightGBM/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@ struct IOConfig: public ConfigBase {
int min_data_in_bin = 5;
double max_conflict_rate = 0.0f;
bool enable_bundle = true;
bool adjacent_bundle = false;
bool has_header = false;
/*! \brief Index or column name of label, default is the first column
* And add a prefix "name:" while using column name */
Expand All @@ -135,7 +134,11 @@ struct IOConfig: public ConfigBase {
* And add a prefix "name:" while using column name
* Note: when using Index, it doesn't count the label index */
std::string categorical_column = "";
std::string device_type = "cpu";
LIGHTGBM_EXPORT void Set(const std::unordered_map<std::string, std::string>& params) override;
private:
void GetDeviceType(const std::unordered_map<std::string,
std::string>& params);
};

/*! \brief Config for objective function */
Expand Down
18 changes: 14 additions & 4 deletions src/io/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,9 +230,22 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
CHECK(min_data_in_bin > 0);
GetDouble(params, "max_conflict_rate", &max_conflict_rate);
GetBool(params, "enable_bundle", &enable_bundle);
GetBool(params, "adjacent_bundle", &adjacent_bundle);
GetDeviceType(params);
}

void IOConfig::GetDeviceType(const std::unordered_map<std::string, std::string>& params) {
std::string value;
if (GetString(params, "device", &value)) {
std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
if (value == std::string("cpu")) {
device_type = "cpu";
} else if (value == std::string("gpu")) {
device_type = "gpu";
} else {
Log::Fatal("Unknown device type %s", value.c_str());
}
}
}

void ObjectiveConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetBool(params, "is_unbalance", &is_unbalance);
Expand Down Expand Up @@ -357,9 +370,6 @@ void BoostingConfig::GetTreeLearnerType(const std::unordered_map<std::string, st
std::transform(value.begin(), value.end(), value.begin(), Common::tolower);
if (value == std::string("serial")) {
tree_learner_type = "serial";
} else if (value == std::string("gpu")) {
tree_learner_type = "serial";
device_type = "gpu";
} else if (value == std::string("feature") || value == std::string("feature_parallel")) {
tree_learner_type = "feature";
} else if (value == std::string("data") || value == std::string("data_parallel")) {
Expand Down
45 changes: 19 additions & 26 deletions src/io/dataset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,6 @@

namespace LightGBM {

#ifdef USE_GPU
const int kMaxBinPerGroup = 256;
#endif // USE_GPU

const char* Dataset::binary_file_token = "______LightGBM_Binary_File_Token______\n";

Dataset::Dataset() {
Expand Down Expand Up @@ -65,38 +61,34 @@ void MarkUsed(std::vector<bool>& mark, const int* indices, int num_indices) {
}
}


std::vector<std::vector<int>> FindGroups(const std::vector<std::unique_ptr<BinMapper>>& bin_mappers,
const std::vector<int>& find_order,
int** sample_indices,
const int* num_per_col,
size_t total_sample_cnt,
data_size_t max_error_cnt,
data_size_t filter_cnt,
data_size_t num_data) {
data_size_t num_data,
bool is_use_gpu) {
const int max_search_group = 100;
const int gpu_max_bin_per_group = 256;
Random rand(num_data);
std::vector<std::vector<int>> features_in_group;
std::vector<std::vector<bool>> conflict_marks;
std::vector<int> group_conflict_cnt;
std::vector<size_t> group_non_zero_cnt;

#ifdef USE_GPU
std::vector<int> group_num_bin;
#endif // USE_GPU

for (auto fidx : find_order) {
const size_t cur_non_zero_cnt = num_per_col[fidx];
bool need_new_group = true;
std::vector<int> available_groups;
for (int gid = 0; gid < static_cast<int>(features_in_group.size()); ++gid) {
if (group_non_zero_cnt[gid] + cur_non_zero_cnt <= total_sample_cnt + max_error_cnt
#ifdef USE_GPU
&& group_num_bin[gid] + bin_mappers[fidx]->num_bin() + (bin_mappers[fidx]->GetDefaultBin() == 0 ? -1 : 0)
<= kMaxBinPerGroup
#endif // USE_GPU
) {
available_groups.push_back(gid);
if (group_non_zero_cnt[gid] + cur_non_zero_cnt <= total_sample_cnt + max_error_cnt){
if (!is_use_gpu || group_num_bin[gid] + bin_mappers[fidx]->num_bin() + (bin_mappers[fidx]->GetDefaultBin() == 0 ? -1 : 0)
<= gpu_max_bin_per_group) {
available_groups.push_back(gid);
}
}
}
std::vector<int> search_groups;
Expand All @@ -120,9 +112,9 @@ std::vector<std::vector<int>> FindGroups(const std::vector<std::unique_ptr<BinMa
group_conflict_cnt[gid] += cnt;
group_non_zero_cnt[gid] += cur_non_zero_cnt - cnt;
MarkUsed(conflict_marks[gid], sample_indices[fidx], num_per_col[fidx]);
#ifdef USE_GPU
group_num_bin[gid] += bin_mappers[fidx]->num_bin() + (bin_mappers[fidx]->GetDefaultBin() == 0 ? -1 : 0);
#endif // USE_GPU
if (is_use_gpu) {
group_num_bin[gid] += bin_mappers[fidx]->num_bin() + (bin_mappers[fidx]->GetDefaultBin() == 0 ? -1 : 0);
}
break;
}
}
Expand All @@ -133,9 +125,9 @@ std::vector<std::vector<int>> FindGroups(const std::vector<std::unique_ptr<BinMa
conflict_marks.emplace_back(total_sample_cnt, false);
MarkUsed(conflict_marks.back(), sample_indices[fidx], num_per_col[fidx]);
group_non_zero_cnt.emplace_back(cur_non_zero_cnt);
#ifdef USE_GPU
group_num_bin.push_back(1 + bin_mappers[fidx]->num_bin() + (bin_mappers[fidx]->GetDefaultBin() == 0 ? -1 : 0));
#endif // USE_GPU
if (is_use_gpu) {
group_num_bin.push_back(1 + bin_mappers[fidx]->num_bin() + (bin_mappers[fidx]->GetDefaultBin() == 0 ? -1 : 0));
}
}
}
return features_in_group;
Expand All @@ -150,7 +142,8 @@ std::vector<std::vector<int>> FastFeatureBundling(std::vector<std::unique_ptr<Bi
data_size_t num_data,
data_size_t min_data,
double sparse_threshold,
bool is_enable_sparse) {
bool is_enable_sparse,
bool is_use_gpu) {
// filter is based on sampling data, so decrease its range
const data_size_t filter_cnt = static_cast<data_size_t>(static_cast<double>(0.95 * min_data) / num_data * total_sample_cnt);
const data_size_t max_error_cnt = static_cast<data_size_t>(total_sample_cnt * max_conflict_rate);
Expand All @@ -176,8 +169,8 @@ std::vector<std::vector<int>> FastFeatureBundling(std::vector<std::unique_ptr<Bi
for (auto sidx : sorted_idx) {
feature_order_by_cnt.push_back(used_features[sidx]);
}
auto features_in_group = FindGroups(bin_mappers, used_features, sample_indices, num_per_col, total_sample_cnt, max_error_cnt, filter_cnt, num_data);
auto group2 = FindGroups(bin_mappers, feature_order_by_cnt, sample_indices, num_per_col, total_sample_cnt, max_error_cnt, filter_cnt, num_data);
auto features_in_group = FindGroups(bin_mappers, used_features, sample_indices, num_per_col, total_sample_cnt, max_error_cnt, filter_cnt, num_data, is_use_gpu);
auto group2 = FindGroups(bin_mappers, feature_order_by_cnt, sample_indices, num_per_col, total_sample_cnt, max_error_cnt, filter_cnt, num_data, is_use_gpu);
if (features_in_group.size() > group2.size()) {
features_in_group = group2;
}
Expand Down Expand Up @@ -239,7 +232,7 @@ void Dataset::Construct(
sample_non_zero_indices, num_per_col, total_sample_cnt,
used_features, io_config.max_conflict_rate,
num_data_, io_config.min_data_in_leaf,
sparse_threshold_, io_config.is_enable_sparse);
sparse_threshold_, io_config.is_enable_sparse, io_config.device_type == std::string("gpu"));
}

num_features_ = 0;
Expand Down
1 change: 1 addition & 0 deletions src/io/sparse_bin.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ class SparseBin: public Bin {
const data_size_t cur_idx = idx_val_pairs[i].first;
const VAL_T bin = idx_val_pairs[i].second;
data_size_t cur_delta = cur_idx - last_idx;
if (i > 0 && cur_delta == 0) { continue; }
while (cur_delta >= 256) {
deltas_.push_back(cur_delta & 0xff);
vals_.push_back(0);
Expand Down

0 comments on commit 353286a

Please sign in to comment.