# sklearn-porter

Repository: https://github.com/nok/sklearn-porter

## AdaBoostClassifier

Documentation: [sklearn.ensemble.AdaBoostClassifier](http://scikit-learn.org/stable/modules/generated/sklearn.ensemble.AdaBoostClassifier.html)

### Load data:

In [1]:
from sklearn.datasets import load_iris

iris_data = load_iris()
X = iris_data.data
y = iris_data.target

print(X.shape, y.shape)

((150, 4), (150,))


### Train classifier:

In [2]:
from sklearn.ensemble import AdaBoostClassifier
from sklearn.tree import DecisionTreeClassifier

base_estimator = DecisionTreeClassifier(max_depth=4, random_state=0)
clf = AdaBoostClassifier(base_estimator=base_estimator, n_estimators=100,
                         random_state=0)
clf.fit(X, y)

AdaBoostClassifier(algorithm='SAMME.R',
          base_estimator=DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=4,
            max_features=None, max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, presort=False, random_state=0,
            splitter='best'),
          learning_rate=1.0, n_estimators=100, random_state=0)

### Transpile classifier:

In [4]:
%%time

from sklearn_porter import Porter

porter = Porter(clf)
output = porter.export(export_data=True)

print(output)

import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.lang.reflect.Type;
import java.util.List;
import java.util.Scanner;
import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;


class AdaBoostClassifier {

    private class Tree {
        private int[] childrenLeft;
        private int[] childrenRight;
        private double[] thresholds;
        private int[] indices;
        private double[][] classes;

        private double[] predict (double[] features, int node) {
            if (this.thresholds[node] != -2) {
                if (features[this.indices[node]] <= this.thresholds[node]) {
                    return this.predict(features, this.childrenLeft[node]);
                } else {
                    return this.predict(features, this.childrenRight[node]);
                }
            }
            return this.classes[node];
        }
        private double[] predict (double[] features) {
            return this.pr

Exported parameters (`data.json`):

In [5]:
%%bash

cat data.json

[{"indices": [3, -2, 3, 2, 3, -2, -2, 3, -2, -2, 2, 0, -2, -2, -2], "thresholds": [0.800000011921, -2.0, 1.75, 4.94999980927, 1.65000009537, -2.0, -2.0, 1.54999995232, -2.0, -2.0, 4.85000038147, 5.94999980927, -2.0, -2.0, -2.0], "classes": [[0.333333333333, 0.333333333333, 0.333333333333], [0.333333333333, 0.0, 0.0], [0.0, 0.333333333333, 0.333333333333], [0.0, 0.326666666667, 0.0333333333333], [0.0, 0.313333333333, 0.00666666666667], [0.0, 0.313333333333, 0.0], [0.0, 0.0, 0.00666666666667], [0.0, 0.0133333333333, 0.0266666666667], [0.0, 0.0, 0.02], [0.0, 0.0133333333333, 0.00666666666667], [0.0, 0.00666666666667, 0.3], [0.0, 0.00666666666667, 0.0133333333333], [0.0, 0.00666666666667, 0.0], [0.0, 0.0, 0.0133333333333], [0.0, 0.0, 0.286666666667]], "childrenRight": [2, -1, 10, 7, 6, -1, -1, 9, -1, -1, 14, 13, -1, -1, -1], "childrenLeft": [1, -1, 3, 4, 5, -1, -1, 8, -1, -1, 11, 12, -1, -1, -1]}, {"indices": [2, 2, -2, 3, 0, -2, -2, 1, -2, -2, -2], "thresholds": [5.14999961853, 2.45000004
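To make the exported structure easier to read, here is a minimal Python sketch that walks the first tree from `data.json` in the same way as the generated `Tree.predict` method: a node whose threshold is `-2` is treated as a leaf, otherwise the feature at `indices[node]` is compared with `thresholds[node]` to pick the left or right child. The arrays are copied from the first tree in the output above; the names `tree`, `LEAF` and `predict_tree` are illustrative and are not part of sklearn-porter. It covers a single tree only, not how the ensemble combines its trees.

```python
# First tree from data.json, copied from the output above.
tree = {
    "indices": [3, -2, 3, 2, 3, -2, -2, 3, -2, -2, 2, 0, -2, -2, -2],
    "thresholds": [0.800000011921, -2.0, 1.75, 4.94999980927, 1.65000009537,
                   -2.0, -2.0, 1.54999995232, -2.0, -2.0, 4.85000038147,
                   5.94999980927, -2.0, -2.0, -2.0],
    "classes": [
        [0.333333333333, 0.333333333333, 0.333333333333],
        [0.333333333333, 0.0, 0.0],
        [0.0, 0.333333333333, 0.333333333333],
        [0.0, 0.326666666667, 0.0333333333333],
        [0.0, 0.313333333333, 0.00666666666667],
        [0.0, 0.313333333333, 0.0],
        [0.0, 0.0, 0.00666666666667],
        [0.0, 0.0133333333333, 0.0266666666667],
        [0.0, 0.0, 0.02],
        [0.0, 0.0133333333333, 0.00666666666667],
        [0.0, 0.00666666666667, 0.3],
        [0.0, 0.00666666666667, 0.0133333333333],
        [0.0, 0.00666666666667, 0.0],
        [0.0, 0.0, 0.0133333333333],
        [0.0, 0.0, 0.286666666667],
    ],
    "childrenRight": [2, -1, 10, 7, 6, -1, -1, 9, -1, -1, 14, 13, -1, -1, -1],
    "childrenLeft": [1, -1, 3, 4, 5, -1, -1, 8, -1, -1, 11, 12, -1, -1, -1],
}

LEAF = -2  # threshold value the generated Java code treats as "no split here"


def predict_tree(tree, features, node=0):
    """Return (leaf_index, class_weights) for one sample, mirroring Tree.predict."""
    if tree["thresholds"][node] != LEAF:
        if features[tree["indices"][node]] <= tree["thresholds"][node]:
            return predict_tree(tree, features, tree["childrenLeft"][node])
        return predict_tree(tree, features, tree["childrenRight"][node])
    return node, tree["classes"][node]


leaf, weights = predict_tree(tree, [1.0, 2.0, 3.0, 4.0])
best_class = weights.index(max(weights))  # index of the largest class weight
print(leaf, weights, best_class)
```

For the sample `1 2 3 4` the walk visits nodes 0, 2, 10, 11 and ends in leaf 12, whose largest weight belongs to class 1. This is the vote of one tree; the ensemble's final answer can differ from it.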

### Run classification in Java:

Save the estimator:

In [6]:
with open('AdaBoostClassifier.java', 'w') as f:
    f.write(output)

Download dependencies (Gson, which the generated class uses to read `data.json`):

In [7]:
%%bash

wget http://central.maven.org/maven2/com/google/code/gson/gson/2.8.2/gson-2.8.2.jar

--2017-11-26 22:27:24--  http://central.maven.org/maven2/com/google/code/gson/gson/2.8.2/gson-2.8.2.jar
Resolving central.maven.org... 151.101.36.209
Connecting to central.maven.org|151.101.36.209|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 232932 (227K) [application/java-archive]
Saving to: 'gson-2.8.2.jar.1'

     0K .......... .......... .......... .......... .......... 21% 1.55M 0s
    50K .......... .......... .......... .......... .......... 43% 4.48M 0s
   100K .......... .......... .......... .......... .......... 65% 6.49M 0s
   150K .......... .......... .......... .......... .......... 87% 6.72M 0s
   200K .......... .......... .......                         100% 6.03M=0.06s

2017-11-26 22:27:24 (3.60 MB/s) - 'gson-2.8.2.jar.1' saved [232932/232932]



Compile the estimator:

In [8]:
%%bash

javac -cp .:gson-2.8.2.jar AdaBoostClassifier.java

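Predict the class of a single sample. The arguments are the parameter file followed by the four feature values, and the output is the predicted class index: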

In [9]:
%%bash

java -cp .:gson-2.8.2.jar AdaBoostClassifier data.json 1 2 3 4

2
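
To classify other samples without retyping the command, the invocation above can be assembled from Python. This is only a sketch: `build_java_command`, `class_name` and `IRIS_CLASS_NAMES` are hypothetical helper names, the jar, class and file names are taken from the cells above, and the class names are the standard iris target names in index order. It assumes the program prints nothing but the class index, as in the output above.

```python
import os

# Standard iris target names, in the index order used by load_iris().target.
IRIS_CLASS_NAMES = ["setosa", "versicolor", "virginica"]


def build_java_command(features, data_file="data.json",
                       jar="gson-2.8.2.jar", main_class="AdaBoostClassifier"):
    """Build the argument list for the `java -cp ...` call shown above."""
    # os.pathsep is ':' on Unix-like systems and ';' on Windows.
    classpath = os.pathsep.join([".", jar])
    return (["java", "-cp", classpath, main_class, data_file]
            + [str(value) for value in features])


def class_name(stdout_text):
    """Map the printed class index (e.g. '2\\n') to an iris class name."""
    return IRIS_CLASS_NAMES[int(stdout_text.strip())]


cmd = build_java_command([1, 2, 3, 4])
print(" ".join(cmd))
print(class_name("2\n"))
```

The list can then be passed to `subprocess.run(cmd, capture_output=True, text=True)` and the resulting `stdout` handed to `class_name`. Building the command as a list avoids shell quoting issues.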
