Skip to content

Commit

Permalink
Add JS variation of the BernoulliNB estimator
Browse files Browse the repository at this point in the history
  • Loading branch information
nok committed Oct 1, 2017
1 parent 98a40a8 commit 9784d6b
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 2 deletions.
2 changes: 1 addition & 1 deletion readme.md
Expand Up @@ -121,7 +121,7 @@ Transpile trained [scikit-learn](https://github.com/scikit-learn/scikit-learn) e
<td><a href="http://scikit-learn.org/stable/modules/generated/sklearn.naive_bayes.BernoulliNB.html#sklearn.naive_bayes.BernoulliNB">sklearn.naive_bayes.BernoulliNB</a></td>
<td align="center"></td>
<td align="center"><a href="examples/classifier/BernoulliNB/java/basics.ipynb">✓</a></td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
Expand Down
9 changes: 8 additions & 1 deletion sklearn_porter/classifier/BernoulliNB/__init__.py
Expand Up @@ -22,6 +22,13 @@ class BernoulliNB(Classifier):
'arr[][]': '{type}[][] {name} = {{{values}}};',
'indent': ' ',
},
'js': {
'type': '{0}',
'arr': '[{0}]',
'arr[]': 'var {name} = [{values}];',
'arr[][]': 'var {name} = [{values}];',
'indent': ' ',
}
}
# @formatter:on

Expand Down Expand Up @@ -139,7 +146,7 @@ def create_method(self):
:return out : string
The built method as string.
"""
n_indents = 1 if self.target_language in ['java'] else 0
n_indents = 1 if self.target_language in ['java', 'js'] else 0
temp_method = self.temp('method.predict', n_indents=n_indents,
skipping=True)
out = temp_method.format(**self.__dict__)
Expand Down
13 changes: 13 additions & 0 deletions sklearn_porter/classifier/BernoulliNB/templates/js/class.txt
@@ -0,0 +1,13 @@
var {class_name} = function() {{

{method}

}};

if (typeof process !== 'undefined' && typeof process.argv !== 'undefined') {{
if (process.argv.length - 2 == {n_features}) {{
var argv = process.argv.slice(2);
var prediction = new {class_name}().{method_name}(argv);
console.log(prediction);
}}
}}
@@ -0,0 +1,33 @@
this.{method_name} = function(atts) {{

{priors}
{neg_probs}
{del_probs}

var jll = new Array({n_classes});
for (var i = 0; i < {n_classes}; i++) {{
var sum = 0.;
for (var j = 0; j < {n_features}; j++) {{
sum += atts[i] * delProbs[j][i];
}}
jll[i] = sum;
}}
for (var i = 0; i < {n_classes}; i++) {{
var sum = 0.;
for (var j = 0; j < {n_features}; j++) {{
sum += negProbs[i][j];
}}
jll[i] += priors[i] + sum;
}}

var highestLikeli = -1;
var classIndex = -1;
for (var i = 0; i < {n_classes}; i++) {{
if (jll[i] > highestLikeli) {{
highestLikeli = jll[i];
classIndex = i;
}}
}}
return classIndex;
}};

38 changes: 38 additions & 0 deletions tests/classifier/BernoulliNB/BernoulliNBJSTest.py
@@ -0,0 +1,38 @@
# -*- coding: utf-8 -*-

import unittest

from sklearn.naive_bayes import BernoulliNB

from ..Classifier import Classifier
from ...language.JavaScript import JavaScript


class BernoulliNBJSTest(JavaScript, Classifier, unittest.TestCase):

def setUp(self):
super(BernoulliNBJSTest, self).setUp()
self.mdl = BernoulliNB()

def tearDown(self):
super(BernoulliNBJSTest, self).tearDown()

@unittest.skip('BernoulliNB is just suitable for discrete data.')
def test_random_features_w_iris_data(self):
pass

@unittest.skip('BernoulliNB is just suitable for discrete data.')
def test_existing_features_w_binary_data(self):
pass

@unittest.skip('BernoulliNB is just suitable for discrete data.')
def test_random_features_w_binary_data(self):
pass

@unittest.skip('BernoulliNB is just suitable for discrete data.')
def test_random_features_w_digits_data(self):
pass

@unittest.skip('BernoulliNB is just suitable for discrete data.')
def test_existing_features_w_digits_data(self):
pass

0 comments on commit 9784d6b

Please sign in to comment.