Skip to content

Commit

Permalink
[python] added get_split_value_histogram method (#2041)
Browse files Browse the repository at this point in the history
* added get_split_value_histogram method

* added param for ordinary return value
  • Loading branch information
StrikerRUS authored and guolinke committed Mar 9, 2019
1 parent 9de526f commit 8d6666e
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 1 deletion.
64 changes: 63 additions & 1 deletion python-package/lightgbm/basic.py
Expand Up @@ -13,7 +13,8 @@
import numpy as np
import scipy.sparse

from .compat import (DataFrame, Series, DataTable,
from .compat import (PANDAS_INSTALLED, DataFrame, Series,
DataTable,
decode_string, string_type,
integer_types, numeric_types,
json, json_default_with_numpy,
Expand Down Expand Up @@ -2427,6 +2428,67 @@ def feature_importance(self, importance_type='split', iteration=None):
else:
return result

def get_split_value_histogram(self, feature, bins=None, xgboost_style=False):
"""Get split value histogram for the specified feature.
Parameters
----------
feature : int or string
The feature name or index the histogram is calculated for.
If int, interpreted as index.
If string, interpreted as name.
bins : int, string or None, optional (default=None)
The maximum number of bins.
If None, or int and > number of unique split values and ``xgboost_style=True``,
the number of bins equals number of unique split values.
If string, it should be one from the list of the supported values by ``numpy.histogram()`` function.
xgboost_style : bool, optional (default=False)
Whether the returned result should be in the same form as it is in XGBoost.
If False, the returned value is tuple of 2 numpy arrays as it is in ``numpy.histogram()`` function.
If True, the returned value is matrix, in which the first column is the right edges of non-empty bins
and the second one is the histogram values.
Returns
-------
result_tuple : tuple of 2 numpy arrays
If ``xgboost_style=False``, the values of the histogram of used splitting values for the specified feature
and the bin edges.
result_array_like : numpy array or pandas DataFrame (if pandas is installed)
If ``xgboost_style=True``, the histogram of used splitting values for the specified feature.
"""
def add(root):
"""Recursively add thresholds."""
if 'split_index' in root: # non-leaf
if feature_names is not None and isinstance(feature, string_type):
split_feature = feature_names[root['split_feature']]
else:
split_feature = root['split_feature']
if split_feature == feature:
values.append(root['threshold'])
add(root['left_child'])
add(root['right_child'])

model = self.dump_model()
feature_names = model.get('feature_names')
tree_infos = model['tree_info']
values = []
for tree_info in tree_infos:
add(tree_info['tree_structure'])

if bins is None or isinstance(bins, integer_types) and xgboost_style:
n_unique = len(np.unique(values))
bins = max(min(n_unique, bins) if bins is not None else n_unique, 1)
hist, bin_edges = np.histogram(values, bins=bins)
if xgboost_style:
ret = np.column_stack((bin_edges[1:], hist))
ret = ret[ret[:, 1] > 0]
if PANDAS_INSTALLED:
return DataFrame(ret, columns=['SplitValue', 'Count'])
else:
return ret
else:
return hist, bin_edges

def __inner_eval(self, data_name, data_idx, feval=None):
"""Evaluate training or validation data."""
if data_idx >= self.__num_dataset:
Expand Down
74 changes: 74 additions & 0 deletions tests/python_package_test/test_engine.py
Expand Up @@ -1242,3 +1242,77 @@ def test_model_size(self):
np.testing.assert_allclose(y_pred, y_pred_new)
except MemoryError:
self.skipTest('not enough RAM')

def test_get_split_value_histogram(self):
X, y = load_boston(True)
lgb_train = lgb.Dataset(X, y)
gbm = lgb.train({'verbose': -1}, lgb_train, num_boost_round=20)
# test XGBoost-style return value
params = {'feature': 0, 'xgboost_style': True}
self.assertTupleEqual(gbm.get_split_value_histogram(**params).shape, (10, 2))
self.assertTupleEqual(gbm.get_split_value_histogram(bins=999, **params).shape, (10, 2))
self.assertTupleEqual(gbm.get_split_value_histogram(bins=-1, **params).shape, (1, 2))
self.assertTupleEqual(gbm.get_split_value_histogram(bins=0, **params).shape, (1, 2))
self.assertTupleEqual(gbm.get_split_value_histogram(bins=1, **params).shape, (1, 2))
self.assertTupleEqual(gbm.get_split_value_histogram(bins=2, **params).shape, (2, 2))
self.assertTupleEqual(gbm.get_split_value_histogram(bins=6, **params).shape, (6, 2))
self.assertTupleEqual(gbm.get_split_value_histogram(bins=7, **params).shape, (6, 2))
if lgb.compat.PANDAS_INSTALLED:
np.testing.assert_almost_equal(
gbm.get_split_value_histogram(0, xgboost_style=True).values,
gbm.get_split_value_histogram(gbm.feature_name()[0], xgboost_style=True).values
)
np.testing.assert_almost_equal(
gbm.get_split_value_histogram(X.shape[-1] - 1, xgboost_style=True).values,
gbm.get_split_value_histogram(gbm.feature_name()[X.shape[-1] - 1], xgboost_style=True).values
)
else:
np.testing.assert_almost_equal(
gbm.get_split_value_histogram(0, xgboost_style=True),
gbm.get_split_value_histogram(gbm.feature_name()[0], xgboost_style=True)
)
np.testing.assert_almost_equal(
gbm.get_split_value_histogram(X.shape[-1] - 1, xgboost_style=True),
gbm.get_split_value_histogram(gbm.feature_name()[X.shape[-1] - 1], xgboost_style=True)
)
# test numpy-style return value
hist, bins = gbm.get_split_value_histogram(0)
self.assertEqual(len(hist), 22)
self.assertEqual(len(bins), 23)
hist, bins = gbm.get_split_value_histogram(0, bins=999)
self.assertEqual(len(hist), 999)
self.assertEqual(len(bins), 1000)
self.assertRaises(ValueError, gbm.get_split_value_histogram, 0, bins=-1)
self.assertRaises(ValueError, gbm.get_split_value_histogram, 0, bins=0)
hist, bins = gbm.get_split_value_histogram(0, bins=1)
self.assertEqual(len(hist), 1)
self.assertEqual(len(bins), 2)
hist, bins = gbm.get_split_value_histogram(0, bins=2)
self.assertEqual(len(hist), 2)
self.assertEqual(len(bins), 3)
hist, bins = gbm.get_split_value_histogram(0, bins=6)
self.assertEqual(len(hist), 6)
self.assertEqual(len(bins), 7)
hist, bins = gbm.get_split_value_histogram(0, bins=7)
self.assertEqual(len(hist), 7)
self.assertEqual(len(bins), 8)
hist_idx, bins_idx = gbm.get_split_value_histogram(0)
hist_name, bins_name = gbm.get_split_value_histogram(gbm.feature_name()[0])
np.testing.assert_array_equal(hist_idx, hist_name)
np.testing.assert_almost_equal(bins_idx, bins_name)
hist_idx, bins_idx = gbm.get_split_value_histogram(X.shape[-1] - 1)
hist_name, bins_name = gbm.get_split_value_histogram(gbm.feature_name()[X.shape[-1] - 1])
np.testing.assert_array_equal(hist_idx, hist_name)
np.testing.assert_almost_equal(bins_idx, bins_name)
# test bins string type
if np.__version__ > '1.11.0':
hist_vals, bin_edges = gbm.get_split_value_histogram(0, bins='auto')
hist = gbm.get_split_value_histogram(0, bins='auto', xgboost_style=True)
if lgb.compat.PANDAS_INSTALLED:
mask = hist_vals > 0
np.testing.assert_array_equal(hist_vals[mask], hist['Count'].values)
np.testing.assert_almost_equal(bin_edges[1:][mask], hist['SplitValue'].values)
else:
mask = hist_vals > 0
np.testing.assert_array_equal(hist_vals[mask], hist[:, 1])
np.testing.assert_almost_equal(bin_edges[1:][mask], hist[:, 0])

0 comments on commit 8d6666e

Please sign in to comment.