diff --git a/docs/api_summary.rst b/docs/api_summary.rst index 7e8303e15..a45d471a5 100644 --- a/docs/api_summary.rst +++ b/docs/api_summary.rst @@ -12,8 +12,35 @@ in *scikit-onnx*. Converters ========== +Both functions convert a *scikit-learn* model into ONNX. +The first one lets the user manually +define the input's name and types. The second one +infers this information from the training data. +These two functions are the main entry points to converter. +The rest of the API is needed if a model has no converter +implemented in this package. A new converter has then to be +registered, whether it is imported from another package +or created from scratch. + .. autofunction:: skl2onnx.convert_sklearn +.. autofunction:: skl2onnx.to_onnx + +Register a new converter +======================== + +If a model has no converter +implemented in this package, a new converter has then to be +registered, whether it is imported from another package +or created from scratch. Section :ref:`l-converter-list` +lists all available converters. + +.. autofunction:: skl2onnx.supported_converters + +.. autofunction:: skl2onnx.update_registered_converter + +.. autofunction:: skl2onnx.update_registered_parser + Manipulate ONNX graphs ====================== @@ -25,15 +52,6 @@ Manipulate ONNX graphs .. autofunction:: skl2onnx.helpers.onnx_helper.save_onnx_model -Register a new converter -======================== - -.. autofunction:: skl2onnx.supported_converters - -.. autofunction:: skl2onnx.update_registered_converter - -.. autofunction:: skl2onnx.update_registered_parser - Parsers ======= diff --git a/docs/examples/plot_nmf.py b/docs/examples/plot_nmf.py new file mode 100644 index 000000000..3d6e77672 --- /dev/null +++ b/docs/examples/plot_nmf.py @@ -0,0 +1,155 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +""" +Custom Operator for NMF Decomposition +===================================== + +`NMF `_ factorizes an input matrix +into two matrices *W, H* of rank *k* so that :math:`WH \\sim M``. +:math:`M=(m_{ij})` may be a binary matrix where *i* is a user +and *j* a product he bought. The prediction +function depends on whether or not the user needs a +recommandation for an existing user or a new user. +This example addresses the first case. + +The second case is more complex as it theoretically +requires the estimation of a new matrix *W* with a +gradient descent. + +.. contents:: + :local: + +Building a simple model ++++++++++++++++++++++++ + +""" + +import os +import skl2onnx +import onnxruntime +import sklearn +from sklearn.decomposition import NMF +import numpy as np +import matplotlib.pyplot as plt +from onnx.tools.net_drawer import GetPydotGraph, GetOpNodeProducer +import onnx +from skl2onnx.algebra.onnx_ops import ( + OnnxArrayFeatureExtractor, OnnxMul, OnnxReduceSum) +from skl2onnx.common.data_types import FloatTensorType +from onnxruntime import InferenceSession + + +mat = np.array([[1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0], + [1, 0, 0, 0], [1, 0, 0, 0]], dtype=np.float64) +mat[:mat.shape[1], :] += np.identity(mat.shape[1]) + +mod = NMF(n_components=2) +W = mod.fit_transform(mat) +H = mod.components_ +pred = mod.inverse_transform(W) + +print("original predictions") +exp = [] +for i in range(mat.shape[0]): + for j in range(mat.shape[1]): + exp.append((i, j, pred[i, j])) + +print(exp) + +####################### +# Let's rewrite the prediction in a way it is closer +# to the function we need to convert into ONNX. + + +def predict(W, H, row_index, col_index): + return np.dot(W[row_index, :], H[:, col_index]) + + +got = [] +for i in range(mat.shape[0]): + for j in range(mat.shape[1]): + got.append((i, j, predict(W, H, i, j))) + +print(got) + + +################################# +# Conversion into ONNX +# ++++++++++++++++++++ +# +# There is no implemented converter for +# `NMF `_ as the function we plan +# to convert is not transformer or a predictor. +# The following converter does not need to be registered, +# it just creates an ONNX graph equivalent to function +# *predict* implemented above. + + +def nmf_to_onnx(W, H): + """ + The function converts a NMF described by matrices + *W*, *H* (*WH* approximate training data *M*). + into a function which takes two indices *(i, j)* + and returns the predictions for it. It assumes + these indices applies on the training data. + """ + col = OnnxArrayFeatureExtractor(H, 'col') + row = OnnxArrayFeatureExtractor(W.T, 'row') + dot = OnnxMul(col, row) + res = OnnxReduceSum(dot, output_names="rec") + indices_type = np.array([0], dtype=np.int64) + onx = res.to_onnx(inputs={'col': indices_type, + 'row': indices_type}, + outputs=[('rec', FloatTensorType((1, 1)))]) + return onx + + +model_onnx = nmf_to_onnx(W, H) +print(model_onnx) + +######################################## +# Let's compute prediction with it. + +sess = InferenceSession(model_onnx.SerializeToString()) + + +def predict_onnx(sess, row_indices, col_indices): + res = sess.run(None, + {'col': col_indices, + 'row': row_indices}) + return res + + +onnx_preds = [] +for i in range(mat.shape[0]): + for j in range(mat.shape[1]): + row_indices = np.array([i], dtype=np.int64) + col_indices = np.array([j], dtype=np.int64) + pred = predict_onnx(sess, row_indices, col_indices)[0] + onnx_preds.append((i, j, pred[0, 0])) + +print(onnx_preds) + + +################################### +# The ONNX graph looks like the following. +pydot_graph = GetPydotGraph( + model_onnx.graph, name=model_onnx.graph.name, + rankdir="TB", node_producer=GetOpNodeProducer("docstring")) +pydot_graph.write_dot("graph_nmf.dot") +os.system('dot -O -Tpng graph_nmf.dot') +image = plt.imread("graph_nmf.dot.png") +plt.imshow(image) +plt.axis('off') + +################################# +# **Versions used for this example** + +print("numpy:", np.__version__) +print("scikit-learn:", sklearn.__version__) +print("onnx: ", onnx.__version__) +print("onnxruntime: ", onnxruntime.__version__) +print("skl2onnx: ", skl2onnx.__version__) diff --git a/docs/index.rst b/docs/index.rst index a88e240ff..ad79af788 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -26,7 +26,6 @@ toolkits into `ONNX `_. pipeline parameterized supported - onnx_ops **Issues, questions** diff --git a/docs/onnx_ops.rst b/docs/onnx_ops.rst deleted file mode 100644 index 7f3820b57..000000000 --- a/docs/onnx_ops.rst +++ /dev/null @@ -1,20 +0,0 @@ - -======================== -Available ONNX operators -======================== - -*skl2onnx* maps every ONNX operators into a class -easy to insert into a graph. These operators get -dynamically added and the list depends on the installed -*ONNX* package. The documentation for these operators -can be found on github: `ONNX Operators.md -`_ -and `ONNX-ML Operators -`_. -Associated to `onnxruntime `_, -the mapping makes it easier to easily check the output -of the *ONNX* operators on any data as shown -in example :ref:`l-onnx-operators`. - -.. supported-onnx-ops:: - diff --git a/docs/supported.rst b/docs/supported.rst index 3989f8df1..8c485df68 100644 --- a/docs/supported.rst +++ b/docs/supported.rst @@ -15,6 +15,8 @@ implements *to_onnx* methods. .. contents:: :local: +.. _l-converter-list: + Covered Converters ================== @@ -37,3 +39,20 @@ Pipeline .. autoclass:: skl2onnx.algebra.sklearn_ops.OnnxSklearnFeatureUnion :members: to_onnx, to_onnx_operator, onnx_parser, onnx_shape_calculator, onnx_converter +Available ONNX operators +======================== + +*skl2onnx* maps every ONNX operators into a class +easy to insert into a graph. These operators get +dynamically added and the list depends on the installed +*ONNX* package. The documentation for these operators +can be found on github: `ONNX Operators.md +`_ +and `ONNX-ML Operators +`_. +Associated to `onnxruntime `_, +the mapping makes it easier to easily check the output +of the *ONNX* operators on any data as shown +in example :ref:`l-onnx-operators`. + +.. supported-onnx-ops:: \ No newline at end of file diff --git a/skl2onnx/algebra/type_helper.py b/skl2onnx/algebra/type_helper.py index 77e25ff01..0ca3aa5cc 100644 --- a/skl2onnx/algebra/type_helper.py +++ b/skl2onnx/algebra/type_helper.py @@ -6,10 +6,13 @@ import numpy as np from ..proto import TensorProto, ValueInfoProto, onnx_proto from ..common._topology import Variable -from ..common.data_types import FloatTensorType, Int64TensorType -from ..common.data_types import StringTensorType -from ..common.data_types import Int32TensorType, DoubleTensorType -from ..common.data_types import BooleanTensorType +from ..common.data_types import ( + BooleanTensorType, + DoubleTensorType, FloatTensorType, + Int64Type, + Int64TensorType, Int32TensorType, + StringTensorType +) def _guess_type_proto(data_type, dims): @@ -39,16 +42,18 @@ def _guess_type(given_type): if isinstance(given_type, np.ndarray): if given_type.dtype == np.float32: return FloatTensorType(given_type.shape) + elif given_type.dtype == np.int32: + return Int32TensorType(given_type.shape) elif given_type.dtype == np.int64: return Int64TensorType(given_type.shape) - elif given_type.dtype == np.str: + elif given_type.dtype == np.str or str(given_type.dtype) in ('