diff --git a/readme.md b/readme.md index b87a2b6d..15ae092c 100644 --- a/readme.md +++ b/readme.md @@ -121,7 +121,7 @@ Transpile trained [scikit-learn](https://github.com/scikit-learn/scikit-learn) e sklearn.naive_bayes.BernoulliNB - + ✓ diff --git a/sklearn_porter/classifier/BernoulliNB/__init__.py b/sklearn_porter/classifier/BernoulliNB/__init__.py index b61661ae..e429d134 100644 --- a/sklearn_porter/classifier/BernoulliNB/__init__.py +++ b/sklearn_porter/classifier/BernoulliNB/__init__.py @@ -22,6 +22,13 @@ class BernoulliNB(Classifier): 'arr[][]': '{type}[][] {name} = {{{values}}};', 'indent': ' ', }, + 'js': { + 'type': '{0}', + 'arr': '[{0}]', + 'arr[]': 'var {name} = [{values}];', + 'arr[][]': 'var {name} = [{values}];', + 'indent': ' ', + } } # @formatter:on @@ -139,7 +146,7 @@ def create_method(self): :return out : string The built method as string. """ - n_indents = 1 if self.target_language in ['java'] else 0 + n_indents = 1 if self.target_language in ['java', 'js'] else 0 temp_method = self.temp('method.predict', n_indents=n_indents, skipping=True) out = temp_method.format(**self.__dict__) diff --git a/sklearn_porter/classifier/BernoulliNB/templates/js/class.txt b/sklearn_porter/classifier/BernoulliNB/templates/js/class.txt new file mode 100644 index 00000000..49582d4a --- /dev/null +++ b/sklearn_porter/classifier/BernoulliNB/templates/js/class.txt @@ -0,0 +1,13 @@ +var {class_name} = function() {{ + + {method} + +}}; + +if (typeof process !== 'undefined' && typeof process.argv !== 'undefined') {{ + if (process.argv.length - 2 == {n_features}) {{ + var argv = process.argv.slice(2); + var prediction = new {class_name}().{method_name}(argv); + console.log(prediction); + }} +}} \ No newline at end of file diff --git a/sklearn_porter/classifier/BernoulliNB/templates/js/method.predict.txt b/sklearn_porter/classifier/BernoulliNB/templates/js/method.predict.txt new file mode 100644 index 00000000..35e64862 --- /dev/null +++ b/sklearn_porter/classifier/BernoulliNB/templates/js/method.predict.txt @@ -0,0 +1,33 @@ +this.{method_name} = function(atts) {{ + + {priors} + {neg_probs} + {del_probs} + + var jll = new Array({n_classes}); + for (var i = 0; i < {n_classes}; i++) {{ + var sum = 0.; + for (var j = 0; j < {n_features}; j++) {{ + sum += atts[i] * delProbs[j][i]; + }} + jll[i] = sum; + }} + for (var i = 0; i < {n_classes}; i++) {{ + var sum = 0.; + for (var j = 0; j < {n_features}; j++) {{ + sum += negProbs[i][j]; + }} + jll[i] += priors[i] + sum; + }} + + var highestLikeli = -1; + var classIndex = -1; + for (var i = 0; i < {n_classes}; i++) {{ + if (jll[i] > highestLikeli) {{ + highestLikeli = jll[i]; + classIndex = i; + }} + }} + return classIndex; +}}; + diff --git a/tests/classifier/BernoulliNB/BernoulliNBJSTest.py b/tests/classifier/BernoulliNB/BernoulliNBJSTest.py new file mode 100644 index 00000000..627f151d --- /dev/null +++ b/tests/classifier/BernoulliNB/BernoulliNBJSTest.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- + +import unittest + +from sklearn.naive_bayes import BernoulliNB + +from ..Classifier import Classifier +from ...language.JavaScript import JavaScript + + +class BernoulliNBJSTest(JavaScript, Classifier, unittest.TestCase): + + def setUp(self): + super(BernoulliNBJSTest, self).setUp() + self.mdl = BernoulliNB() + + def tearDown(self): + super(BernoulliNBJSTest, self).tearDown() + + @unittest.skip('BernoulliNB is just suitable for discrete data.') + def test_random_features_w_iris_data(self): + pass + + @unittest.skip('BernoulliNB is just suitable for discrete data.') + def test_existing_features_w_binary_data(self): + pass + + @unittest.skip('BernoulliNB is just suitable for discrete data.') + def test_random_features_w_binary_data(self): + pass + + @unittest.skip('BernoulliNB is just suitable for discrete data.') + def test_random_features_w_digits_data(self): + pass + + @unittest.skip('BernoulliNB is just suitable for discrete data.') + def test_existing_features_w_digits_data(self): + pass \ No newline at end of file