diff --git a/include/LightGBM/bin.h b/include/LightGBM/bin.h
index f817dfabaa8..4f320698c83 100644
--- a/include/LightGBM/bin.h
+++ b/include/LightGBM/bin.h
@@ -456,7 +456,9 @@ class MultiValBin {
 inline uint32_t BinMapper::ValueToBin(double value) const {
   if (std::isnan(value)) {
-    if (missing_type_ == MissingType::NaN) {
+    if (bin_type_ == BinType::CategoricalBin) {
+      return 0;
+    } else if (missing_type_ == MissingType::NaN) {
       return num_bin_ - 1;
     } else {
       value = 0.0f;
     }
@@ -482,12 +484,12 @@ inline uint32_t BinMapper::ValueToBin(double value) const {
     int int_value = static_cast<int>(value);
     // convert negative value to NaN bin
     if (int_value < 0) {
-      return num_bin_ - 1;
+      return 0;
     }
     if (categorical_2_bin_.count(int_value)) {
       return categorical_2_bin_.at(int_value);
     } else {
-      return num_bin_ - 1;
+      return 0;
     }
   }
 }
diff --git a/src/io/bin.cpp b/src/io/bin.cpp
index 367edaa3f7b..4be39040438 100644
--- a/src/io/bin.cpp
+++ b/src/io/bin.cpp
@@ -439,7 +439,6 @@ namespace LightGBM {
       }
     }
   }
-  num_bin_ = 0;
   int rest_cnt = static_cast<int>(total_sample_cnt - na_cnt);
   if (rest_cnt > 0) {
     const int SPARSE_RATIO = 100;
@@ -449,23 +448,25 @@ namespace LightGBM {
     }
     // sort by counts
     Common::SortForPair<int, int>(&counts_int, &distinct_values_int, 0, true);
-    // avoid first bin is zero
-    if (distinct_values_int[0] == 0) {
-      if (counts_int.size() == 1) {
-        counts_int.push_back(0);
-        distinct_values_int.push_back(distinct_values_int[0] + 1);
-      }
-      std::swap(counts_int[0], counts_int[1]);
-      std::swap(distinct_values_int[0], distinct_values_int[1]);
-    }
     // will ignore the categorical of small counts
-    int cut_cnt = static_cast<int>((total_sample_cnt - na_cnt) * 0.99f);
+    int cut_cnt = static_cast<int>(
+        Common::RoundInt((total_sample_cnt - na_cnt) * 0.99f));
     size_t cur_cat = 0;
     categorical_2_bin_.clear();
     bin_2_categorical_.clear();
     int used_cnt = 0;
-    max_bin = std::min(static_cast<int>(distinct_values_int.size()), max_bin);
+    int distinct_cnt = static_cast<int>(distinct_values_int.size());
+    if (na_cnt > 0) {
+      ++distinct_cnt;
+    }
+    max_bin = std::min(distinct_cnt, max_bin);
     cnt_in_bin.clear();
+
+    // Push the dummy bin for NaN
+    bin_2_categorical_.push_back(-1);
+    categorical_2_bin_[-1] = 0;
+    cnt_in_bin.push_back(0);
+    num_bin_ = 1;
     while (cur_cat < distinct_values_int.size()
            && (used_cnt < cut_cnt || num_bin_ < max_bin)) {
       if (counts_int[cur_cat] < min_data_in_bin && cur_cat > 1) {
@@ -478,21 +479,14 @@ namespace LightGBM {
       ++num_bin_;
       ++cur_cat;
     }
-    // need an additional bin for NaN
-    if (cur_cat == distinct_values_int.size() && na_cnt > 0) {
-      // use -1 to represent NaN
-      bin_2_categorical_.push_back(-1);
-      categorical_2_bin_[-1] = num_bin_;
-      cnt_in_bin.push_back(0);
-      ++num_bin_;
-    }
     // Use MissingType::None to represent this bin contains all categoricals
     if (cur_cat == distinct_values_int.size() && na_cnt == 0) {
       missing_type_ = MissingType::None;
     } else {
       missing_type_ = MissingType::NaN;
     }
-    cnt_in_bin.back() += static_cast<data_size_t>(total_sample_cnt - used_cnt);
+    // fix count of NaN bin
+    cnt_in_bin[0] = static_cast<data_size_t>(total_sample_cnt - used_cnt);
   }
 }

@@ -511,13 +505,6 @@ namespace LightGBM {
   default_bin_ = ValueToBin(0);
   most_freq_bin_ =
       static_cast<uint32_t>(ArrayArgs<data_size_t>::ArgMax(cnt_in_bin));
-  if (bin_type_ == BinType::CategoricalBin) {
-    if (most_freq_bin_ == 0) {
-      CHECK_GT(num_bin_, 1);
-      // FIXME: how to enable `most_freq_bin_ = 0` for categorical features
-      most_freq_bin_ = 1;
-    }
-  }
   const double max_sparse_rate =
       static_cast<double>(cnt_in_bin[most_freq_bin_]) / total_sample_cnt;
   // When most_freq_bin_ != default_bin_, there are some additional data loading costs.
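Net effect of the `bin.h` / `bin.cpp` changes above: bin 0 of a categorical feature is now a dummy bin reserved up front for NaN (stored as category `-1`), and negative or unseen categories fall into it too, instead of being appended at `num_bin_ - 1`. A minimal standalone sketch of the resulting mapping; `ToyCategoricalValueToBin` and its explicit map argument are illustrative, not LightGBM API:

```cpp
// Sketch of the patched categorical ValueToBin() behaviour, assuming bin
// construction has already reserved bin 0 for the NaN (-1) dummy category.
#include <cmath>
#include <cstdint>
#include <unordered_map>

uint32_t ToyCategoricalValueToBin(
    double value, const std::unordered_map<int, uint32_t>& categorical_2_bin) {
  if (std::isnan(value)) {
    return 0;  // NaN now maps to the dummy bin 0 (was num_bin_ - 1)
  }
  const int int_value = static_cast<int>(value);
  if (int_value < 0) {
    return 0;  // negative categories fall into the dummy bin as well
  }
  auto it = categorical_2_bin.find(int_value);
  return it != categorical_2_bin.end() ? it->second : 0;  // unseen -> bin 0
}
```

This is also why the `most_freq_bin_ = 1` workaround in `bin.cpp` can be dropped: bin 0 is now a legitimate most-frequent bin for categorical features.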
diff --git a/src/io/dense_bin.hpp b/src/io/dense_bin.hpp
index d61f7e6489e..122b4214019 100644
--- a/src/io/dense_bin.hpp
+++ b/src/io/dense_bin.hpp
@@ -318,7 +318,9 @@ class DenseBin : public Bin {
     data_size_t gt_count = 0;
     data_size_t* default_indices = gt_indices;
     data_size_t* default_count = &gt_count;
-    if (Common::FindInBitset(threshold, num_threahold, most_freq_bin)) {
+    int8_t offset = most_freq_bin == 0 ? 1 : 0;
+    if (most_freq_bin > 0 &&
+        Common::FindInBitset(threshold, num_threahold, most_freq_bin)) {
       default_indices = lte_indices;
       default_count = &lte_count;
     }
@@ -330,7 +332,7 @@ class DenseBin : public Bin {
       } else if (!USE_MIN_BIN && bin == 0) {
         default_indices[(*default_count)++] = idx;
       } else if (Common::FindInBitset(threshold, num_threahold,
-                                      bin - min_bin)) {
+                                      bin - min_bin + offset)) {
         lte_indices[lte_count++] = idx;
       } else {
         gt_indices[gt_count++] = idx;
diff --git a/src/io/sparse_bin.hpp b/src/io/sparse_bin.hpp
index aa3ed929713..431f6cc3cda 100644
--- a/src/io/sparse_bin.hpp
+++ b/src/io/sparse_bin.hpp
@@ -364,7 +364,8 @@ class SparseBin : public Bin {
     data_size_t* default_indices = gt_indices;
     data_size_t* default_count = &gt_count;
     SparseBinIterator<VAL_T> iterator(this, data_indices[0]);
-    if (Common::FindInBitset(threshold, num_threahold, most_freq_bin)) {
+    int8_t offset = most_freq_bin == 0 ? 1 : 0;
+    if (most_freq_bin > 0 && Common::FindInBitset(threshold, num_threahold, most_freq_bin)) {
       default_indices = lte_indices;
       default_count = &lte_count;
     }
@@ -376,7 +377,7 @@ class SparseBin : public Bin {
       } else if (!USE_MIN_BIN && bin == 0) {
         default_indices[(*default_count)++] = idx;
       } else if (Common::FindInBitset(threshold, num_threahold,
-                                      bin - min_bin)) {
+                                      bin - min_bin + offset)) {
         lte_indices[lte_count++] = idx;
       } else {
         gt_indices[gt_count++] = idx;
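The `dense_bin.hpp` / `sparse_bin.hpp` changes mirror that layout at split time: thresholds never contain the dummy bin, so a row whose bin equals `most_freq_bin` can only go to the <= side when `most_freq_bin` is a real category (`most_freq_bin > 0`) found in the bitset, and the stored bin is shifted by `offset` before the lookup. A rough sketch of the routing, using a plain `std::set` in place of LightGBM's packed bitset and ignoring the `min_bin`/`USE_MIN_BIN` bookkeeping; `ToySplitCategorical` is an illustrative name:

```cpp
// Toy version of categorical split routing after this patch: bin 0 (the
// dummy NaN bin) can never appear in the threshold set, so rows in bin 0
// always take the default side when it is the most frequent bin.
#include <cstdint>
#include <set>
#include <vector>

void ToySplitCategorical(const std::vector<uint32_t>& bins,
                         const std::set<uint32_t>& threshold,
                         uint32_t most_freq_bin,
                         std::vector<int>* lte_indices,
                         std::vector<int>* gt_indices) {
  // Rows in the most frequent bin are routed without a per-row lookup; they
  // go left only when that bin is a real category inside the threshold set.
  std::vector<int>* default_side =
      (most_freq_bin > 0 && threshold.count(most_freq_bin)) ? lte_indices
                                                            : gt_indices;
  for (int i = 0; i < static_cast<int>(bins.size()); ++i) {
    const uint32_t bin = bins[i];
    if (bin == most_freq_bin) {
      default_side->push_back(i);
    } else if (threshold.count(bin)) {
      lte_indices->push_back(i);
    } else {
      gt_indices->push_back(i);
    }
  }
}
```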
diff --git a/src/treelearner/feature_histogram.hpp b/src/treelearner/feature_histogram.hpp
index 8916ee48fd4..7fa341c67ac 100644
--- a/src/treelearner/feature_histogram.hpp
+++ b/src/treelearner/feature_histogram.hpp
@@ -300,8 +300,10 @@ class FeatureHistogram {
     }
     double min_gain_shift =
         gain_shift + meta_->config->min_gain_to_split;
-    bool is_full_categorical = meta_->missing_type == MissingType::None;
-    int used_bin = meta_->num_bin - 1 + is_full_categorical;
+    const int8_t offset = meta_->offset;
+    const int bin_start = 1 - offset;
+    const int bin_end = meta_->num_bin - offset;
+    int used_bin = -1;

     std::vector<int> sorted_idx;
     double l2 = meta_->config->lambda_l2;
@@ -312,11 +314,11 @@ class FeatureHistogram {
     int rand_threshold = 0;
     if (use_onehot) {
       if (USE_RAND) {
-        if (used_bin > 0) {
-          rand_threshold = meta_->rand.NextInt(0, used_bin);
+        if (bin_end - bin_start > 0) {
+          rand_threshold = meta_->rand.NextInt(bin_start, bin_end);
         }
       }
-      for (int t = 0; t < used_bin; ++t) {
+      for (int t = bin_start; t < bin_end; ++t) {
        const auto grad = GET_GRAD(data_, t);
        const auto hess = GET_HESS(data_, t);
        data_size_t cnt =
@@ -366,7 +368,7 @@ class FeatureHistogram {
         }
       }
     } else {
-      for (int i = 0; i < used_bin; ++i) {
+      for (int i = bin_start; i < bin_end; ++i) {
         if (Common::RoundInt(GET_HESS(data_, i) * cnt_factor) >=
             meta_->config->cat_smooth) {
           sorted_idx.push_back(i);
@@ -379,11 +381,11 @@ class FeatureHistogram {
       auto ctr_fun = [this](double sum_grad, double sum_hess) {
         return (sum_grad) / (sum_hess + meta_->config->cat_smooth);
       };
-      std::sort(sorted_idx.begin(), sorted_idx.end(),
-                [this, &ctr_fun](int i, int j) {
-                  return ctr_fun(GET_GRAD(data_, i), GET_HESS(data_, i)) <
-                         ctr_fun(GET_GRAD(data_, j), GET_HESS(data_, j));
-                });
+      std::stable_sort(
+          sorted_idx.begin(), sorted_idx.end(), [this, &ctr_fun](int i, int j) {
+            return ctr_fun(GET_GRAD(data_, i), GET_HESS(data_, i)) <
+                   ctr_fun(GET_GRAD(data_, j), GET_HESS(data_, j));
+          });

       std::vector<int> find_direction(1, 1);
       std::vector<int> start_position(1, 0);
@@ -489,19 +491,19 @@ class FeatureHistogram {
     if (use_onehot) {
       output->num_cat_threshold = 1;
       output->cat_threshold =
-          std::vector<uint32_t>(1, static_cast<uint32_t>(best_threshold));
+          std::vector<uint32_t>(1, static_cast<uint32_t>(best_threshold + offset));
     } else {
       output->num_cat_threshold = best_threshold + 1;
       output->cat_threshold =
           std::vector<uint32_t>(output->num_cat_threshold);
       if (best_dir == 1) {
         for (int i = 0; i < output->num_cat_threshold; ++i) {
-          auto t = sorted_idx[i];
+          auto t = sorted_idx[i] + offset;
           output->cat_threshold[i] = t;
         }
       } else {
         for (int i = 0; i < output->num_cat_threshold; ++i) {
-          auto t = sorted_idx[used_bin - 1 - i];
+          auto t = sorted_idx[used_bin - 1 - i] + offset;
           output->cat_threshold[i] = t;
         }
       }
@@ -649,16 +651,14 @@ class FeatureHistogram {
     double gain_shift = GetLeafGainGivenOutput<USE_L1>(
         sum_gradient, sum_hessian, meta_->config->lambda_l1,
         meta_->config->lambda_l2, parent_output);
     double min_gain_shift = gain_shift + meta_->config->min_gain_to_split;
-    bool is_full_categorical = meta_->missing_type == MissingType::None;
-    int used_bin = meta_->num_bin - 1 + is_full_categorical;
-    if (threshold >= static_cast<uint32_t>(used_bin)) {
+    if (threshold >= static_cast<uint32_t>(meta_->num_bin) || threshold == 0) {
       output->gain = kMinScore;
       Log::Warning("Invalid categorical threshold split");
       return;
     }
     const double cnt_factor = num_data / sum_hessian;
-    const auto grad = GET_GRAD(data_, threshold);
-    const auto hess = GET_HESS(data_, threshold);
+    const auto grad = GET_GRAD(data_, threshold - meta_->offset);
+    const auto hess = GET_HESS(data_, threshold - meta_->offset);
     data_size_t cnt =
         static_cast<data_size_t>(Common::RoundInt(hess * cnt_factor));
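In `feature_histogram.hpp`, the histogram for a categorical feature no longer materializes the dummy bin when it is the most frequent bin, so `meta_->offset` is 1 in that case and histogram slot `t` corresponds to category bin `t + offset`. The search loops therefore run over `[bin_start, bin_end) = [1 - offset, num_bin - offset)`, and `offset` is added back when thresholds are emitted, so bin 0 can never be chosen as a threshold. A sketch of the emission step, mirroring the `sorted_idx[i] + offset` lines above; the function name is ours:

```cpp
// Maps best-split histogram indices back to category bins. With offset == 1,
// histogram slot 0 is category bin 1, so the dummy NaN bin 0 is never emitted.
#include <cstdint>
#include <vector>

std::vector<uint32_t> ToyEmitCatThresholds(const std::vector<int>& sorted_idx,
                                           int num_cat_threshold,
                                           int8_t offset) {
  std::vector<uint32_t> cat_threshold(num_cat_threshold);
  for (int i = 0; i < num_cat_threshold; ++i) {
    // sorted_idx holds histogram slots; + offset recovers the category bin.
    cat_threshold[i] = static_cast<uint32_t>(sorted_idx[i] + offset);
  }
  return cat_threshold;
}
```

The `std::sort` to `std::stable_sort` change looks like a determinism fix: when two categories tie on the `ctr_fun` ratio, the resulting order (and hence the chosen split) no longer depends on the implementation's partitioning.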