From 04dd8534262e2a523ac2830f3eab816ae465aab8 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Sun, 10 Dec 2023 01:01:33 +0000 Subject: [PATCH 1/2] Add ESM onnx support --- docs/source/exporters/onnx/overview.mdx | 1 + optimum/exporters/onnx/model_configs.py | 13 +++++++++++++ optimum/exporters/tasks.py | 7 +++++++ tests/exporters/exporters_utils.py | 1 + 4 files changed, 22 insertions(+) diff --git a/docs/source/exporters/onnx/overview.mdx b/docs/source/exporters/onnx/overview.mdx index 0a5da755a3..650b312367 100644 --- a/docs/source/exporters/onnx/overview.mdx +++ b/docs/source/exporters/onnx/overview.mdx @@ -42,6 +42,7 @@ Supported architectures: - Donut-Swin - Electra - Encoder Decoder +- ESM - Falcon - Flaubert - GPT-2 diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index a58f42dca4..cd62931f6e 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -183,6 +183,19 @@ class DebertaV2OnnxConfig(DebertaOnnxConfig): pass +class EsmOnnxConfig(TextEncoderOnnxConfig): + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + ATOL_FOR_VALIDATION = 1e-4 + + @property + def inputs(self) -> Dict[str, Dict[int, str]]: + dynamic_axis = {0: "batch_size", 1: "sequence_length"} + return { + "input_ids": dynamic_axis, + "attention_mask": dynamic_axis, + } + + class GPT2OnnxConfig(TextDecoderWithPositionIdsOnnxConfig): DEFAULT_ONNX_OPSET = 13 NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_layers="n_layer", num_attention_heads="n_head") diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index 4d3f9f98d0..f1b15829bf 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -525,6 +525,13 @@ class TasksManager: "text2text-generation-with-past", onnx="EncoderDecoderOnnxConfig", ), + "esm": supported_tasks_mapping( + "feature-extraction", + "fill-mask", + "text-classification", + "token-classification", + onnx="EsmOnnxConfig", + ), "falcon": supported_tasks_mapping( "feature-extraction", "feature-extraction-with-past", diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py index 9af7806e7f..3ec8b86eb3 100644 --- a/tests/exporters/exporters_utils.py +++ b/tests/exporters/exporters_utils.py @@ -74,6 +74,7 @@ ], "mohitsha/tiny-random-testing-bert2gpt2": ["text2text-generation", "text2text-generation-with-past"], }, + "esm": "hf-internal-testing/tiny-random-EsmModel", "falcon": { "fxmarty/really-tiny-falcon-testing": [ "feature-extraction", From a2cee7c6dd928f082d943ae15c09e779a99e3acf Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Sun, 10 Dec 2023 21:19:36 +0200 Subject: [PATCH 2/2] set default opset=12 --- optimum/exporters/onnx/model_configs.py | 1 + 1 file changed, 1 insertion(+) diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index cd62931f6e..6c2d66c64c 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -186,6 +186,7 @@ class DebertaV2OnnxConfig(DebertaOnnxConfig): class EsmOnnxConfig(TextEncoderOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedTextConfig ATOL_FOR_VALIDATION = 1e-4 + DEFAULT_ONNX_OPSET = 12 @property def inputs(self) -> Dict[str, Dict[int, str]]: