-
Notifications
You must be signed in to change notification settings - Fork 274
/
lightgbm.py
129 lines (101 loc) · 4.79 KB
/
lightgbm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
"""
Converters for LightGBM models.
"""
import numpy as np
from onnxconverter_common.registration import register_converter
from . import constants
from ._gbdt_commons import convert_gbdt_classifier_common, convert_gbdt_common
from ._tree_commons import TreeParameters
def _tree_traversal(node, lefts, rights, features, thresholds, values, count):
"""
Recursive function for parsing a tree and filling the input data structures.
"""
if "left_child" in node:
features.append(node["split_feature"])
thresholds.append(node["threshold"])
values.append([-1])
lefts.append(count + 1)
rights.append(-1)
pos = len(rights) - 1
count = _tree_traversal(node["left_child"], lefts, rights, features, thresholds, values, count + 1)
rights[pos] = count + 1
return _tree_traversal(node["right_child"], lefts, rights, features, thresholds, values, count + 1)
else:
features.append(0)
thresholds.append(0)
values.append([node["leaf_value"]])
lefts.append(-1)
rights.append(-1)
return count
def _get_tree_parameters(tree_info, extra_config):
    """
    Convert one dumped LightGBM tree into a `TreeParameters` object.

    Args:
        tree_info: A tree entry whose "tree_structure" item holds the root node (presumably one
            element of the "tree_info" list from `dump_model()` — confirm against the caller).
        extra_config: Not read here; kept so the signature matches what the GBDT helpers expect.

    Returns:
        A `TreeParameters` built from the lefts, rights, features, thresholds and values lists
        filled by `_tree_traversal`, starting the node labels at 0.
    """
    lefts, rights, features, thresholds, values = [], [], [], [], []
    root = tree_info["tree_structure"]
    _tree_traversal(root, lefts, rights, features, thresholds, values, 0)
    return TreeParameters(lefts, rights, features, thresholds, values)
def convert_sklearn_lgbm_classifier(operator, device, extra_config):
    """
    Converter for `lightgbm.LGBMClassifier` (trained using the Sklearn API).

    Args:
        operator: An operator wrapping a `lightgbm.LGBMClassifier` model
        device: String defining the type of device the converted operator should be run on
        extra_config: Extra configuration used to select the best conversion strategy

    Returns:
        A PyTorch model

    Raises:
        RuntimeError: If the model uses random-forest boosting (`boosting_type` "rf" or its
            LightGBM alias "random_forest").
    """
    assert operator is not None, "Cannot convert None operator"

    # Random-forest models cannot be converted directly. LightGBM accepts "random_forest"
    # as an alias of "rf", so both spellings must be rejected.
    if operator.raw_operator.boosting_type in ("rf", "random_forest"):
        raise RuntimeError("Unable to directly convert this model. It should be converted into ONNX first.")

    n_features = operator.raw_operator._n_features
    tree_infos = operator.raw_operator.booster_.dump_model()["tree_info"]
    n_classes = operator.raw_operator._n_classes

    return convert_gbdt_classifier_common(
        operator, tree_infos, _get_tree_parameters, n_features, n_classes, extra_config=extra_config
    )
def convert_sklearn_lgbm_regressor(operator, device, extra_config):
    """
    Converter for `lightgbm.LGBMRegressor` and `lightgbm.LGBMRanker` (trained using the Sklearn API).

    Args:
        operator: An operator wrapping a `lightgbm.LGBMRegressor` or `lightgbm.LGBMRanker` model
        device: String defining the type of device the converted operator should be run on
        extra_config: Extra configuration used to select the best conversion strategy

    Returns:
        A PyTorch model

    Raises:
        RuntimeError: If the model uses random-forest boosting (`boosting_type` "rf" or its
            LightGBM alias "random_forest").
    """
    assert operator is not None, "Cannot convert None operator"

    # Reject random-forest models, consistent with `convert_sklearn_lgbm_classifier`.
    # LightGBM accepts "random_forest" as an alias of "rf", so both spellings are checked.
    # (LightGBM RF averages tree outputs instead of summing them — TODO confirm that
    # `convert_gbdt_common` only sums, which is why these models are refused.)
    if operator.raw_operator.boosting_type in ("rf", "random_forest"):
        raise RuntimeError("Unable to directly convert this model. It should be converted into ONNX first.")

    # Get tree information out of the model.
    n_features = operator.raw_operator._n_features
    tree_infos = operator.raw_operator.booster_.dump_model()["tree_info"]

    # NOTE(review): this writes into the caller's `extra_config` dict, so the post-transform
    # stays set for anyone who reuses that dict afterwards — confirm that is intended.
    if operator.raw_operator._objective == "tweedie":
        extra_config[constants.POST_TRANSFORM] = constants.TWEEDIE

    return convert_gbdt_common(operator, tree_infos, _get_tree_parameters, n_features, extra_config=extra_config)
def convert_lgbm_booster(operator, device, extra_config):
    """
    Converter for `lightgbm.Booster`

    Args:
        operator: An operator wrapping a `lightgbm.Booster` model
        device: String defining the type of device the converted operator should be run on
        extra_config: Extra configuration used to select the best conversion strategy

    Returns:
        A PyTorch model

    Raises:
        RuntimeError: If the dumped model has a true "average_output" flag (random-forest boosting).
    """
    assert operator is not None, "Cannot convert None operator"

    # Dump the model once; both the random-forest check and the tree list come from it.
    model_dump = operator.raw_operator.dump_model()

    # LightGBM's model dump includes an "average_output" flag that is set for random-forest
    # boosting. A raw Booster has no `boosting_type` attribute to check, so use this flag
    # instead. `.get` keeps dumps without the key converting exactly as before.
    if model_dump.get("average_output", False):
        raise RuntimeError("Unable to directly convert this model. It should be converted into ONNX first.")

    # Get tree information out of the model.
    n_features = len(operator.raw_operator.feature_name())
    tree_infos = model_dump["tree_info"]

    return convert_gbdt_common(operator, tree_infos, _get_tree_parameters, n_features, extra_config=extra_config)
# Register each converter under its operator alias. The sklearn ranker reuses the regressor
# converter. Registration order matches the original call sequence.
for _alias, _converter in (
    ("SklearnLGBMClassifier", convert_sklearn_lgbm_classifier),
    ("SklearnLGBMRanker", convert_sklearn_lgbm_regressor),
    ("SklearnLGBMRegressor", convert_sklearn_lgbm_regressor),
    ("SklearnBooster", convert_lgbm_booster),
):
    register_converter(_alias, _converter)
# Drop the loop names so the module namespace is the same as with the explicit calls.
del _alias, _converter