Skip to content
Permalink
Browse files

Cleanup naming in Python code

  • Loading branch information...
jonnor committed Oct 20, 2018
1 parent a9fd7af commit b1d94945b5d0f09e85eff57465fd520f9270d54c
Showing with 31 additions and 21 deletions.
  1. +1 −1 MANIFEST.in
  2. +5 −2 emlearn/__init__.py
  3. +8 −0 emlearn/common.py
  4. +11 −0 emlearn/convert.py
  5. +2 −15 emlearn/{randomforest.py → trees.py}
  6. +3 −2 examples/digits.py
  7. +1 −1 test/test_trees.py
@@ -1,3 +1,3 @@
include LICENSE.md
include requirements*.txt
include emtrees/*.h
include emlearn/*.h
@@ -1,5 +1,8 @@

from .randomforest import convert
from . import trees
from . import common

includedir = randomforest.get_include_dir()
from .convert import convert

includedir = common.get_include_dir()

@@ -0,0 +1,8 @@

import os
import os.path

def get_include_dir():
return os.path.join(os.path.dirname(__file__))


@@ -0,0 +1,11 @@

from . import trees

def convert(estimator, kind=None, method='pymodule'):
if kind is None:
kind = type(estimator).__name__

if kind in ['RandomForestClassifier', 'ExtraTreesClassifier']:
return trees.Wrapper(estimator, method)
else:
raise ValueError("Unknown model type: '{}'".format(kind))
@@ -8,13 +8,10 @@

import numpy

from . import common
import emtreesc


def get_include_dir():
return os.path.join(os.path.dirname(__file__))


# Tree representation as 2d array
# feature, value, left_child, right_child
# Leaf node: -1, class, -1, -1
@@ -378,7 +375,7 @@ def run_classifier(bin_path, data):
class CompiledClassifier():
def __init__(self, cmodel, name, call=None, include_dir=None, temp_dir='tmp/'):
if include_dir == None:
include_dir = get_include_dir()
include_dir = common.get_include_dir()
self.bin_path = build_classifier(cmodel, name, include_dir=include_dir, temp_dir=temp_dir, func=call)

def predict(self, X):
@@ -432,13 +429,3 @@ def to_dot(self, **kwargs):
return forest_to_dot(self.forest_, **kwargs)


def convert(estimator, kind=None, method='pymodule'):
if kind is None:
kind = type(estimator).__name__

if kind in ['RandomForestClassifier', 'ExtraTreesClassifier']:
return Wrapper(estimator, method)
else:
raise ValueError("Unknown model type: '{}'".format(kind))


@@ -1,5 +1,6 @@

import emtrees
import emlearn

import numpy
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
@@ -28,7 +29,7 @@
m = numpy.max(Xtrain), numpy.min(Xtrain)

filename = 'digits.h'
cmodel = emtrees.convert(model)
cmodel = emlearn.convert(model)
code = cmodel.save(file=filename)

print('Wrote C code to', filename)
@@ -55,7 +55,7 @@ def test_deduplicate_single_tree():
]
roots = [ 6 ]

de_nodes, de_roots = emlearn.randomforest.remove_duplicate_leaves((nodes, roots))
de_nodes, de_roots = emlearn.trees.remove_duplicate_leaves((nodes, roots))

duplicates = 1
assert len(de_roots) == len(roots)

0 comments on commit b1d9494

Please sign in to comment.
You can’t perform that action at this time.