Skip to content

Commit

Permalink
Initial PR
Browse files Browse the repository at this point in the history
  • Loading branch information
BowenBao committed Feb 22, 2024
1 parent 88f1a9c commit c10f372
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 17 deletions.
5 changes: 5 additions & 0 deletions optimum/exporters/onnx/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def main_export(
legacy: bool = False,
no_dynamic_axes: bool = False,
do_constant_folding: bool = True,
dynamo: bool = False,
**kwargs_shapes,
):
"""
Expand Down Expand Up @@ -162,6 +163,8 @@ def main_export(
If True, disables the use of dynamic axes during ONNX export.
do_constant_folding (bool, defaults to `True`):
PyTorch-specific argument. If `True`, the PyTorch ONNX export will fold constants into adjacent nodes, if possible.
        dynamo (bool, defaults to `False`):
PyTorch-specific argument. If `True`, export with the new Dynamo ONNX Exporter introduced in PyTorch 2+.
**kwargs_shapes (`Dict`):
Shapes to use during inference. This argument allows to override the default shapes used during the ONNX export.
Expand Down Expand Up @@ -368,6 +371,7 @@ def main_export(
task=task,
use_subprocess=use_subprocess,
do_constant_folding=do_constant_folding,
dynamo=dynamo,
**kwargs_shapes,
)

Expand Down Expand Up @@ -404,6 +408,7 @@ def main():
library_name=args.library_name,
legacy=args.legacy,
do_constant_folding=not args.no_constant_folding,
dynamo=args.dynamo,
**input_shapes,
)

Expand Down
70 changes: 55 additions & 15 deletions optimum/exporters/onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,7 @@ def export_pytorch(
input_shapes: Optional[Dict] = None,
no_dynamic_axes: bool = False,
do_constant_folding: bool = True,
dynamo: bool = False,
model_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[List[str], List[str]]:
"""
Expand All @@ -513,6 +514,8 @@ def export_pytorch(
If True, disables the use of dynamic axes during ONNX export.
do_constant_folding (bool, defaults to `True`):
PyTorch-specific argument. If `True`, the PyTorch ONNX export will fold constants into adjacent nodes, if possible.
        dynamo (bool, defaults to `False`):
PyTorch-specific argument. If `True`, export with the new Dynamo ONNX Exporter introduced in PyTorch 2+.
model_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`):
Experimental usage: keyword arguments to pass to the model during
the export. This argument should be used along the `custom_onnx_config` argument
Expand Down Expand Up @@ -568,22 +571,47 @@ def remap(value):
output_names = list(config.outputs.keys())

if no_dynamic_axes:
dynamix_axes = None
dynamic_axes = None
else:
dynamix_axes = dict(chain(inputs.items(), config.outputs.items()))

# Export can work with named args but the dict containing named args has to be the last element of the args
# tuple.
onnx_export(
model,
(dummy_inputs,),
f=output.as_posix(),
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamix_axes,
do_constant_folding=do_constant_folding,
opset_version=opset,
)
dynamic_axes = dict(chain(inputs.items(), config.outputs.items()))

if not dynamo:
# Export can work with named args but the dict containing named args has to be the last element of the args
# tuple.
onnx_export(
model,
(dummy_inputs,),
f=output.as_posix(),
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
do_constant_folding=do_constant_folding,
opset_version=opset,
)
else:
try:
from torch.onnx import dynamo_export as dynamo_onnx_export
except ImportError:
raise MinimumVersionError(
"The dynamo export feature is only available in PyTorch >= 2.1."
)
export_options = torch.onnx.ExportOptions(
dynamic_shapes=dynamic_axes is not None,
)
if opset != 18:
logger.warning(
"Dynamo ONNX export only supports opset 18 for now. The opset will be set to 18."
)
# Args and kwargs are supported natively in dynamo onnx export.
onnx_program = dynamo_onnx_export(
model,
export_options = export_options,
**dummy_inputs,
)
# TODO: Much of the later code performing external data clean-up is
# unnecessary for dynamo onnx export. ModelProto is directly accessible
# from onnx_program.
onnx_program.save(output.as_posix())

# check if external data was exported
# TODO: this is quite inefficient as we load in memory if models are <2GB without external data
Expand Down Expand Up @@ -706,6 +734,7 @@ def export_models(
dtype: Optional[str] = None,
no_dynamic_axes: bool = False,
do_constant_folding: bool = True,
dynamo: bool = False,
model_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[List[List[str]], List[List[str]]]:
"""
Expand Down Expand Up @@ -736,6 +765,8 @@ def export_models(
If True, disables the use of dynamic axes during ONNX export.
do_constant_folding (bool, defaults to `True`):
PyTorch-specific argument. If `True`, the PyTorch ONNX export will fold constants into adjacent nodes, if possible.
        dynamo (bool, defaults to `False`):
PyTorch-specific argument. If `True`, export with the new Dynamo ONNX Exporter introduced in PyTorch 2+.
model_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`):
Experimental usage: keyword arguments to pass to the model during
the export. This argument should be used along the `custom_onnx_config` argument
Expand Down Expand Up @@ -771,6 +802,7 @@ def export_models(
dtype=dtype,
no_dynamic_axes=no_dynamic_axes,
do_constant_folding=do_constant_folding,
dynamo=dynamo,
model_kwargs=model_kwargs,
)
)
Expand All @@ -790,6 +822,7 @@ def export(
dtype: Optional[str] = None,
no_dynamic_axes: bool = False,
do_constant_folding: bool = True,
dynamo: bool = False,
model_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[List[str], List[str]]:
"""
Expand Down Expand Up @@ -817,6 +850,8 @@ def export(
If True, disables the use of dynamic axes during ONNX export.
do_constant_folding (bool, defaults to `True`):
PyTorch-specific argument. If `True`, the PyTorch ONNX export will fold constants into adjacent nodes, if possible.
        dynamo (bool, defaults to `False`):
PyTorch-specific argument. If `True`, export with the new Dynamo ONNX Exporter introduced in PyTorch 2+.
model_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`):
Experimental usage: keyword arguments to pass to the model during
the export. This argument should be used along the `custom_onnx_config` argument
Expand Down Expand Up @@ -874,6 +909,7 @@ def export(
input_shapes=input_shapes,
no_dynamic_axes=no_dynamic_axes,
do_constant_folding=do_constant_folding,
dynamo=dynamo,
model_kwargs=model_kwargs,
)

Expand Down Expand Up @@ -918,6 +954,7 @@ def onnx_export_from_model(
task: Optional[str] = None,
use_subprocess: bool = False,
do_constant_folding: bool = True,
dynamo: bool = False,
**kwargs_shapes,
):
"""
Expand Down Expand Up @@ -973,6 +1010,8 @@ def onnx_export_from_model(
If True, disables the use of dynamic axes during ONNX export.
do_constant_folding (bool, defaults to `True`):
PyTorch-specific argument. If `True`, the PyTorch ONNX export will fold constants into adjacent nodes, if possible.
        dynamo (bool, defaults to `False`):
PyTorch-specific argument. If `True`, export with the new Dynamo ONNX Exporter introduced in PyTorch 2+.
**kwargs_shapes (`Dict`):
Shapes to use during inference. This argument allows to override the default shapes used during the ONNX export.
Expand Down Expand Up @@ -1159,6 +1198,7 @@ def onnx_export_from_model(
dtype=float_dtype,
no_dynamic_axes=no_dynamic_axes,
do_constant_folding=do_constant_folding,
dynamo=dynamo,
model_kwargs=model_kwargs,
)

Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@
"protobuf>=3.20.1",
"accelerate", # ORTTrainer requires it.
],
"exporters": ["onnx", "onnxruntime", "timm"],
"exporters-gpu": ["onnx", "onnxruntime-gpu", "timm"],
"exporters": ["onnx", "onnxruntime", "onnxscript", "timm"],
"exporters-gpu": ["onnx", "onnxruntime-gpu", "onnxscript", "timm"],
"exporters-tf": [
"tensorflow>=2.4,<=2.12.1",
"tf2onnx",
Expand Down

0 comments on commit c10f372

Please sign in to comment.