Allow checking a serialized model (onnx#3403)
Avoids repeatedly serializing a large model (e.g. when running it with
onnxruntime).

Allow running shape inference on a serialized model

Signed-off-by: IceTDrinker <49040125+IceTDrinker@users.noreply.github.com>
Signed-off-by: neginraoof <neginmr@utexas.edu>
IceTDrinker authored and neginraoof committed May 14, 2021
1 parent 046295e commit ba19d8c
Showing 4 changed files with 32 additions and 11 deletions.
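In practice this lets callers serialize a model once and reuse the same bytes everywhere. A minimal sketch of the workflow the change enables (not part of this commit; the Relu model, names, and shapes below are illustrative only):

    import onnx
    import onnx.checker
    import onnx.shape_inference
    from onnx import helper, TensorProto

    # Build a trivial single-node model (hypothetical example).
    node = helper.make_node("Relu", ["X"], ["Y"])
    graph = helper.make_graph(
        [node], "demo",
        [helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 2])],
        [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 2])])
    model = helper.make_model(graph, producer_name="demo")

    # Serialize once; the same bytes can feed the checker, shape inference,
    # and e.g. an onnxruntime InferenceSession.
    model_bytes = model.SerializeToString()
    onnx.checker.check_model(model_bytes)
    inferred = onnx.shape_inference.infer_shapes(model_bytes)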
6 changes: 3 additions & 3 deletions onnx/checker.py
@@ -24,7 +24,7 @@
 import onnx.defs
 from google.protobuf.message import Message
 from typing import TypeVar, Callable, Any, Type, cast, Union, Text
-from six import string_types
+from six import string_types, binary_type
 import onnx.shape_inference
 import sys
@@ -91,16 +91,16 @@ def check_sparse_tensor(sparse, ctx=DEFAULT_CONTEXT):  # type: (SparseTensorProt
     C.check_sparse_tensor(sparse.SerializeToString(), ctx)


-def check_model(model, full_check=False):  # type: (Union[ModelProto, Text], bool) -> None
+def check_model(model, full_check=False):  # type: (Union[ModelProto, Text, bytes], bool) -> None
     # If model is a path instead of ModelProto
     if isinstance(model, string_types):
         C.check_model_path(model)
         if full_check:
             onnx.shape_inference.infer_shapes_path(model, check_type=True, strict_mode=True)
     else:
+        protobuf_string = model if isinstance(model, binary_type) else model.SerializeToString()
         # If the protobuf is larger than 2GB,
         # remind users to use the model path to check instead
-        protobuf_string = model.SerializeToString()
         if sys.getsizeof(protobuf_string) > MAXIMUM_PROTOBUF:
             raise ValueError('This protobuf of onnx model is too large (>2GB). Call check_model with model path instead.')
         C.check_model(protobuf_string)
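With this change check_model dispatches on the argument type: a Text path is checked via C.check_model_path (plus infer_shapes_path when full_check is set), while ModelProto and bytes inputs are checked in memory, the bytes case skipping the extra SerializeToString call. A sketch of the three call styles, assuming a model saved at the hypothetical path 'model.onnx':

    import onnx
    import onnx.checker

    onnx.checker.check_model('model.onnx')               # Text: checked from disk
    model = onnx.load('model.onnx')
    onnx.checker.check_model(model)                      # ModelProto: serialized internally
    onnx.checker.check_model(model.SerializeToString())  # bytes: used as-is (new here)

Only the path-based call sidesteps the MAXIMUM_PROTOBUF (2GB) guard; both in-memory forms still raise ValueError for oversized models.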
16 changes: 8 additions & 8 deletions onnx/shape_inference.py
@@ -12,8 +12,8 @@
 import onnx
 import onnx.onnx_cpp2py_export.shape_inference as C
 from onnx import ModelProto
-from six import string_types
-from typing import Text
+from six import string_types, binary_type
+from typing import Text, Union

 """Apply shape inference to the provided ModelProto.
@@ -24,23 +24,23 @@
     bug in shape inference), and the result is unspecified.
 Arguments:
-    input (Union[ModelProto, Text], Text, bool) -> ModelProto
+    input (Union[ModelProto, Text, bytes], Text, bool) -> ModelProto
 Return:
     return (ModelProto) model with inferred shape information
 """


-def infer_shapes(model, check_type=False, strict_mode=False):  # type: (ModelProto, bool, bool) -> ModelProto
-    if isinstance(model, ModelProto):
-        model_str = model.SerializeToString()
+def infer_shapes(model, check_type=False, strict_mode=False):  # type: (Union[ModelProto, bytes], bool, bool) -> ModelProto
+    if isinstance(model, (ModelProto, binary_type)):
+        model_str = model if isinstance(model, binary_type) else model.SerializeToString()
         inferred_model_str = C.infer_shapes(model_str, check_type, strict_mode)
         return onnx.load_from_string(inferred_model_str)
     elif isinstance(model, string_types):
-        raise TypeError('infer_shapes only accepts ModelProto,'
+        raise TypeError('infer_shapes only accepts ModelProto or bytes,'
                         'you can use infer_shapes_path for the model path (String).')
     else:
-        raise TypeError('infer_shapes only accepts ModelProto, '
+        raise TypeError('infer_shapes only accepts ModelProto or bytes, '
                         'incorrect type: {}'.format(type(model)))
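infer_shapes follows the same pattern: ModelProto and bytes are handled in memory, while a path string is rejected with a pointer to infer_shapes_path. A sketch of the resulting behaviour, again assuming a hypothetical 'model.onnx' on disk:

    import onnx
    import onnx.shape_inference

    model = onnx.load('model.onnx')
    # bytes input is now accepted and avoids a second serialization
    inferred = onnx.shape_inference.infer_shapes(model.SerializeToString())
    try:
        onnx.shape_inference.infer_shapes('model.onnx')  # a path still raises
    except TypeError as err:
        print(err)  # suggests infer_shapes_path for paths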
12 changes: 12 additions & 0 deletions onnx/test/checker_test.py
@@ -214,6 +214,18 @@ def test_check_model(self):  # type: () -> None

         checker.check_model(model)

+    def test_check_serialized_model(self):  # type: () -> None
+        node = helper.make_node(
+            "Relu", ["X"], ["Y"], name="test")
+        graph = helper.make_graph(
+            [node],
+            "test",
+            [helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 2])],
+            [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 2])])
+        model = helper.make_model(graph, producer_name='test')
+
+        checker.check_model(model.SerializeToString())
+
     def test_check_old_model(self):  # type: () -> None
         node = helper.make_node(
             "Pad", ["X"], ["Y"], paddings=(0, 0, 0, 0))
9 changes: 9 additions & 0 deletions onnx/test/shape_inference_test.py
@@ -3421,6 +3421,15 @@ def test_infer_initializer_input_consistency_differnt_rank(self):  # type: () ->
         # Inferred shape and existing shape differ in rank: (3) vs (2)
         self.assertRaises(onnx.shape_inference.InferenceError, onnx.shape_inference.infer_shapes, original_model, strict_mode=True)

+    def test_infer_initializer_input_consistency_all_none_serialized(self):  # type: () -> None
+        # Reuse the test_infer_initializer_input_consistency_all_none test case,
+        # but check with a serialized model
+        initializer_shape = (8, 7)
+        input_shape = (None, None)  # acceptable
+        original_model = self.prepare_input_initializer_tensors(initializer_shape, input_shape)
+
+        onnx.shape_inference.infer_shapes(original_model.SerializeToString(), strict_mode=True)
+
     def test_trilu_upper(self):  # type: () -> None
         graph = self._make_graph(
             [('x', TensorProto.FLOAT, (3, 4, 5)),
