Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable export of model with fixed shape #1643

Merged
merged 7 commits into from
Jan 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/exporters/onnx/overview.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ Supported architectures from [🤗 Timm](https://huggingface.co/docs/timm/index)
- EfficientNet
- EfficientNet (Knapsack Pruned)
- Ensemble Adversarial Inception ResNet v2
- ESE-VoVNet (Partial support with static shapes)
- FBNet
- (Gluon) Inception v3
- (Gluon) ResNet
Expand Down
4 changes: 4 additions & 0 deletions optimum/commands/export/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,9 @@ def parse_args_onnx(parser):
"Also disable the use of position_ids for text-generation models that require it for batched generation. This argument is introduced for backward compatibility and will be removed in a future release of Optimum."
),
)
optional_group.add_argument(
"--no-dynamic-axes", action="store_true", help="Disable dynamic axes during ONNX export"
)

input_group = parser.add_argument_group(
"Input shapes (if necessary, this allows to override the shapes of the input given to the ONNX exporter, that requires an example input)."
Expand Down Expand Up @@ -263,6 +266,7 @@ def run(self):
_variant=self.args.variant,
library_name=self.args.library_name,
legacy=self.args.legacy,
no_dynamic_axes=self.args.no_dynamic_axes,
model_kwargs=self.args.model_kwargs,
**input_shapes,
)
4 changes: 4 additions & 0 deletions optimum/exporters/onnx/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ def main_export(
_variant: str = "default",
library_name: Optional[str] = None,
legacy: bool = False,
no_dynamic_axes: bool = False,
**kwargs_shapes,
):
"""
Expand Down Expand Up @@ -270,6 +271,8 @@ def main_export(
The library of the model (`"transformers"` or `"diffusers"` or `"timm"` or `"sentence_transformers"`). If not provided, will attempt to automatically detect the library name for the checkpoint.
legacy (`bool`, defaults to `False`):
Disable the use of position_ids for text-generation models that require it for batched generation. Also enable to export decoder only models in three files (without + with past and the merged model). This argument is introduced for backward compatibility and will be removed in a future release of Optimum.
no_dynamic_axes (`bool`, defaults to `False`):
If `True`, disables the use of dynamic axes during ONNX export.
**kwargs_shapes (`Dict`):
Shapes to use during inference. This argument allows to override the default shapes used during the ONNX export.

Expand Down Expand Up @@ -556,6 +559,7 @@ def main_export(
input_shapes=input_shapes,
device=device,
dtype="fp16" if fp16 is True else None,
no_dynamic_axes=no_dynamic_axes,
model_kwargs=model_kwargs,
)

Expand Down
18 changes: 17 additions & 1 deletion optimum/exporters/onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,7 @@ def export_pytorch(
device: str = "cpu",
dtype: Optional["torch.dtype"] = None,
input_shapes: Optional[Dict] = None,
no_dynamic_axes: bool = False,
model_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[List[str], List[str]]:
"""
Expand All @@ -508,6 +509,8 @@ def export_pytorch(
Data type to remap the model inputs to. PyTorch-only. Only `torch.float16` is supported.
input_shapes (`Optional[Dict]`, defaults to `None`):
If specified, allows to use specific shapes for the example input provided to the ONNX exporter.
no_dynamic_axes (`bool`, defaults to `False`):
If `True`, disables the use of dynamic axes during ONNX export.
model_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`):
Experimental usage: keyword arguments to pass to the model during
the export. This argument should be used along the `custom_onnx_config` argument
Expand Down Expand Up @@ -562,6 +565,11 @@ def remap(value):
input_names = list(inputs.keys())
output_names = list(config.outputs.keys())

if no_dynamic_axes:
dynamix_axes = None
else:
dynamix_axes = dict(chain(inputs.items(), config.outputs.items()))

# Export can work with named args but the dict containing named args has to be the last element of the args
# tuple.
onnx_export(
Expand All @@ -570,7 +578,7 @@ def remap(value):
f=output.as_posix(),
input_names=input_names,
output_names=output_names,
dynamic_axes=dict(chain(inputs.items(), config.outputs.items())),
dynamic_axes=dynamix_axes,
do_constant_folding=True,
opset_version=opset,
)
Expand Down Expand Up @@ -694,6 +702,7 @@ def export_models(
input_shapes: Optional[Dict] = None,
disable_dynamic_axes_fix: Optional[bool] = False,
dtype: Optional[str] = None,
no_dynamic_axes: bool = False,
model_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[List[List[str]], List[List[str]]]:
"""
Expand All @@ -720,6 +729,8 @@ def export_models(
Whether to disable the default dynamic axes fixing.
dtype (`Optional[str]`, defaults to `None`):
Data type to remap the model inputs to. PyTorch-only. Only `fp16` is supported.
no_dynamic_axes (`bool`, defaults to `False`):
If `True`, disables the use of dynamic axes during ONNX export.
model_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`):
Experimental usage: keyword arguments to pass to the model during
the export. This argument should be used along the `custom_onnx_config` argument
Expand Down Expand Up @@ -753,6 +764,7 @@ def export_models(
input_shapes=input_shapes,
disable_dynamic_axes_fix=disable_dynamic_axes_fix,
dtype=dtype,
no_dynamic_axes=no_dynamic_axes,
model_kwargs=model_kwargs,
)
)
Expand All @@ -770,6 +782,7 @@ def export(
input_shapes: Optional[Dict] = None,
disable_dynamic_axes_fix: Optional[bool] = False,
dtype: Optional[str] = None,
no_dynamic_axes: bool = False,
model_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[List[str], List[str]]:
"""
Expand All @@ -793,6 +806,8 @@ def export(
Whether to disable the default dynamic axes fixing.
dtype (`Optional[str]`, defaults to `None`):
Data type to remap the model inputs to. PyTorch-only. Only `fp16` is supported.
no_dynamic_axes (`bool`, defaults to `False`):
If `True`, disables the use of dynamic axes during ONNX export.
model_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`):
Experimental usage: keyword arguments to pass to the model during
the export. This argument should be used along the `custom_onnx_config` argument
Expand Down Expand Up @@ -855,6 +870,7 @@ def export(
device=device,
input_shapes=input_shapes,
dtype=torch_dtype,
no_dynamic_axes=no_dynamic_axes,
model_kwargs=model_kwargs,
)

Expand Down
21 changes: 21 additions & 0 deletions tests/exporters/exporters_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@
"num_choices": [4],
}

# Candidate input-shape values for the "no dynamic axes" export test on transformers models.
# The CLI test hands this mapping to `grid_parameters` and forwards each yielded dict to
# `main_export` as shape keyword arguments (presumably one dict per combination of the
# values below -- confirm against `grid_parameters`).
NO_DYNAMIC_AXES_EXPORT_SHAPES_TRANSFORMERS = {
"batch_size": [1, 3, 5],
"num_choices": [2, 4],
"sequence_length": [8, 33, 96],
}

PYTORCH_EXPORT_MODELS_TINY = {
"albert": "hf-internal-testing/tiny-random-AlbertModel",
Expand Down Expand Up @@ -325,3 +330,19 @@
"sentence-transformers-clip": "sentence-transformers/all-MiniLM-L6-v2",
"sentence-transformers-transformer": "sentence-transformers/clip-ViT-B-32-multilingual-v1",
}


# Transformers checkpoints exercised by the CLI export test that runs with `no_dynamic_axes=True`.
# Keys appear to be model types and values the tiny test checkpoints to export; the mapping is
# expanded into test cases by `_get_models_to_test`.
PYTORCH_TRANSFORMERS_MODEL_NO_DYNAMIC_AXES = {
"albert": "hf-internal-testing/tiny-random-AlbertModel",
"gpt2": "hf-internal-testing/tiny-random-gpt2",
"roberta": "hf-internal-testing/tiny-random-RobertaModel",
"roformer": "hf-internal-testing/tiny-random-RoFormerModel",
}


# timm checkpoints exercised by the CLI export test that runs with `no_dynamic_axes=True`.
# Layout: config name -> {checkpoint id: [tasks to export]}; expanded into test cases by
# `_get_models_to_test(..., library_name="timm")`.
PYTORCH_TIMM_MODEL_NO_DYNAMIC_AXES = {
"default-timm-config": {
"timm/ese_vovnet39b.ra_in1k": ["image-classification"],
"timm/ese_vovnet19b_dw.ra_in1k": ["image-classification"],
}
}
105 changes: 104 additions & 1 deletion tests/exporters/onnx/test_exporters_onnx_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,20 @@
ONNX_DECODER_WITH_PAST_NAME,
ONNX_ENCODER_NAME,
)
from optimum.utils.testing_utils import require_diffusers, require_sentence_transformers, require_timm
from optimum.utils.testing_utils import grid_parameters, require_diffusers, require_sentence_transformers, require_timm


if is_torch_available():
from optimum.exporters.tasks import TasksManager

from ..exporters_utils import (
NO_DYNAMIC_AXES_EXPORT_SHAPES_TRANSFORMERS,
PYTORCH_EXPORT_MODELS_TINY,
PYTORCH_SENTENCE_TRANSFORMERS_MODEL,
PYTORCH_STABLE_DIFFUSION_MODEL,
PYTORCH_TIMM_MODEL,
PYTORCH_TIMM_MODEL_NO_DYNAMIC_AXES,
PYTORCH_TRANSFORMERS_MODEL_NO_DYNAMIC_AXES,
)


Expand Down Expand Up @@ -179,6 +182,7 @@ def _onnx_export(
device: str = "cpu",
fp16: bool = False,
variant: str = "default",
no_dynamic_axes: bool = False,
model_kwargs: Optional[Dict] = None,
):
with TemporaryDirectory() as tmpdir:
Expand All @@ -193,11 +197,54 @@ def _onnx_export(
monolith=monolith,
no_post_process=no_post_process,
_variant=variant,
no_dynamic_axes=no_dynamic_axes,
model_kwargs=model_kwargs,
)
except MinimumVersionError as e:
pytest.skip(f"Skipping due to minimum version requirements not met. Full error: {e}")

def _onnx_export_no_dynamic_axes(
    self,
    model_name: str,
    task: str,
    input_shape: dict,
    input_shape_for_validation: list,
    monolith: bool = False,
    no_post_process: bool = False,
    optimization_level: Optional[str] = None,
    device: str = "cpu",
    fp16: bool = False,
    variant: str = "default",
    model_kwargs: Optional[Dict] = None,
):
    """
    Export a model with `no_dynamic_axes=True` and check that its first graph input is fully static.

    Args:
        model_name (`str`):
            Checkpoint passed to `main_export` as `model_name_or_path`.
        task (`str`):
            Task to export the model for.
        input_shape (`dict`):
            Shape overrides forwarded to `main_export` as keyword arguments.
        input_shape_for_validation (`list` or `tuple` of `int`):
            Expected dimensions of the first input of the exported graph.
        monolith, no_post_process, optimization_level, device, fp16, variant, model_kwargs:
            Forwarded unchanged to `main_export`.

    The test is skipped when the export raises `MinimumVersionError`.
    """
    with TemporaryDirectory() as tmpdir:
        # Only the export itself is guarded: a version error skips the test, while any
        # failure in the checks below is reported as a real failure.
        try:
            main_export(
                model_name_or_path=model_name,
                output=tmpdir,
                task=task,
                device=device,
                fp16=fp16,
                optimize=optimization_level,
                monolith=monolith,
                no_post_process=no_post_process,
                _variant=variant,
                no_dynamic_axes=True,
                model_kwargs=model_kwargs,
                **input_shape,
            )
        except MinimumVersionError as e:
            pytest.skip(f"Skipping due to minimum version requirements not met. Full error: {e}")

        model = onnx.load(Path(tmpdir) / "model.onnx")
        first_input_dims = model.graph.input[0].type.tensor_type.shape.dim

        # A symbolic (dynamic) dimension carries a non-empty `dim_param` name.
        is_dynamic = any(dim.dim_param for dim in first_input_dims)
        self.assertFalse(is_dynamic)

        # Normalise to a list so that a tuple of expected dims compares equal as well.
        model_input_shape = [dim.dim_value for dim in first_input_dims]
        self.assertEqual(model_input_shape, list(input_shape_for_validation))

@parameterized.expand(PYTORCH_STABLE_DIFFUSION_MODEL.items())
@require_torch
@require_vision
Expand Down Expand Up @@ -258,6 +305,32 @@ def test_exporters_cli_pytorch_cpu_timm(
):
self._onnx_export(model_name, task, monolith, no_post_process, variant=variant)

@parameterized.expand(_get_models_to_test(PYTORCH_TIMM_MODEL_NO_DYNAMIC_AXES, library_name="timm"))
@require_torch
@require_vision
@require_timm
@slow
@pytest.mark.timm_test
@pytest.mark.run_slow
def test_exporters_cli_pytorch_cpu_timm_no_dynamic_axes(
    self,
    test_name: str,
    model_type: str,
    model_name: str,
    task: str,
    variant: str,
    monolith: bool,
    no_post_process: bool,
):
    """Export a timm checkpoint without dynamic axes for several batch sizes and validate the static input shape."""
    batch_size_grid = {"batch_size": [1, 3, 5]}
    for shape_kwargs in grid_parameters(batch_size_grid, yield_dict=True, add_test_name=False):
        # NOTE: The timm models use input shapes from the model config, so only the batch size
        # varies here; the remaining dimensions are fixed to 3 x 224 x 224.
        expected_dims = [shape_kwargs["batch_size"], 3, 224, 224]

        self._onnx_export_no_dynamic_axes(
            model_name,
            task,
            input_shape=shape_kwargs,
            input_shape_for_validation=expected_dims,
            monolith=monolith,
            no_post_process=no_post_process,
            variant=variant,
        )

@parameterized.expand(_get_models_to_test(PYTORCH_TIMM_MODEL, library_name="timm"))
@require_torch_gpu
@require_vision
Expand Down Expand Up @@ -322,6 +395,36 @@ def test_exporters_cli_pytorch_cpu(

self._onnx_export(model_name, task, monolith, no_post_process, variant=variant, model_kwargs=model_kwargs)

@parameterized.expand(_get_models_to_test(PYTORCH_TRANSFORMERS_MODEL_NO_DYNAMIC_AXES))
@require_torch
@require_vision
def test_exporters_cli_pytorch_cpu_no_dynamic_axes(
    self,
    test_name: str,
    model_type: str,
    model_name: str,
    task: str,
    variant: str,
    monolith: bool,
    no_post_process: bool,
):
    """Export a transformers checkpoint without dynamic axes over a grid of shapes and validate the static input shape."""
    for shape_kwargs in grid_parameters(
        NO_DYNAMIC_AXES_EXPORT_SHAPES_TRANSFORMERS, yield_dict=True, add_test_name=False
    ):
        # Expected layout is (batch, sequence), with a choices axis inserted in between
        # for the multiple-choice task.
        expected_dims = [shape_kwargs["batch_size"]]
        if task == "multiple-choice":
            expected_dims.append(shape_kwargs["num_choices"])
        expected_dims.append(shape_kwargs["sequence_length"])

        self._onnx_export_no_dynamic_axes(
            model_name,
            task,
            input_shape=shape_kwargs,
            input_shape_for_validation=expected_dims,
            monolith=monolith,
            no_post_process=no_post_process,
            variant=variant,
        )

@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS_TINY))
@require_vision
@require_torch_gpu
Expand Down