Skip to content

Commit

Permalink
fix #833
Browse files Browse the repository at this point in the history
  • Loading branch information
guolinke committed Aug 15, 2017
1 parent 549bbdf commit e99f986
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions src/io/bin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,32 +140,36 @@ std::vector<double> GreedyFindBin(const double* distinct_values, const int* coun
return bin_upper_bound;
}

std::vector<double> FindBinWithZeroAsMissing(const double* distinct_values, const int* counts,
std::vector<double> FindBinWithZeroAsOneBin(const double* distinct_values, const int* counts,
int num_distinct_values, int max_bin, size_t total_sample_cnt, int min_data_in_bin) {
std::vector<double> bin_upper_bound;
int left_cnt_data = 0;
int cnt_missing = 0;
int cnt_zero = 0;
int right_cnt_data = 0;
for (int i = 0; i < num_distinct_values; ++i) {
if (distinct_values[i] <= -kZeroAsMissingValueRange) {
left_cnt_data += counts[i];
} else if (distinct_values[i] > kZeroAsMissingValueRange) {
right_cnt_data += counts[i];
} else {
cnt_missing += counts[i];
cnt_zero += counts[i];
}
}

int left_cnt = 0;
int left_cnt = -1;
for (int i = 0; i < num_distinct_values; ++i) {
if (distinct_values[i] > -kZeroAsMissingValueRange) {
left_cnt = i;
break;
}
}

if (left_cnt < 0) {
left_cnt = num_distinct_values;
}

if (left_cnt > 0) {
int left_max_bin = static_cast<int>(static_cast<double>(left_cnt_data) / (total_sample_cnt - cnt_missing) * (max_bin - 1));
int left_max_bin = static_cast<int>(static_cast<double>(left_cnt_data) / (total_sample_cnt - cnt_zero) * (max_bin - 1));
bin_upper_bound = GreedyFindBin(distinct_values, counts, left_cnt, left_max_bin, left_cnt_data, min_data_in_bin);
bin_upper_bound.back() = -kZeroAsMissingValueRange;
}
Expand Down Expand Up @@ -257,14 +261,14 @@ void BinMapper::FindBin(double* values, int num_sample_values, size_t total_samp
int num_distinct_values = static_cast<int>(distinct_values.size());
if (bin_type_ == BinType::NumericalBin) {
if (missing_type_ == MissingType::Zero) {
bin_upper_bound_ = FindBinWithZeroAsMissing(distinct_values.data(), counts.data(), num_distinct_values, max_bin, total_sample_cnt, min_data_in_bin);
bin_upper_bound_ = FindBinWithZeroAsOneBin(distinct_values.data(), counts.data(), num_distinct_values, max_bin, total_sample_cnt, min_data_in_bin);
if (bin_upper_bound_.size() == 2) {
missing_type_ = MissingType::None;
}
} else if (missing_type_ == MissingType::None) {
bin_upper_bound_ = FindBinWithZeroAsMissing(distinct_values.data(), counts.data(), num_distinct_values, max_bin, total_sample_cnt, min_data_in_bin);
bin_upper_bound_ = FindBinWithZeroAsOneBin(distinct_values.data(), counts.data(), num_distinct_values, max_bin, total_sample_cnt, min_data_in_bin);
} else {
bin_upper_bound_ = FindBinWithZeroAsMissing(distinct_values.data(), counts.data(), num_distinct_values, max_bin - 1, total_sample_cnt - na_cnt, min_data_in_bin);
bin_upper_bound_ = FindBinWithZeroAsOneBin(distinct_values.data(), counts.data(), num_distinct_values, max_bin - 1, total_sample_cnt - na_cnt, min_data_in_bin);
bin_upper_bound_.push_back(NaN);
}
num_bin_ = static_cast<int>(bin_upper_bound_.size());
Expand Down

0 comments on commit e99f986

Please sign in to comment.