Skip to content

Commit b1d9494

Browse files
committed
Cleanup naming in Python code
1 parent a9fd7af commit b1d9494

File tree

7 files changed

+31
-21
lines changed

7 files changed

+31
-21
lines changed

MANIFEST.in

Lines changed: 1 addition & 1 deletion
Original file line numberOriginal file lineDiff line numberDiff line change
@@ -1,3 +1,3 @@
1
include LICENSE.md
1
include LICENSE.md
2
include requirements*.txt
2
include requirements*.txt
3-
include emtrees/*.h
3+
include emlearn/*.h

emlearn/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberOriginal file lineDiff line numberDiff line change
@@ -1,5 +1,8 @@
1

1

2-
from .randomforest import convert
2+
from . import trees
3+
from . import common
3

4

4-
includedir = randomforest.get_include_dir()
5+
from .convert import convert
6+
7+
includedir = common.get_include_dir()
5

8

emlearn/common.py

Lines changed: 8 additions & 0 deletions
Original file line numberOriginal file lineDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
2+
import os
3+
import os.path
4+
5+
def get_include_dir():
6+
return os.path.join(os.path.dirname(__file__))
7+
8+

emlearn/convert.py

Lines changed: 11 additions & 0 deletions
Original file line numberOriginal file lineDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
2+
from . import trees
3+
4+
def convert(estimator, kind=None, method='pymodule'):
5+
if kind is None:
6+
kind = type(estimator).__name__
7+
8+
if kind in ['RandomForestClassifier', 'ExtraTreesClassifier']:
9+
return trees.Wrapper(estimator, method)
10+
else:
11+
raise ValueError("Unknown model type: '{}'".format(kind))

emlearn/randomforest.py renamed to emlearn/trees.py

Lines changed: 2 additions & 15 deletions
Original file line numberOriginal file lineDiff line numberDiff line change
@@ -8,13 +8,10 @@
8

8

9
import numpy
9
import numpy
10

10

11+
from . import common
11
import emtreesc
12
import emtreesc
12

13

13

14

14-
def get_include_dir():
15-
return os.path.join(os.path.dirname(__file__))
16-
17-
18
# Tree representation as 2d array
15
# Tree representation as 2d array
19
# feature, value, left_child, right_child
16
# feature, value, left_child, right_child
20
# Leaf node: -1, class, -1, -1
17
# Leaf node: -1, class, -1, -1
@@ -378,7 +375,7 @@ def run_classifier(bin_path, data):
378
class CompiledClassifier():
375
class CompiledClassifier():
379
def __init__(self, cmodel, name, call=None, include_dir=None, temp_dir='tmp/'):
376
def __init__(self, cmodel, name, call=None, include_dir=None, temp_dir='tmp/'):
380
if include_dir == None:
377
if include_dir == None:
381-
include_dir = get_include_dir()
378+
include_dir = common.get_include_dir()
382
self.bin_path = build_classifier(cmodel, name, include_dir=include_dir, temp_dir=temp_dir, func=call)
379
self.bin_path = build_classifier(cmodel, name, include_dir=include_dir, temp_dir=temp_dir, func=call)
383

380

384
def predict(self, X):
381
def predict(self, X):
@@ -432,13 +429,3 @@ def to_dot(self, **kwargs):
432
return forest_to_dot(self.forest_, **kwargs)
429
return forest_to_dot(self.forest_, **kwargs)
433

430

434

431

435-
def convert(estimator, kind=None, method='pymodule'):
436-
if kind is None:
437-
kind = type(estimator).__name__
438-
439-
if kind in ['RandomForestClassifier', 'ExtraTreesClassifier']:
440-
return Wrapper(estimator, method)
441-
else:
442-
raise ValueError("Unknown model type: '{}'".format(kind))
443-
444-

examples/digits.py

Lines changed: 3 additions & 2 deletions
Original file line numberOriginal file lineDiff line numberDiff line change
@@ -1,5 +1,6 @@
1

1

2-
import emtrees
2+
import emlearn
3+
3
import numpy
4
import numpy
4
from sklearn.model_selection import train_test_split
5
from sklearn.model_selection import train_test_split
5
from sklearn.ensemble import RandomForestClassifier
6
from sklearn.ensemble import RandomForestClassifier
@@ -28,7 +29,7 @@
28
m = numpy.max(Xtrain), numpy.min(Xtrain)
29
m = numpy.max(Xtrain), numpy.min(Xtrain)
29

30

30
filename = 'digits.h'
31
filename = 'digits.h'
31-
cmodel = emtrees.convert(model)
32+
cmodel = emlearn.convert(model)
32
code = cmodel.save(file=filename)
33
code = cmodel.save(file=filename)
33

34

34
print('Wrote C code to', filename)
35
print('Wrote C code to', filename)

test/test_trees.py

Lines changed: 1 addition & 1 deletion
Original file line numberOriginal file lineDiff line numberDiff line change
@@ -55,7 +55,7 @@ def test_deduplicate_single_tree():
55
]
55
]
56
roots = [ 6 ]
56
roots = [ 6 ]
57

57

58-
de_nodes, de_roots = emlearn.randomforest.remove_duplicate_leaves((nodes, roots))
58+
de_nodes, de_roots = emlearn.trees.remove_duplicate_leaves((nodes, roots))
59

59

60
duplicates = 1
60
duplicates = 1
61
assert len(de_roots) == len(roots)
61
assert len(de_roots) == len(roots)

0 commit comments

Comments
 (0)