Skip to content

Commit

Permalink
fix #991 (#992)
Browse files Browse the repository at this point in the history
* refine categorical split

* a bug fix

* fix a bug
  • Loading branch information
guolinke committed Oct 13, 2017
1 parent 4aa3296 commit ef22127
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 3 deletions.
6 changes: 5 additions & 1 deletion src/io/bin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,11 @@ namespace LightGBM {
// sort by counts
Common::SortForPair<int, int>(counts_int, distinct_values_int, 0, true);
// avoid first bin is zero
if (distinct_values_int[0] == 0 && counts_int.size() > 1) {
if (distinct_values_int[0] == 0 || (counts_int.size() == 1 && na_cnt > 0)) {
if (counts_int.size() == 1) {
counts_int.push_back(0);
distinct_values_int.push_back(distinct_values_int[0] + 1);
}
std::swap(counts_int[0], counts_int[1]);
std::swap(distinct_values_int[0], distinct_values_int[1]);
}
Expand Down
10 changes: 8 additions & 2 deletions src/treelearner/feature_histogram.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,11 +206,17 @@ class FeatureHistogram {
output->cat_threshold = std::vector<uint32_t>(output->num_cat_threshold);
if (best_dir == 1) {
for (int i = 0; i < output->num_cat_threshold; ++i) {
output->cat_threshold[i] = sorted_idx[i];
auto t = sorted_idx[i];
if (data_[t].cnt > 0) {
output->cat_threshold[i] = t;
}
}
} else {
for (int i = 0; i < output->num_cat_threshold; ++i) {
output->cat_threshold[i] = sorted_idx[used_bin - 1 - i];
auto t = sorted_idx[used_bin - 1 - i];
if (data_[t].cnt > 0) {
output->cat_threshold[i] = t;
}
}
}
}
Expand Down
62 changes: 62 additions & 0 deletions tests/python_package_test/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,68 @@ def test_categorical_handle(self):
pred = gbm.predict(X_train)
np.testing.assert_almost_equal(pred, y)

def test_categorical_handle2(self):
    # One categorical column whose only observed category is 0, alternating
    # with NaN. The labels follow the NaN pattern exactly, so a single
    # two-leaf tree with learning_rate 1 is expected to reproduce them.
    labels = [0, 1, 0, 1, 0, 1]
    raw_feature = [0, np.nan, 0, np.nan, 0, np.nan]

    # Column vector: one row per sample, a single feature column.
    features = np.array(raw_feature).reshape(len(raw_feature), 1)
    targets = np.array(labels)
    train_set = lgb.Dataset(features, targets)
    valid_set = lgb.Dataset(features, targets)

    params = {
        'objective': 'regression',
        'metric': 'auc',
        'verbose': -1,
        'boost_from_average': False,
        'min_data': 1,
        'num_leaves': 2,
        'learning_rate': 1,
        'min_data_in_bin': 1,
        'min_data_per_group': 1,
        'zero_as_missing': False,
        'categorical_column': 0,
    }
    history = {}
    booster = lgb.train(
        params,
        train_set,
        num_boost_round=1,
        valid_sets=valid_set,
        verbose_eval=True,
        evals_result=history,
    )

    # Predictions on the training rows must match the labels.
    predicted = booster.predict(features)
    np.testing.assert_almost_equal(predicted, labels)

def test_categorical_handle3(self):
    # Same scenario as a lone category alternating with NaN, but the single
    # observed category is the non-zero value 11. The labels follow the NaN
    # pattern, so one two-leaf tree with learning_rate 1 should fit them.
    labels = [0, 1, 0, 1, 0, 1]
    raw_feature = [11, np.nan, 11, np.nan, 11, np.nan]

    # Column vector: one row per sample, a single feature column.
    features = np.array(raw_feature).reshape(len(raw_feature), 1)
    targets = np.array(labels)
    train_set = lgb.Dataset(features, targets)
    valid_set = lgb.Dataset(features, targets)

    params = {
        'objective': 'regression',
        'metric': 'auc',
        'verbose': -1,
        'boost_from_average': False,
        'min_data': 1,
        'num_leaves': 2,
        'learning_rate': 1,
        'min_data_in_bin': 1,
        'min_data_per_group': 1,
        'zero_as_missing': False,
        'categorical_column': 0,
    }
    history = {}
    booster = lgb.train(
        params,
        train_set,
        num_boost_round=1,
        valid_sets=valid_set,
        verbose_eval=True,
        evals_result=history,
    )

    # Predictions on the training rows must match the labels.
    predicted = booster.predict(features)
    np.testing.assert_almost_equal(predicted, labels)

def test_multiclass(self):
X, y = load_digits(10, True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
Expand Down

0 comments on commit ef22127

Please sign in to comment.