Skip to content

Commit

Permalink
Cleanup naming in Python code
Browse files Browse the repository at this point in the history
  • Loading branch information
jonnor committed Oct 20, 2018
1 parent a9fd7af commit b1d9494
Show file tree
Hide file tree
Showing 7 changed files with 31 additions and 21 deletions.
2 changes: 1 addition & 1 deletion MANIFEST.in
@@ -1,3 +1,3 @@
include LICENSE.md include LICENSE.md
include requirements*.txt include requirements*.txt
include emtrees/*.h include emlearn/*.h
7 changes: 5 additions & 2 deletions emlearn/__init__.py
@@ -1,5 +1,8 @@


from .randomforest import convert from . import trees
from . import common


includedir = randomforest.get_include_dir() from .convert import convert

includedir = common.get_include_dir()


8 changes: 8 additions & 0 deletions emlearn/common.py
@@ -0,0 +1,8 @@

import os
import os.path

def get_include_dir():
return os.path.join(os.path.dirname(__file__))


11 changes: 11 additions & 0 deletions emlearn/convert.py
@@ -0,0 +1,11 @@

from . import trees

def convert(estimator, kind=None, method='pymodule'):
if kind is None:
kind = type(estimator).__name__

if kind in ['RandomForestClassifier', 'ExtraTreesClassifier']:
return trees.Wrapper(estimator, method)
else:
raise ValueError("Unknown model type: '{}'".format(kind))
17 changes: 2 additions & 15 deletions emlearn/randomforest.py → emlearn/trees.py
Expand Up @@ -8,13 +8,10 @@


import numpy import numpy


from . import common
import emtreesc import emtreesc




def get_include_dir():
return os.path.join(os.path.dirname(__file__))


# Tree representation as 2d array # Tree representation as 2d array
# feature, value, left_child, right_child # feature, value, left_child, right_child
# Leaf node: -1, class, -1, -1 # Leaf node: -1, class, -1, -1
Expand Down Expand Up @@ -378,7 +375,7 @@ def run_classifier(bin_path, data):
class CompiledClassifier(): class CompiledClassifier():
def __init__(self, cmodel, name, call=None, include_dir=None, temp_dir='tmp/'): def __init__(self, cmodel, name, call=None, include_dir=None, temp_dir='tmp/'):
if include_dir == None: if include_dir == None:
include_dir = get_include_dir() include_dir = common.get_include_dir()
self.bin_path = build_classifier(cmodel, name, include_dir=include_dir, temp_dir=temp_dir, func=call) self.bin_path = build_classifier(cmodel, name, include_dir=include_dir, temp_dir=temp_dir, func=call)


def predict(self, X): def predict(self, X):
Expand Down Expand Up @@ -432,13 +429,3 @@ def to_dot(self, **kwargs):
return forest_to_dot(self.forest_, **kwargs) return forest_to_dot(self.forest_, **kwargs)




def convert(estimator, kind=None, method='pymodule'):
if kind is None:
kind = type(estimator).__name__

if kind in ['RandomForestClassifier', 'ExtraTreesClassifier']:
return Wrapper(estimator, method)
else:
raise ValueError("Unknown model type: '{}'".format(kind))


5 changes: 3 additions & 2 deletions examples/digits.py
@@ -1,5 +1,6 @@


import emtrees import emlearn

import numpy import numpy
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier from sklearn.ensemble import RandomForestClassifier
Expand Down Expand Up @@ -28,7 +29,7 @@
m = numpy.max(Xtrain), numpy.min(Xtrain) m = numpy.max(Xtrain), numpy.min(Xtrain)


filename = 'digits.h' filename = 'digits.h'
cmodel = emtrees.convert(model) cmodel = emlearn.convert(model)
code = cmodel.save(file=filename) code = cmodel.save(file=filename)


print('Wrote C code to', filename) print('Wrote C code to', filename)
Expand Down
2 changes: 1 addition & 1 deletion test/test_trees.py
Expand Up @@ -55,7 +55,7 @@ def test_deduplicate_single_tree():
] ]
roots = [ 6 ] roots = [ 6 ]


de_nodes, de_roots = emlearn.randomforest.remove_duplicate_leaves((nodes, roots)) de_nodes, de_roots = emlearn.trees.remove_duplicate_leaves((nodes, roots))


duplicates = 1 duplicates = 1
assert len(de_roots) == len(roots) assert len(de_roots) == len(roots)
Expand Down

0 comments on commit b1d9494

Please sign in to comment.