Support set schema inference function in python #5940
base: main
File: onnx/cpp2py_export.cc
@@ -2,8 +2,10 @@
 //
 // SPDX-License-Identifier: Apache-2.0

+#include <pybind11/functional.h>
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
+#include <pybind11_protobuf/native_proto_caster.h>

 #include <climits>
 #include <limits>

@@ -15,6 +17,7 @@
 #include "onnx/defs/parser.h"
 #include "onnx/defs/printer.h"
 #include "onnx/defs/schema.h"
+#include "onnx/defs/shape_inference.h"
 #include "onnx/inliner/inliner.h"
 #include "onnx/py_utils.h"
 #include "onnx/shape_inference/implementation.h"
@@ -104,6 +107,8 @@ std::unordered_map<std::string, py::bytes> CallNodeInferenceFunction(
 }

 PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) {
+  pybind11_protobuf::ImportNativeProtoCasters();
+
   onnx_cpp2py_export.doc() = "Python interface to ONNX";

   onnx_cpp2py_export.attr("ONNX_ML") = py::bool_(

@@ -114,6 +119,9 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) {
 #endif  // ONNX_ML
 );

+  // Avoid a segmentation fault: release the Python inference functions held by
+  // custom schemas while the interpreter is still alive.
+  onnx_cpp2py_export.add_object("_cleanup", py::capsule([] { OpSchemaRegistry::OpSchemaDeregisterAll(); }));
Review comment: Can you clarify when this gets invoked?

Review reply: I think the segfault is caused by: […]

About `_cleanup`, see the pybind11 notes on module destructors: https://pybind11.readthedocs.io/en/stable/advanced/misc.html#module-destructors
+
   // Submodule `schema`
   auto defs = onnx_cpp2py_export.def_submodule("defs");
   defs.doc() = "Schema submodule";
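Per the pybind11 documentation linked above, the capsule's destructor runs when the module object is deallocated, normally at interpreter shutdown. A pure-Python analogue of the pattern, for illustration only (the `_Cleanup` class is a hypothetical stand-in, not the actual binding):

```python
# Pure-Python analogue of the `_cleanup` capsule: an object stored as a module
# attribute has its finalizer invoked when the module is torn down at
# interpreter exit. At that point the real capsule calls
# OpSchemaRegistry::OpSchemaDeregisterAll(), releasing the py::function objects
# held by custom schemas before they could dangle past interpreter finalization.
class _Cleanup:  # hypothetical stand-in for py::capsule
    def __del__(self) -> None:
        print("interpreter exit: deregistering all custom schemas")

_cleanup = _Cleanup()
```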
@@ -394,7 +402,14 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) {
       func_proto.SerializeToString(&func_bytes);
     }
     return py::bytes(func_bytes);
-  });
+  })
+      .def(
+          "set_type_and_shape_inference_function",
+          [](OpSchema& op, const std::function<void(InferenceContext*)>& func) -> OpSchema& {
+            auto wrapper = [=](InferenceContext& ctx) { func(&ctx); };
+            return op.TypeAndShapeInferenceFunction(wrapper);
+          },
+          py::return_value_policy::reference_internal);

   defs.def(
       "has_schema",
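For illustration, a minimal sketch of how the new binding is used from Python, modeled on the test added below; the `CustomIdentity` op and its pass-through inference body are hypothetical:

```python
import onnx
from onnx.defs import OpSchema

# Hypothetical single-input, single-output op, used only for illustration.
schema = OpSchema(
    "CustomIdentity",
    "custom.domain",
    1,
    inputs=[OpSchema.FormalParameter("x", "float")],
    outputs=[OpSchema.FormalParameter("y", "float")],
)

def infer(ctx: onnx.shape_inference.InferenceContext) -> None:
    # Copy the input's type and shape to the output.
    out = ctx.get_output_type(0)
    out.CopyFrom(ctx.get_input_type(0))
    ctx.set_output_type(0, out)

# The new binding attaches the Python callable as the schema's
# type-and-shape inference function.
schema.set_type_and_shape_inference_function(infer)
onnx.defs.register_schema(schema)
```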
@@ -625,6 +640,29 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) {
   shape_inference.doc() = "Shape Inference submodule";
   py::register_exception<InferenceError>(shape_inference, "InferenceError");

+  py::class_<InferenceContext> inference_ctx(shape_inference, "InferenceContext", "Inference context");
+
+  inference_ctx.def("get_attribute", &InferenceContext::getAttribute);
+  inference_ctx.def("get_num_inputs", &InferenceContext::getNumInputs);
+  inference_ctx.def("has_input", &InferenceContext::hasInput);
+  inference_ctx.def("get_input_type", &InferenceContext::getInputType);
+  inference_ctx.def("get_input_data", &InferenceContext::getInputData);
+  inference_ctx.def("get_input_sparse_data", &InferenceContext::getInputSparseData);
+  inference_ctx.def("get_symbolic_input", &InferenceContext::getSymbolicInput);
+  inference_ctx.def("get_graph_attribute_inferencer", &InferenceContext::getGraphAttributeInferencer);
+  inference_ctx.def("get_num_outputs", &InferenceContext::getNumOutputs);
+  inference_ctx.def("get_output_type", &InferenceContext::getOutputType, py::return_value_policy::reference);
+  inference_ctx.def("set_output_type", [](InferenceContext& self, size_t idx, const TypeProto& src) {
+    auto* dst = self.getOutputType(idx);
+    if (dst == nullptr || dst == &src) {
+      return;
+    }
+    dst->CopyFrom(src);
+  });
+
+  py::class_<GraphInferencer> graph_inferencer(shape_inference, "GraphInferencer", "Graph Inferencer");
+  graph_inferencer.def("do_inferencing", &GraphInferencer::doInferencing);
+
   shape_inference.def(
       "infer_shapes",
       [](const py::bytes& bytes, bool check_type, bool strict_mode, bool data_prop) {
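The bindings above expose the C++ `InferenceContext` to Python; note that `set_output_type` copies the supplied `TypeProto` into the context's output slot. A sketch of an inference function exercising these methods, assuming a hypothetical reduction op with one float input and an optional `axis` INT attribute:

```python
import onnx

def infer(ctx: onnx.shape_inference.InferenceContext) -> None:
    assert ctx.get_num_inputs() == 1 and ctx.has_input(0)
    in_type = ctx.get_input_type(0)        # returns a TypeProto
    axis_attr = ctx.get_attribute("axis")  # AttributeProto, or None if absent
    axis = axis_attr.i if axis_attr is not None else 0
    out = onnx.TypeProto()
    out.tensor_type.elem_type = in_type.tensor_type.elem_type
    # Drop the reduced dimension `axis` from the output shape.
    for i, d in enumerate(in_type.tensor_type.shape.dim):
        if i != axis:
            out.tensor_type.shape.dim.add().CopyFrom(d)
    ctx.set_output_type(0, out)            # copied into the output slot
```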
File: onnx/test/shape_inference_test.py
@@ -10172,6 +10172,103 @@ def test_check_type_when_schema_has_empty_io(self):
             op_schema.name, op_schema.since_version, op_schema.domain
         )

+    def test_custom_schema_shape_inference(self) -> None:
+        # CustomOp schema:
+        #   attrs:
+        #     out_len: [L0, L1, ...]
+        #   inputs:
+        #     a[N, La]
+        #     b[N, Lb]
+        #   outputs:
+        #     out0[N, La * Lb, L0]
+        #     out1[N, La * Lb, L1]
+        #     ...
+        N = 3
+        La = 32
+        Lb = 64
+        out_len = [1, 2]
+        outs = [f"out{i}" for i in range(len(out_len))]
+        graph = self._make_graph(
+            [
+                ("a", TensorProto.FLOAT, (N, La)),
+                ("b", TensorProto.FLOAT, (N, Lb)),
+            ],
+            [make_node("CustomOp", ["a", "b"], outs, out_len=out_len)],
+            [],
+        )
+        with self.assertRaises(onnx.checker.ValidationError):
+            self._assert_inferred(
+                graph,
+                [
+                    make_tensor_value_info(
+                        f"out{i}", TensorProto.FLOAT, (N, La * Lb, Li)
+                    )
+                    for i, Li in enumerate(out_len)
+                ],
+            )
+
+        schema = OpSchema(
+            "CustomOp",
+            "",
+            1,
+            inputs=[
+                defs.OpSchema.FormalParameter("a", "float"),
+                defs.OpSchema.FormalParameter("b", "float"),
+            ],
+            outputs=[
+                defs.OpSchema.FormalParameter(
+                    "out", "float", param_option=OpSchema.FormalParameterOption.Variadic
+                ),
+            ],
+            attributes=[
+                defs.OpSchema.Attribute("out_len", defs.OpSchema.AttrType.INTS)
+            ],
+        )
+
+        def func(ctx: onnx.shape_inference.InferenceContext):
+            def parse_tensor_input(t: TypeProto):
+                assert isinstance(t, TypeProto)
+                return (
+                    t.tensor_type.elem_type,
+                    [
+                        d.dim_value if d.HasField("dim_value") else None
+                        for d in t.tensor_type.shape.dim
+                    ],
+                )
+
+            assert ctx.get_num_inputs() == 2
+            in0 = ctx.get_input_type(0)
Review comment: My concern with this is that it goes through serialization to access the type information, which is not really efficient. I would change the API so that it does not return a TypeProto but the type and the shape as regular Python objects.

Review comment: I think I agree with Xavier, but I am a bit confused also. I see the method implementation serializes proto values to string and returns them. We could just return a pointer to the C++ proto object (wrapped as a Python object). Is that your suggestion, Xavier?

Review comment: Should we closely mimic the C++ API design, or should we integrate Python's native types for interactions? Using the proto pointer for interactions may require additional code to bind it to Python (if there is another way, please correct me), or we would need to include a third-party library.

Review comment: It looks like you are using pybind11's protobuf conversion (between C++ and Python) above? Should this example be updated to match that? I guess it is still doing the same thing, but at least we don't need extra code to manage the serialization/deserialization. It seems OK to me, in the sense that this is what we currently have for Python/C++ interop anyway (that is, serialize/deserialize). Until we come up with a better solution, this seems fine.
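For concreteness, the API shape proposed in the first comment might look like the following; the method name and return convention are hypothetical and not part of this PR:

```python
# Hypothetical alternative to get_input_type(): return plain Python values
# instead of a serialized/deserialized TypeProto.
elem_type, shape = ctx.get_input_type_and_shape(0)
# e.g. elem_type == TensorProto.FLOAT, shape == [3, 32] (None for dynamic dims)
```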
+            in1 = ctx.get_input_type(1)
+            in0_type, in0_shape = parse_tensor_input(in0)
+            in1_type, in1_shape = parse_tensor_input(in1)
+            assert in0_type == in1_type == TensorProto.FLOAT
+            assert len(in0_shape) == len(in1_shape) == 2
+            assert in0_shape[0] == in1_shape[0]
+            N, La = in0_shape
+            _, Lb = in1_shape
+            attr = ctx.get_attribute("out_len")
+            out_len = attr.ints
+            assert len(out_len) == ctx.get_num_outputs()
+            for i in range(ctx.get_num_outputs()):
+                out = ctx.get_output_type(i)
+                out.tensor_type.elem_type = in0_type
+                out.tensor_type.shape.dim.add().dim_value = N
+                out.tensor_type.shape.dim.add().dim_value = La * Lb
+                out.tensor_type.shape.dim.add().dim_value = out_len[i]
+                ctx.set_output_type(i, out)
Review comment: Same comment here; we should avoid serialization, with something like […]
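The elided suggestion presumably amounts to mutating the output proto in place rather than copying it back; a hypothetical sketch, which only works if the caster hands back a live reference rather than a deserialized copy:

```python
# Hypothetical in-place variant: if get_output_type(i) returned a live
# reference to the C++ TypeProto, mutating it would update the context
# directly and the final ctx.set_output_type(i, out) copy could be dropped.
out = ctx.get_output_type(i)
out.tensor_type.elem_type = in0_type
```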
+
+        schema.set_type_and_shape_inference_function(func)
+        onnx.defs.register_schema(schema)
+
+        self._assert_inferred(
+            graph,
+            [
+                make_tensor_value_info(f"out{i}", TensorProto.FLOAT, (N, La * Lb, Li))
+                for i, Li in enumerate(out_len)
+            ],
+        )
+        onnx.defs.deregister_schema(schema.name, schema.since_version, schema.domain)
+

 if __name__ == "__main__":
     unittest.main()
Review comment: Perhaps a more graceful approach would be to collect the schemas registered from Python and deregister them during cleanup. However, I'm not sure it's worth the effort; in most cases, invoking cleanup implies that Python is exiting.
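A sketch of that suggestion, assuming a module-level registry on the Python side; `_python_registered`, `register_schema_tracked`, and `_deregister_all_tracked` are hypothetical names:

```python
import onnx

# Keys of schemas registered from Python, so cleanup can undo exactly these.
_python_registered: list[tuple[str, int, str]] = []

def register_schema_tracked(schema: onnx.defs.OpSchema) -> None:
    # Register as usual, but remember the key for later deregistration.
    onnx.defs.register_schema(schema)
    _python_registered.append((schema.name, schema.since_version, schema.domain))

def _deregister_all_tracked() -> None:
    # Deregister only the Python-registered schemas, instead of calling
    # OpSchemaRegistry::OpSchemaDeregisterAll() on everything.
    for name, version, domain in _python_registered:
        onnx.defs.deregister_schema(name, version, domain)
    _python_registered.clear()
```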