Skip to content

Commit

Permalink
Reactivate BART Onnx Export (#1666)
Browse files Browse the repository at this point in the history
* Reactivate BART Onnx Export

* Update tests/onnxruntime/test_modeling.py

* Update tests/onnxruntime/test_modeling.py

* Bart Reactivation - fxmarty review

---------

Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>
  • Loading branch information
claeyzre and fxmarty committed Jan 30, 2024
1 parent 45c1c09 commit da6f9e2
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 6 deletions.
1 change: 1 addition & 0 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,7 @@ def flatten_past_key_values(self, flattened_output, name, idx, t):

class BartOnnxConfig(M2M100OnnxConfig):
    """ONNX export configuration for BART, inheriting all inputs/outputs from M2M100.

    BART shares its ONNX graph structure with M2M100, so only the export
    constraints differ from the parent class.
    """

    # Opset 14 is required because BART uses F.scaled_dot_product_attention
    # by default on recent torch versions (see MIN_TORCH_VERSION below).
    DEFAULT_ONNX_OPSET = 14
    # Earlier 2.1.x releases had an export regression for this model
    # (https://github.com/pytorch/pytorch/issues/110597), so require >= 2.1.2.
    MIN_TORCH_VERSION = version.parse("2.1.2")


Expand Down
6 changes: 2 additions & 4 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,10 +317,8 @@ class TasksManager:
"text-generation-with-past",
"text2text-generation",
"text2text-generation-with-past",
# text-classification and question-answering can be supported, but the ONNX export is currently broken due to a regression in PyTorch 2.1.
# Reference: https://github.com/pytorch/pytorch/issues/110597.
# "text-classification",
# "question-answering",
"text-classification",
"question-answering",
onnx="BartOnnxConfig",
),
# BEiT cannot be used with the masked image modeling autoclass, so this task is excluded here
Expand Down
4 changes: 2 additions & 2 deletions tests/onnxruntime/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1227,7 +1227,7 @@ def test_trust_remote_code(self):
class ORTModelForQuestionAnsweringIntegrationTest(ORTModelTestMixin):
SUPPORTED_ARCHITECTURES = [
"albert",
# "bart", # see tasks.py
"bart",
"bert",
# "big_bird",
# "bigbird_pegasus",
Expand Down Expand Up @@ -1592,7 +1592,7 @@ def test_compare_to_io_binding(self, model_arch):
class ORTModelForSequenceClassificationIntegrationTest(ORTModelTestMixin):
SUPPORTED_ARCHITECTURES = [
"albert",
# "bart", # see tasks.py
"bart",
"bert",
# "big_bird",
# "bigbird_pegasus",
Expand Down
1 change: 1 addition & 0 deletions tests/onnxruntime/test_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def tearDownClass(cls):
class ORTOptimizerTest(unittest.TestCase):
# Contribution note: Please add test models in alphabetical order. Find test models here: https://huggingface.co/hf-internal-testing.
SUPPORTED_ARCHITECTURES_WITH_MODEL_ID = (
(ORTModelForSequenceClassification, "hf-internal-testing/tiny-random-bart"),
(ORTModelForSequenceClassification, "hf-internal-testing/tiny-random-bert"),
# (ORTModelForSequenceClassification, "hf-internal-testing/tiny-random-big_bird"),
(ORTModelForSequenceClassification, "hf-internal-testing/tiny-random-distilbert"),
Expand Down
1 change: 1 addition & 0 deletions tests/onnxruntime/test_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ class ORTDynamicQuantizationTest(unittest.TestCase):
(ORTModelForSequenceClassification, "hf-internal-testing/tiny-random-bert", 30),
(ORTModelForSequenceClassification, "hf-internal-testing/tiny-random-roberta", 30),
(ORTModelForSequenceClassification, "hf-internal-testing/tiny-random-distilbert", 30),
(ORTModelForSequenceClassification, "hf-internal-testing/tiny-random-bart", 32),
)

SUPPORTED_DECODER_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS = (
Expand Down

0 comments on commit da6f9e2

Please sign in to comment.