Skip to content

Commit

Permalink
ORTOptimizer for the model type Segformer (#1820)
Browse files Browse the repository at this point in the history
* add segformer

* black

* make format

* decoder_hidden_size not a list

* tests pass now

* use max

* use zero

---------

Co-authored-by: Zach Deane-Mayer <zach@ai-insight-solutions.com>
  • Loading branch information
zachmayer and Zach Deane-Mayer committed Jun 5, 2024
1 parent 7a0757a commit ac951ca
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 4 deletions.
11 changes: 8 additions & 3 deletions optimum/onnxruntime/modeling_ort.py
Original file line number Diff line number Diff line change
Expand Up @@ -1746,13 +1746,18 @@ class ORTModelForSemanticSegmentation(ORTModel):
checkpoint="optimum/segformer-b0-finetuned-ade-512-512",
)
)
def forward(self, **kwargs):
use_torch = isinstance(next(iter(kwargs.values())), torch.Tensor)
def forward(
self,
pixel_values: Union[torch.Tensor, np.ndarray],
**kwargs,
):
use_torch = isinstance(pixel_values, torch.Tensor)
self.raise_on_numpy_input_io_binding(use_torch)

if self.device.type == "cuda" and self.use_io_binding:
io_binding = IOBindingHelper.prepare_io_binding(
self,
pixel_values,
**kwargs,
ordered_input_names=self._ordered_input_names,
)
Expand All @@ -1769,7 +1774,7 @@ def forward(self, **kwargs):
# converts output to namedtuple for pipelines post-processing
return SemanticSegmenterOutput(logits=outputs["logits"])
else:
onnx_inputs = self._prepare_onnx_inputs(use_torch=use_torch, **kwargs)
onnx_inputs = self._prepare_onnx_inputs(use_torch=use_torch, pixel_values=pixel_values, **kwargs)

# run inference
onnx_outputs = self.model.run(None, onnx_inputs)
Expand Down
1 change: 1 addition & 0 deletions optimum/onnxruntime/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ class ORTConfigManager:
"nystromformer": "bert",
"pegasus": "bert",
"roberta": "bert",
"segformer": "vit",
"t5": "bert",
"vit": "vit",
"whisper": "bart",
Expand Down
15 changes: 14 additions & 1 deletion optimum/utils/normalized_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,19 @@ class NormalizedVisionConfig(NormalizedConfig):
INPUT_SIZE = "input_size"


class NormalizedSegformerConfig(NormalizedVisionConfig):
NUM_ATTENTION_HEADS = "num_attention_heads"
HIDDEN_SIZE = "hidden_sizes"

# If the attribute is a list, return 0
# 0 means let the optimizer infer the correct value based on the model graph
def __getattr__(self, attr_name):
attr_value = super().__getattr__(attr_name)
if isinstance(attr_value, list):
attr_value = 0
return attr_value


class NormalizedTextAndVisionConfig(NormalizedTextConfig, NormalizedVisionConfig):
TEXT_CONFIG = None
VISION_CONFIG = None
Expand Down Expand Up @@ -203,7 +216,6 @@ class NormalizedConfigManager:
'owlvit',
'perceiver',
'roformer',
'segformer',
'squeezebert',
'table-transformer',
"""
Expand Down Expand Up @@ -258,6 +270,7 @@ class NormalizedConfigManager:
"regnet": NormalizedVisionConfig,
"resnet": NormalizedVisionConfig,
"roberta": NormalizedTextConfig,
"segformer": NormalizedSegformerConfig,
"speech-to-text": SpeechToTextLikeNormalizedTextConfig,
"splinter": NormalizedTextConfig,
"t5": T5LikeNormalizedTextConfig,
Expand Down
2 changes: 2 additions & 0 deletions tests/onnxruntime/test_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
AutoOptimizationConfig,
ORTConfig,
ORTModelForImageClassification,
ORTModelForSemanticSegmentation,
ORTModelForSequenceClassification,
ORTOptimizer,
)
Expand Down Expand Up @@ -171,6 +172,7 @@ def test_compare_original_seq2seq_model_with_optimized_model(self, model_cls, mo

# Contribution note: Please add test models in alphabetical order. Find test models here: https://huggingface.co/hf-internal-testing.
SUPPORTED_IMAGE_ARCHITECTURES_WITH_MODEL_ID = (
(ORTModelForSemanticSegmentation, "hf-internal-testing/tiny-random-segformer"),
(ORTModelForImageClassification, "hf-internal-testing/tiny-random-vit"),
)

Expand Down

0 comments on commit ac951ca

Please sign in to comment.