Support set schema inference function in python #5940

Open · wants to merge 15 commits into main · Changes from 3 commits
86 changes: 86 additions & 0 deletions onnx/cpp2py_export.cc
@@ -2,6 +2,7 @@
//
// SPDX-License-Identifier: Apache-2.0

#include <pybind11/functional.h>

Check warning on line 5 in onnx/cpp2py_export.cc (GitHub Actions / clang-tidy-review): 'pybind11/functional.h' file not found [clang-diagnostic-error]
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

@@ -394,6 +395,12 @@
func_proto.SerializeToString(&func_bytes);
}
return py::bytes(func_bytes);
})
.def(
"set_type_and_shape_inference_function",
[](OpSchema* op, const std::function<void(InferenceContext*)>& func) {
auto wrapper = [=](InferenceContext& ctx) { func(&ctx); };
return op->TypeAndShapeInferenceFunction(wrapper);
});

defs.def(
@@ -625,6 +632,85 @@
shape_inference.doc() = "Shape Inference submodule";
py::register_exception<InferenceError>(shape_inference, "InferenceError");

py::class_<InferenceContext> inference_ctx(shape_inference, "InferenceContext", "Inference context");

inference_ctx.def("__get_attribute", [](InferenceContext* ctx, std::string name) {
auto attr = ctx->getAttribute(name);
std::string data;
attr->SerializeToString(&data);
return py::bytes(data);
});
inference_ctx.def("get_num_inputs", &InferenceContext::getNumInputs);
inference_ctx.def("has_input", &InferenceContext::hasInput);
inference_ctx.def("__get_input_type", [](InferenceContext* ctx, size_t index) {
auto type = ctx->getInputType(index);
std::string data;
type->SerializeToString(&data);
return py::bytes(data);
});
inference_ctx.def("__get_input_data", [](InferenceContext* ctx, size_t index) {
auto tensor = ctx->getInputData(index);
std::string data;
tensor->SerializeToString(&data);
return py::bytes(data);
});
inference_ctx.def("__get_input_sparse_data", [](InferenceContext* ctx, size_t index) {
auto stensor = ctx->getInputSparseData(index);
std::string data;
stensor->SerializeToString(&data);
return py::bytes(data);
});
inference_ctx.def("__get_symbolic_input", [](InferenceContext* ctx, size_t index) {
auto shape = ctx->getSymbolicInput(index);
std::string data;
shape->SerializeToString(&data);
return py::bytes(data);
});
inference_ctx.def("__get_graph_attribute_inferencer", &InferenceContext::getGraphAttributeInferencer);
inference_ctx.def("get_num_outputs", &InferenceContext::getNumOutputs);
inference_ctx.def("__get_output_type", [](InferenceContext* ctx, size_t index) {
auto type = ctx->getOutputType(index);
std::string data;
type->SerializeToString(&data);
return py::bytes(data);
});
inference_ctx.def("__set_output_type", [](InferenceContext* ctx, size_t index, py::bytes bytes) {
auto type = ctx->getOutputType(index);
ParseProtoFromPyBytes(type, bytes);
});

py::class_<GraphInferencer> graph_inferencer(shape_inference, "GraphInferencer", "Graph Inferencer");
graph_inferencer.def(
"__do_inferencing",
[](GraphInferencer* inferencer,
const std::vector<py::bytes>& input_types,
const std::vector<py::bytes>& input_data) {
std::vector<TypeProto> type_proto;
std::vector<TensorProto> tensor_proto;
std::vector<const TypeProto*> type_inputs;
std::vector<const TensorProto*> tensor_inputs;
for (const auto& bytes : input_types) {
TypeProto proto{};
ParseProtoFromPyBytes(&proto, bytes);
type_proto.emplace_back(proto);
type_inputs.emplace_back(&type_proto.back());
}
for (const auto& bytes : input_data) {
TensorProto proto{};
ParseProtoFromPyBytes(&proto, bytes);
tensor_proto.emplace_back(proto);
tensor_inputs.emplace_back(&tensor_proto.back());
}
auto ret = inferencer->doInferencing(type_inputs, tensor_inputs);
std::vector<py::bytes> out;
for (const auto& type : ret) {
std::string data;
type->SerializeToString(&data);
out.emplace_back(py::bytes(data));
}
return out;
});

shape_inference.def(
"infer_shapes",
[](const py::bytes& bytes, bool check_type, bool strict_mode, bool data_prop) {
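The `set_type_and_shape_inference_function` binding added above is meant to be driven from Python with a callable that receives an `InferenceContext`. A minimal sketch of such a callable, assuming an `OpSchema` object is already available (schema construction and registration are not part of this diff):

```python
from onnx import TypeProto
from onnx.shape_inference import InferenceContext  # module-level alias added by this PR


def infer_custom_identity(ctx: InferenceContext) -> None:
    """Copy the element type and shape of input 0 to output 0."""
    if ctx.get_num_inputs() < 1 or not ctx.has_input(0):
        return
    out = TypeProto()
    out.CopyFrom(ctx.get_input_type(0))  # the wrapper decodes the bytes returned by C++
    ctx.set_output_type(0, out)          # the wrapper serializes the proto back for C++


# With a schema object in hand (construction/registration not shown here):
# schema.set_type_and_shape_inference_function(infer_custom_identity)
```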
1 change: 1 addition & 0 deletions onnx/onnx_cpp2py_export/defs.pyi
@@ -68,6 +68,7 @@ class OpSchema:
    ) -> dict[str, bytes]: ...
    @property
    def function_body(self) -> FunctionProto: ...
    def set_type_and_shape_inference_function(self, func) -> None: ...

    class TypeConstraintParam:
        def __init__(
32 changes: 32 additions & 0 deletions onnx/onnx_cpp2py_export/shape_inference.pyi
@@ -1,6 +1,38 @@
from typing import List

from onnx import AttributeProto, TypeProto, TensorProto, SparseTensorProto, TensorShapeProto

class InferenceError(Exception): ...

class GraphInferencer:
    # Impl in cpp (onnx/cpp2py_export.cc)
    def __do_inferencing(self, input_types: List[bytes], input_data: List[bytes]) -> List[bytes]: ...
    # Impl in py (onnx/shape_inference.py)
    def do_inferencing(self, input_types: List[TypeProto], input_data: List[TensorProto]) -> List[TypeProto]: ...

class InferenceContext:
    # Impl in cpp (onnx/cpp2py_export.cc)
    def get_num_inputs(self) -> int: ...
    def has_input(self, idx: int) -> bool: ...
    def get_num_outputs(self) -> int: ...
    def __get_attribute(self, name: str) -> bytes: ...
    def __get_input_type(self, idx: int) -> bytes: ...
    def __get_input_data(self, idx: int) -> bytes: ...
    def __get_input_sparse_data(self, idx: int) -> bytes: ...
    def __get_symbolic_input(self, idx: int) -> bytes: ...
    def __get_graph_attribute_inferencer(self) -> GraphInferencer: ...
    def __get_output_type(self, idx: int) -> bytes: ...
    def __set_output_type(self, idx: int, output: bytes) -> None: ...
    # Impl in py (onnx/shape_inference.py)
    def get_attribute(self, name: str) -> AttributeProto: ...
    def get_input_type(self, idx: int) -> TypeProto: ...
    def get_input_data(self, idx: int) -> TensorProto: ...
    def get_input_sparse_data(self, idx: int) -> SparseTensorProto: ...
    def get_symbolic_input(self, idx: int) -> TensorShapeProto: ...
    def get_graph_attribute_inferencer(self) -> GraphInferencer: ...
    def get_output_type(self, idx: int) -> TypeProto: ...
    def set_output_type(self, idx: int, output: TypeProto) -> None: ...

def infer_shapes(
    b: bytes, check_type: bool, strict_mode: bool, data_prop: bool
) -> bytes: ...
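The stub above splits each accessor into a byte-level `__`-prefixed method implemented in C++ and a proto-level wrapper implemented in `onnx/shape_inference.py` (next file). For operators that carry a graph attribute, it also declares a `GraphInferencer`. A hedged sketch of how an inference function might delegate to it, assuming `get_graph_attribute_inferencer` is wired up on the Python side as declared (only the `__`-prefixed C++ binding appears in this commit):

```python
from onnx.shape_inference import InferenceContext


def infer_subgraph_op(ctx: InferenceContext) -> None:
    """Hypothetical op whose output types come from its graph attribute."""
    # Assumed to exist per the stub above; the Python-side wrapper is not shown in this commit.
    inferencer = ctx.get_graph_attribute_inferencer()
    input_types = [ctx.get_input_type(i) for i in range(ctx.get_num_inputs())]
    input_data = []  # no statically known input values in this sketch
    for i, out_type in enumerate(inferencer.do_inferencing(input_types, input_data)):
        ctx.set_output_type(i, out_type)
```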
53 changes: 52 additions & 1 deletion onnx/shape_inference.py
@@ -14,7 +14,58 @@

import onnx
import onnx.onnx_cpp2py_export.shape_inference as C # noqa: N812
from onnx import AttributeProto, FunctionProto, ModelProto, TypeProto
from onnx import (
    AttributeProto,
    FunctionProto,
    ModelProto,
    SparseTensorProto,
    TensorProto,
    TensorShapeProto,
    TypeProto,
)

GraphInferencer = C.GraphInferencer


def _do_inferencing(
    self, input_types: list[TypeProto], input_data: list[TensorProto]
) -> list[TypeProto]:
    input_types_bytes = [proto.SerializeToString() for proto in input_types]
    input_data_bytes = [proto.SerializeToString() for proto in input_data]
    ret = self.__do_inferencing(input_types_bytes, input_data_bytes)
    return [TypeProto.FromString(data) for data in ret]


GraphInferencer.do_inferencing = _do_inferencing # type: ignore


def _parse_to_proto(attr, proto_type):
    def impl(self, *args, **kwargs):
        data = getattr(self, attr)(*args, **kwargs)
        return proto_type.FromString(data)

    return impl


InferenceContext = C.InferenceContext
InferenceContext.get_attribute = _parse_to_proto("__get_attribute", AttributeProto) # type: ignore
InferenceContext.get_input_type = _parse_to_proto("__get_input_type", TypeProto) # type: ignore
InferenceContext.get_input_data = _parse_to_proto("__get_input_data", TensorProto) # type: ignore
InferenceContext.get_input_sparse_data = _parse_to_proto( # type: ignore
    "__get_input_sparse_data", SparseTensorProto
)
InferenceContext.get_symbolic_input = _parse_to_proto( # type: ignore
    "__get_symbolic_input", TensorShapeProto
)
InferenceContext.get_output_type = _parse_to_proto("__get_output_type", TypeProto) # type: ignore


def _op_set_output_type(self, idx: int, output: TypeProto):
    data = output.SerializeToString()
    self.__set_output_type(idx, data)


InferenceContext.set_output_type = _op_set_output_type # type: ignore


def infer_shapes(
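All of the wrappers above follow the same bridge pattern: protos cross the pybind boundary as serialized bytes, and the Python layer converts on either side of the call. A small self-contained illustration of that round trip, using only the plain protobuf API (nothing here is specific to this PR):

```python
from onnx import TensorProto, TypeProto

# Build a TypeProto the way an inference function might.
t = TypeProto()
t.tensor_type.elem_type = TensorProto.FLOAT
t.tensor_type.shape.dim.add().dim_value = 3

data = t.SerializeToString()               # what _op_set_output_type hands to the C++ binding
roundtripped = TypeProto.FromString(data)  # what _parse_to_proto does with bytes coming back
assert roundtripped == t
```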