fix zero bin in categorical split (#3305)
* fix zero bin

* some fix

* fix bin mapping

* fix

* fix bug

* use stable sort

* fix cat forced split

* Apply suggestions from code review

* Apply suggestions from code review

* Apply suggestions from code review
guolinke committed Aug 15, 2020
1 parent 27c9aa8 commit 0391076
Showing 5 changed files with 46 additions and 54 deletions.
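Taken together, the five diffs below move a categorical feature's NaN bin from the end of the bin range to a dummy bin fixed at index 0. The old layout appended a NaN bin only when missing values were present, and needed a special swap to keep the raw category value 0 out of the first bin; the new layout always reserves bin 0 for NaN, negative, and unseen category values, and the split paths compensate with a one-bin offset whenever bin 0 is the most frequent bin. An informal sketch of the two layouts (illustration only, not tool output):

    before: [cat_a, cat_b, ..., cat_z, NaN]        (NaN bin appended only if NaN present)
    after:  [NaN/other, cat_a, cat_b, ..., cat_z]  (bin 0 always the dummy bin)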
include/LightGBM/bin.h (5 additions, 3 deletions)

@@ -456,7 +456,9 @@ class MultiValBin {
 
 inline uint32_t BinMapper::ValueToBin(double value) const {
   if (std::isnan(value)) {
-    if (missing_type_ == MissingType::NaN) {
+    if (bin_type_ == BinType::CategoricalBin) {
+      return 0;
+    } else if (missing_type_ == MissingType::NaN) {
       return num_bin_ - 1;
     } else {
       value = 0.0f;
@@ -482,12 +484,12 @@ inline uint32_t BinMapper::ValueToBin(double value) const {
     int int_value = static_cast<int>(value);
     // convert negative value to NaN bin
     if (int_value < 0) {
-      return num_bin_ - 1;
+      return 0;
     }
     if (categorical_2_bin_.count(int_value)) {
       return categorical_2_bin_.at(int_value);
     } else {
-      return num_bin_ - 1;
+      return 0;
     }
   }
 }
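The net effect on categorical lookups, as a self-contained sketch (CategoricalValueToBin is a hypothetical helper mirroring the patched branch of BinMapper::ValueToBin, not an actual LightGBM function):

    #include <cmath>
    #include <cstdint>
    #include <unordered_map>

    // Patched behavior: NaN, negative, and unseen category values all land
    // in the dummy bin 0; mapped categories occupy bins 1..num_bin-1.
    uint32_t CategoricalValueToBin(
        double value,
        const std::unordered_map<int, uint32_t>& categorical_2_bin) {
      if (std::isnan(value)) return 0;      // NaN -> dummy bin 0
      const int int_value = static_cast<int>(value);
      if (int_value < 0) return 0;          // negative -> dummy bin 0
      const auto it = categorical_2_bin.find(int_value);
      if (it == categorical_2_bin.end()) return 0;  // unseen -> dummy bin 0
      return it->second;
    }

Previously all three of these cases went to the last bin (num_bin_ - 1), the same bin used for numerical NaN.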
src/io/bin.cpp (15 additions, 28 deletions)

@@ -439,7 +439,6 @@ namespace LightGBM {
       }
     }
   }
-  num_bin_ = 0;
  int rest_cnt = static_cast<int>(total_sample_cnt - na_cnt);
  if (rest_cnt > 0) {
    const int SPARSE_RATIO = 100;
@@ -449,23 +448,25 @@ namespace LightGBM {
    }
    // sort by counts
    Common::SortForPair<int, int>(&counts_int, &distinct_values_int, 0, true);
-    // avoid first bin is zero
-    if (distinct_values_int[0] == 0) {
-      if (counts_int.size() == 1) {
-        counts_int.push_back(0);
-        distinct_values_int.push_back(distinct_values_int[0] + 1);
-      }
-      std::swap(counts_int[0], counts_int[1]);
-      std::swap(distinct_values_int[0], distinct_values_int[1]);
-    }
    // will ignore the categorical of small counts
-    int cut_cnt = static_cast<int>((total_sample_cnt - na_cnt) * 0.99f);
+    int cut_cnt = static_cast<int>(
+        Common::RoundInt((total_sample_cnt - na_cnt) * 0.99f));
    size_t cur_cat = 0;
    categorical_2_bin_.clear();
    bin_2_categorical_.clear();
    int used_cnt = 0;
-    max_bin = std::min(static_cast<int>(distinct_values_int.size()), max_bin);
+    int distinct_cnt = static_cast<int>(distinct_values_int.size());
+    if (na_cnt > 0) {
+      ++distinct_cnt;
+    }
+    max_bin = std::min(distinct_cnt, max_bin);
    cnt_in_bin.clear();
+
+    // Push the dummy bin for NaN
+    bin_2_categorical_.push_back(-1);
+    categorical_2_bin_[-1] = 0;
+    cnt_in_bin.push_back(0);
+    num_bin_ = 1;
    while (cur_cat < distinct_values_int.size()
           && (used_cnt < cut_cnt || num_bin_ < max_bin)) {
      if (counts_int[cur_cat] < min_data_in_bin && cur_cat > 1) {
@@ -478,21 +479,14 @@ namespace LightGBM {
      ++num_bin_;
      ++cur_cat;
    }
-    // need an additional bin for NaN
-    if (cur_cat == distinct_values_int.size() && na_cnt > 0) {
-      // use -1 to represent NaN
-      bin_2_categorical_.push_back(-1);
-      categorical_2_bin_[-1] = num_bin_;
-      cnt_in_bin.push_back(0);
-      ++num_bin_;
-    }
    // Use MissingType::None to represent this bin contains all categoricals
    if (cur_cat == distinct_values_int.size() && na_cnt == 0) {
      missing_type_ = MissingType::None;
    } else {
      missing_type_ = MissingType::NaN;
    }
-    cnt_in_bin.back() += static_cast<int>(total_sample_cnt - used_cnt);
+    // fix count of NaN bin
+    cnt_in_bin[0] = static_cast<int>(total_sample_cnt - used_cnt);
  }
 }
 
@@ -511,13 +505,6 @@ namespace LightGBM {
  default_bin_ = ValueToBin(0);
  most_freq_bin_ =
      static_cast<uint32_t>(ArrayArgs<int>::ArgMax(cnt_in_bin));
-  if (bin_type_ == BinType::CategoricalBin) {
-    if (most_freq_bin_ == 0) {
-      CHECK_GT(num_bin_, 1);
-      // FIXME: how to enable `most_freq_bin_ = 0` for categorical features
-      most_freq_bin_ = 1;
-    }
-  }
  const double max_sparse_rate =
      static_cast<double>(cnt_in_bin[most_freq_bin_]) / total_sample_cnt;
  // When most_freq_bin_ != default_bin_, there are some additional data loading costs.
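The new construction order can be condensed as follows (a minimal sketch under the assumption that categories arrive already sorted by count; BuildCategoricalBins is a hypothetical name, not a LightGBM function):

    #include <unordered_map>
    #include <vector>

    // Condensed view of the patched FindBin logic for categorical features:
    // the NaN dummy bin is pushed first, so real categories start at bin 1.
    void BuildCategoricalBins(const std::vector<int>& categories_by_count,
                              std::vector<int>* bin_2_categorical,
                              std::unordered_map<int, int>* categorical_2_bin,
                              std::vector<int>* cnt_in_bin) {
      bin_2_categorical->clear();
      categorical_2_bin->clear();
      cnt_in_bin->clear();
      bin_2_categorical->push_back(-1);  // bin 0: -1 stands for NaN
      (*categorical_2_bin)[-1] = 0;
      cnt_in_bin->push_back(0);          // fixed up once real counts are known
      int num_bin = 1;                   // bin numbering now starts at 1
      for (int cat : categories_by_count) {
        bin_2_categorical->push_back(cat);
        (*categorical_2_bin)[cat] = num_bin++;
      }
    }

The new cnt_in_bin[0] assignment in the diff then charges every sample not placed in a category bin (NaN values plus trimmed rare categories) to the dummy bin, which also removes the need for the separate NaN-bin append that the old code did at the end.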
src/io/dense_bin.hpp (4 additions, 2 deletions)

@@ -318,7 +318,9 @@ class DenseBin : public Bin {
     data_size_t gt_count = 0;
     data_size_t* default_indices = gt_indices;
     data_size_t* default_count = &gt_count;
-    if (Common::FindInBitset(threshold, num_threahold, most_freq_bin)) {
+    int8_t offset = most_freq_bin == 0 ? 1 : 0;
+    if (most_freq_bin > 0 &&
+        Common::FindInBitset(threshold, num_threahold, most_freq_bin)) {
       default_indices = lte_indices;
       default_count = &lte_count;
     }
@@ -330,7 +332,7 @@ class DenseBin : public Bin {
       } else if (!USE_MIN_BIN && bin == 0) {
         default_indices[(*default_count)++] = idx;
       } else if (Common::FindInBitset(threshold, num_threahold,
-                                      bin - min_bin)) {
+                                      bin - min_bin + offset)) {
         lte_indices[lte_count++] = idx;
       } else {
         gt_indices[gt_count++] = idx;
src/io/sparse_bin.hpp (3 additions, 2 deletions)

@@ -364,7 +364,8 @@ class SparseBin : public Bin {
     data_size_t* default_indices = gt_indices;
     data_size_t* default_count = &gt_count;
     SparseBinIterator<VAL_T> iterator(this, data_indices[0]);
-    if (Common::FindInBitset(threshold, num_threahold, most_freq_bin)) {
+    int8_t offset = most_freq_bin == 0 ? 1 : 0;
+    if (most_freq_bin > 0 && Common::FindInBitset(threshold, num_threahold, most_freq_bin)) {
       default_indices = lte_indices;
       default_count = &lte_count;
     }
@@ -376,7 +377,7 @@ class SparseBin : public Bin {
       } else if (!USE_MIN_BIN && bin == 0) {
         default_indices[(*default_count)++] = idx;
       } else if (Common::FindInBitset(threshold, num_threahold,
-                                      bin - min_bin)) {
+                                      bin - min_bin + offset)) {
         lte_indices[lte_count++] = idx;
       } else {
         gt_indices[gt_count++] = idx;
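Both the dense and sparse SplitCategorical paths apply the same two adjustments: the added most_freq_bin > 0 guard ensures the dummy bin can never flip the default direction to the left side, and the probe index gains a one-bin offset. A simplified reading (GoesLeft and in_bitset are hypothetical stand-ins; the guard and offset suggest bin 0 is skipped when a feature is packed into its group, so stored values start at real bin 1 when most_freq_bin == 0):

    #include <cstdint>
    #include <vector>

    // Adding the offset recovers the feature's real bin index from the
    // stored group-local value before testing threshold membership;
    // in_bitset stands in for Common::FindInBitset over real bins.
    bool GoesLeft(uint32_t stored_bin, uint32_t min_bin, uint32_t most_freq_bin,
                  const std::vector<bool>& in_bitset) {
      const uint32_t offset = (most_freq_bin == 0) ? 1 : 0;
      const uint32_t real_bin = stored_bin - min_bin + offset;
      return real_bin < in_bitset.size() && in_bitset[real_bin];
    }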
src/treelearner/feature_histogram.hpp (19 additions, 19 deletions)

@@ -300,8 +300,10 @@ class FeatureHistogram {
     }
 
     double min_gain_shift = gain_shift + meta_->config->min_gain_to_split;
-    bool is_full_categorical = meta_->missing_type == MissingType::None;
-    int used_bin = meta_->num_bin - 1 + is_full_categorical;
+    const int8_t offset = meta_->offset;
+    const int bin_start = 1 - offset;
+    const int bin_end = meta_->num_bin - offset;
+    int used_bin = -1;
 
     std::vector<int> sorted_idx;
     double l2 = meta_->config->lambda_l2;
@@ -312,11 +314,11 @@ class FeatureHistogram {
     int rand_threshold = 0;
     if (use_onehot) {
       if (USE_RAND) {
-        if (used_bin > 0) {
-          rand_threshold = meta_->rand.NextInt(0, used_bin);
+        if (bin_end - bin_start > 0) {
+          rand_threshold = meta_->rand.NextInt(bin_start, bin_end);
         }
       }
-      for (int t = 0; t < used_bin; ++t) {
+      for (int t = bin_start; t < bin_end; ++t) {
         const auto grad = GET_GRAD(data_, t);
         const auto hess = GET_HESS(data_, t);
         data_size_t cnt =
@@ -366,7 +368,7 @@ class FeatureHistogram {
         }
       }
     } else {
-      for (int i = 0; i < used_bin; ++i) {
+      for (int i = bin_start; i < bin_end; ++i) {
         if (Common::RoundInt(GET_HESS(data_, i) * cnt_factor) >=
             meta_->config->cat_smooth) {
           sorted_idx.push_back(i);
@@ -379,11 +381,11 @@ class FeatureHistogram {
       auto ctr_fun = [this](double sum_grad, double sum_hess) {
         return (sum_grad) / (sum_hess + meta_->config->cat_smooth);
       };
-      std::sort(sorted_idx.begin(), sorted_idx.end(),
-                [this, &ctr_fun](int i, int j) {
-                  return ctr_fun(GET_GRAD(data_, i), GET_HESS(data_, i)) <
-                         ctr_fun(GET_GRAD(data_, j), GET_HESS(data_, j));
-                });
+      std::stable_sort(
+          sorted_idx.begin(), sorted_idx.end(), [this, &ctr_fun](int i, int j) {
+            return ctr_fun(GET_GRAD(data_, i), GET_HESS(data_, i)) <
+                   ctr_fun(GET_GRAD(data_, j), GET_HESS(data_, j));
+          });
 
       std::vector<int> find_direction(1, 1);
       std::vector<int> start_position(1, 0);
@@ -489,19 +491,19 @@ class FeatureHistogram {
     if (use_onehot) {
       output->num_cat_threshold = 1;
       output->cat_threshold =
-          std::vector<uint32_t>(1, static_cast<uint32_t>(best_threshold));
+          std::vector<uint32_t>(1, static_cast<uint32_t>(best_threshold + offset));
     } else {
       output->num_cat_threshold = best_threshold + 1;
       output->cat_threshold =
           std::vector<uint32_t>(output->num_cat_threshold);
       if (best_dir == 1) {
         for (int i = 0; i < output->num_cat_threshold; ++i) {
-          auto t = sorted_idx[i];
+          auto t = sorted_idx[i] + offset;
           output->cat_threshold[i] = t;
         }
       } else {
         for (int i = 0; i < output->num_cat_threshold; ++i) {
-          auto t = sorted_idx[used_bin - 1 - i];
+          auto t = sorted_idx[used_bin - 1 - i] + offset;
           output->cat_threshold[i] = t;
         }
       }
@@ -649,16 +651,14 @@ class FeatureHistogram {
     double gain_shift = GetLeafGainGivenOutput<true>(
         sum_gradient, sum_hessian, meta_->config->lambda_l1, meta_->config->lambda_l2, parent_output);
     double min_gain_shift = gain_shift + meta_->config->min_gain_to_split;
-    bool is_full_categorical = meta_->missing_type == MissingType::None;
-    int used_bin = meta_->num_bin - 1 + is_full_categorical;
-    if (threshold >= static_cast<uint32_t>(used_bin)) {
+    if (threshold >= static_cast<uint32_t>(meta_->num_bin) || threshold == 0) {
       output->gain = kMinScore;
       Log::Warning("Invalid categorical threshold split");
       return;
     }
     const double cnt_factor = num_data / sum_hessian;
-    const auto grad = GET_GRAD(data_, threshold);
-    const auto hess = GET_HESS(data_, threshold);
+    const auto grad = GET_GRAD(data_, threshold - meta_->offset);
+    const auto hess = GET_HESS(data_, threshold - meta_->offset);
     data_size_t cnt =
         static_cast<data_size_t>(Common::RoundInt(hess * cnt_factor));
 
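Two points in this file are worth spelling out. The bin_start/bin_end bounds and the "+ offset" on emitted thresholds appear to convert between histogram indices and real bin indices: when meta_->offset == 1 the histogram omits the dummy bin 0, so histogram index t corresponds to real bin t + 1. And per the "use stable sort" commit above, std::sort was replaced by std::stable_sort when ordering bins by their smoothed gradient/hessian ratio: with std::sort, ties between equal ratios can be broken differently across standard-library implementations, making the chosen categorical split, and hence the grown tree, non-deterministic. A minimal illustration of the guarantee (standalone sketch, not LightGBM code):

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    // stable_sort keeps the original index order for equal keys, so bins
    // with identical ctr values always come out in ascending bin order on
    // every platform; std::sort gives no such guarantee.
    std::vector<int> SortBinsByCtr(const std::vector<double>& ctr) {
      std::vector<int> idx(ctr.size());
      for (std::size_t i = 0; i < idx.size(); ++i) idx[i] = static_cast<int>(i);
      std::stable_sort(idx.begin(), idx.end(),
                       [&ctr](int a, int b) { return ctr[a] < ctr[b]; });
      return idx;
    }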
