Commit cd6e5db

Consider GraphInferenceContext in inference functions: InferenceContext (#4632)

* Expose GraphInferenceContext in Python interface for inference functions
* use the same map
* add opset_imports and handle input_types
* graph_opset_import to clarify
* fix lint
* fix black
* add a test
* make_opsetid
* replace test with Add in subgraph
* remove unused
* use shorter name for opset_imports and ir_version

Signed-off-by: Chun-Wei Chen <jacky82226@gmail.com>
jcwchen committed Nov 18, 2022
1 parent 466edb7 commit cd6e5db
Showing 3 changed files with 126 additions and 57 deletions.
16 changes: 13 additions & 3 deletions onnx/cpp2py_export.cc
@@ -58,7 +58,9 @@ std::unordered_map<std::string, py::bytes> CallNodeInferenceFunction(
     const py::bytes& nodeBytes,
     std::unordered_map<std::string, py::bytes> valueTypesByNameBytes,
     std::unordered_map<std::string, py::bytes> inputDataByNameBytes,
-    std::unordered_map<std::string, py::bytes> inputSparseDataByNameBytes) {
+    std::unordered_map<std::string, py::bytes> inputSparseDataByNameBytes,
+    std::unordered_map<std::string, int> opsetImports,
+    const int irVersion) {
   NodeProto node{};
   ParseProtoFromPyBytes(&node, nodeBytes);
   // Early fail if node is badly defined - may throw ValidationError
@@ -68,9 +70,15 @@ std::unordered_map<std::string, py::bytes> CallNodeInferenceFunction(
   const auto& valueTypes = ParseProtoFromBytesMap<TypeProto>(valueTypesByNameBytes);
   const auto& inputData = ParseProtoFromBytesMap<const TensorProto>(inputDataByNameBytes);
   const auto& inputSparseData = ParseProtoFromBytesMap<const SparseTensorProto>(inputSparseDataByNameBytes);
+  if (opsetImports.empty()) {
+    opsetImports[schema->domain()] = schema->SinceVersion();
+  }
 
+  shape_inference::GraphInferenceContext graphInferenceContext(
+      valueTypes.second, opsetImports, nullptr, {}, OpSchemaRegistry::Instance(), nullptr, irVersion);
   // Construct inference context and get results - may throw InferenceError
-  shape_inference::InferenceContextImpl ctx(node, valueTypes.second, inputData.second, inputSparseData.second);
+  shape_inference::InferenceContextImpl ctx(
+      node, valueTypes.second, inputData.second, inputSparseData.second, nullptr, &graphInferenceContext);
   schema->GetTypeAndShapeInferenceFunction()(ctx);
   // Verify the inference succeeded - may also throw ValidationError
   // Note that input types were not validated until now (except that their count was correct)
@@ -142,7 +150,9 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) {
           py::arg("nodeBytes"),
          py::arg("valueTypesByNameBytes"),
           py::arg("inputDataByNameBytes") = std::unordered_map<std::string, py::bytes>{},
-          py::arg("inputSparseDataByNameBytes") = std::unordered_map<std::string, py::bytes>{})
+          py::arg("inputSparseDataByNameBytes") = std::unordered_map<std::string, py::bytes>{},
+          py::arg("opsetImports") = std::unordered_map<std::string, int>{},
+          py::arg("irVersion") = int(IR_VERSION))
       .def(
           "get_context_dependent_function",
           [](OpSchema* op, const py::bytes& bytes, const std::vector<py::bytes>& input_types_bytes) -> py::bytes {
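When no opsetImports are supplied, the fallback above maps the schema's own domain to its since_version, so omitting opset_imports from Python should behave the same as spelling that default out. A minimal sketch of the equivalence (the Add node and its shapes are illustrative, not part of this commit):

    import onnx
    from onnx.defs import get_schema
    from onnx.helper import make_node, make_opsetid, make_tensor_type_proto
    from onnx.shape_inference import infer_node_outputs

    schema = get_schema("Add")  # latest Add schema in the default domain
    node = make_node("Add", ["A", "B"], ["C"])
    types = {
        "A": make_tensor_type_proto(onnx.TensorProto.FLOAT, (2, 3)),
        "B": make_tensor_type_proto(onnx.TensorProto.FLOAT, (2, 3)),
    }

    # With opset_imports omitted, the C++ side falls back to the schema's
    # domain and since_version, so both calls should infer C as FLOAT (2, 3).
    implicit = infer_node_outputs(schema, node, types)
    explicit = infer_node_outputs(
        schema,
        node,
        types,
        opset_imports=[make_opsetid(schema.domain, schema.since_version)],
    )
    assert implicit == explicit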
18 changes: 15 additions & 3 deletions onnx/shape_inference.py
@@ -4,7 +4,7 @@
 """
 
-from typing import Dict, Optional, Union
+from typing import Dict, List, Optional, Union
 
 import onnx
 import onnx.onnx_cpp2py_export.shape_inference as C
@@ -88,18 +88,28 @@ def infer_node_outputs(
     input_types: Dict[str, onnx.TypeProto],
     input_data: Optional[Dict[str, onnx.TensorProto]] = None,
     input_sparse_data: Optional[Dict[str, onnx.SparseTensorProto]] = None,
+    opset_imports: Optional[List[onnx.OperatorSetIdProto]] = None,
+    ir_version: int = onnx.IR_VERSION,
 ) -> Dict[str, onnx.TypeProto]:
     if not schema.has_type_and_shape_inference_function:  # type: ignore
         return {}
     if input_data is None:
         input_data = {}
     if input_sparse_data is None:
         input_sparse_data = {}
+    if opset_imports is None:
+        passed_opset_imports = {}
+    else:
+        passed_opset_imports = {opset.domain: opset.version for opset in opset_imports}
 
-    # To avoid copying on C++ side, pass only what is needed for this inference call
+    # catch KeyError if node's input does not exist in input_types
     passed_input_types = {
         key: input_types[key].SerializeToString() for key in node.input
     }
+    # input_types will also be used as outer_scope_value_types so do not filter by node's input here
+    for key in input_types:
+        if key not in passed_input_types:
+            passed_input_types[key] = input_types[key].SerializeToString()
     passed_input_data = {
         key: input_data[key].SerializeToString()
         for key in node.input
@@ -116,7 +126,9 @@ def infer_node_outputs(
         passed_input_types,
         passed_input_data,
         passed_sparse_input_data,
-    )
+        passed_opset_imports,
+        ir_version,
+    )  # type: ignore[call-arg]
     return {key: onnx.TypeProto.FromString(out) for key, out in outputs.items()}
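For Python callers this adds opset_imports and ir_version keyword arguments, and it stops filtering input_types down to the node's own inputs, so values defined in an enclosing graph stay visible when the node carries a subgraph. A sketch of a call using the new arguments (the Reshape shapes, the opset version, and the extra "outer" entry are illustrative):

    import numpy as np

    import onnx
    from onnx.defs import get_schema
    from onnx.helper import make_node, make_opsetid, make_tensor_type_proto
    from onnx.numpy_helper import from_array
    from onnx.shape_inference import infer_node_outputs

    node = make_node("Reshape", ["x", "shape"], ["y"])
    out = infer_node_outputs(
        get_schema("Reshape", 14),
        node,
        {
            "x": make_tensor_type_proto(onnx.TensorProto.FLOAT, (4, 4)),
            "shape": make_tensor_type_proto(onnx.TensorProto.INT64, (2,)),
            # Not an input of `node`: forwarded as an outer-scope value type
            # for subgraph inference instead of being dropped.
            "outer": make_tensor_type_proto(onnx.TensorProto.FLOAT, (1,)),
        },
        input_data={"shape": from_array(np.array([2, 8], dtype=np.int64))},
        opset_imports=[make_opsetid("", 14)],  # new in this change
        ir_version=onnx.IR_VERSION,  # new in this change
    )
    # out["y"] should be a FLOAT tensor type with shape (2, 8)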
149 changes: 98 additions & 51 deletions onnx/test/inference_function_test.py
@@ -3,123 +3,119 @@
 
 import numpy as np
 
-import onnx
-import onnx.numpy_helper
-import onnx.shape_inference
+from onnx import TensorProto, TypeProto
+from onnx.checker import ValidationError
+from onnx.defs import OpSchema, get_all_schemas_with_history, get_schema
+from onnx.helper import (
+    make_graph,
+    make_node,
+    make_opsetid,
+    make_tensor_type_proto,
+    make_tensor_value_info,
+)
+from onnx.numpy_helper import from_array
+from onnx.shape_inference import InferenceError, infer_node_outputs
 
 ADD_SCHEMA = max(
-    (
-        s
-        for s in onnx.defs.get_all_schemas_with_history()
-        if s.name == "Add" and s.domain == ""
-    ),
+    (s for s in get_all_schemas_with_history() if s.name == "Add" and s.domain == ""),
     key=lambda s: s.since_version,
 )
 RESHAPE_SCHEMA = max(
     (
         s
-        for s in onnx.defs.get_all_schemas_with_history()
+        for s in get_all_schemas_with_history()
         if s.name == "Reshape" and s.domain == ""
     ),
     key=lambda s: s.since_version,
 )
 
-_tensor = onnx.helper.make_tensor_type_proto
-
 
 def _to_tensor_types(
     tensor_types: Dict[str, Tuple[int, Tuple[Union[int, str, None], ...]]]
-) -> Dict[str, onnx.TypeProto]:
-    return {
-        key: onnx.helper.make_tensor_type_proto(*value)
-        for key, value in tensor_types.items()
-    }
+) -> Dict[str, TypeProto]:
+    return {key: make_tensor_type_proto(*value) for key, value in tensor_types.items()}
 
 
 def _run_case(
-    schema: onnx.defs.OpSchema,
+    schema: OpSchema,
     input_names: List[str],
     output_names: List[str],
-    input_types: Dict[str, onnx.TypeProto],
+    input_types: Dict[str, TypeProto],
     input_data: Optional[Dict[str, np.ndarray]] = None,
-) -> Dict[str, onnx.TypeProto]:
+) -> Dict[str, TypeProto]:
     if input_data is None:
         input_data = {}
-    return onnx.shape_inference.infer_node_outputs(
+    return infer_node_outputs(
         schema,
-        onnx.helper.make_node(
-            schema.name, input_names, output_names, domain=schema.domain
-        ),
+        make_node(schema.name, input_names, output_names, domain=schema.domain),
         input_types,
-        {key: onnx.numpy_helper.from_array(arr) for key, arr in input_data.items()},
+        {key: from_array(arr) for key, arr in input_data.items()},
    )
 
 
 class TestInferenceFunctionCall(unittest.TestCase):
     def test_add_inference(self) -> None:
         cases = [
             (
-                {"A": (onnx.TensorProto.FLOAT, ()), "B": (onnx.TensorProto.FLOAT, ())},
-                {"C": (onnx.TensorProto.FLOAT, ())},
+                {"A": (TensorProto.FLOAT, ()), "B": (TensorProto.FLOAT, ())},
+                {"C": (TensorProto.FLOAT, ())},
             ),
             (
                 {
-                    "A": (onnx.TensorProto.FLOAT, (None, 2)),
-                    "B": (onnx.TensorProto.FLOAT, (2,)),
+                    "A": (TensorProto.FLOAT, (None, 2)),
+                    "B": (TensorProto.FLOAT, (2,)),
                 },
-                {"C": (onnx.TensorProto.FLOAT, (None, 2))},
+                {"C": (TensorProto.FLOAT, (None, 2))},
             ),
             (
                 {
-                    "A": (onnx.TensorProto.FLOAT, (None, 2)),
-                    "B": (onnx.TensorProto.FLOAT, (1, 2)),
+                    "A": (TensorProto.FLOAT, (None, 2)),
+                    "B": (TensorProto.FLOAT, (1, 2)),
                 },
-                {"C": (onnx.TensorProto.FLOAT, (None, 2))},
+                {"C": (TensorProto.FLOAT, (None, 2))},
             ),
             (
                 {
-                    "A": (onnx.TensorProto.DOUBLE, ("n", "m")),
-                    "B": (onnx.TensorProto.DOUBLE, (1, "n", "m")),
+                    "A": (TensorProto.DOUBLE, ("n", "m")),
+                    "B": (TensorProto.DOUBLE, (1, "n", "m")),
                 },
-                {"C": (onnx.TensorProto.DOUBLE, (1, "n", "m"))},
+                {"C": (TensorProto.DOUBLE, (1, "n", "m"))},
             ),
             (
                 {
-                    "A": (onnx.TensorProto.FLOAT, ("x", 2)),
-                    "B": (onnx.TensorProto.FLOAT, ("y", 2)),
+                    "A": (TensorProto.FLOAT, ("x", 2)),
+                    "B": (TensorProto.FLOAT, ("y", 2)),
                 },
-                {"C": (onnx.TensorProto.FLOAT, (None, 2))},
+                {"C": (TensorProto.FLOAT, (None, 2))},
             ),
         ]
         for ins, outs in cases:
             assert _run_case(ADD_SCHEMA, ["A", "B"], ["C"], _to_tensor_types(ins)) == _to_tensor_types(outs)  # type: ignore
 
     def test_add_inference_raises_errors(self) -> None:
-        with self.assertRaises(onnx.checker.ValidationError):
+        with self.assertRaises(ValidationError):
             _run_case(
                 ADD_SCHEMA,
                 ["A"],
                 ["C"],
-                _to_tensor_types({"A": (onnx.TensorProto.FLOAT, (3, 4))}),
+                _to_tensor_types({"A": (TensorProto.FLOAT, (3, 4))}),
             )
-        with self.assertRaises(onnx.checker.ValidationError):
+        with self.assertRaises(ValidationError):
             _run_case(
                 ADD_SCHEMA,
                 ["A", "B"],
                 ["C"],
-                _to_tensor_types(
-                    {"A": (onnx.TensorProto.FLOAT, (3, 4)), "B": (2, (3, 4))}
-                ),
+                _to_tensor_types({"A": (TensorProto.FLOAT, (3, 4)), "B": (2, (3, 4))}),
             )
-        with self.assertRaises(onnx.shape_inference.InferenceError):
+        with self.assertRaises(InferenceError):
             _run_case(
                 ADD_SCHEMA,
                 ["A", "B"],
                 ["C"],
                 _to_tensor_types(
                     {
-                        "A": (onnx.TensorProto.FLOAT, (2, 4)),
-                        "B": (onnx.TensorProto.FLOAT, (3, 4)),
+                        "A": (TensorProto.FLOAT, (2, 4)),
+                        "B": (TensorProto.FLOAT, (3, 4)),
                     }
                 ),
             )
@@ -128,7 +124,7 @@ def test_add_inference_raises_errors(self) -> None:
                 ADD_SCHEMA,
                 ["A", "B"],
                 ["C"],
-                _to_tensor_types({"A": (onnx.TensorProto.FLOAT, (3, 4))}),
+                _to_tensor_types({"A": (TensorProto.FLOAT, (3, 4))}),
             )
 
     def test_reshape_inference(self) -> None:
@@ -138,12 +134,63 @@ def test_reshape_inference(self) -> None:
             ["y"],
             _to_tensor_types(
                 {
-                    "x": (onnx.TensorProto.FLOAT, (5, 4)),
-                    "t": (onnx.TensorProto.INT64, (3,)),
+                    "x": (TensorProto.FLOAT, (5, 4)),
+                    "t": (TensorProto.INT64, (3,)),
                 }
             ),
             {"t": np.array([2, 2, 5], dtype=np.int64)},
-        ) == _to_tensor_types({"y": (onnx.TensorProto.FLOAT, (2, 2, 5))})
+        ) == _to_tensor_types({"y": (TensorProto.FLOAT, (2, 2, 5))})
+
+    def test_scan_inference_with_subgraph(self) -> None:
+        seq_len = "sequence"
+        input_size = 2
+        loop_state_size = 3
+
+        input_value_infos = [
+            make_tensor_value_info("loop_state_in", TensorProto.UNDEFINED, None),
+            make_tensor_value_info("input", TensorProto.UNDEFINED, None),
+            make_tensor_value_info("outer", TensorProto.UNDEFINED, None),
+        ]
+        output_value_infos = [
+            make_tensor_value_info("loop_state_out", TensorProto.UNDEFINED, None),
+            make_tensor_value_info("output", TensorProto.FLOAT, (seq_len, input_size)),
+        ]
+
+        subgraph = make_graph(
+            [
+                make_node("Identity", ["loop_state_in"], ["loop_state_out"]),
+                make_node("Add", ["input", "outer"], ["output"]),
+            ],
+            "subgraph",
+            input_value_infos,
+            output_value_infos,
+        )
+
+        assert infer_node_outputs(
+            get_schema("Scan", 9),
+            make_node(
+                "Scan",
+                ["loop_state_orig", "scan_input", "scan_outer"],
+                ["loop_state_final", "scan_output"],
+                num_scan_inputs=1,
+                body=subgraph,
+            ),
+            _to_tensor_types(
+                {
+                    "loop_state_orig": (TensorProto.FLOAT, (loop_state_size,)),
+                    "scan_input": (TensorProto.FLOAT, (seq_len, input_size)),
+                    "scan_outer": (TensorProto.FLOAT, (input_size,)),
+                }
+            ),
+            # Same as default value in Scan-9
+            opset_imports=[make_opsetid("", 9)],
+            ir_version=4,
+        ) == _to_tensor_types(
+            {
+                "loop_state_final": (TensorProto.FLOAT, (loop_state_size,)),
+                "scan_output": (TensorProto.FLOAT, (seq_len, input_size)),
+            }
        )
 
 
 if __name__ == "__main__":
