From ac951ca0ae49983b1515b84342fea15c3b7ec35c Mon Sep 17 00:00:00 2001 From: Zach Deane-Mayer <581590+zachmayer@users.noreply.github.com> Date: Wed, 5 Jun 2024 02:18:42 -0400 Subject: [PATCH] ORTOptimizer for the model type Segformer (#1820) * add segformer * black * make format * decoder_hidden_size not a list * tests pass now * use max * use zero --------- Co-authored-by: Zach Deane-Mayer --- optimum/onnxruntime/modeling_ort.py | 11 ++++++++--- optimum/onnxruntime/utils.py | 1 + optimum/utils/normalized_config.py | 15 ++++++++++++++- tests/onnxruntime/test_optimization.py | 2 ++ 4 files changed, 25 insertions(+), 4 deletions(-) diff --git a/optimum/onnxruntime/modeling_ort.py b/optimum/onnxruntime/modeling_ort.py index eb38a7fef1..b65e1d3b29 100644 --- a/optimum/onnxruntime/modeling_ort.py +++ b/optimum/onnxruntime/modeling_ort.py @@ -1746,13 +1746,18 @@ class ORTModelForSemanticSegmentation(ORTModel): checkpoint="optimum/segformer-b0-finetuned-ade-512-512", ) ) - def forward(self, **kwargs): - use_torch = isinstance(next(iter(kwargs.values())), torch.Tensor) + def forward( + self, + pixel_values: Union[torch.Tensor, np.ndarray], + **kwargs, + ): + use_torch = isinstance(pixel_values, torch.Tensor) self.raise_on_numpy_input_io_binding(use_torch) if self.device.type == "cuda" and self.use_io_binding: io_binding = IOBindingHelper.prepare_io_binding( self, + pixel_values, **kwargs, ordered_input_names=self._ordered_input_names, ) @@ -1769,7 +1774,7 @@ def forward(self, **kwargs): # converts output to namedtuple for pipelines post-processing return SemanticSegmenterOutput(logits=outputs["logits"]) else: - onnx_inputs = self._prepare_onnx_inputs(use_torch=use_torch, **kwargs) + onnx_inputs = self._prepare_onnx_inputs(use_torch=use_torch, pixel_values=pixel_values, **kwargs) # run inference onnx_outputs = self.model.run(None, onnx_inputs) diff --git a/optimum/onnxruntime/utils.py b/optimum/onnxruntime/utils.py index 0e1da447a6..37d0feefcc 100644 --- a/optimum/onnxruntime/utils.py +++ b/optimum/onnxruntime/utils.py @@ -128,6 +128,7 @@ class ORTConfigManager: "nystromformer": "bert", "pegasus": "bert", "roberta": "bert", + "segformer": "vit", "t5": "bert", "vit": "vit", "whisper": "bart", diff --git a/optimum/utils/normalized_config.py b/optimum/utils/normalized_config.py index 682f70e3ca..81207b7649 100644 --- a/optimum/utils/normalized_config.py +++ b/optimum/utils/normalized_config.py @@ -102,6 +102,19 @@ class NormalizedVisionConfig(NormalizedConfig): INPUT_SIZE = "input_size" +class NormalizedSegformerConfig(NormalizedVisionConfig): + NUM_ATTENTION_HEADS = "num_attention_heads" + HIDDEN_SIZE = "hidden_sizes" + + # If the attribute is a list, return 0 + # 0 means let the optimizer infer the correct value based on the model graph + def __getattr__(self, attr_name): + attr_value = super().__getattr__(attr_name) + if isinstance(attr_value, list): + attr_value = 0 + return attr_value + + class NormalizedTextAndVisionConfig(NormalizedTextConfig, NormalizedVisionConfig): TEXT_CONFIG = None VISION_CONFIG = None @@ -203,7 +216,6 @@ class NormalizedConfigManager: 'owlvit', 'perceiver', 'roformer', - 'segformer', 'squeezebert', 'table-transformer', """ @@ -258,6 +270,7 @@ class NormalizedConfigManager: "regnet": NormalizedVisionConfig, "resnet": NormalizedVisionConfig, "roberta": NormalizedTextConfig, + "segformer": NormalizedSegformerConfig, "speech-to-text": SpeechToTextLikeNormalizedTextConfig, "splinter": NormalizedTextConfig, "t5": T5LikeNormalizedTextConfig, diff --git a/tests/onnxruntime/test_optimization.py b/tests/onnxruntime/test_optimization.py index c9cadbaa82..82109fcd11 100644 --- a/tests/onnxruntime/test_optimization.py +++ b/tests/onnxruntime/test_optimization.py @@ -36,6 +36,7 @@ AutoOptimizationConfig, ORTConfig, ORTModelForImageClassification, + ORTModelForSemanticSegmentation, ORTModelForSequenceClassification, ORTOptimizer, ) @@ -171,6 +172,7 @@ def test_compare_original_seq2seq_model_with_optimized_model(self, model_cls, mo # Contribution note: Please add test models in alphabetical order. Find test models here: https://huggingface.co/hf-internal-testing. SUPPORTED_IMAGE_ARCHITECTURES_WITH_MODEL_ID = ( + (ORTModelForSemanticSegmentation, "hf-internal-testing/tiny-random-segformer"), (ORTModelForImageClassification, "hf-internal-testing/tiny-random-vit"), )