Skip to content

Commit

Permalink
[ONNX] Enables data propogation for onnx shape inference (#3280)
Browse files Browse the repository at this point in the history
This small change seems to dramatically improve shape inference for
complex models, and consequently, improves onnx importer reliability.
  • Loading branch information
zjgarvey committed May 8, 2024
1 parent 346a536 commit 0abc586
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
def import_onnx(contents):
# Import the ONNX model proto from the file contents:
raw_model = onnx.load_from_string(contents)
# since it does not affect current e2e tests, data_prop is left false here
model_proto = onnx.shape_inference.infer_shapes(raw_model)

# Import the ONNX module into an MLIR module:
Expand Down
15 changes: 13 additions & 2 deletions python/torch_mlir/tools/import_onnx/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ def load_onnx_model(args: argparse.Namespace) -> onnx.ModelProto:
# in-memory shape inference. If not, go ahead and do the shape inference.
try:
onnx.checker.check_model(raw_model)
inferred_model = onnx.shape_inference.infer_shapes(raw_model)
inferred_model = onnx.shape_inference.infer_shapes(
raw_model, data_prop=args.data_prop
)
return inferred_model
except ValueError:
pass
Expand All @@ -103,7 +105,9 @@ def load_onnx_model(args: argparse.Namespace) -> onnx.ModelProto:
# Model is too big for in-memory inference: do file-based shape inference
# to a temp file.
temp_inferred_file = temp_dir / "inferred.onnx"
onnx.shape_inference.infer_shapes_path(args.input_file, temp_inferred_file)
onnx.shape_inference.infer_shapes_path(
args.input_file, temp_inferred_file, data_prop=args.data_prop
)

# Sanity check the shape-inferred model to be sure we have a good model
# for the importer. This call uses the file-based method, as the
Expand Down Expand Up @@ -138,6 +142,13 @@ def parse_arguments(argv=None) -> argparse.Namespace:
action="store_true",
help="Disable verification prior to printing",
)
parser.add_argument(
"--data-prop",
dest="data_prop",
default=True,
action=argparse.BooleanOptionalAction,
help="Toggle data propogation for onnx shape inference",
)
parser.add_argument(
"--keep-temps", action="store_true", help="Keep intermediate files"
)
Expand Down

0 comments on commit 0abc586

Please sign in to comment.