Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add an example to create a custom converter for a NMF transformer (#167)
add an example on how to create an ONNX function for a NMF decomposition
- Loading branch information
Showing
11 changed files
with
348 additions
and
43 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
|
||
""" | ||
Custom Operator for NMF Decomposition | ||
===================================== | ||
`NMF <https://scikit-learn.org/stable/modules/generated/ | ||
sklearn.decomposition.NMF.html>`_ factorizes an input matrix | ||
into two matrices *W, H* of rank *k* so that :math:`WH \\sim M``. | ||
:math:`M=(m_{ij})` may be a binary matrix where *i* is a user | ||
and *j* a product he bought. The prediction | ||
function depends on whether or not the user needs a | ||
recommandation for an existing user or a new user. | ||
This example addresses the first case. | ||
The second case is more complex as it theoretically | ||
requires the estimation of a new matrix *W* with a | ||
gradient descent. | ||
.. contents:: | ||
:local: | ||
Building a simple model | ||
+++++++++++++++++++++++ | ||
""" | ||
|
||
import os | ||
import skl2onnx | ||
import onnxruntime | ||
import sklearn | ||
from sklearn.decomposition import NMF | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
from onnx.tools.net_drawer import GetPydotGraph, GetOpNodeProducer | ||
import onnx | ||
from skl2onnx.algebra.onnx_ops import ( | ||
OnnxArrayFeatureExtractor, OnnxMul, OnnxReduceSum) | ||
from skl2onnx.common.data_types import FloatTensorType | ||
from onnxruntime import InferenceSession | ||
|
||
|
||
mat = np.array([[1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0], | ||
[1, 0, 0, 0], [1, 0, 0, 0]], dtype=np.float64) | ||
mat[:mat.shape[1], :] += np.identity(mat.shape[1]) | ||
|
||
mod = NMF(n_components=2) | ||
W = mod.fit_transform(mat) | ||
H = mod.components_ | ||
pred = mod.inverse_transform(W) | ||
|
||
print("original predictions") | ||
exp = [] | ||
for i in range(mat.shape[0]): | ||
for j in range(mat.shape[1]): | ||
exp.append((i, j, pred[i, j])) | ||
|
||
print(exp) | ||
|
||
####################### | ||
# Let's rewrite the prediction in a way it is closer | ||
# to the function we need to convert into ONNX. | ||
|
||
|
||
def predict(W, H, row_index, col_index): | ||
return np.dot(W[row_index, :], H[:, col_index]) | ||
|
||
|
||
got = [] | ||
for i in range(mat.shape[0]): | ||
for j in range(mat.shape[1]): | ||
got.append((i, j, predict(W, H, i, j))) | ||
|
||
print(got) | ||
|
||
|
||
################################# | ||
# Conversion into ONNX | ||
# ++++++++++++++++++++ | ||
# | ||
# There is no implemented converter for | ||
# `NMF <https://scikit-learn.org/stable/modules/generated/ | ||
# sklearn.decomposition.NMF.html>`_ as the function we plan | ||
# to convert is not transformer or a predictor. | ||
# The following converter does not need to be registered, | ||
# it just creates an ONNX graph equivalent to function | ||
# *predict* implemented above. | ||
|
||
|
||
def nmf_to_onnx(W, H): | ||
""" | ||
The function converts a NMF described by matrices | ||
*W*, *H* (*WH* approximate training data *M*). | ||
into a function which takes two indices *(i, j)* | ||
and returns the predictions for it. It assumes | ||
these indices applies on the training data. | ||
""" | ||
col = OnnxArrayFeatureExtractor(H, 'col') | ||
row = OnnxArrayFeatureExtractor(W.T, 'row') | ||
dot = OnnxMul(col, row) | ||
res = OnnxReduceSum(dot, output_names="rec") | ||
indices_type = np.array([0], dtype=np.int64) | ||
onx = res.to_onnx(inputs={'col': indices_type, | ||
'row': indices_type}, | ||
outputs=[('rec', FloatTensorType((1, 1)))]) | ||
return onx | ||
|
||
|
||
model_onnx = nmf_to_onnx(W, H) | ||
print(model_onnx) | ||
|
||
######################################## | ||
# Let's compute prediction with it. | ||
|
||
sess = InferenceSession(model_onnx.SerializeToString()) | ||
|
||
|
||
def predict_onnx(sess, row_indices, col_indices): | ||
res = sess.run(None, | ||
{'col': col_indices, | ||
'row': row_indices}) | ||
return res | ||
|
||
|
||
onnx_preds = [] | ||
for i in range(mat.shape[0]): | ||
for j in range(mat.shape[1]): | ||
row_indices = np.array([i], dtype=np.int64) | ||
col_indices = np.array([j], dtype=np.int64) | ||
pred = predict_onnx(sess, row_indices, col_indices)[0] | ||
onnx_preds.append((i, j, pred[0, 0])) | ||
|
||
print(onnx_preds) | ||
|
||
|
||
################################### | ||
# The ONNX graph looks like the following. | ||
pydot_graph = GetPydotGraph( | ||
model_onnx.graph, name=model_onnx.graph.name, | ||
rankdir="TB", node_producer=GetOpNodeProducer("docstring")) | ||
pydot_graph.write_dot("graph_nmf.dot") | ||
os.system('dot -O -Tpng graph_nmf.dot') | ||
image = plt.imread("graph_nmf.dot.png") | ||
plt.imshow(image) | ||
plt.axis('off') | ||
|
||
################################# | ||
# **Versions used for this example** | ||
|
||
print("numpy:", np.__version__) | ||
print("scikit-learn:", sklearn.__version__) | ||
print("onnx: ", onnx.__version__) | ||
print("onnxruntime: ", onnxruntime.__version__) | ||
print("skl2onnx: ", skl2onnx.__version__) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import unittest | ||
import numpy as np | ||
from skl2onnx.algebra.type_helper import _guess_type | ||
from skl2onnx.common.data_types import ( | ||
FloatTensorType, Int64TensorType, | ||
Int32TensorType, StringTensorType | ||
) | ||
|
||
|
||
class TestAlgebraTestHelper(unittest.TestCase): | ||
|
||
def test_guess_type(self): | ||
dtypes = [ | ||
(np.int32, Int32TensorType), | ||
(np.int64, Int64TensorType), | ||
(np.float32, FloatTensorType), | ||
(np.str, StringTensorType) | ||
] | ||
for dtype, exp in dtypes: | ||
if dtype == np.str: | ||
mat = np.empty((3, 3), dtype=dtype) | ||
mat[:, :] = "" | ||
else: | ||
mat = np.zeros((3, 3), dtype=dtype) | ||
res = _guess_type(mat) | ||
assert isinstance(res, exp) | ||
|
||
dtypes = [np.float64] | ||
for dtype in dtypes: | ||
mat = np.zeros((3, 3), dtype=dtype) | ||
try: | ||
_guess_type(mat) | ||
raise AssertionError("It should fail for type " | ||
"{}".format(dtype)) | ||
except NotImplementedError: | ||
pass | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
Oops, something went wrong.