Skip to content

Commit

Permalink
add mt5 to optimization tests + model_feature argument
Browse files Browse the repository at this point in the history
  • Loading branch information
chainyo committed Aug 10, 2022
1 parent f2bb4b8 commit d6b8632
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions tests/onnxruntime/test_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
DistilBertForSequenceClassification,
ElectraForSequenceClassification,
GPT2ForSequenceClassification,
MT5ForConditionalGeneration,
RobertaForSequenceClassification,
XLMRobertaForSequenceClassification,
)
Expand All @@ -55,26 +56,27 @@ def test_save_and_load(self):
class ORTOptimizerTest(unittest.TestCase):

SUPPORTED_ARCHITECTURES_WITH_MODEL_ID = (
(BertForSequenceClassification, "hf-internal-testing/tiny-random-bert"),
(DistilBertForSequenceClassification, "hf-internal-testing/tiny-random-distilbert"),
(BartForSequenceClassification, "hf-internal-testing/tiny-random-bart"),
(GPT2ForSequenceClassification, "hf-internal-testing/tiny-random-gpt2"),
(RobertaForSequenceClassification, "hf-internal-testing/tiny-random-roberta"),
(ElectraForSequenceClassification, "hf-internal-testing/tiny-random-electra"),
(XLMRobertaForSequenceClassification, "hf-internal-testing/tiny-xlm-roberta"),
(BigBirdForSequenceClassification, "hf-internal-testing/tiny-random-big_bird"),
(BartForSequenceClassification, "hf-internal-testing/tiny-random-bart", "sequence-classification"),
(BertForSequenceClassification, "hf-internal-testing/tiny-random-bert", "sequence-classification"),
(BigBirdForSequenceClassification, "hf-internal-testing/tiny-random-big_bird", "sequence-classification"),
(DistilBertForSequenceClassification, "hf-internal-testing/tiny-random-distilbert", "sequence-classification"),
(ElectraForSequenceClassification, "hf-internal-testing/tiny-random-electra", "sequence-classification"),
(GPT2ForSequenceClassification, "hf-internal-testing/tiny-random-gpt2", "sequence-classification"),
(MT5ForConditionalGeneration, "hf-internal-testing/tiny-random-mt5", "seq2seq-lm"),
(RobertaForSequenceClassification, "hf-internal-testing/tiny-random-roberta", "sequence-classification"),
(XLMRobertaForSequenceClassification, "hf-internal-testing/tiny-xlm-roberta", "sequence-classification"),
)

@parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_MODEL_ID)
def test_optimize(self, model_cls, model_name):
def test_optimize(self, model_cls, model_name, model_feature):
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = model_cls(AutoConfig.from_pretrained(model_name))
optimization_config = OptimizationConfig(optimization_level=2, optimize_with_onnxruntime_only=False)
with tempfile.TemporaryDirectory() as tmp_dir:
output_dir = Path(tmp_dir)
model_path = output_dir.joinpath("model.onnx")
optimized_model_path = output_dir.joinpath("model-optimized.onnx")
optimizer = ORTOptimizer(tokenizer, model, feature="sequence-classification")
optimizer = ORTOptimizer(tokenizer, model, feature=model_feature)
optimizer.export(
onnx_model_path=model_path,
onnx_optimized_model_output_path=optimized_model_path,
Expand Down

0 comments on commit d6b8632

Please sign in to comment.