# sklearn-porter

Repository: https://github.com/nok/sklearn-porter

## RandomForestClassifier

Documentation: [sklearn.ensemble.RandomForestClassifier](http://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html)

### Loading data:

In [1]:
from sklearn.datasets import load_iris

iris_data = load_iris()
X = iris_data.data
y = iris_data.target

print(X.shape, y.shape)

((150, 4), (150,))


### Train classifier:

In [2]:
from sklearn.ensemble import RandomForestClassifier

clf = RandomForestClassifier(n_estimators=15, max_depth=None,
                             min_samples_split=2, random_state=0)
clf.fit(X, y)

RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
            max_depth=None, max_features='auto', max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, n_estimators=15, n_jobs=1,
            oob_score=False, random_state=0, verbose=0, warm_start=False)

### Transpile classifier:

In [4]:
%%time

from sklearn_porter import Porter

porter = Porter(clf)
output = porter.export(export_data=True)

print(output)

import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.lang.reflect.Type;
import java.util.List;
import java.util.Scanner;
import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;


class RandomForestClassifier {

    private class Tree {
        private int[] childrenLeft;
        private int[] childrenRight;
        private double[] thresholds;
        private int[] indices;
        private double[][] classes;

        private int predict (double[] features, int node) {
            if (this.thresholds[node] != -2) {
                if (features[this.indices[node]] <= this.thresholds[node]) {
                    return this.predict(features, this.childrenLeft[node]);
                } else {
                    return this.predict(features, this.childrenRight[node]);
                }
            }
            return RandomForestClassifier.findMax(this.classes[node]);
        }
        private int predict (double[] features) {


Parameters:

In [5]:
%%bash

cat data.json

[{"indices": [3, -2, 2, 3, -2, 1, -2, -2, 0, -2, 2, -2, -2], "thresholds": [0.75, -2.0, 4.85000038147, 1.65000009537, -2.0, 3.0, -2.0, -2.0, 6.59999990463, -2.0, 5.19999980927, -2.0, -2.0], "classes": [[47.0, 44.0, 59.0], [47.0, 0.0, 0.0], [0.0, 44.0, 59.0], [0.0, 43.0, 3.0], [0.0, 42.0, 0.0], [0.0, 1.0, 3.0], [0.0, 0.0, 3.0], [0.0, 1.0, 0.0], [0.0, 1.0, 56.0], [0.0, 0.0, 27.0], [0.0, 1.0, 29.0], [0.0, 1.0, 0.0], [0.0, 0.0, 29.0]], "childrenRight": [2, -1, 8, 5, -1, 7, -1, -1, 10, -1, 12, -1, -1], "childrenLeft": [1, -1, 3, 4, -1, 6, -1, -1, 9, -1, 11, -1, -1]}, {"indices": [3, -2, 3, 2, -2, 2, 1, -2, -2, -2, 2, 1, -2, -2, -2], "thresholds": [0.800000011921, -2.0, 1.75, 4.94999980927, -2.0, 5.44999980927, 2.45000004768, -2.0, -2.0, -2.0, 4.85000038147, 3.09999990463, -2.0, -2.0, -2.0], "classes": [[46.0, 62.0, 42.0], [46.0, 0.0, 0.0], [0.0, 62.0, 42.0], [0.0, 61.0, 5.0], [0.0, 58.0, 0.0], [0.0, 3.0, 5.0], [0.0, 3.0, 2.0], [0.0, 0.0, 2.0], [0.0, 3.0, 0.0], [0.0, 0.0, 3.0], [0.0, 1.0, 37

### Run classification in Java:

Save the transpiled estimator:

In [7]:
with open('RandomForestClassifier.java', 'w') as f:
    f.write(output)

Download the dependencies:

In [8]:
%%bash

wget http://central.maven.org/maven2/com/google/code/gson/gson/2.8.2/gson-2.8.2.jar

--2017-11-26 21:49:59--  http://central.maven.org/maven2/com/google/code/gson/gson/2.8.2/gson-2.8.2.jar
Resolving central.maven.org... 151.101.36.209
Connecting to central.maven.org|151.101.36.209|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 232932 (227K) [application/java-archive]
Saving to: 'gson-2.8.2.jar.3'

     0K .......... .......... .......... .......... .......... 21% 1.66M 0s
    50K .......... .......... .......... .......... .......... 43% 4.30M 0s
   100K .......... .......... .......... .......... .......... 65% 6.07M 0s
   150K .......... .......... .......... .......... .......... 87% 7.76M 0s
   200K .......... .......... .......                         100% 5.68M=0.06s

2017-11-26 21:49:59 (3.71 MB/s) - 'gson-2.8.2.jar.3' saved [232932/232932]



Compiling:

In [9]:
%%bash

javac -cp .:gson-2.8.2.jar RandomForestClassifier.java

Prediction:

In [10]:
%%bash

java -cp .:gson-2.8.2.jar RandomForestClassifier data.json 1 2 3 4

1
