
[MRG+1] EHN add decimals parameter for export_graphviz (scikit-learn#8698)

* EHN add decimals parameter for export_graphviz

* FIX address comments

* TST add test for classification

* TST/FIX address comments

* FIX comments raghav
glemaitre authored and dmohns committed Aug 7, 2017
1 parent 32517e7 commit d565138
Showing 2 changed files with 81 additions and 9 deletions.
28 changes: 22 additions & 6 deletions sklearn/tree/export.py
@@ -11,6 +11,8 @@
# Li Li <aiki.nogard@gmail.com>
# License: BSD 3 clause

from numbers import Integral

import numpy as np
import warnings

@@ -73,7 +75,7 @@ def export_graphviz(decision_tree, out_file=SENTINEL, max_depth=None,
feature_names=None, class_names=None, label='all',
filled=False, leaves_parallel=False, impurity=True,
node_ids=False, proportion=False, rotate=False,
rounded=False, special_characters=False):
rounded=False, special_characters=False, precision=3):
"""Export a decision tree in DOT format.
This function generates a GraphViz representation of the decision tree,
@@ -143,6 +145,10 @@ def export_graphviz(decision_tree, out_file=SENTINEL, max_depth=None,
When set to ``False``, ignore special characters for PostScript
compatibility.
precision : int, optional (default=3)
Number of digits of precision for floating point in the values of
impurity, threshold and value attributes of each node.
Returns
-------
dot_data : string
@@ -162,6 +168,7 @@ def export_graphviz(decision_tree, out_file=SENTINEL, max_depth=None,
>>> clf = clf.fit(iris.data, iris.target)
>>> tree.export_graphviz(clf,
... out_file='tree.dot') # doctest: +SKIP
"""

def get_color(value):
@@ -226,7 +233,8 @@ def node_to_str(tree, node_id, criterion):
characters[2])
node_string += '%s %s %s%s' % (feature,
characters[3],
round(tree.threshold[node_id], 4),
round(tree.threshold[node_id],
precision),
characters[4])

# Write impurity
@@ -237,7 +245,7 @@ def node_to_str(tree, node_id, criterion):
criterion = "impurity"
if labels:
node_string += '%s = ' % criterion
node_string += (str(round(tree.impurity[node_id], 4)) +
node_string += (str(round(tree.impurity[node_id], precision)) +
characters[4])

# Write node sample count
@@ -260,16 +268,16 @@ def node_to_str(tree, node_id, criterion):
node_string += 'value = '
if tree.n_classes[0] == 1:
# Regression
value_text = np.around(value, 4)
value_text = np.around(value, precision)
elif proportion:
# Classification
value_text = np.around(value, 2)
value_text = np.around(value, precision)
elif np.all(np.equal(np.mod(value, 1), 0)):
# Classification without floating-point weights
value_text = value.astype(int)
else:
# Classification with floating-point weights
value_text = np.around(value, 4)
value_text = np.around(value, precision)
# Strip whitespace
value_text = str(value_text.astype('S32')).replace("b'", "'")
value_text = value_text.replace("' '", ", ").replace("'", "")
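
Note (not part of the diff): thresholds and impurities go through Python's round() while value arrays go through np.around, so trailing zeros are dropped rather than padded to the requested precision. A quick doctest-style illustration of the scalar case:

>>> round(0.6666666, 3)
0.667
>>> round(0.8, 3)   # fewer than `precision` digits stay as they are
0.8
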
@@ -402,6 +410,14 @@ def recurse(tree, node_id, criterion, parent=None, depth=0):
return_string = True
out_file = six.StringIO()

if isinstance(precision, Integral):
if precision < 0:
raise ValueError("'precision' should be greater or equal to 0."
" Got {} instead.".format(precision))
else:
raise ValueError("'precision' should be an integer. Got {}"
" instead.".format(type(precision)))

# Check length of feature_names before getting into the tree node
# Raise error if length of feature_names does not match
# n_features_ in the decision_tree
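For context, a minimal usage sketch of the new precision parameter documented above (not part of the commit; the iris data, max_depth=2 and precision=2 are arbitrary illustration choices):

from sklearn import tree
from sklearn.datasets import load_iris

# Fit a small tree, mirroring the docstring example above.
iris = load_iris()
clf = tree.DecisionTreeClassifier(max_depth=2, random_state=0)
clf = clf.fit(iris.data, iris.target)

# With out_file=None the DOT source is returned as a string; precision=2
# limits impurity, threshold and value to at most two decimal digits.
dot_data = tree.export_graphviz(clf, out_file=None, precision=2)
print(dot_data)
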
62 changes: 59 additions & 3 deletions sklearn/tree/tests/test_export.py
@@ -2,14 +2,18 @@
Testing for export functions of decision trees (sklearn.tree.export).
"""

from re import finditer
from re import finditer, search

from numpy.random import RandomState

from sklearn.base import ClassifierMixin
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.tree import export_graphviz
from sklearn.externals.six import StringIO
from sklearn.utils.testing import assert_in, assert_equal, assert_raises
from sklearn.utils.testing import assert_raise_message
from sklearn.utils.testing import (assert_in, assert_equal, assert_raises,
assert_less_equal, assert_raises_regex,
assert_raise_message)
from sklearn.exceptions import NotFittedError

# toy sample
@@ -235,6 +239,13 @@ def test_graphviz_errors():
out = StringIO()
assert_raises(IndexError, export_graphviz, clf, out, class_names=[])

# Check precision error
out = StringIO()
assert_raises_regex(ValueError, "should be greater or equal",
export_graphviz, clf, out, precision=-1)
assert_raises_regex(ValueError, "should be an integer",
export_graphviz, clf, out, precision="1")


def test_friedman_mse_in_graphviz():
clf = DecisionTreeRegressor(criterion="friedman_mse", random_state=0)
@@ -249,3 +260,48 @@ def test_friedman_mse_in_graphviz():

for finding in finditer("\[.*?samples.*?\]", dot_data.getvalue()):
assert_in("friedman_mse", finding.group())


def test_precision():

rng_reg = RandomState(2)
rng_clf = RandomState(8)
for X, y, clf in zip(
(rng_reg.random_sample((5, 2)),
rng_clf.random_sample((1000, 4))),
(rng_reg.random_sample((5, )),
rng_clf.randint(2, size=(1000, ))),
(DecisionTreeRegressor(criterion="friedman_mse", random_state=0,
max_depth=1),
DecisionTreeClassifier(max_depth=1, random_state=0))):

clf.fit(X, y)
for precision in (4, 3):
dot_data = export_graphviz(clf, out_file=None, precision=precision,
proportion=True)

# With the current random state, the impurity and the threshold
# carry exactly the number of decimal digits set by the precision
# parameter of export_graphviz, so they are checked with strict
# equality. The reported value has only two decimal digits, so only
# a less-or-equal comparison is done for it.

# check value
for finding in finditer("value = \d+\.\d+", dot_data):
assert_less_equal(
len(search("\.\d+", finding.group()).group()),
precision + 1)
# check impurity
if isinstance(clf, ClassifierMixin):
pattern = "gini = \d+\.\d+"
else:
pattern = "friedman_mse = \d+\.\d+"

# check impurity
for finding in finditer(pattern, dot_data):
assert_equal(len(search("\.\d+", finding.group()).group()),
precision + 1)
# check threshold
for finding in finditer("<= \d+\.\d+", dot_data):
assert_equal(len(search("\.\d+", finding.group()).group()),
precision + 1)
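
To make the regex checks in test_precision concrete, here is a standalone sketch (not part of the commit; the iris data and max_depth=1 are arbitrary) that applies the same kind of digit counting to an exported tree and triggers the new validation error:

import re

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=1, random_state=0)
clf.fit(iris.data, iris.target)
dot_data = export_graphviz(clf, out_file=None, precision=3)

# Each threshold such as "X[2] <= 2.45" should carry at most 3 decimal
# digits; round() drops trailing zeros, hence "at most" rather than "exactly".
for match in re.finditer(r"<= \d+\.\d+", dot_data):
    fractional = re.search(r"\.\d+", match.group()).group()
    assert len(fractional) <= 3 + 1  # the dot plus up to `precision` digits

# Negative or non-integer precision values are rejected up front.
try:
    export_graphviz(clf, out_file=None, precision=-1)
except ValueError as exc:
    print(exc)  # 'precision' should be greater or equal to 0. Got -1 instead.
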
