Skip to content
Permalink
Browse files

trees: Also support DecisionTreeClassifier

Fixes #6
  • Loading branch information...
jonnor committed Feb 23, 2019
1 parent 8489654 commit a482202238fe38d7b8652592393feb7aececa5d1
Showing with 9 additions and 5 deletions.
  1. +1 −1 emlearn/convert.py
  2. +6 −1 emlearn/trees.py
  3. +2 −3 test/test_trees.py
@@ -8,7 +8,7 @@ def convert(estimator, kind=None, method='pymodule'):
kind = type(estimator).__name__

# Uname instead of instance to avoid hard dependency on the libraries
if kind in ['RandomForestClassifier', 'ExtraTreesClassifier']:
if kind in ['RandomForestClassifier', 'ExtraTreesClassifier', 'DecisionTreeClassifier']:
return trees.Wrapper(estimator, method)
elif kind == 'MLPClassifier':
return net.convert_sklearn_mlp(estimator, method)
@@ -338,7 +338,12 @@ def generate_c_forest(forest, name='myclassifier'):
class Wrapper:
def __init__(self, estimator, classifier):

self.forest_ = flatten_forest([ e.tree_ for e in estimator.estimators_])
if hasattr(estimator, 'estimators_'):
trees = [ e.tree_ for e in estimator.estimators_]
else:
trees = [ estimator.tree_ ]

self.forest_ = flatten_forest(trees)
self.forest_ = remove_duplicate_leaves(self.forest_)

if classifier == 'pymodule':
@@ -4,10 +4,8 @@
import numpy.testing

from sklearn import datasets
from sklearn import model_selection
from sklearn.ensemble import RandomForestClassifier, ExtraTreesClassifier
from sklearn import metrics
from sklearn.utils.estimator_checks import check_estimator
from sklearn.tree import DecisionTreeClassifier

import emlearn
import pytest
@@ -18,6 +16,7 @@
MODELS = {
'RFC': RandomForestClassifier(n_estimators=10, random_state=random),
'ETC': ExtraTreesClassifier(n_estimators=10, random_state=random),
'DTC': DecisionTreeClassifier(random_state=random),
}
DATASETS = {
'binary': datasets.make_classification(n_classes=2, n_samples=100, random_state=random),

0 comments on commit a482202

Please sign in to comment.
You can’t perform that action at this time.