From 2f3d2d955ff85b53ac6c47dd462fc1f217a20bae Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Thu, 11 Jan 2024 20:23:03 +0530 Subject: [PATCH 1/7] add no dynamic axes arg --- optimum/commands/export/onnx.py | 4 ++++ optimum/exporters/onnx/__main__.py | 4 ++++ optimum/exporters/onnx/convert.py | 19 ++++++++++++++++++- 3 files changed, 26 insertions(+), 1 deletion(-) diff --git a/optimum/commands/export/onnx.py b/optimum/commands/export/onnx.py index 55f8b9dc1d..9405e38fef 100644 --- a/optimum/commands/export/onnx.py +++ b/optimum/commands/export/onnx.py @@ -150,6 +150,9 @@ def parse_args_onnx(parser): "Also disable the use of position_ids for text-generation models that require it for batched generation. This argument is introduced for backward compatibility and will be removed in a future release of Optimum." ), ) + optional_group.add_argument( + "--no_dynamic_axes", action="store_true", help="Disable dynamic axes during ONNX export" + ) input_group = parser.add_argument_group( "Input shapes (if necessary, this allows to override the shapes of the input given to the ONNX exporter, that requires an example input)." @@ -263,6 +266,7 @@ def run(self): _variant=self.args.variant, library_name=self.args.library_name, legacy=self.args.legacy, + no_dynamic_axes=self.args.no_dynamic_axes, model_kwargs=self.args.model_kwargs, **input_shapes, ) diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py index 91b1b0c3b6..1da7f4e186 100644 --- a/optimum/exporters/onnx/__main__.py +++ b/optimum/exporters/onnx/__main__.py @@ -190,6 +190,7 @@ def main_export( _variant: str = "default", library_name: Optional[str] = None, legacy: bool = False, + no_dynamic_axes: bool = False, **kwargs_shapes, ): """ @@ -270,6 +271,8 @@ def main_export( The library of the model (`"transformers"` or `"diffusers"` or `"timm"` or `"sentence_transformers"`). If not provided, will attempt to automatically detect the library name for the checkpoint. legacy (`bool`, defaults to `False`): Disable the use of position_ids for text-generation models that require it for batched generation. Also enable to export decoder only models in three files (without + with past and the merged model). This argument is introduced for backward compatibility and will be removed in a future release of Optimum. + no_dynamic_axes (bool, defaults to `False`): + If True, disables the use of dynamic axes during ONNX export. **kwargs_shapes (`Dict`): Shapes to use during inference. This argument allows to override the default shapes used during the ONNX export. @@ -556,6 +559,7 @@ def main_export( input_shapes=input_shapes, device=device, dtype="fp16" if fp16 is True else None, + no_dynamic_axes=no_dynamic_axes, model_kwargs=model_kwargs, ) diff --git a/optimum/exporters/onnx/convert.py b/optimum/exporters/onnx/convert.py index d9ea9c2f09..b29a92df5f 100644 --- a/optimum/exporters/onnx/convert.py +++ b/optimum/exporters/onnx/convert.py @@ -487,6 +487,7 @@ def export_pytorch( device: str = "cpu", dtype: Optional["torch.dtype"] = None, input_shapes: Optional[Dict] = None, + no_dynamic_axes: bool = False, model_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[List[str], List[str]]: """ @@ -508,6 +509,8 @@ def export_pytorch( Data type to remap the model inputs to. PyTorch-only. Only `torch.float16` is supported. input_shapes (`Optional[Dict]`, defaults to `None`): If specified, allows to use specific shapes for the example input provided to the ONNX exporter. + no_dynamic_axes (bool, defaults to `False`): + If True, disables the use of dynamic axes during ONNX export. model_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`): Experimental usage: keyword arguments to pass to the model during the export. This argument should be used along the `custom_onnx_config` argument @@ -561,6 +564,12 @@ def remap(value): inputs = config.ordered_inputs(model) input_names = list(inputs.keys()) output_names = list(config.outputs.keys()) + + + if no_dynamic_axes: + dynamix_axes = None + else: + dynamix_axes = dict(chain(inputs.items(), config.outputs.items())) # Export can work with named args but the dict containing named args has to be the last element of the args # tuple. @@ -570,7 +579,7 @@ def remap(value): f=output.as_posix(), input_names=input_names, output_names=output_names, - dynamic_axes=dict(chain(inputs.items(), config.outputs.items())), + dynamic_axes=dynamix_axes, do_constant_folding=True, opset_version=opset, ) @@ -694,6 +703,7 @@ def export_models( input_shapes: Optional[Dict] = None, disable_dynamic_axes_fix: Optional[bool] = False, dtype: Optional[str] = None, + no_dynamic_axes: bool = False, model_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[List[List[str]], List[List[str]]]: """ @@ -720,6 +730,8 @@ def export_models( Whether to disable the default dynamic axes fixing. dtype (`Optional[str]`, defaults to `None`): Data type to remap the model inputs to. PyTorch-only. Only `fp16` is supported. + no_dynamic_axes (bool, defaults to `False`): + If True, disables the use of dynamic axes during ONNX export. model_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`): Experimental usage: keyword arguments to pass to the model during the export. This argument should be used along the `custom_onnx_config` argument @@ -753,6 +765,7 @@ def export_models( input_shapes=input_shapes, disable_dynamic_axes_fix=disable_dynamic_axes_fix, dtype=dtype, + no_dynamic_axes=no_dynamic_axes, model_kwargs=model_kwargs, ) ) @@ -770,6 +783,7 @@ def export( input_shapes: Optional[Dict] = None, disable_dynamic_axes_fix: Optional[bool] = False, dtype: Optional[str] = None, + no_dynamic_axes: bool = False, model_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[List[str], List[str]]: """ @@ -793,6 +807,8 @@ def export( Whether to disable the default dynamic axes fixing. dtype (`Optional[str]`, defaults to `None`): Data type to remap the model inputs to. PyTorch-only. Only `fp16` is supported. + no_dynamic_axes (bool, defaults to `False`): + If True, disables the use of dynamic axes during ONNX export. model_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`): Experimental usage: keyword arguments to pass to the model during the export. This argument should be used along the `custom_onnx_config` argument @@ -855,6 +871,7 @@ def export( device=device, input_shapes=input_shapes, dtype=torch_dtype, + no_dynamic_axes=no_dynamic_axes, model_kwargs=model_kwargs, ) From 7e98acf13f003b827c0be99ccc61516a18f11c76 Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Fri, 12 Jan 2024 12:59:12 +0530 Subject: [PATCH 2/7] add tests --- optimum/exporters/onnx/__main__.py | 5 +++++ tests/exporters/exporters_utils.py | 8 +++++++ .../exporters/onnx/test_exporters_onnx_cli.py | 22 +++++++++++++++++++ 3 files changed, 35 insertions(+) diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py index 1da7f4e186..96b048db1e 100644 --- a/optimum/exporters/onnx/__main__.py +++ b/optimum/exporters/onnx/__main__.py @@ -447,6 +447,11 @@ def main_export( f" `--task {task_non_past} --monolith`, or `--task {task}` without the monolith argument." ) + if library_name == "timm" and no_dynamic_axes is False and "vovnet" in model_type: + raise ValueError( + f"The export of {model_type} is not supported with dynamic axes. Please pass --no-dynamic-axes to export the model with fixed shapes" + ) + if original_task == "auto": synonyms_for_task = sorted(TasksManager.synonyms_for_task(task)) if synonyms_for_task: diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py index 3efcd3a7eb..cd850d8c95 100644 --- a/tests/exporters/exporters_utils.py +++ b/tests/exporters/exporters_utils.py @@ -325,3 +325,11 @@ "sentence-transformers-clip": "sentence-transformers/all-MiniLM-L6-v2", "sentence-transformers-transformer": "sentence-transformers/clip-ViT-B-32-multilingual-v1", } + + +PYTORCH_TIMM_MODEL_NO_DYNAMIC_AXES = { + "default-timm-config": { + "timm/ese_vovnet39b.ra_in1k": ["image-classification"], + "timm/ese_vovnet19b_dw.ra_in1k": ["image-classification"], + } +} diff --git a/tests/exporters/onnx/test_exporters_onnx_cli.py b/tests/exporters/onnx/test_exporters_onnx_cli.py index 828c808b69..2847424c4e 100644 --- a/tests/exporters/onnx/test_exporters_onnx_cli.py +++ b/tests/exporters/onnx/test_exporters_onnx_cli.py @@ -44,6 +44,7 @@ PYTORCH_SENTENCE_TRANSFORMERS_MODEL, PYTORCH_STABLE_DIFFUSION_MODEL, PYTORCH_TIMM_MODEL, + PYTORCH_TIMM_MODEL_NO_DYNAMIC_AXES, ) @@ -179,6 +180,7 @@ def _onnx_export( device: str = "cpu", fp16: bool = False, variant: str = "default", + no_dynamic_axes: bool = False, model_kwargs: Optional[Dict] = None, ): with TemporaryDirectory() as tmpdir: @@ -193,6 +195,7 @@ def _onnx_export( monolith=monolith, no_post_process=no_post_process, _variant=variant, + no_dynamic_axes=no_dynamic_axes, model_kwargs=model_kwargs, ) except MinimumVersionError as e: @@ -258,6 +261,25 @@ def test_exporters_cli_pytorch_cpu_timm( ): self._onnx_export(model_name, task, monolith, no_post_process, variant=variant) + @parameterized.expand(_get_models_to_test(PYTORCH_TIMM_MODEL_NO_DYNAMIC_AXES, library_name="timm")) + @require_torch + @require_vision + @require_timm + # @slow + @pytest.mark.timm_test + # @pytest.mark.run_slow + def test_exporters_cli_pytorch_cpu_timm_no_dynamic_axes( + self, + test_name: str, + model_type: str, + model_name: str, + task: str, + variant: str, + monolith: bool, + no_post_process: bool, + ): + self._onnx_export(model_name, task, monolith, no_post_process, variant=variant, no_dynamic_axes=True) + @parameterized.expand(_get_models_to_test(PYTORCH_TIMM_MODEL, library_name="timm")) @require_torch_gpu @require_vision From d9515adf4ecb4a73f3b34aa5b1f23232bd4d79b6 Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Fri, 12 Jan 2024 13:08:27 +0530 Subject: [PATCH 3/7] update doc --- docs/source/exporters/onnx/overview.mdx | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/exporters/onnx/overview.mdx b/docs/source/exporters/onnx/overview.mdx index a84d7308b2..24c2123466 100644 --- a/docs/source/exporters/onnx/overview.mdx +++ b/docs/source/exporters/onnx/overview.mdx @@ -121,6 +121,7 @@ Supported architectures from [🤗 Timm](https://huggingface.co/docs/timm/index) - EfficientNet - EfficientNet (Knapsack Pruned) - Ensemble Adversarial Inception ResNet v2 +- ESE-VoVNet (Partial support with static shapes) - FBNet - (Gluon) Inception v3 - (Gluon) ResNet From b7866c28168c83ad0c22c73dd0ebbc94f4e05dbe Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Fri, 12 Jan 2024 13:12:14 +0530 Subject: [PATCH 4/7] update doc --- optimum/commands/export/onnx.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/commands/export/onnx.py b/optimum/commands/export/onnx.py index 9405e38fef..1e87c77bf7 100644 --- a/optimum/commands/export/onnx.py +++ b/optimum/commands/export/onnx.py @@ -151,7 +151,7 @@ def parse_args_onnx(parser): ), ) optional_group.add_argument( - "--no_dynamic_axes", action="store_true", help="Disable dynamic axes during ONNX export" + "--no-dynamic-axes", action="store_true", help="Disable dynamic axes during ONNX export" ) input_group = parser.add_argument_group( From bef9d0e01cfbb2259b398d432cdf5f96bbdbbecb Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Fri, 12 Jan 2024 13:18:24 +0530 Subject: [PATCH 5/7] update test --- tests/exporters/onnx/test_exporters_onnx_cli.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/exporters/onnx/test_exporters_onnx_cli.py b/tests/exporters/onnx/test_exporters_onnx_cli.py index 2847424c4e..66c04c2d78 100644 --- a/tests/exporters/onnx/test_exporters_onnx_cli.py +++ b/tests/exporters/onnx/test_exporters_onnx_cli.py @@ -265,9 +265,9 @@ def test_exporters_cli_pytorch_cpu_timm( @require_torch @require_vision @require_timm - # @slow + @slow @pytest.mark.timm_test - # @pytest.mark.run_slow + @pytest.mark.run_slow def test_exporters_cli_pytorch_cpu_timm_no_dynamic_axes( self, test_name: str, From b1af6f3b6d8fae4f76b1f7b0ca8b7c1ee28c2454 Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Mon, 15 Jan 2024 21:30:42 +0530 Subject: [PATCH 6/7] add test for shape check --- optimum/exporters/onnx/__main__.py | 5 -- tests/exporters/exporters_utils.py | 13 +++ .../exporters/onnx/test_exporters_onnx_cli.py | 85 ++++++++++++++++++- 3 files changed, 96 insertions(+), 7 deletions(-) diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py index 96b048db1e..1da7f4e186 100644 --- a/optimum/exporters/onnx/__main__.py +++ b/optimum/exporters/onnx/__main__.py @@ -447,11 +447,6 @@ def main_export( f" `--task {task_non_past} --monolith`, or `--task {task}` without the monolith argument." ) - if library_name == "timm" and no_dynamic_axes is False and "vovnet" in model_type: - raise ValueError( - f"The export of {model_type} is not supported with dynamic axes. Please pass --no-dynamic-axes to export the model with fixed shapes" - ) - if original_task == "auto": synonyms_for_task = sorted(TasksManager.synonyms_for_task(task)) if synonyms_for_task: diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py index cd850d8c95..d96719f6cb 100644 --- a/tests/exporters/exporters_utils.py +++ b/tests/exporters/exporters_utils.py @@ -28,6 +28,11 @@ "num_choices": [4], } +NO_DYNAMIC_AXES_EXPORT_SHAPES_TRANSFORMERS = { + "batch_size": [1, 3, 5], + "num_choices": [2, 4], + "sequence_length": [8, 33, 96], +} PYTORCH_EXPORT_MODELS_TINY = { "albert": "hf-internal-testing/tiny-random-AlbertModel", @@ -327,6 +332,14 @@ } +PYTORCH_TRANSFORMERS_MODEL_NO_DYNAMIC_AXES = { + "albert": "hf-internal-testing/tiny-random-AlbertModel", + "gpt2": "hf-internal-testing/tiny-random-gpt2", + "roberta": "hf-internal-testing/tiny-random-RobertaModel", + "roformer": "hf-internal-testing/tiny-random-RoFormerModel", +} + + PYTORCH_TIMM_MODEL_NO_DYNAMIC_AXES = { "default-timm-config": { "timm/ese_vovnet39b.ra_in1k": ["image-classification"], diff --git a/tests/exporters/onnx/test_exporters_onnx_cli.py b/tests/exporters/onnx/test_exporters_onnx_cli.py index 66c04c2d78..397ecd16e7 100644 --- a/tests/exporters/onnx/test_exporters_onnx_cli.py +++ b/tests/exporters/onnx/test_exporters_onnx_cli.py @@ -33,18 +33,20 @@ ONNX_DECODER_WITH_PAST_NAME, ONNX_ENCODER_NAME, ) -from optimum.utils.testing_utils import require_diffusers, require_sentence_transformers, require_timm +from optimum.utils.testing_utils import grid_parameters, require_diffusers, require_sentence_transformers, require_timm if is_torch_available(): from optimum.exporters.tasks import TasksManager from ..exporters_utils import ( + NO_DYNAMIC_AXES_EXPORT_SHAPES_TRANSFORMERS, PYTORCH_EXPORT_MODELS_TINY, PYTORCH_SENTENCE_TRANSFORMERS_MODEL, PYTORCH_STABLE_DIFFUSION_MODEL, PYTORCH_TIMM_MODEL, PYTORCH_TIMM_MODEL_NO_DYNAMIC_AXES, + PYTORCH_TRANSFORMERS_MODEL_NO_DYNAMIC_AXES, ) @@ -201,6 +203,48 @@ def _onnx_export( except MinimumVersionError as e: pytest.skip(f"Skipping due to minimum version requirements not met. Full error: {e}") + def _onnx_export_no_dynamic_axes( + self, + model_name: str, + task: str, + input_shape: dict, + input_shape_for_validation: tuple, + monolith: bool = False, + no_post_process: bool = False, + optimization_level: Optional[str] = None, + device: str = "cpu", + fp16: bool = False, + variant: str = "default", + model_kwargs: Optional[Dict] = None, + ): + with TemporaryDirectory() as tmpdir: + try: + main_export( + model_name_or_path=model_name, + output=tmpdir, + task=task, + device=device, + fp16=fp16, + optimize=optimization_level, + monolith=monolith, + no_post_process=no_post_process, + _variant=variant, + no_dynamic_axes=True, + model_kwargs=model_kwargs, + **input_shape, + ) + + model = onnx.load(Path(tmpdir) / "model.onnx") + + is_dynamic = any(dim.dim_param for dim in model.graph.input[0].type.tensor_type.shape.dim) + self.assertFalse(is_dynamic) + + model_input_shape = [dim.dim_value for dim in model.graph.input[0].type.tensor_type.shape.dim] + self.assertEqual(model_input_shape, input_shape_for_validation) + + except MinimumVersionError as e: + pytest.skip(f"Skipping due to minimum version requirements not met. Full error: {e}") + @parameterized.expand(PYTORCH_STABLE_DIFFUSION_MODEL.items()) @require_torch @require_vision @@ -278,7 +322,14 @@ def test_exporters_cli_pytorch_cpu_timm_no_dynamic_axes( monolith: bool, no_post_process: bool, ): - self._onnx_export(model_name, task, monolith, no_post_process, variant=variant, no_dynamic_axes=True) + input_shapes_iterator = grid_parameters({"batch_size": [1, 3, 5]}, yield_dict=True, add_test_name=False) + for input_shape in input_shapes_iterator: + # NOTE: The timm models use input shapes from the model config, so we need to fix the other shapes of the model. + input_shape_for_validation = [input_shape["batch_size"], 3, 224, 224] + + self._onnx_export_no_dynamic_axes( + model_name, task, input_shape, input_shape_for_validation, monolith, no_post_process, variant=variant + ) @parameterized.expand(_get_models_to_test(PYTORCH_TIMM_MODEL, library_name="timm")) @require_torch_gpu @@ -344,6 +395,36 @@ def test_exporters_cli_pytorch_cpu( self._onnx_export(model_name, task, monolith, no_post_process, variant=variant, model_kwargs=model_kwargs) + @parameterized.expand(_get_models_to_test(PYTORCH_TRANSFORMERS_MODEL_NO_DYNAMIC_AXES)) + @require_torch + @require_vision + def test_exporters_cli_pytorch_cpu_no_dynamic_axes( + self, + test_name: str, + model_type: str, + model_name: str, + task: str, + variant: str, + monolith: bool, + no_post_process: bool, + ): + input_shapes_iterator = grid_parameters( + NO_DYNAMIC_AXES_EXPORT_SHAPES_TRANSFORMERS, yield_dict=True, add_test_name=False + ) + for input_shape in input_shapes_iterator: + if task == "multiple-choice": + input_shape_for_validation = [ + input_shape["batch_size"], + input_shape["num_choices"], + input_shape["sequence_length"], + ] + else: + input_shape_for_validation = [input_shape["batch_size"], input_shape["sequence_length"]] + + self._onnx_export_no_dynamic_axes( + model_name, task, input_shape, input_shape_for_validation, monolith, no_post_process, variant=variant + ) + @parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS_TINY)) @require_vision @require_torch_gpu From b07dadce74be5c414b55f1d0d571ff26a5994f48 Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Tue, 16 Jan 2024 17:40:42 +0530 Subject: [PATCH 7/7] fix style --- optimum/exporters/onnx/convert.py | 1 - 1 file changed, 1 deletion(-) diff --git a/optimum/exporters/onnx/convert.py b/optimum/exporters/onnx/convert.py index b29a92df5f..0b7fc5d90f 100644 --- a/optimum/exporters/onnx/convert.py +++ b/optimum/exporters/onnx/convert.py @@ -564,7 +564,6 @@ def remap(value): inputs = config.ordered_inputs(model) input_names = list(inputs.keys()) output_names = list(config.outputs.keys()) - if no_dynamic_axes: dynamix_axes = None