-
Notifications
You must be signed in to change notification settings - Fork 99
/
plot_pipeline_lightgbm.py
140 lines (114 loc) · 4.25 KB
/
plot_pipeline_lightgbm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""
.. _example-lightgbm:

Convert a pipeline with a LightGbm model
========================================

.. index:: LightGbm

*sklearn-onnx* only converts *scikit-learn* models into *ONNX*,
but many libraries implement the *scikit-learn* API so that their models
can be included in a *scikit-learn* pipeline. This example considers
a pipeline including a *LightGbm* model. *sklearn-onnx* can convert
the whole pipeline as long as it knows the converter associated with
a *LGBMClassifier*. Let's see how to do it.

A couple of errors might happen while trying to convert
your own pipeline; some of them are described
and explained in :ref:`errors-pipeline`.

.. contents::
    :local:

Train a LightGBM classifier
+++++++++++++++++++++++++++
"""
import lightgbm
import onnxmltools
import skl2onnx
import onnxruntime
import onnx
import sklearn
import matplotlib.pyplot as plt
import os
from onnx.tools.net_drawer import GetPydotGraph, GetOpNodeProducer
import onnxruntime as rt
from skl2onnx import convert_sklearn, update_registered_converter
from skl2onnx.common.shape_calculator import calculate_linear_classifier_output_shapes # noqa
from onnxmltools.convert.lightgbm.operator_converters.LightGbm import convert_lightgbm # noqa
import onnxmltools.convert.common.data_types
from skl2onnx.common.data_types import FloatTensorType
import numpy
from sklearn.datasets import load_iris
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from lightgbm import LGBMClassifier
# Load iris and keep only the first two feature columns.
iris = load_iris()
X, y = iris.data[:, :2], iris.target

# Shuffle the rows. Features and labels are reindexed with the same
# shuffled index array so that they stay aligned.
order = numpy.arange(X.shape[0])
numpy.random.shuffle(order)
X, y = X[order, :].copy(), y[order].copy()

# Scale the features, then fit a LightGBM classifier with 3 estimators.
steps = [
    ('scaler', StandardScaler()),
    ('lgbm', LGBMClassifier(n_estimators=3)),
]
pipe = Pipeline(steps)
pipe.fit(X, y)
######################################
# Register the converter for LGBMClassifier
# +++++++++++++++++++++++++++++++++++++++++
#
# The converter is implemented in *onnxmltools*:
# `onnxmltools...LightGbm.py
# <https://github.com/onnx/onnxmltools/blob/master/onnxmltools/convert/
# lightgbm/operator_converters/LightGbm.py>`_.
# and the shape calculator:
# `onnxmltools...Classifier.py
# <https://github.com/onnx/onnxmltools/blob/master/onnxmltools/convert/
# lightgbm/shape_calculators/Classifier.py>`_.
##############################################
# The converter and shape calculator are imported at the top of this file.
###########################
# Let's register the new converter.
update_registered_converter(
    LGBMClassifier,                             # model class being registered
    'LightGbmLGBMClassifier',                   # alias for that model class
    calculate_linear_classifier_output_shapes,  # shape calculator (skl2onnx)
    convert_lightgbm)                           # converter (onnxmltools)
##################################
# Convert the pipeline
# ++++++++++++++++++++
model_onnx = convert_sklearn(
    pipe, 'pipeline_lightgbm',
    # A single float input named 'input': the first dimension is left
    # unspecified (None), the second is fixed to 2 features.
    [('input', FloatTensorType([None, 2]))])
# And save: write the serialized ONNX model to disk in binary mode.
with open("pipeline_lightgbm.onnx", "wb") as f:
    f.write(model_onnx.SerializeToString())
###########################
# Compare the predictions
# +++++++++++++++++++++++
#
# Predictions with LightGbm.
sample_labels = pipe.predict(X[:5])  # labels for the first five rows
print("predict", sample_labels)
first_row_proba = pipe.predict_proba(X[:1])  # probabilities, first row only
print("predict_proba", first_row_proba)
##########################
# Predictions with onnxruntime.
sess = rt.InferenceSession(
    "pipeline_lightgbm.onnx",
    # Select the execution provider explicitly: recent onnxruntime builds
    # that bundle several providers refuse to create a session without
    # it, and the CPU provider is the one this example needs.
    providers=["CPUExecutionProvider"])
# Feed the same five rows as float32 to match the FloatTensorType input
# declared at conversion time; None as first argument requests all outputs.
onnx_inputs = {"input": X[:5].astype(numpy.float32)}
pred_onx = sess.run(None, onnx_inputs)
# Output 0 is printed as the labels, output 1 as the probabilities
# (only the first row, to mirror the scikit-learn output above).
print("predict", pred_onx[0])
print("predict_proba", pred_onx[1][:1])
##################################
# Display the ONNX graph
# ++++++++++++++++++++++
node_producer = GetOpNodeProducer(
    "docstring", color="yellow",
    fillcolor="yellow", style="filled")
# Build a top-to-bottom (rankdir="TB") pydot graph of the ONNX graph and
# dump it in DOT format.
pydot_graph = GetPydotGraph(
    model_onnx.graph, name=model_onnx.graph.name, rankdir="TB",
    node_producer=node_producer)
pydot_graph.write_dot("pipeline.dot")

# Render the DOT file to PNG with Graphviz; the image is read back below
# from "pipeline.dot.png". os.system reports failure only through its
# return value, so check it here instead of letting a missing (or stale)
# PNG cause a confusing error or a wrong picture later.
dot_status = os.system('dot -O -Gdpi=300 -Tpng pipeline.dot')
if dot_status != 0:
    raise RuntimeError(
        "Graphviz 'dot' failed with status %r; is Graphviz installed "
        "and available on PATH?" % dot_status)

image = plt.imread("pipeline.dot.png")
fig, ax = plt.subplots(figsize=(40, 20))
ax.imshow(image)
ax.axis('off')
#################################
# **Versions used for this example**
for label, module in [
        ("numpy:", numpy),
        ("scikit-learn:", sklearn),
        ("onnx: ", onnx),
        ("onnxruntime: ", onnxruntime),
        ("skl2onnx: ", skl2onnx),
        ("onnxmltools: ", onnxmltools),
        ("lightgbm: ", lightgbm)]:
    # Labels keep their exact original spacing so the output is unchanged.
    print(label, module.__version__)