diff --git a/docs/source/exporters/onnx/overview.mdx b/docs/source/exporters/onnx/overview.mdx index 0a5da755a3..650b312367 100644 --- a/docs/source/exporters/onnx/overview.mdx +++ b/docs/source/exporters/onnx/overview.mdx @@ -42,6 +42,7 @@ Supported architectures: - Donut-Swin - Electra - Encoder Decoder +- ESM - Falcon - Flaubert - GPT-2 diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 443aa6fc70..d4885d0eba 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -183,6 +183,20 @@ class DebertaV2OnnxConfig(DebertaOnnxConfig): pass +class EsmOnnxConfig(TextEncoderOnnxConfig): + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + ATOL_FOR_VALIDATION = 1e-4 + DEFAULT_ONNX_OPSET = 12 + + @property + def inputs(self) -> Dict[str, Dict[int, str]]: + dynamic_axis = {0: "batch_size", 1: "sequence_length"} + return { + "input_ids": dynamic_axis, + "attention_mask": dynamic_axis, + } + + class GPT2OnnxConfig(TextDecoderWithPositionIdsOnnxConfig): DEFAULT_ONNX_OPSET = 13 NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_layers="n_layer", num_attention_heads="n_head") diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index 6abb9e7412..96d68cd812 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -525,6 +525,13 @@ class TasksManager: "text2text-generation-with-past", onnx="EncoderDecoderOnnxConfig", ), + "esm": supported_tasks_mapping( + "feature-extraction", + "fill-mask", + "text-classification", + "token-classification", + onnx="EsmOnnxConfig", + ), "falcon": supported_tasks_mapping( "feature-extraction", "feature-extraction-with-past", diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py index 9af7806e7f..3ec8b86eb3 100644 --- a/tests/exporters/exporters_utils.py +++ b/tests/exporters/exporters_utils.py @@ -74,6 +74,7 @@ ], "mohitsha/tiny-random-testing-bert2gpt2": ["text2text-generation", "text2text-generation-with-past"], }, + "esm": "hf-internal-testing/tiny-random-EsmModel", "falcon": { "fxmarty/really-tiny-falcon-testing": [ "feature-extraction",