# sklearn-porter

Repository: https://github.com/nok/sklearn-porter

## DecisionTreeClassifier

Documentation: [sklearn.tree.DecisionTreeClassifier](http://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html)

### Loading data:

In [1]:
from sklearn.datasets import load_iris

iris_data = load_iris()
X = iris_data.data
y = iris_data.target

print(X.shape, y.shape)

((150, 4), (150,))


### Train classifier:

In [2]:
from sklearn.tree import tree

clf = tree.DecisionTreeClassifier()
clf.fit(X, y)

DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
            max_features=None, max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, presort=False, random_state=None,
            splitter='best')

### Transpile classifier:

In [4]:
%%time

from sklearn_porter import Porter

porter = Porter(clf)
output = porter.export(export_data=True)

print(output)

import com.google.gson.Gson;
import java.io.File;
import java.io.FileNotFoundException;
import java.util.Scanner;


class DecisionTreeClassifier {

    private class Classifier {
        private int[] leftChilds;
        private int[] rightChilds;
        private double[] thresholds;
        private int[] indices;
        private int[][] classes;
    }
    private Classifier clf;

    public DecisionTreeClassifier(String file) throws FileNotFoundException {
        String jsonStr = new Scanner(new File(file)).useDelimiter("\\Z").next();
        this.clf = new Gson().fromJson(jsonStr, Classifier.class);
    }

    public int predict(double[] features, int node) {
        if (this.clf.thresholds[node] != -2) {
            if (features[this.clf.indices[node]] <= this.clf.thresholds[node]) {
                return predict(features, this.clf.leftChilds[node]);
            } else {
                return predict(features, this.clf.rightChilds[node]);
            }
        }
        return fi

Parameters:

In [5]:
%%bash

cat data.json

{"leftChilds": [1, -1, 3, 4, 5, -1, -1, 8, -1, 10, -1, -1, 13, 14, -1, -1, -1], "rightChilds": [2, -1, 12, 7, 6, -1, -1, 9, -1, 11, -1, -1, 16, 15, -1, -1, -1], "thresholds": [0.800000011921, -2.0, 1.75, 4.94999980927, 1.65000009537, -2.0, -2.0, 1.54999995232, -2.0, 5.44999980927, -2.0, -2.0, 4.85000038147, 5.94999980927, -2.0, -2.0, -2.0], "classes": [[50.0, 50.0, 50.0], [50.0, 0.0, 0.0], [0.0, 50.0, 50.0], [0.0, 49.0, 5.0], [0.0, 47.0, 1.0], [0.0, 47.0, 0.0], [0.0, 0.0, 1.0], [0.0, 2.0, 4.0], [0.0, 0.0, 3.0], [0.0, 2.0, 1.0], [0.0, 2.0, 0.0], [0.0, 0.0, 1.0], [0.0, 1.0, 45.0], [0.0, 1.0, 2.0], [0.0, 1.0, 0.0], [0.0, 0.0, 2.0], [0.0, 0.0, 43.0]], "indices": [3, -2, 3, 2, 3, -2, -2, 3, -2, 2, -2, -2, 2, 0, -2, -2, -2]}

### Run classification in Java:

Save the estimator:

In [6]:
with open('DecisionTreeClassifier.java', 'w') as f:
    f.write(output)

Download dependencies:

In [7]:
%%bash

wget http://central.maven.org/maven2/com/google/code/gson/gson/2.8.2/gson-2.8.2.jar

--2017-11-26 23:17:02--  http://central.maven.org/maven2/com/google/code/gson/gson/2.8.2/gson-2.8.2.jar
Resolving central.maven.org... 151.101.36.209
Connecting to central.maven.org|151.101.36.209|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 232932 (227K) [application/java-archive]
Saving to: 'gson-2.8.2.jar'

     0K .......... .......... .......... .......... .......... 21% 1.91M 0s
    50K .......... .......... .......... .......... .......... 43% 3.07M 0s
   100K .......... .......... .......... .......... .......... 65% 8.36M 0s
   150K .......... .......... .......... .......... .......... 87% 7.89M 0s
   200K .......... .......... .......                         100% 6.95M=0.06s

2017-11-26 23:17:02 (3.87 MB/s) - 'gson-2.8.2.jar' saved [232932/232932]



Compiling:

In [8]:
%%bash

javac -cp .:gson-2.8.2.jar DecisionTreeClassifier.java

Prediction:

In [9]:
%%bash

java -cp .:gson-2.8.2.jar DecisionTreeClassifier data.json 1 2 3 4

1
