Support setting a schema inference function in Python #5940

Open · wants to merge 15 commits into base: main
3 changes: 3 additions & 0 deletions .gitmodules
@@ -2,3 +2,6 @@
path = third_party/pybind11
url = https://github.com/pybind/pybind11.git
branch = master
[submodule "third_party/pybind11_protobuf"]
path = third_party/pybind11_protobuf
url = https://github.com/pybind/pybind11_protobuf.git
23 changes: 21 additions & 2 deletions CMakeLists.txt
@@ -541,7 +541,25 @@ if(BUILD_ONNX_PYTHON)
endif()
endif()

add_library(onnx_cpp2py_export MODULE "${ONNX_ROOT}/onnx/cpp2py_export.cc")
set(SOURCE_FILE "${ONNX_ROOT}/onnx/cpp2py_export.cc")

# In order to keep the versions of the third-party libraries consistent,
# use the source files directly.
set(PYBIND11_PROTOBUF_DIR "${ONNX_ROOT}/third_party/pybind11_protobuf/")
if(EXISTS "${PYBIND11_PROTOBUF_DIR}")
set(SOURCE_FILE
"${SOURCE_FILE}"
${PYBIND11_PROTOBUF_DIR}/pybind11_protobuf/native_proto_caster.h
${PYBIND11_PROTOBUF_DIR}/pybind11_protobuf/enum_type_caster.h
${PYBIND11_PROTOBUF_DIR}/pybind11_protobuf/proto_cast_util.cc
${PYBIND11_PROTOBUF_DIR}/pybind11_protobuf/proto_cast_util.h
${PYBIND11_PROTOBUF_DIR}/pybind11_protobuf/proto_caster_impl.h
)
else()
message(FATAL_ERROR "cannot find pybind11_protobuf")
endif()

add_library(onnx_cpp2py_export MODULE "${SOURCE_FILE}")
set_target_properties(onnx_cpp2py_export PROPERTIES PREFIX "")
set_target_properties(onnx_cpp2py_export
PROPERTIES COMPILE_FLAGS "-fvisibility=hidden")
@@ -565,7 +583,8 @@ if(BUILD_ONNX_PYTHON)

target_include_directories(onnx_cpp2py_export PUBLIC
"${pybind11_INCLUDE_DIRS}"
"${PYTHON_INCLUDE_DIRS}")
"${PYTHON_INCLUDE_DIRS}"
"${PYBIND11_PROTOBUF_DIR}")

if(APPLE)
set_target_properties(onnx_cpp2py_export
40 changes: 39 additions & 1 deletion onnx/cpp2py_export.cc
@@ -2,8 +2,10 @@
//
// SPDX-License-Identifier: Apache-2.0

#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11_protobuf/native_proto_caster.h>

#include <climits>
#include <limits>
@@ -15,6 +17,7 @@
#include "onnx/defs/parser.h"
#include "onnx/defs/printer.h"
#include "onnx/defs/schema.h"
#include "onnx/defs/shape_inference.h"
#include "onnx/inliner/inliner.h"
#include "onnx/py_utils.h"
#include "onnx/shape_inference/implementation.h"
@@ -104,6 +107,8 @@ std::unordered_map<std::string, py::bytes> CallNodeInferenceFunction(
}

PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) {
pybind11_protobuf::ImportNativeProtoCasters();

onnx_cpp2py_export.doc() = "Python interface to ONNX";

onnx_cpp2py_export.attr("ONNX_ML") = py::bool_(
@@ -114,6 +119,9 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) {
#endif // ONNX_ML
);

// Avoid a segmentation fault: the Python functions held by custom schemas must be released before the interpreter shuts down
onnx_cpp2py_export.add_object("_cleanup", py::capsule([] { OpSchemaRegistry::OpSchemaDeregisterAll(); }));
Contributor Author:
Perhaps a more graceful approach would be to collect the schemas registered from Python and deregister them during cleanup. However, I'm not sure if it's worth the effort. In most cases, invoking cleanup implies that Python is exiting.
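
A minimal sketch of that idea, assuming the existing `onnx.defs.register_schema` / `onnx.defs.deregister_schema` helpers; the bookkeeping list and wrapper functions below are hypothetical and not part of this PR:

```python
import atexit

import onnx.defs

# Hypothetical bookkeeping: remember only the schemas registered from Python,
# so cleanup can deregister exactly those instead of the whole registry.
_python_registered_schemas: list = []


def register_python_schema(schema) -> None:
    """Register a schema and remember it for targeted deregistration."""
    onnx.defs.register_schema(schema)
    _python_registered_schemas.append(schema)


def _deregister_python_schemas() -> None:
    """Deregister only the schemas that were registered from Python."""
    for schema in _python_registered_schemas:
        onnx.defs.deregister_schema(schema.name, schema.since_version, schema.domain)
    _python_registered_schemas.clear()


# Run the targeted cleanup at interpreter shutdown instead of calling
# OpSchemaRegistry::OpSchemaDeregisterAll() from the C++ side.
atexit.register(_deregister_python_schemas)
```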

Contributor:
Can you clarify when this gets invoked?

Contributor Author:
I think the segfault is caused by the following: the Python object (the inference function in the custom schema) needs to be destroyed before the Python interpreter is destroyed. However, the static container within the schema factory is destroyed after the main function ends, and thus after the interpreter, so we need to destroy the Python object manually.

About '_cleanup' : https://pybind11.readthedocs.io/en/stable/advanced/misc.html#module-destructors


// Submodule `schema`
auto defs = onnx_cpp2py_export.def_submodule("defs");
defs.doc() = "Schema submodule";
@@ -394,7 +402,14 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) {
func_proto.SerializeToString(&func_bytes);
}
return py::bytes(func_bytes);
});
})
.def(
"set_type_and_shape_inference_function",
[](OpSchema& op, const std::function<void(InferenceContext*)>& func) -> OpSchema& {
auto wrapper = [=](InferenceContext& ctx) { func(&ctx); };
return op.TypeAndShapeInferenceFunction(wrapper);
},
py::return_value_policy::reference_internal);

defs.def(
"has_schema",
@@ -625,6 +640,29 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) {
shape_inference.doc() = "Shape Inference submodule";
py::register_exception<InferenceError>(shape_inference, "InferenceError");

py::class_<InferenceContext> inference_ctx(shape_inference, "InferenceContext", "Inference context");

inference_ctx.def("get_attribute", &InferenceContext::getAttribute);
inference_ctx.def("get_num_inputs", &InferenceContext::getNumInputs);
inference_ctx.def("has_input", &InferenceContext::hasInput);
inference_ctx.def("get_input_type", &InferenceContext::getInputType);
inference_ctx.def("get_input_data", &InferenceContext::getInputData);
inference_ctx.def("get_input_sparse_data", &InferenceContext::getInputSparseData);
inference_ctx.def("get_symbolic_input", &InferenceContext::getSymbolicInput);
inference_ctx.def("get_graph_attribute_inferencer", &InferenceContext::getGraphAttributeInferencer);
inference_ctx.def("get_num_outputs", &InferenceContext::getNumOutputs);
inference_ctx.def("get_output_type", &InferenceContext::getOutputType, py::return_value_policy::reference);
inference_ctx.def("set_output_type", [](InferenceContext& self, size_t idx, const TypeProto& src) {
auto* dst = self.getOutputType(idx);
if (dst == nullptr || dst == &src) {
return;
}
dst->CopyFrom(src);
});

py::class_<GraphInferencer> graph_inferencer(shape_inference, "GraphInferencer", "Graph Inferencer");
graph_inferencer.def("do_inferencing", &GraphInferencer::doInferencing);

shape_inference.def(
"infer_shapes",
[](const py::bytes& bytes, bool check_type, bool strict_mode, bool data_prop) {
1 change: 1 addition & 0 deletions onnx/onnx_cpp2py_export/defs.pyi
@@ -68,6 +68,7 @@ class OpSchema:
) -> dict[str, bytes]: ...
@property
def function_body(self) -> FunctionProto: ...
def set_type_and_shape_inference_function(self, func) -> None: ...

class TypeConstraintParam:
def __init__(
19 changes: 19 additions & 0 deletions onnx/onnx_cpp2py_export/shape_inference.pyi
@@ -1,6 +1,25 @@
from typing import List

from onnx import AttributeProto, TypeProto, TensorProto, SparseTensorProto, TensorShapeProto

class InferenceError(Exception): ...

class GraphInferencer:
def do_inferencing(self, input_types: List[TypeProto], input_data: List[TensorProto]) -> List[TypeProto]: ...

class InferenceContext:
def get_num_inputs(self) -> int: ...
def has_input(self, idx: int) -> bool: ...
def get_num_outputs(self) -> int: ...
def get_attribute(self, name: str) -> AttributeProto: ...
def get_input_type(self, idx: int) -> TypeProto: ...
def get_input_data(self, idx: int) -> TensorProto: ...
def get_input_sparse_data(self, idx: int) -> SparseTensorProto: ...
def get_symbolic_input(self, idx: int) -> TensorShapeProto: ...
def get_graph_attribute_inferencer(self) -> GraphInferencer: ...
def get_output_type(self, idx: int) -> TypeProto: ...
def set_output_type(self, idx: int, type: TypeProto) -> None: ...

def infer_shapes(
b: bytes, check_type: bool, strict_mode: bool, data_prop: bool
) -> bytes: ...
3 changes: 3 additions & 0 deletions onnx/shape_inference.py
@@ -16,6 +16,9 @@
import onnx.onnx_cpp2py_export.shape_inference as C # noqa: N812
from onnx import AttributeProto, FunctionProto, ModelProto, TypeProto

GraphInferencer = C.GraphInferencer
InferenceContext = C.InferenceContext


def infer_shapes(
model: ModelProto | bytes,
97 changes: 97 additions & 0 deletions onnx/test/shape_inference_test.py
@@ -10172,6 +10172,103 @@ def test_check_type_when_schema_has_empty_io(self):
op_schema.name, op_schema.since_version, op_schema.domain
)

def test_custom_schema_shape_inference(self) -> None:
# CustomOp schema:
# attrs:
# out_len: [L0, L1, ...]
# inputs:
# a[N, La]
# b[N, Lb]
# outputs:
# out0[N, La * Lb, L0]
# out1[N, La * Lb, L1]
# ...
N = 3
La = 32
Lb = 64
out_len = [1, 2]
outs = [f"out{i}" for i in range(len(out_len))]
graph = self._make_graph(
[
("a", TensorProto.FLOAT, (N, La)),
("b", TensorProto.FLOAT, (N, Lb)),
],
[make_node("CustomOp", ["a", "b"], outs, out_len=out_len)],
[],
)
with self.assertRaises(onnx.checker.ValidationError):
self._assert_inferred(
graph,
[
make_tensor_value_info(
f"out{i}", TensorProto.FLOAT, (N, La * Lb, Li)
)
for i, Li in enumerate(out_len)
],
)

schema = OpSchema(
"CustomOp",
"",
1,
inputs=[
defs.OpSchema.FormalParameter("a", "float"),
defs.OpSchema.FormalParameter("b", "float"),
],
outputs=[
defs.OpSchema.FormalParameter(
"out", "float", param_option=OpSchema.FormalParameterOption.Variadic
),
],
attributes=[
defs.OpSchema.Attribute("out_len", defs.OpSchema.AttrType.INTS)
],
)

def func(ctx: onnx.shape_inference.InferenceContext):
def parse_tensor_input(t: TypeProto):
assert isinstance(t, TypeProto)
return (
t.tensor_type.elem_type,
[
d.dim_value if d.HasField("dim_value") else None
for d in t.tensor_type.shape.dim
],
)

assert ctx.get_num_inputs() == 2
in0 = ctx.get_input_type(0)
Contributor:
My concern with this is that it goes through serialization to access the type information, which is not really efficient. I would change the API so that it does not return a TypeProto but the type and the shape as regular Python objects.
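
For illustration, here is roughly how such an accessor could look from Python; `get_input_type_and_shape` and its return layout (an element-type int plus a list of dims, with `None` for unknown dimensions) are assumptions sketching the suggestion, not an API that exists in this PR:

```python
import onnx


# Current approach in this PR: read a TypeProto from the context.
def infer_with_proto(ctx) -> None:
    t = ctx.get_input_type(0)
    elem_type = t.tensor_type.elem_type
    dims = [
        d.dim_value if d.HasField("dim_value") else None
        for d in t.tensor_type.shape.dim
    ]
    assert elem_type == onnx.TensorProto.FLOAT


# Suggested alternative (hypothetical): plain Python objects, so no proto
# needs to be serialized across the C++/Python boundary.
def infer_with_plain_types(ctx) -> None:
    elem_type, dims = ctx.get_input_type_and_shape(0)  # hypothetical method
    assert elem_type == onnx.TensorProto.FLOAT
```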

Contributor:
I think I agree with Xavier, but I am also a bit confused. I see that the method implementation serializes proto values to strings and returns them. We could just return a pointer to the C++ proto object (wrapped as a Python object). Is that your suggestion, Xavier?

Contributor Author:
Should we closely mimic the C++ API design, or should we use Python's native types for interaction? Using the proto pointer for interaction may require additional code to bind it to Python (if there is another way, please correct me), or we would need to include a third-party library.

Contributor:
It looks like you are using pybind11's protobuf conversion (between C++ and Python) above? Should this example be updated to match that? I guess it is still doing the same thing, but at least we don't need extra code to manage the serialization/deserialization?

It seems OK to me, in the sense that this is what we currently have for Python–C++ interop anyway (that is, serialize/deserialize). Until we come up with a better solution, this seems fine.

in1 = ctx.get_input_type(1)
in0_type, in0_shape = parse_tensor_input(in0)
in1_type, in1_shape = parse_tensor_input(in1)
assert in0_type == in1_type == TensorProto.FLOAT
assert len(in0_shape) == len(in1_shape) == 2
assert in0_shape[0] == in1_shape[0]
N, La = in0_shape
_, Lb = in1_shape
attr = ctx.get_attribute("out_len")
out_len = attr.ints
assert len(out_len) == ctx.get_num_outputs()
for i in range(ctx.get_num_outputs()):
out = ctx.get_output_type(i)
out.tensor_type.elem_type = in0_type
out.tensor_type.shape.dim.add().dim_value = N
out.tensor_type.shape.dim.add().dim_value = La * Lb
out.tensor_type.shape.dim.add().dim_value = out_len[i]
ctx.set_output_type(i, out)
Contributor:
Same comment here: we should avoid serialization with something like set_output_type_and_shape(in0_type, (N, La * Lb, out_len[i])). The type would be created on the C++ side, there would be no serialization, and it would be more efficient.
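
Sketching that suggestion against the output loop above — the `set_output_type_and_shape` method and its exact signature are hypothetical (an output index is assumed in addition to the reviewer's arguments), and nothing here exists in this PR:

```python
# Hypothetical rewrite of the output loop with the suggested API; the TypeProto
# would be constructed on the C++ side, so no serialization is needed.
for i in range(ctx.get_num_outputs()):
    ctx.set_output_type_and_shape(i, in0_type, (N, La * Lb, out_len[i]))
```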


schema.set_type_and_shape_inference_function(func)
onnx.defs.register_schema(schema)

self._assert_inferred(
graph,
[
make_tensor_value_info(f"out{i}", TensorProto.FLOAT, (N, La * Lb, Li))
for i, Li in enumerate(out_len)
],
)
onnx.defs.deregister_schema(schema.name, schema.since_version, schema.domain)


if __name__ == "__main__":
unittest.main()
1 change: 1 addition & 0 deletions third_party/pybind11_protobuf
Submodule pybind11_protobuf added at 84653a