Skip to content

Commit

Permalink
set input_shape first to enable more shape inference chances
Browse files Browse the repository at this point in the history
Signed-off-by: daquexian <daquexian566@gmail.com>
  • Loading branch information
daquexian committed May 31, 2021
1 parent feda919 commit 99f544e
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 11 deletions.
2 changes: 1 addition & 1 deletion onnxsim/__main__.py
Expand Up @@ -38,7 +38,7 @@ def main():
raise RuntimeError(
'Please pass "--input-shape" argument for generating random input and checking equality. Run "python3 -m onnxsim -h" for details.')
if args.input_shape is not None and not args.dynamic_input_shape:
print("Note: The input shape of the simplified model will be overwritten by the value of '--input--shape' argument. Pass '--dynamic-input-shape' if it is not what you want. Run 'python3 -m onnxsim -h' for details.")
print("Note: The input shape of the simplified model will be overwritten by the value of '--input-shape' argument. Pass '--dynamic-input-shape' if it is not what you want. Run 'python3 -m onnxsim -h' for details.")
input_shapes = dict()
if args.input_shape is not None:
for x in args.input_shape:
Expand Down
21 changes: 11 additions & 10 deletions onnxsim/onnx_simplifier.py
Expand Up @@ -345,7 +345,7 @@ def clean_constant_nodes(const_nodes: List[onnx.NodeProto], res: Tensors):
return [node for node in const_nodes if node.output[0] in res]


def check_and_update_input_shapes(model: onnx.ModelProto, input_shapes: TensorShapesWithOptionalKey) -> TensorShapes:
def check_and_update_input_shapes(model: onnx.ModelProto, input_shapes: TensorShapesWithOptionalKey, dynamic_input_shape: bool) -> TensorShapes:
input_names = get_input_names(model)
if None in input_shapes:
if len(input_names) == 1:
Expand All @@ -358,6 +358,15 @@ def check_and_update_input_shapes(model: onnx.ModelProto, input_shapes: TensorSh
if x not in input_names:
raise RuntimeError(
'The model doesn\'t have input named "{}"'.format(x))

# Overwrite model input shape
if not dynamic_input_shape:
for name, input_shape in input_shapes.items():
for ipt in model.graph.input:
if ipt.name == name:
for i, dim in enumerate(ipt.type.tensor_type.shape.dim):
dim.dim_value = input_shape[i]

return input_shapes # type: ignore


Expand Down Expand Up @@ -452,7 +461,7 @@ def simplify(model: Union[str, onnx.ModelProto],
elif input_name not in input_shapes:
input_shapes[input_name] = shape

updated_input_shapes = check_and_update_input_shapes(model, input_shapes)
updated_input_shapes = check_and_update_input_shapes(model, input_shapes, dynamic_input_shape)

def infer_shapes_and_optimize(model: onnx.ModelProto) -> onnx.ModelProto:
def infer_shapes_if_applicable(model: onnx.ModelProto) -> onnx.ModelProto:
Expand Down Expand Up @@ -482,14 +491,6 @@ def constant_folding(model: onnx.ModelProto) -> onnx.ModelProto:

model = fixed_point(model, infer_shapes_and_optimize, constant_folding)

# Overwrite model input shape
if not dynamic_input_shape:
for name, input_shape in updated_input_shapes.items():
for ipt in model.graph.input:
if ipt.name == name:
for i, dim in enumerate(ipt.type.tensor_type.shape.dim):
dim.dim_value = input_shape[i]

check_ok = check(model_ori, model, check_n,
input_shapes=updated_input_shapes)

Expand Down

0 comments on commit 99f544e

Please sign in to comment.