/
dt2code.py
52 lines (45 loc) · 2 KB
/
dt2code.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import numpy as np
import re
# Matches a parenthesised group such as "(cm)" so it can be dropped from
# feature names. Raw string: "\(" is not a valid escape in a normal literal.
re_parenthesis = re.compile(r'\([^)]+\)')


def dt2code(tree, feature_names, class_names, func_name='f'):
    '''
    Converting scikit-learn's DecisionTreeClassifier to Python code

    The generated source defines a single function named ``func_name``.
    Its keyword parameters are the sanitized feature names (each defaulting
    to 0), its docstring lists the "class index -> class name" mapping, and
    its body is a nested if/else chain that returns the index of the class
    with the largest value at the reached leaf.

    Args:
        <sklearn.tree.DecisionTreeClassifier> tree
            Only the arrays of ``tree.tree_`` are read (children_left,
            children_right, threshold, feature, value, n_node_samples).
        <list> feature_names
            One name per feature column; spaces become underscores and
            parenthesised parts such as "(cm)" are removed.
        <list> class_names
            Names shown in the generated docstring, in class-index order.
        <str> func_name
            Name of the generated function.
    Return:
        <str> code
    '''
    def convert_feature_name(feature_name):
        # e.g. "sepal length (cm)" -> "sepal_length"
        feature_name = feature_name.replace(' ', '_')
        feature_name = re_parenthesis.sub('', feature_name)
        return feature_name.strip('_')

    left = tree.tree_.children_left
    right = tree.tree_.children_right
    threshold = tree.tree_.threshold
    value = tree.tree_.value
    n_node_samples = tree.tree_.n_node_samples

    feature_names = [convert_feature_name(name) for name in feature_names]
    # Leaf nodes carry a negative sentinel in tree_.feature. Using it as a
    # list index raises IndexError when there are fewer than two features,
    # so leaves get a placeholder that is never emitted.
    features = [feature_names[i] if i >= 0 else None
                for i in tree.tree_.feature]

    def gen_lines(node, indent):
        pad = ' ' * indent
        # A node is a leaf when both child pointers are equal (no children).
        # Testing the threshold against -2 instead would misclassify a real
        # split whose threshold happens to be exactly -2.0.
        if left[node] == right[node]:
            class_idx = np.argmax(value[node])
            return ["%sreturn %d # samples=%s\n"
                    % (pad, class_idx, n_node_samples[node])]

        lines = ["%sif %s <= %s: # samples=%s\n"
                 % (pad, features[node], str(threshold[node]),
                    n_node_samples[node])]
        lines.extend(gen_lines(left[node], indent + 4))
        lines.append("%selse:\n" % pad)
        lines.extend(gen_lines(right[node], indent + 4))
        return lines

    # Build "a=0, b=0"; an empty feature list yields "def f():" rather than
    # the invalid "def f(=0):".
    params = ', '.join('%s=0' % name for name in feature_names)

    body_indent = 4
    doc_pad = ' ' * body_indent  # docstring must sit at the body's indent
    parts = ['def %s(%s):\n' % (func_name, params)]
    parts.append('%s"""\n' % doc_pad)
    parts.extend('%s%d -> %s\n' % (doc_pad, i, x)
                 for (i, x) in enumerate(class_names))
    parts.append('%s"""\n' % doc_pad)
    parts.extend(gen_lines(0, body_indent))
    return ''.join(parts)