Skip to content

Commit

Permalink
add mt5 to ORTConfigManager conf list (#341)
Browse files Browse the repository at this point in the history
* add MT5 to ORTOptimizer (after rebase and Optimizer changes)

* fix typo

* fix test for MT5 Seq2SeqLM model

* fix mt5 -> bart
  • Loading branch information
chainyo committed Oct 2, 2022
1 parent ceadc02 commit 5f7ac4d
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 1 deletion.
1 change: 1 addition & 0 deletions optimum/onnxruntime/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class ORTConfigManager:
"electra": ("num_attention_heads", "hidden_size", "bert"),
"gpt2": ("n_head", "n_embd", "gpt2"),
"gpt_neo": ("num_heads", "hidden_size", "gpt2"),
"mt5": ("num_heads", "d_model", "bart"),
"marian": ("encoder_attention_heads", "d_model", "bart"),
"roberta": ("num_attention_heads", "hidden_size", "bert"),
"xlm-roberta": ("num_attention_heads", "hidden_size", "bert"),
Expand Down
3 changes: 2 additions & 1 deletion tests/onnxruntime/test_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ class ORTOptimizerTest(unittest.TestCase):

# Contribution note: Please add test models in alphabetical order. Find test models here: https://huggingface.co/hf-internal-testing.
SUPPORTED_ARCHITECTURES_WITH_MODEL_ID = (
(ORTModelForSequenceClassification, "hf-internal-testing/tiny-random-bert"),
(ORTModelForSequenceClassification, "hf-internal-testing/tiny-random-bart"),
(ORTModelForSequenceClassification, "hf-internal-testing/tiny-random-bert"),
(ORTModelForSequenceClassification, "hf-internal-testing/tiny-random-big_bird"),
(ORTModelForSequenceClassification, "hf-internal-testing/tiny-random-distilbert"),
(ORTModelForSequenceClassification, "hf-internal-testing/tiny-random-electra"),
Expand Down Expand Up @@ -77,6 +77,7 @@ def test_compare_original_model_with_optimized_model(self, model_cls, model_name
(ORTModelForSeq2SeqLM, "hf-internal-testing/tiny-random-bart", True),
(ORTModelForSeq2SeqLM, "hf-internal-testing/tiny-random-marian", False),
(ORTModelForSeq2SeqLM, "hf-internal-testing/tiny-random-marian", True),
(ORTModelForSeq2SeqLM, "hf-internal-testing/tiny-random-onnx-mt5", False),
)

@parameterized.expand(SUPPORTED_SEQ2SEQ_ARCHITECTURES_WITH_MODEL_ID)
Expand Down

0 comments on commit 5f7ac4d

Please sign in to comment.