Add the ORTModelForSemanticSegmentation class (#539)
* Initial commit for ORTModelForImageSegmentation

* Example name and io binding output size fix

* Refactor to ORTModelForSemanticSegmentation

* IOBindingHelper for io_binding in ORTModelForSemanticSegmentation

* Black and isort formatting

* Apply various suggestions from code review

Co-authored-by: Jingya HUANG <44135271+JingyaHuang@users.noreply.github.com>

* Alphanumeric order in modeling_ort init

* Fixing docstring model name and removing comments

* Fixing tests and add import

* Black formatting and fixed unmatched quote

* Removing export_feature class attribute

* Fixing test model name

* Adding 'image-segmentation' task to optimum/pipelines and fixing tests

Co-authored-by: Jingya HUANG <44135271+JingyaHuang@users.noreply.github.com>
TheoMrc and JingyaHuang committed Dec 16, 2022
1 parent f9feeca commit e8d9877
Showing 4 changed files with 237 additions and 0 deletions.
2 changes: 2 additions & 0 deletions optimum/onnxruntime/__init__.py
@@ -33,6 +33,7 @@
"ORTModelForImageClassification",
"ORTModelForMultipleChoice",
"ORTModelForQuestionAnswering",
"ORTModelForSemanticSegmentation",
"ORTModelForSequenceClassification",
"ORTModelForTokenClassification",
],
@@ -65,6 +66,7 @@
ORTModelForImageClassification,
ORTModelForMultipleChoice,
ORTModelForQuestionAnswering,
ORTModelForSemanticSegmentation,
ORTModelForSequenceClassification,
ORTModelForTokenClassification,
)
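With this export in place, the new class is importable from the package root. A minimal sketch, assuming an environment with the `onnxruntime` extra of optimum installed:

```python
# Minimal sketch: the class is now exported from the package root
# (assumes optimum is installed with the `onnxruntime` extra).
from optimum.onnxruntime import ORTModelForSemanticSegmentation

# Grounded in the modeling diff below: the class wraps AutoModelForSemanticSegmentation.
assert ORTModelForSemanticSegmentation.auto_model_class.__name__ == "AutoModelForSemanticSegmentation"
```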
114 changes: 114 additions & 0 deletions optimum/onnxruntime/modeling_ort.py
@@ -27,6 +27,7 @@
AutoModelForImageClassification,
AutoModelForMultipleChoice,
AutoModelForQuestionAnswering,
AutoModelForSemanticSegmentation,
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
)
@@ -37,6 +38,7 @@
ModelOutput,
MultipleChoiceModelOutput,
QuestionAnsweringModelOutput,
SemanticSegmenterOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
)
@@ -1541,6 +1543,118 @@ def forward(
return ImageClassifierOutput(logits=logits)


SEMANTIC_SEGMENTATION_EXAMPLE = r"""
Example of semantic segmentation:
```python
>>> import requests
>>> from PIL import Image
>>> from optimum.onnxruntime import {model_class}
>>> from transformers import {processor_class}
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> preprocessor = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> inputs = preprocessor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> logits = outputs.logits
```
Example using `transformers.pipeline`:
```python
>>> import requests
>>> from PIL import Image
>>> from transformers import {processor_class}, pipeline
>>> from optimum.onnxruntime import {model_class}
>>> preprocessor = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> onnx_image_segmenter = pipeline("image-segmentation", model=model, feature_extractor=preprocessor)
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> pred = onnx_image_segmenter(url)
```
"""


@add_start_docstrings(
"""
ONNX Model with an all-MLP decode head on top, e.g. for ADE20k or CityScapes.
""",
ONNX_MODEL_START_DOCSTRING,
)
class ORTModelForSemanticSegmentation(ORTModel):
"""
Semantic Segmentation model for ONNX.
"""

auto_model_class = AutoModelForSemanticSegmentation

def __init__(self, model=None, config=None, use_io_binding=True, **kwargs):
super().__init__(model, config, use_io_binding, **kwargs)
self.model_inputs = {input_key.name: idx for idx, input_key in enumerate(self.model.get_inputs())}
self.model_outputs = {output_key.name: idx for idx, output_key in enumerate(self.model.get_outputs())}
self.model_input_names = list(self.model_inputs.keys())
self.model_output_names = list(self.model_outputs.keys())

@add_start_docstrings_to_model_forward(
ONNX_IMAGE_INPUTS_DOCSTRING.format("batch_size, num_channels, height, width")
+ SEMANTIC_SEGMENTATION_EXAMPLE.format(
processor_class=_FEATURE_EXTRACTOR_FOR_DOC,
model_class="ORTModelForSemanticSegmentation",
checkpoint="optimum/segformer-b0-finetuned-ade-512-512",
)
)
def forward(self, **kwargs):
if self.device.type == "cuda" and self.use_io_binding:
io_binding = IOBindingHelper.prepare_io_binding(self, **kwargs)

# run inference with binding
io_binding.synchronize_inputs()
self.model.run_with_iobinding(io_binding)
io_binding.synchronize_outputs()

outputs = {}
for name, output in zip(self.model_output_names, io_binding._iobinding.get_outputs()):
outputs[name] = IOBindingHelper.to_pytorch(output)

# converts output to namedtuple for pipelines post-processing
return SemanticSegmenterOutput(logits=outputs["logits"])
else:
# converts pytorch inputs into numpy inputs for onnx
onnx_inputs = self._prepare_onnx_inputs(**kwargs)

# run inference
onnx_outputs = self.model.run(None, onnx_inputs)
outputs = self._prepare_onnx_outputs(onnx_outputs)

# converts output to namedtuple for pipelines post-processing
return SemanticSegmenterOutput(logits=outputs["logits"])

def _prepare_onnx_inputs(self, **kwargs):
model_inputs = {input_key.name: idx for idx, input_key in enumerate(self.model.get_inputs())}
onnx_inputs = {}
# converts pytorch inputs into numpy inputs for onnx
for input_name in model_inputs.keys():
onnx_inputs[input_name] = kwargs.pop(input_name).cpu().detach().numpy()

return onnx_inputs

def _prepare_onnx_outputs(self, onnx_outputs):
model_outputs = {output_key.name: idx for idx, output_key in enumerate(self.model.get_outputs())}
outputs = {}
# converts ONNX Runtime outputs into torch tensors on the model device
for output, idx in model_outputs.items():
outputs[output] = torch.from_numpy(onnx_outputs[idx]).to(self.device)

return outputs


CUSTOM_TASKS_EXAMPLE = r"""
Example of custom tasks (e.g. a sentence transformer taking `pooler_output` as output):
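The CPU branch of `forward` above round-trips between PyTorch and NumPy around `InferenceSession.run`. A standalone sketch of that conversion, with a hypothetical `model.onnx` path standing in for the exported file:

```python
# Sketch of the torch -> numpy -> torch round trip performed by forward() on CPU.
# "model.onnx" is a hypothetical exported file, not an artifact of this PR.
import onnxruntime
import torch

session = onnxruntime.InferenceSession("model.onnx")

# _prepare_onnx_inputs: detach torch tensors and hand numpy arrays to ONNX Runtime
pixel_values = torch.rand(1, 3, 512, 512)
onnx_inputs = {session.get_inputs()[0].name: pixel_values.cpu().detach().numpy()}

# run inference, then _prepare_onnx_outputs: wrap the raw numpy outputs back into tensors
raw_outputs = session.run(None, onnx_inputs)
logits = torch.from_numpy(raw_outputs[0])
```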
8 changes: 8 additions & 0 deletions optimum/pipelines.py
@@ -20,6 +20,7 @@
AutomaticSpeechRecognitionPipeline,
FeatureExtractionPipeline,
ImageClassificationPipeline,
ImageSegmentationPipeline,
Pipeline,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
@@ -50,6 +51,7 @@
ORTModelForFeatureExtraction,
ORTModelForImageClassification,
ORTModelForQuestionAnswering,
ORTModelForSemanticSegmentation,
ORTModelForSeq2SeqLM,
ORTModelForSequenceClassification,
ORTModelForSpeechSeq2Seq,
@@ -70,6 +72,12 @@
"default": "google/vit-base-patch16-224",
"type": "image",
},
"image-segmentation": {
"impl": ImageSegmentationPipeline,
"class": (ORTModelForSemanticSegmentation,) if is_onnxruntime_available() else (),
"default": "nvidia/segformer-b0-finetuned-ade-512-512",
"type": "image",
},
"question-answering": {
"impl": QuestionAnsweringPipeline,
"class": (ORTModelForQuestionAnswering,),
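With the task registered, optimum's `pipeline` factory can resolve it by name. A hedged usage sketch, relying on the default checkpoint declared above:

```python
# Usage sketch: "image-segmentation" now resolves through optimum's pipeline factory.
# With no model given, it should fall back to the registered default checkpoint.
from optimum.pipelines import pipeline

onnx_segmenter = pipeline("image-segmentation")
predictions = onnx_segmenter("http://images.cocodataset.org/val2017/000000039769.jpg")
print(predictions[0]["label"])  # each prediction carries a label and a mask
```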
113 changes: 113 additions & 0 deletions tests/onnxruntime/test_modeling.py
@@ -28,6 +28,7 @@
AutoModelForImageClassification,
AutoModelForMultipleChoice,
AutoModelForQuestionAnswering,
AutoModelForSemanticSegmentation,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForSpeechSeq2Seq,
@@ -52,6 +53,7 @@
ORTModelForImageClassification,
ORTModelForMultipleChoice,
ORTModelForQuestionAnswering,
ORTModelForSemanticSegmentation,
ORTModelForSeq2SeqLM,
ORTModelForSequenceClassification,
ORTModelForSpeechSeq2Seq,
@@ -85,6 +87,7 @@
"bigbird_pegasus": "hf-internal-testing/tiny-random-bigbird_pegasus",
"gpt2": "hf-internal-testing/tiny-random-gpt2",
"vit": "hf-internal-testing/tiny-random-vit",
"segformer": "hf-internal-testing/tiny-random-SegformerForSemanticSegmentation",
"whisper": "openai/whisper-tiny.en",
}

@@ -1282,6 +1285,116 @@ def test_compare_to_io_binding(self, *args, **kwargs):
gc.collect()


class ORTModelForSemanticSegmentationIntegrationTest(unittest.TestCase):
SUPPORTED_ARCHITECTURES_WITH_MODEL_ID = {
"segformer": "hf-internal-testing/tiny-random-SegformerForSemanticSegmentation",
}

def test_load_vanilla_transformers_which_is_not_supported(self):
with self.assertRaises(Exception) as context:
_ = ORTModelForSemanticSegmentation.from_pretrained(MODEL_NAMES["t5"], from_transformers=True)

self.assertIn("Unrecognized configuration class", str(context.exception))

@parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_MODEL_ID.items())
def test_compare_to_transformers(self, *args, **kwargs):
model_arch, model_id = args
set_seed(SEED)
onnx_model = ORTModelForSemanticSegmentation.from_pretrained(model_id, from_transformers=True)

self.assertIsInstance(onnx_model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession)
self.assertIsInstance(onnx_model.config, PretrainedConfig)

set_seed(SEED)
trfs_model = AutoModelForSemanticSegmentation.from_pretrained(model_id)
preprocessor = get_preprocessor(model_id)
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = preprocessor(images=image, return_tensors="pt")
onnx_outputs = onnx_model(**inputs)

self.assertTrue("logits" in onnx_outputs)
self.assertTrue(isinstance(onnx_outputs.logits, torch.Tensor))

with torch.no_grad():
trfs_outputs = trfs_model(**inputs)

# compare tensor outputs
self.assertTrue(torch.allclose(onnx_outputs.logits, trfs_outputs.logits, atol=1e-4))

gc.collect()

@parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_MODEL_ID.items())
def test_pipeline_ort_model(self, *args, **kwargs):
model_arch, model_id = args
onnx_model = ORTModelForSemanticSegmentation.from_pretrained(model_id, from_transformers=True)
preprocessor = get_preprocessor(model_id)
pipe = pipeline("image-segmentation", model=onnx_model, feature_extractor=preprocessor)
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
outputs = pipe(url)

self.assertEqual(pipe.device, onnx_model.device)
self.assertTrue(outputs[0]["mask"] is not None)
self.assertTrue(isinstance(outputs[0]["label"], str))

gc.collect()

@pytest.mark.run_in_series
def test_pipeline_model_is_none(self):
pipe = pipeline("image-segmentation")
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
outputs = pipe(url)
# compare model output class
self.assertTrue(outputs[0]["mask"] is not None)
self.assertTrue(isinstance(outputs[0]["label"], str))

@parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_MODEL_ID.items())
@require_torch_gpu
def test_pipeline_on_gpu(self, *args, **kwargs):
model_arch, model_id = args
onnx_model = ORTModelForSemanticSegmentation.from_pretrained(model_id, from_transformers=True)
preprocessor = get_preprocessor(model_id)
pipe = pipeline("image-segmentation", model=onnx_model, feature_extractor=preprocessor, device=0)
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
outputs = pipe(url)
# check model device
self.assertEqual(pipe.model.device.type.lower(), "cuda")

# compare model output class
self.assertTrue(outputs[0]["mask"] is not None)
self.assertTrue(isinstance(outputs[0]["label"], str))

gc.collect()

@parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_MODEL_ID.items())
@require_torch_gpu
def test_compare_to_io_binding(self, *args, **kwargs):
model_arch, model_id = args
set_seed(SEED)
onnx_model = ORTModelForSemanticSegmentation.from_pretrained(
model_id, from_transformers=True, use_io_binding=False
)
set_seed(SEED)
io_model = ORTModelForSemanticSegmentation.from_pretrained(
model_id, from_transformers=True, use_io_binding=True
)

preprocessor = get_preprocessor(model_id)
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = preprocessor(images=[image] * 2, return_tensors="pt")
onnx_outputs = onnx_model(**inputs)
io_outputs = io_model(**inputs)

self.assertTrue("logits" in io_outputs)
self.assertIsInstance(io_outputs.logits, torch.Tensor)

# compare tensor outputs
self.assertTrue(torch.equal(onnx_outputs.logits, io_outputs.logits))

gc.collect()


class ORTModelForSeq2SeqLMIntegrationTest(unittest.TestCase):
SUPPORTED_ARCHITECTURES = (
"t5",
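A standalone, hedged version of the parity check at the core of `test_compare_to_transformers`, using the tiny checkpoint from the test table and `AutoFeatureExtractor` in place of the test suite's `get_preprocessor` helper:

```python
# Standalone sketch of the ONNX-vs-PyTorch parity check from the test above.
import requests
import torch
from PIL import Image
from transformers import AutoFeatureExtractor, AutoModelForSemanticSegmentation
from optimum.onnxruntime import ORTModelForSemanticSegmentation

model_id = "hf-internal-testing/tiny-random-SegformerForSemanticSegmentation"
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

preprocessor = AutoFeatureExtractor.from_pretrained(model_id)
inputs = preprocessor(images=image, return_tensors="pt")

onnx_model = ORTModelForSemanticSegmentation.from_pretrained(model_id, from_transformers=True)
pt_model = AutoModelForSemanticSegmentation.from_pretrained(model_id)

with torch.no_grad():
    pt_logits = pt_model(**inputs).logits
assert torch.allclose(onnx_model(**inputs).logits, pt_logits, atol=1e-4)
```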
