# sklearn-porter

Repository: https://github.com/nok/sklearn-porter

## BernoulliNB

Documentation: [sklearn.naive_bayes.BernoulliNB](http://scikit-learn.org/stable/modules/generated/sklearn.naive_bayes.BernoulliNB.html)

### Loading data:

In [1]:
from sklearn.datasets import load_iris

iris_data = load_iris()
X = iris_data.data
y = iris_data.target

print(X.shape, y.shape)

((150, 4), (150,))


### Train classifier:

In [2]:
from sklearn.naive_bayes import BernoulliNB

clf = BernoulliNB()
clf.fit(X, y)

BernoulliNB(alpha=1.0, binarize=0.0, class_prior=None, fit_prior=True)

### Transpile classifier:

In [4]:
%%time

from sklearn_porter import Porter

porter = Porter(clf)
output = porter.export(export_data=True)

print(output)

import java.io.File;
import java.io.FileNotFoundException;
import java.util.*;
import com.google.gson.Gson;


class BernoulliNB {

    private class Classifier {
        private double[] priors;
        private double[][] negProbs;
        private double[][] delProbs;
    }

    private Classifier clf;

    public BernoulliNB(String file) throws FileNotFoundException {
        String jsonStr = new Scanner(new File(file)).useDelimiter("\\Z").next();
        this.clf = new Gson().fromJson(jsonStr, Classifier.class);
    }

    public int predict(double[] features) {
        int nClasses = this.clf.priors.length;
        int nFeatures = this.clf.delProbs.length;

        double[] jll = new double[nClasses];
        for (int i = 0; i < nClasses; i++) {
            double sum = 0.;
            for (int j = 0; j < nFeatures; j++) {
                sum += features[i] * this.clf.delProbs[j][i];
            }
            jll[i] = sum;
        }
        for (int i = 0; i < nClasses; i++) {
     

Parameters:

In [5]:
%%bash

cat data.json

{"priors": [-1.09861228867, -1.09861228867, -1.09861228867], "delProbs": [[3.93182563272, 3.93182563272, 3.93182563272], [3.93182563272, 3.93182563272, 3.93182563272], [3.93182563272, 3.93182563272, 3.93182563272], [3.93182563272, 3.93182563272, 3.93182563272]], "negProbs": [[-3.95124371858, -3.95124371858, -3.95124371858, -3.95124371858], [-3.95124371858, -3.95124371858, -3.95124371858, -3.95124371858], [-3.95124371858, -3.95124371858, -3.95124371858, -3.95124371858]]}

### Run classification in Java:

Save the transpiled estimator:

In [6]:
with open('BernoulliNB.java', 'w') as f:
    f.write(output)

Download the dependencies:

In [7]:
%%bash

wget http://central.maven.org/maven2/com/google/code/gson/gson/2.8.2/gson-2.8.2.jar

--2017-12-02 16:22:33--  http://central.maven.org/maven2/com/google/code/gson/gson/2.8.2/gson-2.8.2.jar
Resolving central.maven.org... 151.101.36.209
Connecting to central.maven.org|151.101.36.209|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 232932 (227K) [application/java-archive]
Saving to: 'gson-2.8.2.jar'

     0K .......... .......... .......... .......... .......... 21% 1.65M 0s
    50K .......... .......... .......... .......... .......... 43% 2.75M 0s
   100K .......... .......... .......... .......... .......... 65% 3.53M 0s
   150K .......... .......... .......... .......... .......... 87% 1.77M 0s
   200K .......... .......... .......                         100% 2.79M=0.1s

2017-12-02 16:22:33 (2.26 MB/s) - 'gson-2.8.2.jar' saved [232932/232932]



Compiling:

In [8]:
%%bash

javac -cp .:gson-2.8.2.jar BernoulliNB.java

Prediction:

In [9]:
%%bash

java -cp .:gson-2.8.2.jar BernoulliNB data.json 1 2 3 4

2
