
Commit b61d33c: renaming trees to match paper
ksaur committed Mar 27, 2020 · 1 parent cedecba
Showing 9 changed files with 140 additions and 139 deletions.
38 changes: 19 additions & 19 deletions hummingbird/operator_converters/_tree_commons.py
@@ -34,33 +34,33 @@ def find_depth(node, current_depth):


class TreeImpl(Enum):
-    batch = 1
-    beam = 2
-    beampp = 3
+    gemm = 1
+    tree_trav = 2
+    perf_tree_trav = 3


# @low and @high are optimization parameters
# TODO: document and explain these choices
def get_gbdt_by_config_or_depth(extra_config, max_depth, low=3, high=10):
    if 'tree_implementation' not in extra_config:
        if max_depth is not None and max_depth <= low:
-            return TreeImpl.batch
+            return TreeImpl.gemm
        elif max_depth is not None and max_depth <= high:
-            return TreeImpl.beam
+            return TreeImpl.tree_trav
        else:
-            return TreeImpl.beampp
+            return TreeImpl.perf_tree_trav

-    if extra_config['tree_implementation'] == 'batch':
-        return TreeImpl.batch
-    elif extra_config['tree_implementation'] == 'beam':
-        return TreeImpl.beam
-    elif extra_config['tree_implementation'] == 'beam++':
-        return TreeImpl.beampp
+    if extra_config['tree_implementation'] == 'gemm':
+        return TreeImpl.gemm
+    elif extra_config['tree_implementation'] == 'tree_trav':
+        return TreeImpl.tree_trav
+    elif extra_config['tree_implementation'] == 'perf_tree_trav':
+        return TreeImpl.perf_tree_trav
    else:
        raise ValueError("Tree implementation {} not found".format(extra_config))
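The new member names presumably mirror the three tree-compilation strategies in the Hummingbird paper (GEMM, TreeTraversal, and PerfectTreeTraversal), replacing the older batch/beam/beam++ terms. With the defaults above (low=3, high=10), the renamed chooser behaves as follows; a quick sanity sketch assuming the function exactly as diffed:

    >>> get_gbdt_by_config_or_depth({}, max_depth=3)     # shallow trees -> GEMM
    <TreeImpl.gemm: 1>
    >>> get_gbdt_by_config_or_depth({}, max_depth=8)     # mid-depth trees -> tree traversal
    <TreeImpl.tree_trav: 2>
    >>> get_gbdt_by_config_or_depth({}, max_depth=None)  # unknown depth -> perfect tree traversal
    <TreeImpl.perf_tree_trav: 3>
    >>> get_gbdt_by_config_or_depth({'tree_implementation': 'gemm'}, max_depth=12)
    <TreeImpl.gemm: 1>                                   # an explicit config always wins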


-def get_parameters_for_beam_generic(lefts, rights, features, thresholds, values):
+def get_parameters_for_tree_trav_generic(lefts, rights, features, thresholds, values):
    """This is used by all trees."""
    ids = [i for i in range(len(lefts))]
    nodes = list(zip(ids, lefts, rights, features, thresholds, values))
@@ -106,7 +106,7 @@ def get_parameters_for_beam_generic(lefts, rights, features, thresholds, values)
return [depth, nodes_map, ids, lefts, rights, features, thresholds, values]


-def get_parameters_for_beam_sklearn_estimators(tree):
+def get_parameters_for_tree_trav_sklearn_estimators(tree):
    """This function is used by sklearn-based trees.
    Includes SklearnRandomForestClassifier/Regressor and SklearnGradientBoostingClassifier
@@ -124,10 +124,10 @@ def get_parameters_for_beam_sklearn_estimators(tree):
    if values.shape[1] > 1:
        values /= np.sum(values, axis=1, keepdims=True)

-    return get_parameters_for_beam_generic(lefts, rights, features, thresholds, values)
+    return get_parameters_for_tree_trav_generic(lefts, rights, features, thresholds, values)


-def get_parameters_for_batch_generic(lefts, rights, features, thresholds, values, weights, biases, n_splits):
+def get_parameters_for_gemm_generic(lefts, rights, features, thresholds, values, weights, biases, n_splits):
    """This is used by all trees."""

    hidden_weights = []
@@ -184,7 +184,7 @@ def get_parameters_for_batch_generic(lefts, rights, features, thresholds, values
return weights, biases


-def get_parameters_for_batch(tree):
+def get_parameters_for_gemm(tree):
    """This function is used by sklearn-based trees.
    Includes SklearnRandomForestClassifier/Regressor and SklearnGradientBoostingClassifier
@@ -220,7 +220,7 @@ def get_parameters_for_batch(tree):
    weights.append(np.array(hidden_weights).astype("float32"))
    biases.append(np.array(hidden_biases).astype("float32"))
    n_splits = len(hidden_weights)
-    return get_parameters_for_batch_generic(lefts, rights, features, thresholds, values, weights, biases, n_splits)
+    return get_parameters_for_gemm_generic(lefts, rights, features, thresholds, values, weights, biases, n_splits)


class BatchedTreeEnsemble(torch.nn.Module):
13 changes: 7 additions & 6 deletions hummingbird/operator_converters/gbdt.py
@@ -8,7 +8,8 @@

import torch
import numpy as np
-from ._tree_commons import get_parameters_for_batch, get_parameters_for_beam_sklearn_estimators, get_gbdt_by_config_or_depth
+from ._tree_commons import get_parameters_for_gemm, \
+    get_parameters_for_tree_trav_sklearn_estimators, get_gbdt_by_config_or_depth
from ._tree_commons import BatchedTreeEnsemble, BeamTreeEnsemble, BeamPPTreeEnsemble, TreeImpl
from ..common._registration import register_converter

@@ -269,14 +270,14 @@ def convert_sklearn_gbdt_classifier(operator, device, extra_config):
    tree_type = get_gbdt_by_config_or_depth(extra_config, max_depth, low=4)

    # TODO: automatically find the max tree depth by traversing the trees without relying on user input.
-    if tree_type == TreeImpl.batch:
-        net_parameters = [get_parameters_for_batch(e) for e in sklearn_rf_classifier.estimators_]
+    if tree_type == TreeImpl.gemm:
+        net_parameters = [get_parameters_for_gemm(e) for e in sklearn_rf_classifier.estimators_]
        return BatchGBDTClassifier(net_parameters, n_features, classes_list, learning_rate, alpha, device)

-    net_parameters = [get_parameters_for_beam_sklearn_estimators(e) for e in sklearn_rf_classifier.estimators_]
-    if tree_type == TreeImpl.beampp:
+    net_parameters = [get_parameters_for_tree_trav_sklearn_estimators(e) for e in sklearn_rf_classifier.estimators_]
+    if tree_type == TreeImpl.perf_tree_trav:
        return BeamPPGBDTClassifier(net_parameters, n_features, classes_list, learning_rate, alpha, device)
-    else: # Remaining possible case: tree_type == TreeImpl.beam
+    else: # Remaining possible case: tree_type == TreeImpl.tree_trav
        if sklearn_rf_classifier.max_depth is None:
            warnings.warn("GBDT model does not have a defined max_depth value. Consider setting one as it "
                          "will help the translator to pick a better translation method")
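Note that this commit renames the TreeImpl members and the parameter helpers but keeps the model classes under their old names (BatchGBDTClassifier, BeamGBDTClassifier, BeamPPGBDTClassifier, and the *TreeEnsemble modules imported above), so the dispatch now pairs new enum names with old class names. Restated as a lookup table for readability; a sketch only, not code from this commit, with the tree_trav row inferred from the truncated else branch:

    # Hypothetical condensation of convert_sklearn_gbdt_classifier's dispatch:
    # each implementation pairs a parameter extractor with a model class.
    GBDT_CLASSIFIER_DISPATCH = {
        TreeImpl.gemm: (get_parameters_for_gemm, BatchGBDTClassifier),
        TreeImpl.tree_trav: (get_parameters_for_tree_trav_sklearn_estimators, BeamGBDTClassifier),
        TreeImpl.perf_tree_trav: (get_parameters_for_tree_trav_sklearn_estimators, BeamPPGBDTClassifier),
    }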
30 changes: 15 additions & 15 deletions hummingbird/operator_converters/lightgbm.py
@@ -9,7 +9,7 @@
from hummingbird.operator_converters.gbdt import BeamPPGBDTRegressor, BeamGBDTClassifier, BeamGBDTRegressor

from ._tree_commons import get_gbdt_by_config_or_depth, TreeImpl
-from ._tree_commons import get_parameters_for_beam_generic, get_parameters_for_batch_generic
+from ._tree_commons import get_parameters_for_tree_trav_generic, get_parameters_for_gemm_generic
from ..common._registration import register_converter


@@ -34,7 +34,7 @@ def _tree_traversal(node, ls, rs, fs, ts, vs, count):
return count


-def _get_tree_parameters_for_batch(tree_info, n_features):
+def _get_tree_parameters_for_gemm(tree_info, n_features):
    lefts = []
    rights = []
    features = []
@@ -61,10 +61,10 @@ def _get_tree_parameters_for_batch(tree_info, n_features):

    # second hidden layer has ANDs for each leaf of the decision tree.
    # depth first enumeration of the tree in order to determine the AND by the path.
-    return get_parameters_for_batch_generic(lefts, rights, features, thresholds, values, weights, biases, n_splits)
+    return get_parameters_for_gemm_generic(lefts, rights, features, thresholds, values, weights, biases, n_splits)


-def _get_tree_parameters_for_beam(tree_info):
+def _get_tree_parameters_for_tree_trav(tree_info):
    lefts = []
    rights = []
    features = []
@@ -73,7 +73,7 @@ def _get_tree_parameters_for_beam(tree_info):
    _tree_traversal(tree_info['tree_structure'], lefts,
                    rights, features, thresholds, values, 0)

-    return get_parameters_for_beam_generic(lefts, rights, features, thresholds, values)
+    return get_parameters_for_tree_trav_generic(lefts, rights, features, thresholds, values)


def convert_sklearn_lgbm_classifier(operator, device, extra_config):
@@ -89,14 +89,14 @@ def convert_sklearn_lgbm_classifier(operator, device, extra_config):
    max_depth = operator.raw_operator.max_depth # TODO FIXME this should be a call to max_depth and NOT fall through!
    tree_type = get_gbdt_by_config_or_depth(extra_config, max_depth)

-    if tree_type == TreeImpl.batch:
-        net_parameters = [_get_tree_parameters_for_batch(tree_info, n_features) for tree_info in tree_infos]
+    if tree_type == TreeImpl.gemm:
+        net_parameters = [_get_tree_parameters_for_gemm(tree_info, n_features) for tree_info in tree_infos]
        return BatchGBDTClassifier(net_parameters, n_features, classes, device=device)

-    net_parameters = [_get_tree_parameters_for_beam(tree_info) for tree_info in tree_infos]
-    if tree_type == TreeImpl.beam:
+    net_parameters = [_get_tree_parameters_for_tree_trav(tree_info) for tree_info in tree_infos]
+    if tree_type == TreeImpl.tree_trav:
        return BeamGBDTClassifier(net_parameters, n_features, classes, device=device)
-    else: # Remaining possible case: tree_type == TreeImpl.beampp
+    else: # Remaining possible case: tree_type == TreeImpl.perf_tree_trav
        return BeamPPGBDTClassifier(net_parameters, n_features, classes, device=device)


@@ -106,14 +106,14 @@ def convert_sklearn_lgbm_regressor(operator, device, extra_config):
    max_depth = operator.raw_operator.max_depth # TODO FIXME this should be a call to max_depth and NOT fall through!
    tree_type = get_gbdt_by_config_or_depth(extra_config, max_depth)

-    if tree_type == TreeImpl.batch:
-        net_parameters = [_get_tree_parameters_for_batch(tree_info, n_features) for tree_info in tree_infos]
+    if tree_type == TreeImpl.gemm:
+        net_parameters = [_get_tree_parameters_for_gemm(tree_info, n_features) for tree_info in tree_infos]
        return BatchGBDTRegressor(net_parameters, n_features, device=device)

-    net_parameters = [_get_tree_parameters_for_beam(tree_info) for tree_info in tree_infos]
-    if tree_type == TreeImpl.beam:
+    net_parameters = [_get_tree_parameters_for_tree_trav(tree_info) for tree_info in tree_infos]
+    if tree_type == TreeImpl.tree_trav:
        return BeamGBDTRegressor(net_parameters, n_features, device=device)
-    else: # Remaining possible case: tree_type == TreeImpl.beampp
+    else: # Remaining possible case: tree_type == TreeImpl.perf_tree_trav
        return BeamPPGBDTRegressor(net_parameters, n_features, device=device)


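Callers select among the renamed implementations through extra_config, exactly as the updated tests at the bottom of this commit do. A minimal end-to-end sketch; convert_sklearn as the entry point is an assumption based on this repository's API at the time, not something this diff shows:

    import numpy as np
    from lightgbm import LGBMClassifier
    from hummingbird import convert_sklearn  # assumed import path

    X = np.random.rand(200, 20).astype(np.float32)
    y = np.random.randint(3, size=200)
    model = LGBMClassifier(n_estimators=10, max_depth=8).fit(X, y)

    # New flag values: "gemm", "tree_trav", or "perf_tree_trav"
    # (formerly "batch", "beam", and "beam++").
    torch_model = convert_sklearn(model, extra_config={"tree_implementation": "tree_trav"})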
22 changes: 11 additions & 11 deletions hummingbird/operator_converters/random_forest.py
@@ -8,7 +8,7 @@

import torch

-from ._tree_commons import get_parameters_for_batch, get_parameters_for_beam_sklearn_estimators, find_depth, Node
+from ._tree_commons import get_parameters_for_gemm, get_parameters_for_tree_trav_sklearn_estimators, find_depth, Node
from ._tree_commons import BatchedTreeEnsemble, BeamTreeEnsemble, BeamPPTreeEnsemble
from ._tree_commons import TreeImpl, get_gbdt_by_config_or_depth
from ..common._registration import register_converter
@@ -183,16 +183,16 @@ def convert_sklearn_random_forest_classifier(operator, device, extra_config):
    max_depth = max_depth = find_max_depth(operator)
    tree_type = get_gbdt_by_config_or_depth(extra_config, max_depth, low=4)

-    if tree_type == TreeImpl.batch:
-        net_parameters = [get_parameters_for_batch(e) for e in sklearn_rf_classifier.estimators_]
+    if tree_type == TreeImpl.gemm:
+        net_parameters = [get_parameters_for_gemm(e) for e in sklearn_rf_classifier.estimators_]
        return BatchRandomForestClassifier(
            net_parameters, sklearn_rf_classifier.n_features_, operator.raw_operator.classes_.tolist(), device)

-    net_parameters = [get_parameters_for_beam_sklearn_estimators(e) for e in sklearn_rf_classifier.estimators_]
-    if tree_type == TreeImpl.beam:
+    net_parameters = [get_parameters_for_tree_trav_sklearn_estimators(e) for e in sklearn_rf_classifier.estimators_]
+    if tree_type == TreeImpl.tree_trav:
        return BeamRandomForestClassifier(
            net_parameters, sklearn_rf_classifier.n_features_, operator.raw_operator.classes_.tolist(), device)
-    else: # Remaining possible case: tree_type == TreeImpl.beampp
+    else: # Remaining possible case: tree_type == TreeImpl.perf_tree_trav
        return BeamPPRandomForestClassifier(
            net_parameters, sklearn_rf_classifier.n_features_, operator.raw_operator.classes_.tolist(), device)

@@ -203,14 +203,14 @@ def convert_sklearn_random_forest_regressor(operator, device, extra_config):
    # TODO max_depth should be a call to max_depth() without relying on user input.
    tree_type = get_gbdt_by_config_or_depth(extra_config, sklearn_rf_regressor.max_depth, low=4)

-    if tree_type == TreeImpl.batch:
-        net_parameters = [get_parameters_for_batch(e) for e in sklearn_rf_regressor.estimators_]
+    if tree_type == TreeImpl.gemm:
+        net_parameters = [get_parameters_for_gemm(e) for e in sklearn_rf_regressor.estimators_]
        return BatchRandomForestRegressor(net_parameters, sklearn_rf_regressor.n_features_, device)

-    net_parameters = [get_parameters_for_beam_sklearn_estimators(e) for e in sklearn_rf_regressor.estimators_]
-    if tree_type == TreeImpl.beampp:
+    net_parameters = [get_parameters_for_tree_trav_sklearn_estimators(e) for e in sklearn_rf_regressor.estimators_]
+    if tree_type == TreeImpl.perf_tree_trav:
        return BeamPPRandomForestRegressor(net_parameters, sklearn_rf_regressor.n_features_, device)
-    else: # Remaining possible case: tree_type == TreeImpl.beam
+    else: # Remaining possible case: tree_type == TreeImpl.tree_trav
        if sklearn_rf_regressor.max_depth is None: # TODO: remove these comments when we call max_depth()
            warnings.warn("RandomForest model does not have a defined max_depth value. Consider setting one as it "
                          "will help the translator to pick a better translation method")
30 changes: 15 additions & 15 deletions hummingbird/operator_converters/xgb.py
@@ -9,7 +9,7 @@
from hummingbird.operator_converters.gbdt import BeamPPGBDTRegressor, BeamGBDTClassifier, BeamGBDTRegressor

from ._tree_commons import get_gbdt_by_config_or_depth, TreeImpl
-from ._tree_commons import get_parameters_for_beam_generic, get_parameters_for_batch_generic
+from ._tree_commons import get_parameters_for_tree_trav_generic, get_parameters_for_gemm_generic
from ..common._registration import register_converter


@@ -55,7 +55,7 @@ def _tree_traversal(tree_info, ls, rs, fs, ts, vs):
count += 1


-def _get_tree_parameters_for_batch(tree_info, n_features):
+def _get_tree_parameters_for_gemm(tree_info, n_features):
    lefts = []
    rights = []
    features = []
@@ -92,10 +92,10 @@ def _get_tree_parameters_for_batch(tree_info, n_features):

    # second hidden layer has ANDs for each leaf of the decision tree.
    # depth first enumeration of the tree in order to determine the AND by the path.
-    return get_parameters_for_batch_generic(lefts, rights, features, thresholds, values, weights, biases, n_splits)
+    return get_parameters_for_gemm_generic(lefts, rights, features, thresholds, values, weights, biases, n_splits)


-def _get_tree_parameters_for_beam(tree_info):
+def _get_tree_parameters_for_tree_trav(tree_info):
    lefts = []
    rights = []
    features = []
@@ -113,7 +113,7 @@ def _get_tree_parameters_for_beam(tree_info):
        thresholds = [0, 0, 0]
        values = [np.array([0.0]), values[0], values[0]]

-    return get_parameters_for_beam_generic(lefts, rights, features, thresholds, values)
+    return get_parameters_for_tree_trav_generic(lefts, rights, features, thresholds, values)


def convert_sklearn_xgb_classifier(operator, device, extra_config):
@@ -129,14 +129,14 @@ def convert_sklearn_xgb_classifier(operator, device, extra_config):
    max_depth = operator.raw_operator.max_depth # TODO this should be a call to max_depth() and NOT fall through!
    tree_type = get_gbdt_by_config_or_depth(extra_config, max_depth)

-    if tree_type == TreeImpl.batch:
-        net_parameters = [_get_tree_parameters_for_batch(tree_info, n_features) for tree_info in tree_infos]
+    if tree_type == TreeImpl.gemm:
+        net_parameters = [_get_tree_parameters_for_gemm(tree_info, n_features) for tree_info in tree_infos]
        return BatchGBDTClassifier(net_parameters, n_features, classes, device=device)

-    net_parameters = [_get_tree_parameters_for_beam(tree_info) for tree_info in tree_infos]
-    if tree_type == TreeImpl.beam:
+    net_parameters = [_get_tree_parameters_for_tree_trav(tree_info) for tree_info in tree_infos]
+    if tree_type == TreeImpl.tree_trav:
        return BeamGBDTClassifier(net_parameters, n_features, classes, device=device)
-    else: # Remaining possible case: tree_type == TreeImpl.beampp
+    else: # Remaining possible case: tree_type == TreeImpl.perf_tree_trav
        return BeamPPGBDTClassifier(net_parameters, n_features, classes, device=device)


@@ -149,14 +149,14 @@ def convert_sklearn_xgb_regressor(operator, device, extra_config):
    alpha = [operator.raw_operator.base_score] # TODO in xgboost 1.0.2, remove brackets
    tree_type = get_gbdt_by_config_or_depth(extra_config, max_depth)

-    if tree_type == TreeImpl.batch:
-        net_parameters = [_get_tree_parameters_for_batch(tree_info, n_features) for tree_info in tree_infos]
+    if tree_type == TreeImpl.gemm:
+        net_parameters = [_get_tree_parameters_for_gemm(tree_info, n_features) for tree_info in tree_infos]
        return BatchGBDTRegressor(net_parameters, n_features, alpha=alpha, device=device)

-    net_parameters = [_get_tree_parameters_for_beam(tree_info) for tree_info in tree_infos]
-    if tree_type == TreeImpl.beam:
+    net_parameters = [_get_tree_parameters_for_tree_trav(tree_info) for tree_info in tree_infos]
+    if tree_type == TreeImpl.tree_trav:
        return BeamGBDTRegressor(net_parameters, n_features, alpha=alpha, device=device)
-    else: # Remaining possible case: tree_type == TreeImpl.beampp
+    else: # Remaining possible case: tree_type == TreeImpl.perf_tree_trav
        return BeamPPGBDTRegressor(net_parameters, n_features, alpha=alpha, device=device)


36 changes: 18 additions & 18 deletions tests/test_lightgbm_converters.py
@@ -38,17 +38,17 @@ def test_lgbm_binary_classifier_converter(self):
    def test_lgbm_multi_classifier_converter(self):
        self._run_lgbm_classifier_converter(3)

-    # batch
-    def test_lgbm_batch_classifier_converter(self):
-        self._run_lgbm_classifier_converter(3, extra_config={"tree_implementation": "batch"})
+    # gemm
+    def test_lgbm_gemm_classifier_converter(self):
+        self._run_lgbm_classifier_converter(3, extra_config={"tree_implementation": "gemm"})

-    # beam
-    def test_lgbm_beam_classifier_converter(self):
-        self._run_lgbm_classifier_converter(3, extra_config={"tree_implementation": "beam"})
+    # tree_trav
+    def test_lgbm_tree_trav_classifier_converter(self):
+        self._run_lgbm_classifier_converter(3, extra_config={"tree_implementation": "tree_trav"})

-    # beam++
-    def test_lgbm_beampp_classifier_converter(self):
-        self._run_lgbm_classifier_converter(3, extra_config={"tree_implementation": "beam++"})
+    # perf_tree_trav
+    def test_lgbm_perf_tree_trav_classifier_converter(self):
+        self._run_lgbm_classifier_converter(3, extra_config={"tree_implementation": "perf_tree_trav"})

    def _run_lgbm_regressor_converter(self, num_classes, extra_config={}):
        for max_depth in [1, 3, 8, 10, 12, None]:
@@ -75,17 +75,17 @@ def test_lgbm_binary_regressor_converter(self):
    def test_lgbm_multi_regressor_converter(self):
        self._run_lgbm_regressor_converter(3)

-    # batch
-    def test_lgbm_batch_regressor_converter(self):
-        self._run_lgbm_regressor_converter(3, extra_config={"tree_implementation": "batch"})
+    # gemm
+    def test_lgbm_gemm_regressor_converter(self):
+        self._run_lgbm_regressor_converter(3, extra_config={"tree_implementation": "gemm"})

-    # beam
-    def test_lgbm_beam_regressor_converter(self):
-        self._run_lgbm_regressor_converter(3, extra_config={"tree_implementation": "beam"})
+    # tree_trav
+    def test_lgbm_tree_trav_regressor_converter(self):
+        self._run_lgbm_regressor_converter(3, extra_config={"tree_implementation": "tree_trav"})

-    # beam++
-    def test_lgbm_beampp_regressor_converter(self):
-        self._run_lgbm_regressor_converter(3, extra_config={"tree_implementation": "beam++"})
+    # perf_tree_trav
+    def test_lgbm_perf_tree_trav_regressor_converter(self):
+        self._run_lgbm_regressor_converter(3, extra_config={"tree_implementation": "perf_tree_trav"})


if __name__ == "__main__":
