Skip to content

Commit

Permalink
Rename embedded templates
Browse files Browse the repository at this point in the history
  • Loading branch information
nok committed Oct 31, 2017
1 parent e2740fd commit b727186
Show file tree
Hide file tree
Showing 13 changed files with 19 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,9 @@ def export(self, class_name, method_name, embedded=False):
self.classes = classes

if self.target_method == 'predict':
return self.predict(class_name, method_name, embedded)
return self.predict(embedded)

def predict(self, class_name, method_name, embedded):
def predict(self, embedded):
"""
Transpile the predict method.
Expand All @@ -190,8 +190,8 @@ def predict(self, class_name, method_name, embedded):
The transpiled predict method as string.
"""
if embedded:
method = self.create_method_embedded(class_name, method_name)
out = self.create_class_embedded(method, class_name, method_name)
method = self.create_method_embedded()
out = self.create_class_embedded(method)
return out

out = self.create_class()
Expand All @@ -210,7 +210,7 @@ def create_class(self):
out = temp_class.format(**self.__dict__)
return out

def create_method_embedded(self, class_name, method_name):
def create_method_embedded(self):
"""
Build the estimator method or function.
Expand All @@ -222,14 +222,16 @@ def create_method_embedded(self, class_name, method_name):
n_indents = 1 if self.target_language in ['java', 'js',
'php', 'ruby'] else 0
branches = self.indent(self.create_tree(), n_indents=1)
temp_method = self.temp('method.embedded', n_indents=n_indents,
temp_method = self.temp('embedded.method', n_indents=n_indents,
skipping=True)
out = temp_method.format(class_name=class_name, method_name=method_name,
out = temp_method.format(class_name=self.class_name,
method_name=self.method_name,
n_classes=self.n_classes,
n_features=self.n_features,
n_classes=self.n_classes, branches=branches)
branches=branches)
return out

def create_class_embedded(self, method, class_name, method_name):
def create_class_embedded(self, method):
"""
Build the estimator class.
Expand All @@ -238,10 +240,12 @@ def create_class_embedded(self, method, class_name, method_name):
:return out : string
The built class as string.
"""
temp_class = self.temp('class.embedded')
out = temp_class.format(class_name=class_name, method_name=method_name,
method=method, n_classes=self.n_classes,
n_features=self.n_features)
temp_class = self.temp('embedded.class')
out = temp_class.format(class_name=self.class_name,
method_name=self.method_name,
n_classes=self.n_classes,
n_features=self.n_features,
method=method)
return out

def create_branches(self, left_nodes, right_nodes, threshold,
Expand Down Expand Up @@ -307,7 +311,8 @@ def create_tree(self):
:return out : string
The tree branches as string.
"""
indentation = 1 if self.target_language in ['java', 'js', 'php', 'ruby'] else 0
indentation = 1 if self.target_language in ['java', 'js',
'php', 'ruby'] else 0
return self.create_branches(
self.estimator.tree_.children_left,
self.estimator.tree_.children_right,
Expand Down

0 comments on commit b727186

Please sign in to comment.