Skip to content

Commit

Permalink
[python] disable default pandas cat features if cat features were exp…
Browse files Browse the repository at this point in the history
…licitly provided (#2121)

* disable default pandas cat features if cat features were explicitly provided

* added assertion for cat features
  • Loading branch information
StrikerRUS committed Apr 22, 2019
1 parent d115769 commit 4be53a5
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 15 deletions.
8 changes: 4 additions & 4 deletions python-package/lightgbm/basic.py
Expand Up @@ -274,18 +274,18 @@ def _data_from_pandas(data, feature_name, categorical_feature, pandas_categorica
if categorical_feature is not None:
if feature_name is None:
feature_name = list(data.columns)
if categorical_feature == 'auto':
if categorical_feature == 'auto': # use cat cols from DataFrame
categorical_feature = cat_cols_not_ordered
else:
categorical_feature = list(categorical_feature) + cat_cols_not_ordered
else: # use cat cols specified by user
categorical_feature = list(categorical_feature)
if feature_name == 'auto':
feature_name = list(data.columns)
data_dtypes = data.dtypes
if not all(dtype.name in PANDAS_DTYPE_MAPPER for dtype in data_dtypes):
bad_fields = [data.columns[i] for i, dtype in
enumerate(data_dtypes) if dtype.name not in PANDAS_DTYPE_MAPPER]
raise ValueError("DataFrame.dtypes for data must be int, float or bool.\n"
"Did not expect the data types in fields "
"Did not expect the data types in the following fields: "
+ ', '.join(bad_fields))
data = data.values.astype('float')
else:
Expand Down
25 changes: 22 additions & 3 deletions tests/python_package_test/test_engine.py
Expand Up @@ -553,7 +553,7 @@ def test_template(init_model=None, return_model=False):
@unittest.skipIf(not lgb.compat.PANDAS_INSTALLED, 'pandas is not installed')
def test_pandas_categorical(self):
import pandas as pd
np.random.seed(42) # sometimes there is no difference how E col is treated (cat or not cat)
np.random.seed(42) # sometimes there is no difference how cols are treated (cat or not cat)
X = pd.DataFrame({"A": np.random.permutation(['a', 'b', 'c', 'd'] * 75), # str
"B": np.random.permutation([1, 2, 3] * 100), # int
"C": np.random.permutation([0.1, 0.2, -0.1, -0.1, 0.2] * 60), # float
Expand Down Expand Up @@ -581,18 +581,23 @@ def test_pandas_categorical(self):
lgb_train = lgb.Dataset(X, y)
gbm0 = lgb.train(params, lgb_train, num_boost_round=10)
pred0 = gbm0.predict(X_test)
self.assertEqual(lgb_train.categorical_feature, 'auto')
lgb_train = lgb.Dataset(X, pd.DataFrame(y)) # also test that label can be one-column pd.DataFrame
gbm1 = lgb.train(params, lgb_train, num_boost_round=10, categorical_feature=[0])
pred1 = gbm1.predict(X_test)
self.assertListEqual(lgb_train.categorical_feature, [0])
lgb_train = lgb.Dataset(X, pd.Series(y)) # also test that label can be pd.Series
gbm2 = lgb.train(params, lgb_train, num_boost_round=10, categorical_feature=['A'])
pred2 = gbm2.predict(X_test)
self.assertListEqual(lgb_train.categorical_feature, ['A'])
lgb_train = lgb.Dataset(X, y)
gbm3 = lgb.train(params, lgb_train, num_boost_round=10, categorical_feature=['A', 'B', 'C', 'D'])
pred3 = gbm3.predict(X_test)
self.assertListEqual(lgb_train.categorical_feature, ['A', 'B', 'C', 'D'])
gbm3.save_model('categorical.model')
gbm4 = lgb.Booster(model_file='categorical.model')
pred4 = gbm4.predict(X_test)
self.assertListEqual(lgb_train.categorical_feature, ['A', 'B', 'C', 'D'])
model_str = gbm4.model_to_string()
gbm4.model_from_string(model_str, False)
pred5 = gbm4.predict(X_test)
Expand All @@ -601,22 +606,36 @@ def test_pandas_categorical(self):
lgb_train = lgb.Dataset(X, y)
gbm6 = lgb.train(params, lgb_train, num_boost_round=10, categorical_feature=['E'])
pred7 = gbm6.predict(X_test)
np.testing.assert_almost_equal(pred0, pred1)
np.testing.assert_almost_equal(pred0, pred2)
self.assertListEqual(lgb_train.categorical_feature, ['E'])
lgb_train = lgb.Dataset(X, y)
gbm7 = lgb.train(params, lgb_train, num_boost_round=10, categorical_feature=[])
pred8 = gbm7.predict(X_test)
self.assertListEqual(lgb_train.categorical_feature, [])
self.assertRaises(AssertionError,
np.testing.assert_almost_equal,
pred0, pred1)
self.assertRaises(AssertionError,
np.testing.assert_almost_equal,
pred0, pred2)
np.testing.assert_almost_equal(pred1, pred2)
np.testing.assert_almost_equal(pred0, pred3)
np.testing.assert_almost_equal(pred0, pred4)
np.testing.assert_almost_equal(pred0, pred5)
np.testing.assert_almost_equal(pred0, pred6)
self.assertRaises(AssertionError,
np.testing.assert_almost_equal,
pred0, pred7) # ordered cat features aren't treated as cat features by default
self.assertRaises(AssertionError,
np.testing.assert_almost_equal,
pred0, pred8)
self.assertListEqual(gbm0.pandas_categorical, cat_values)
self.assertListEqual(gbm1.pandas_categorical, cat_values)
self.assertListEqual(gbm2.pandas_categorical, cat_values)
self.assertListEqual(gbm3.pandas_categorical, cat_values)
self.assertListEqual(gbm4.pandas_categorical, cat_values)
self.assertListEqual(gbm5.pandas_categorical, cat_values)
self.assertListEqual(gbm6.pandas_categorical, cat_values)
self.assertListEqual(gbm7.pandas_categorical, cat_values)

def test_reference_chain(self):
X = np.random.normal(size=(100, 2))
Expand Down
27 changes: 19 additions & 8 deletions tests/python_package_test/test_sklearn.py
Expand Up @@ -206,7 +206,7 @@ def test_sklearn_integration(self):
@unittest.skipIf(not lgb.compat.PANDAS_INSTALLED, 'pandas is not installed')
def test_pandas_categorical(self):
import pandas as pd
np.random.seed(42) # sometimes there is no difference how E col is treated (cat or not cat)
np.random.seed(42) # sometimes there is no difference how cols are treated (cat or not cat)
X = pd.DataFrame({"A": np.random.permutation(['a', 'b', 'c', 'd'] * 75), # str
"B": np.random.permutation([1, 2, 3] * 100), # int
"C": np.random.permutation([0.1, 0.2, -0.1, -0.1, 0.2] * 60), # float
Expand All @@ -227,32 +227,43 @@ def test_pandas_categorical(self):
X_test[cat_cols_actual] = X_test[cat_cols_actual].astype('category')
cat_values = [X[col].cat.categories.tolist() for col in cat_cols_to_store]
gbm0 = lgb.sklearn.LGBMClassifier().fit(X, y)
pred0 = gbm0.predict(X_test)
pred0 = gbm0.predict(X_test, raw_score=True)
pred_prob = gbm0.predict_proba(X_test)[:, 1]
gbm1 = lgb.sklearn.LGBMClassifier().fit(X, pd.Series(y), categorical_feature=[0])
pred1 = gbm1.predict(X_test)
pred1 = gbm1.predict(X_test, raw_score=True)
gbm2 = lgb.sklearn.LGBMClassifier().fit(X, y, categorical_feature=['A'])
pred2 = gbm2.predict(X_test)
pred2 = gbm2.predict(X_test, raw_score=True)
gbm3 = lgb.sklearn.LGBMClassifier().fit(X, y, categorical_feature=['A', 'B', 'C', 'D'])
pred3 = gbm3.predict(X_test)
pred3 = gbm3.predict(X_test, raw_score=True)
gbm3.booster_.save_model('categorical.model')
gbm4 = lgb.Booster(model_file='categorical.model')
pred4 = gbm4.predict(X_test)
gbm5 = lgb.sklearn.LGBMClassifier().fit(X, y, categorical_feature=['E'])
pred5 = gbm5.predict(X_test)
np.testing.assert_almost_equal(pred0, pred1)
np.testing.assert_almost_equal(pred0, pred2)
pred5 = gbm5.predict(X_test, raw_score=True)
gbm6 = lgb.sklearn.LGBMClassifier().fit(X, y, categorical_feature=[])
pred6 = gbm6.predict(X_test, raw_score=True)
self.assertRaises(AssertionError,
np.testing.assert_almost_equal,
pred0, pred1)
self.assertRaises(AssertionError,
np.testing.assert_almost_equal,
pred0, pred2)
np.testing.assert_almost_equal(pred1, pred2)
np.testing.assert_almost_equal(pred0, pred3)
np.testing.assert_almost_equal(pred_prob, pred4)
self.assertRaises(AssertionError,
np.testing.assert_almost_equal,
pred0, pred5) # ordered cat features aren't treated as cat features by default
self.assertRaises(AssertionError,
np.testing.assert_almost_equal,
pred0, pred6)
self.assertListEqual(gbm0.booster_.pandas_categorical, cat_values)
self.assertListEqual(gbm1.booster_.pandas_categorical, cat_values)
self.assertListEqual(gbm2.booster_.pandas_categorical, cat_values)
self.assertListEqual(gbm3.booster_.pandas_categorical, cat_values)
self.assertListEqual(gbm4.pandas_categorical, cat_values)
self.assertListEqual(gbm5.booster_.pandas_categorical, cat_values)
self.assertListEqual(gbm6.booster_.pandas_categorical, cat_values)

def test_predict(self):
iris = load_iris()
Expand Down

0 comments on commit 4be53a5

Please sign in to comment.