diff --git a/readme.md b/readme.md
index b87a2b6d..15ae092c 100644
--- a/readme.md
+++ b/readme.md
@@ -121,7 +121,7 @@ Transpile trained [scikit-learn](https://github.com/scikit-learn/scikit-learn) e
sklearn.naive_bayes.BernoulliNB |
|
✓ |
- |
+ ✓ |
|
|
|
diff --git a/sklearn_porter/classifier/BernoulliNB/__init__.py b/sklearn_porter/classifier/BernoulliNB/__init__.py
index b61661ae..e429d134 100644
--- a/sklearn_porter/classifier/BernoulliNB/__init__.py
+++ b/sklearn_porter/classifier/BernoulliNB/__init__.py
@@ -22,6 +22,13 @@ class BernoulliNB(Classifier):
'arr[][]': '{type}[][] {name} = {{{values}}};',
'indent': ' ',
},
+ 'js': {
+ 'type': '{0}',
+ 'arr': '[{0}]',
+ 'arr[]': 'var {name} = [{values}];',
+ 'arr[][]': 'var {name} = [{values}];',
+ 'indent': ' ',
+ }
}
# @formatter:on
@@ -139,7 +146,7 @@ def create_method(self):
:return out : string
The built method as string.
"""
- n_indents = 1 if self.target_language in ['java'] else 0
+ n_indents = 1 if self.target_language in ['java', 'js'] else 0
temp_method = self.temp('method.predict', n_indents=n_indents,
skipping=True)
out = temp_method.format(**self.__dict__)
diff --git a/sklearn_porter/classifier/BernoulliNB/templates/js/class.txt b/sklearn_porter/classifier/BernoulliNB/templates/js/class.txt
new file mode 100644
index 00000000..49582d4a
--- /dev/null
+++ b/sklearn_porter/classifier/BernoulliNB/templates/js/class.txt
@@ -0,0 +1,13 @@
+var {class_name} = function() {{
+
+ {method}
+
+}};
+
+if (typeof process !== 'undefined' && typeof process.argv !== 'undefined') {{
+ if (process.argv.length - 2 == {n_features}) {{
+ var argv = process.argv.slice(2);
+ var prediction = new {class_name}().{method_name}(argv);
+ console.log(prediction);
+ }}
+}}
\ No newline at end of file
diff --git a/sklearn_porter/classifier/BernoulliNB/templates/js/method.predict.txt b/sklearn_porter/classifier/BernoulliNB/templates/js/method.predict.txt
new file mode 100644
index 00000000..35e64862
--- /dev/null
+++ b/sklearn_porter/classifier/BernoulliNB/templates/js/method.predict.txt
@@ -0,0 +1,33 @@
+this.{method_name} = function(atts) {{
+
+ {priors}
+ {neg_probs}
+ {del_probs}
+
+ var jll = new Array({n_classes});
+ for (var i = 0; i < {n_classes}; i++) {{
+ var sum = 0.;
+ for (var j = 0; j < {n_features}; j++) {{
+ sum += atts[i] * delProbs[j][i];
+ }}
+ jll[i] = sum;
+ }}
+ for (var i = 0; i < {n_classes}; i++) {{
+ var sum = 0.;
+ for (var j = 0; j < {n_features}; j++) {{
+ sum += negProbs[i][j];
+ }}
+ jll[i] += priors[i] + sum;
+ }}
+
+ var highestLikeli = -1;
+ var classIndex = -1;
+ for (var i = 0; i < {n_classes}; i++) {{
+ if (jll[i] > highestLikeli) {{
+ highestLikeli = jll[i];
+ classIndex = i;
+ }}
+ }}
+ return classIndex;
+}};
+
diff --git a/tests/classifier/BernoulliNB/BernoulliNBJSTest.py b/tests/classifier/BernoulliNB/BernoulliNBJSTest.py
new file mode 100644
index 00000000..627f151d
--- /dev/null
+++ b/tests/classifier/BernoulliNB/BernoulliNBJSTest.py
@@ -0,0 +1,38 @@
+# -*- coding: utf-8 -*-
+
+import unittest
+
+from sklearn.naive_bayes import BernoulliNB
+
+from ..Classifier import Classifier
+from ...language.JavaScript import JavaScript
+
+
+class BernoulliNBJSTest(JavaScript, Classifier, unittest.TestCase):
+
+ def setUp(self):
+ super(BernoulliNBJSTest, self).setUp()
+ self.mdl = BernoulliNB()
+
+ def tearDown(self):
+ super(BernoulliNBJSTest, self).tearDown()
+
+ @unittest.skip('BernoulliNB is just suitable for discrete data.')
+ def test_random_features_w_iris_data(self):
+ pass
+
+ @unittest.skip('BernoulliNB is just suitable for discrete data.')
+ def test_existing_features_w_binary_data(self):
+ pass
+
+ @unittest.skip('BernoulliNB is just suitable for discrete data.')
+ def test_random_features_w_binary_data(self):
+ pass
+
+ @unittest.skip('BernoulliNB is just suitable for discrete data.')
+ def test_random_features_w_digits_data(self):
+ pass
+
+ @unittest.skip('BernoulliNB is just suitable for discrete data.')
+ def test_existing_features_w_digits_data(self):
+ pass
\ No newline at end of file