Skip to content

Commit

Permalink
Add Ruby variation of the RandomForestClassifier estimator
Browse files Browse the repository at this point in the history
  • Loading branch information
nok committed Sep 26, 2017
1 parent 84b48b9 commit 3775501
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 2 deletions.
2 changes: 1 addition & 1 deletion readme.md
Expand Up @@ -70,7 +70,7 @@ Transpile trained [scikit-learn](https://github.com/scikit-learn/scikit-learn) e
<td align="center"><a href="examples/classifier/RandomForestClassifier/js/basics.ipynb">✓</a></td>
<td align="center"></td>
<td align="center">✓</td>
<td align="center"></td>
<td align="center"></td>
</tr>
<tr>
<td><a href="http://scikit-learn.org/0.18/modules/generated/sklearn.ensemble.ExtraTreesClassifier.html">sklearn.ensemble.ExtraTreesClassifier</a></td>
Expand Down
10 changes: 9 additions & 1 deletion sklearn_porter/classifier/RandomForestClassifier/__init__.py
Expand Up @@ -50,6 +50,14 @@ class RandomForestClassifier(Classifier):
'indent': ' ',
'join': '; ',
},
'ruby': {
'if': 'if atts[{0}] {1} {2}',
'else': 'else',
'endif': 'end',
'arr': 'classes[{0}] = {1}',
'indent': ' ',
'join': '',
},
}
# @formatter:on

Expand Down Expand Up @@ -240,7 +248,7 @@ def create_method(self):
fns = '\n'.join(fns)

# Merge generated content:
n_indents = 1 if self.target_language in ['java', 'js', 'php'] else 0
n_indents = 1 if self.target_language in ['java', 'js', 'php', 'ruby'] else 0
temp_method = self.temp('method')
out = temp_method.format(method_name=self.method_name,
method_calls=fn_names, methods=fns,
Expand Down
@@ -0,0 +1,10 @@
class {class_name}

{method}

end

if ARGV.length == {n_features}
atts = ARGV.collect {{ |i| i.to_f }}
puts {class_name}.{method_name}(atts)
end
@@ -0,0 +1,7 @@
{methods}
def self.{method_name} (atts)
classes = Array.new({n_classes}, 0)
{method_calls}
pos_max = classes.each_with_index.select {{|e, i| e==classes.max}}.map &:last
return pos_max.min
end
@@ -0,0 +1 @@
idx = {class_name}.{method_name}(atts); classes[idx] = classes[idx] + 1
@@ -0,0 +1,6 @@
def self.{method_name}_{method_id} (atts)
classes = Array.new({n_classes}, 0)
{tree_branches}
pos_max = classes.each_with_index.select {{|e, i| e==classes.max}}.map &:last
return pos_max.min
end
@@ -0,0 +1,18 @@
# -*- coding: utf-8 -*-

from unittest import TestCase

from sklearn.ensemble import RandomForestClassifier

from ..Classifier import Classifier
from ...language.Ruby import Ruby


class RandomForestClassifierRubyTest(Ruby, Classifier, TestCase):

def setUp(self):
super(RandomForestClassifierRubyTest, self).setUp()
self.mdl = RandomForestClassifier(n_estimators=20, random_state=0)

def tearDown(self):
super(RandomForestClassifierRubyTest, self).tearDown()

0 comments on commit 3775501

Please sign in to comment.