Skip to content

Commit

Permalink
[python] added plot_split_value_histogram function (#2043)
Browse files Browse the repository at this point in the history
* added plot_split_value_histogram function

* updated init module

* added plot split value histogram example

* added plot_split_value_histogram to notebook

* added test

* fixed pylint

* updated API docs

* fixed grammar

* set y ticks to int values in a more efficient way
  • Loading branch information
StrikerRUS authored and henry0312 committed May 1, 2019
1 parent 65c7779 commit 611cf5d
Show file tree
Hide file tree
Showing 7 changed files with 201 additions and 4 deletions.
2 changes: 2 additions & 0 deletions docs/Python-API.rst
Expand Up @@ -62,6 +62,8 @@ Plotting

.. autofunction:: lightgbm.plot_importance

.. autofunction:: lightgbm.plot_split_value_histogram

.. autofunction:: lightgbm.plot_metric

.. autofunction:: lightgbm.plot_tree
Expand Down
1 change: 1 addition & 0 deletions examples/python-guide/README.md
Expand Up @@ -57,5 +57,6 @@ Examples include:
- Train and record eval results for further plotting
- Plot metrics recorded during training
- Plot feature importances
- Plot split value histogram
- Plot one specified tree
- Plot one specified tree with Graphviz
50 changes: 48 additions & 2 deletions examples/python-guide/notebooks/interactive_plot_example.ipynb

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions examples/python-guide/plot_example.py
Expand Up @@ -50,6 +50,10 @@
# Feature importance chart, limited to 10 features via max_num_features.
ax = lgb.plot_importance(gbm, max_num_features=10)
plt.show()

print('Plotting split value histogram...')
# Histogram of the split values for the feature named 'f26'.
# A string `bins` value (here 'auto') must be one of the values supported by numpy.histogram().
ax = lgb.plot_split_value_histogram(gbm, feature='f26', bins='auto')
plt.show()

print('Plotting 54th tree...') # one tree uses a categorical feature to split
# tree_index is zero-based, so index 53 selects the 54th tree announced above.
ax = lgb.plot_tree(gbm, tree_index=53, figsize=(15, 15), show_info=['split_gain'])
plt.show()
Expand Down
5 changes: 3 additions & 2 deletions python-package/lightgbm/__init__.py
Expand Up @@ -19,7 +19,8 @@
except ImportError:
pass
# Plotting helpers are imported on a best-effort basis: an ImportError raised here
# is ignored, in which case these names are simply not defined by the package.
try:
from .plotting import plot_importance, plot_metric, plot_tree, create_tree_digraph
from .plotting import (plot_importance, plot_split_value_histogram, plot_metric,
plot_tree, create_tree_digraph)
except ImportError:
pass

Expand All @@ -34,7 +35,7 @@
'train', 'cv',
'LGBMModel', 'LGBMRegressor', 'LGBMClassifier', 'LGBMRanker',
'print_evaluation', 'record_evaluation', 'reset_parameter', 'early_stopping',
'plot_importance', 'plot_metric', 'plot_tree', 'create_tree_digraph']
'plot_importance', 'plot_split_value_histogram', 'plot_metric', 'plot_tree', 'create_tree_digraph']

# REMOVEME: remove warning after 2.3.0 version release
if system() == 'Darwin':
Expand Down
104 changes: 104 additions & 0 deletions python-package/lightgbm/plotting.py
Expand Up @@ -141,6 +141,110 @@ def plot_importance(booster, ax=None, height=0.2,
return ax


def plot_split_value_histogram(booster, feature, bins=None, ax=None, width_coef=0.8,
                               xlim=None, ylim=None,
                               title='Split value histogram for feature with @index/name@ @feature@',
                               xlabel='Feature split value', ylabel='Count',
                               figsize=None, grid=True, **kwargs):
    """Plot split value histogram for the specified feature of the model.

    Parameters
    ----------
    booster : Booster or LGBMModel
        Booster or LGBMModel instance of which feature split value histogram should be plotted.
    feature : int or string
        The feature name or index the histogram is plotted for.
        If int, interpreted as index.
        If string, interpreted as name.
    bins : int, string or None, optional (default=None)
        The maximum number of bins.
        If None, the number of bins equals number of unique split values.
        If string, it should be one from the list of the supported values by ``numpy.histogram()`` function.
    ax : matplotlib.axes.Axes or None, optional (default=None)
        Target axes instance.
        If None, new figure and axes will be created.
    width_coef : float, optional (default=0.8)
        Coefficient for histogram bar width.
        The bar width is ``width_coef`` times the width of the first bin.
    xlim : tuple of 2 elements or None, optional (default=None)
        Tuple passed to ``ax.set_xlim()``.
        If None, the range of the bin edges extended by 20% of its length on each side is used.
    ylim : tuple of 2 elements or None, optional (default=None)
        Tuple passed to ``ax.set_ylim()``.
        If None, ``(0, 1.1 * max(count))`` is used.
    title : string or None, optional (default="Split value histogram for feature with @index/name@ @feature@")
        Axes title.
        If None, title is disabled.
        @feature@ placeholder can be used, and it will be replaced with the value of ``feature`` parameter.
        @index/name@ placeholder can be used,
        and it will be replaced with ``index`` word in case of ``int`` type ``feature`` parameter
        or ``name`` word in case of ``string`` type ``feature`` parameter.
    xlabel : string or None, optional (default="Feature split value")
        X-axis title label.
        If None, title is disabled.
    ylabel : string or None, optional (default="Count")
        Y-axis title label.
        If None, title is disabled.
    figsize : tuple of 2 elements or None, optional (default=None)
        Figure size.
        Only used (and only validated) when ``ax`` is None.
    grid : bool, optional (default=True)
        Whether to add a grid for axes.
    **kwargs
        Other parameters passed to ``ax.bar()``.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The plot with specified model's feature split value histogram.

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    TypeError
        If ``booster`` is neither a Booster nor an LGBMModel.
    ValueError
        If ``feature`` was not used in splitting, i.e. every histogram count is zero.
    """
    if MATPLOTLIB_INSTALLED:
        import matplotlib.pyplot as plt
        from matplotlib.ticker import MaxNLocator
    else:
        raise ImportError('You must install matplotlib to plot split value histogram.')

    if isinstance(booster, LGBMModel):
        booster = booster.booster_
    elif not isinstance(booster, Booster):
        raise TypeError('booster must be Booster or LGBMModel.')

    # Keep the ``bins`` argument intact; the returned edges get their own name.
    hist, bin_edges = booster.get_split_value_histogram(feature=feature, bins=bins,
                                                        xgboost_style=False)
    if np.count_nonzero(hist) == 0:
        raise ValueError('Cannot plot split value histogram, '
                         'because feature {} was not used in splitting'.format(feature))
    # The first bin's width is used for every bar.
    width = width_coef * (bin_edges[1] - bin_edges[0])
    centred = (bin_edges[:-1] + bin_edges[1:]) / 2

    # Validate every user-supplied tuple BEFORE creating a figure or drawing,
    # so that an invalid argument neither leaves an orphan figure behind
    # nor a half-finished plot on a caller-provided axes.
    if ax is None and figsize is not None:
        _check_not_tuple_of_2_elements(figsize, 'figsize')
    if xlim is not None:
        _check_not_tuple_of_2_elements(xlim, 'xlim')
    else:
        range_result = bin_edges[-1] - bin_edges[0]
        xlim = (bin_edges[0] - range_result * 0.2, bin_edges[-1] + range_result * 0.2)
    if ylim is not None:
        _check_not_tuple_of_2_elements(ylim, 'ylim')
    else:
        ylim = (0, max(hist) * 1.1)

    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=figsize)

    ax.bar(centred, hist, align='center', width=width, **kwargs)
    ax.set_xlim(xlim)

    # Counts are integers, so restrict the y ticks to integer positions.
    ax.yaxis.set_major_locator(MaxNLocator(integer=True))
    ax.set_ylim(ylim)

    if title is not None:
        title = title.replace('@feature@', str(feature))
        title = title.replace('@index/name@', ('name' if isinstance(feature, string_type) else 'index'))
        ax.set_title(title)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    ax.grid(grid)
    return ax


def plot_metric(booster, metric=None, dataset_names=None,
ax=None, xlim=None, ylim=None,
title='Metric during training',
Expand Down
39 changes: 39 additions & 0 deletions tests/python_package_test/test_plotting.py
Expand Up @@ -60,6 +60,45 @@ def test_plot_importance(self):
self.assertTupleEqual(ax2.patches[2].get_facecolor(), (0, .5, 0, 1.)) # g
self.assertTupleEqual(ax2.patches[3].get_facecolor(), (0, 0, 1., 1.)) # b

@unittest.skipIf(not MATPLOTLIB_INSTALLED, 'matplotlib is not installed')
def test_plot_split_value_histogram(self):
    """Check titles, labels, bar counts and bar colours of the split value histogram plot."""
    # Case 1: native Booster, feature addressed by index, all defaults.
    booster = lgb.train(self.params, self.train_data, num_boost_round=10)
    default_ax = lgb.plot_split_value_histogram(booster, 27)
    self.assertIsInstance(default_ax, matplotlib.axes.Axes)
    self.assertEqual(default_ax.get_title(), 'Split value histogram for feature with index 27')
    self.assertEqual(default_ax.get_xlabel(), 'Feature split value')
    self.assertEqual(default_ax.get_ylabel(), 'Count')
    self.assertLessEqual(len(default_ax.patches), 2)

    # Case 2: sklearn wrapper, feature addressed by name, custom title/labels/colour.
    classifier = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True)
    classifier.fit(self.X_train, self.y_train)
    feature_name = classifier.booster_.feature_name()[27]

    named_ax = lgb.plot_split_value_histogram(classifier, feature_name, figsize=(10, 5),
                                              title='Histogram for feature @index/name@ @feature@',
                                              xlabel='x', ylabel='y', color='r')
    self.assertIsInstance(named_ax, matplotlib.axes.Axes)
    self.assertEqual(named_ax.get_title(),
                     'Histogram for feature name {}'.format(classifier.booster_.feature_name()[27]))
    self.assertEqual(named_ax.get_xlabel(), 'x')
    self.assertEqual(named_ax.get_ylabel(), 'y')
    self.assertLessEqual(len(named_ax.patches), 2)
    red = (1., 0, 0, 1.)
    for bar in named_ax.patches:
        self.assertTupleEqual(bar.get_facecolor(), red)  # red

    # Case 3: fixed number of bins, per-bar colours, title and labels disabled.
    binned_ax = lgb.plot_split_value_histogram(booster, 27, bins=10, color=['r', 'y', 'g', 'b'],
                                               title=None, xlabel=None, ylabel=None)
    self.assertIsInstance(binned_ax, matplotlib.axes.Axes)
    self.assertEqual(binned_ax.get_title(), '')
    self.assertEqual(binned_ax.get_xlabel(), '')
    self.assertEqual(binned_ax.get_ylabel(), '')
    self.assertEqual(len(binned_ax.patches), 10)
    expected_rgba = [
        (1., 0, 0, 1.),  # r
        (.75, .75, 0, 1.),  # y
        (0, .5, 0, 1.),  # g
        (0, 0, 1., 1.),  # b
    ]
    for position, rgba in enumerate(expected_rgba):
        self.assertTupleEqual(binned_ax.patches[position].get_facecolor(), rgba)

    # Case 4: feature 0 was not used in splitting, so plotting must fail.
    with self.assertRaises(ValueError):
        lgb.plot_split_value_histogram(booster, 0)

@unittest.skipIf(not MATPLOTLIB_INSTALLED or not GRAPHVIZ_INSTALLED, 'matplotlib or graphviz is not installed')
def test_plot_tree(self):
gbm = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True)
Expand Down

0 comments on commit 611cf5d

Please sign in to comment.