# sklearn-porter

Repository: https://github.com/nok/sklearn-porter

## ExtraTreesClassifier

Documentation: [sklearn.ensemble.ExtraTreesClassifier](http://scikit-learn.org/stable/modules/generated/sklearn.ensemble.ExtraTreesClassifier.html)

### Loading data:

In [1]:
from sklearn.datasets import load_iris

iris_data = load_iris()
X = iris_data.data
y = iris_data.target

print(X.shape, y.shape)

((150, 4), (150,))


### Train classifier:

In [2]:
from sklearn.ensemble import ExtraTreesClassifier

clf = ExtraTreesClassifier(n_estimators=15, max_depth=None,
                           min_samples_split=2, random_state=0)
clf.fit(X, y)

ExtraTreesClassifier(bootstrap=False, class_weight=None, criterion='gini',
           max_depth=None, max_features='auto', max_leaf_nodes=None,
           min_impurity_decrease=0.0, min_impurity_split=None,
           min_samples_leaf=1, min_samples_split=2,
           min_weight_fraction_leaf=0.0, n_estimators=15, n_jobs=1,
           oob_score=False, random_state=0, verbose=0, warm_start=False)

### Transpile classifier:

In [4]:
%%time

from sklearn_porter import Porter

porter = Porter(clf)
output = porter.export(export_data=True)

print(output)

import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.lang.reflect.Type;
import java.util.List;
import java.util.Scanner;
import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;


class ExtraTreesClassifier {

    private class Tree {
        private int[] childrenLeft;
        private int[] childrenRight;
        private double[] thresholds;
        private int[] indices;
        private double[][] classes;

        private int predict (double[] features, int node) {
            if (this.thresholds[node] != -2) {
                if (features[this.indices[node]] <= this.thresholds[node]) {
                    return this.predict(features, this.childrenLeft[node]);
                } else {
                    return this.predict(features, this.childrenRight[node]);
                }
            }
            return ExtraTreesClassifier.findMax(this.classes[node]);
        }
        private int predict (double[] features) {
    

Parameters:

In [5]:
%%bash

cat data.json

[{"indices": [0, 1, 2, -2, 3, -2, -2, 3, -2, -2, 3, 2, 3, -2, 3, 1, -2, -2, 3, -2, 2, -2, 3, 3, 2, -2, 2, -2, -2, -2, 1, -2, -2, -2, -2], "thresholds": [5.52112836752, 2.54236914672, 2.32859402389, -2.0, 1.4511234333, -2.0, -2.0, 0.83878367201, -2.0, -2.0, 1.89369306498, 5.18076136748, 0.25425931195, -2.0, 1.29791099198, 3.14667550028, -2.0, -2.0, 1.49177428996, -2.0, 4.64861094171, -2.0, 1.78027764001, 1.50659268982, 4.89970329741, -2.0, 4.98785335085, -2.0, -2.0, -2.0, 3.12447121782, -2.0, -2.0, -2.0, -2.0], "classes": [[50.0, 50.0, 50.0], [47.0, 11.0, 1.0], [1.0, 8.0, 1.0], [1.0, 0.0, 0.0], [0.0, 8.0, 1.0], [0.0, 8.0, 0.0], [0.0, 0.0, 1.0], [46.0, 3.0, 0.0], [46.0, 0.0, 0.0], [0.0, 3.0, 0.0], [3.0, 39.0, 49.0], [3.0, 39.0, 15.0], [3.0, 39.0, 7.0], [1.0, 0.0, 0.0], [2.0, 39.0, 7.0], [2.0, 8.0, 0.0], [0.0, 8.0, 0.0], [2.0, 0.0, 0.0], [0.0, 31.0, 7.0], [0.0, 17.0, 0.0], [0.0, 14.0, 7.0], [0.0, 7.0, 0.0], [0.0, 7.0, 7.0], [0.0, 6.0, 2.0], [0.0, 3.0, 2.0], [0.0, 1.0, 0.0], [0.0, 2.0, 2.0

### Run classification in Java:

Save the transpiled estimator:

In [7]:
with open('ExtraTreesClassifier.java', 'w') as f:
    f.write(output)

Download the dependencies:

In [8]:
%%bash

wget http://central.maven.org/maven2/com/google/code/gson/gson/2.8.2/gson-2.8.2.jar

--2017-11-26 22:02:04--  http://central.maven.org/maven2/com/google/code/gson/gson/2.8.2/gson-2.8.2.jar
Resolving central.maven.org... 151.101.36.209
Connecting to central.maven.org|151.101.36.209|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 232932 (227K) [application/java-archive]
Saving to: 'gson-2.8.2.jar'

     0K .......... .......... .......... .......... .......... 21% 1.85M 0s
    50K .......... .......... .......... .......... .......... 43% 3.92M 0s
   100K .......... .......... .......... .......... .......... 65% 3.51M 0s
   150K .......... .......... .......... .......... .......... 87% 3.16M 0s
   200K .......... .......... .......                         100% 5.85M=0.07s

2017-11-26 22:02:04 (3.05 MB/s) - 'gson-2.8.2.jar' saved [232932/232932]



Compiling:

In [9]:
%%bash

javac -cp .:gson-2.8.2.jar ExtraTreesClassifier.java

Prediction:

In [10]:
%%bash

java -cp .:gson-2.8.2.jar ExtraTreesClassifier data.json 1 2 3 4

2
