Skip to content

Commit

Permalink
Add support for markuplm ONNX export (#1784)
Browse files Browse the repository at this point in the history
* Add xpath dummy generator

* Add markuplm onnx config

* Update docs

* Add model to tests

* Get pad ids from normalized config

* Use hf-internal model

* Add markuplm to tiny exports

* Apply formatting
  • Loading branch information
pogzyb committed Apr 10, 2024
1 parent 69af5db commit 5ea14c1
Show file tree
Hide file tree
Showing 7 changed files with 84 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/exporters/onnx/overview.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
- Llama
- M2M-100
- Marian
- MarkupLM
- MBart
- Mistral
- MobileBert
Expand Down
20 changes: 20 additions & 0 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
DummyVisionEmbeddingsGenerator,
DummyVisionEncoderDecoderPastKeyValuesGenerator,
DummyVisionInputGenerator,
DummyXPathSeqInputGenerator,
FalconDummyPastKeyValuesGenerator,
GemmaDummyPastKeyValuesGenerator,
GPTBigCodeDummyPastKeyValuesGenerator,
Expand Down Expand Up @@ -182,6 +183,25 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
return common_inputs


class MarkupLMOnnxConfig(BertOnnxConfig):
    """ONNX export configuration for MarkupLM models.

    Reuses the BERT text inputs and adds the xpath tag/subscript sequence
    inputs that MarkupLM consumes, each carrying an extra depth dimension.
    """

    DUMMY_INPUT_GENERATOR_CLASSES = (
        DummyTextInputGenerator,
        DummyXPathSeqInputGenerator,
    )

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        # Text inputs are dynamic over batch and sequence; xpath inputs have
        # one additional dynamic axis for the xpath depth.
        text_axes = {0: "batch_size", 1: "sequence_length"}
        xpath_axes = {0: "batch_size", 1: "sequence_length", 2: "max_depth"}

        model_inputs = {name: text_axes for name in ("input_ids", "attention_mask", "token_type_ids")}
        model_inputs.update((name, xpath_axes) for name in ("xpath_subs_seq", "xpath_tags_seq"))
        return model_inputs


class DebertaV2OnnxConfig(DebertaOnnxConfig):
pass

Expand Down
7 changes: 7 additions & 0 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,6 +732,13 @@ class TasksManager:
"text-generation-with-past",
onnx="MarianOnnxConfig",
),
"markuplm": supported_tasks_mapping(
"feature-extraction",
"text-classification",
"token-classification",
"question-answering",
onnx="MarkupLMOnnxConfig",
),
"mbart": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
Expand Down
1 change: 1 addition & 0 deletions optimum/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
DummyVisionEmbeddingsGenerator,
DummyVisionEncoderDecoderPastKeyValuesGenerator,
DummyVisionInputGenerator,
DummyXPathSeqInputGenerator,
FalconDummyPastKeyValuesGenerator,
GemmaDummyPastKeyValuesGenerator,
GPTBigCodeDummyPastKeyValuesGenerator,
Expand Down
52 changes: 52 additions & 0 deletions optimum/utils/input_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,58 @@ def generate(
return self.random_int_tensor(shape, max_value, min_value=min_value, framework=framework, dtype=int_dtype)


class DummyXPathSeqInputGenerator(DummyTextInputGenerator):
    """Dummy input generator for MarkupLM xpath tag/subscript sequences.

    Produces integer tensors shaped ``(batch_size, sequence_length, max_depth)``,
    with values bounded by the pad ids taken from the normalized config.
    """

    SUPPORTED_INPUT_NAMES = (
        "xpath_tags_seq",
        "xpath_subs_seq",
    )

    def __init__(
        self,
        task: str,
        normalized_config: NormalizedTextConfig,
        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
        sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
        num_choices: int = DEFAULT_DUMMY_SHAPES["num_choices"],
        random_batch_size_range: Optional[Tuple[int, int]] = None,
        random_sequence_length_range: Optional[Tuple[int, int]] = None,
        random_num_choices_range: Optional[Tuple[int, int]] = None,
        padding_side: str = "right",
        **kwargs,
    ):
        super().__init__(
            task,
            normalized_config,
            batch_size=batch_size,
            sequence_length=sequence_length,
            num_choices=num_choices,
            random_batch_size_range=random_batch_size_range,
            random_sequence_length_range=random_sequence_length_range,
            random_num_choices_range=random_num_choices_range,
            padding_side=padding_side,
            **kwargs,
        )
        # Pad ids and xpath depth come from the normalized config
        # (see NormalizedConfigManager entry for "markuplm").
        self.tag_pad_id = normalized_config.tag_pad_id
        self.subs_pad_id = normalized_config.subs_pad_id
        self.max_depth = normalized_config.max_depth

    def generate(
        self,
        input_name: str,
        framework: str = "pt",
        int_dtype: str = "int64",
        float_dtype: str = "fp32",
    ):
        """Return a random integer tensor for *input_name* in the requested framework."""
        if input_name == "xpath_tags_seq":
            upper_bound = self.tag_pad_id
        else:
            upper_bound = self.subs_pad_id
        dims = [self.batch_size, self.sequence_length, self.max_depth]
        return self.random_int_tensor(dims, upper_bound, min_value=0, framework=framework, dtype=int_dtype)


class DummyDecoderTextInputGenerator(DummyTextInputGenerator):
"""
Generates dummy decoder text inputs.
Expand Down
1 change: 1 addition & 0 deletions optimum/utils/normalized_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ class NormalizedConfigManager:
"llama": NormalizedTextConfigWithGQA,
"longt5": T5LikeNormalizedTextConfig,
"marian": BartLikeNormalizedTextConfig,
"markuplm": NormalizedTextConfig,
"mbart": BartLikeNormalizedTextConfig,
"mistral": NormalizedTextConfigWithGQA,
"mixtral": NormalizedTextConfigWithGQA,
Expand Down
2 changes: 2 additions & 0 deletions tests/exporters/exporters_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@
# "longformer": "allenai/longformer-base-4096",
"m2m-100": "hf-internal-testing/tiny-random-m2m_100",
"marian": "sshleifer/tiny-marian-en-de", # hf-internal-testing ones are broken
"markuplm": "hf-internal-testing/tiny-random-MarkupLMModel",
"mbart": "hf-internal-testing/tiny-random-mbart",
"mistral": "echarlaix/tiny-random-mistral",
"mobilebert": "hf-internal-testing/tiny-random-MobileBertModel",
Expand Down Expand Up @@ -237,6 +238,7 @@
# "longformer": "allenai/longformer-base-4096",
"m2m-100": "hf-internal-testing/tiny-random-m2m_100", # Not using facebook/m2m100_418M because it takes too much time for testing.
"marian": "Helsinki-NLP/opus-mt-en-de",
"markuplm": "hf-internal-testing/tiny-random-MarkupLMModel",
"mbart": "sshleifer/tiny-mbart",
"mobilebert": "google/mobilebert-uncased",
# "mobilenet_v1": "google/mobilenet_v1_0.75_192",
Expand Down

0 comments on commit 5ea14c1

Please sign in to comment.