
Add the ORTModelForSemanticSegmentation class #539

Merged

Conversation

@TheoMrc (Contributor) commented Dec 2, 2022

What does this PR do?

This PR aims to implement the ORTModelForImageSegmentation class to provide support for image segmentation .onnx models and full integration of such models through transformers pipelines for CPU or GPU ONNX Runtime inference (see issue #382).

Implementation details

The ORTModelForImageSegmentation class was based on the already implemented ORTModelForImageClassification in optimum/onnxruntime/modeling_ort.py, with several modifications:

  1. For CPU and GPU inference:
  • the class was added to optimum/onnxruntime/__init__.py
  • the self.forward method returns a SemanticSegmenterOutput instead of an ImageClassifierOutput
  • the correct auto_model_class and export_feature are referenced
  • all tests were copied from the ORTModelForImageClassificationIntegrationTest in tests/onnxruntime/test_modeling.py
  2. For GPU inference:
  • logits_shape was changed in ORTModelForImageSegmentation.prepare_logits_buffer to return a 4-dimensional tensor of shape (input_batch_size, self.config.num_labels, output_height, output_width).
    The issue is that I did not find a way to get the model output size, which differs from the input size, from config.json or from any other attribute of ORTModelForImageSegmentation or ORTModelForImageSegmentation.model. A rough sketch of what this buffer preparation could look like is shown right after this list.
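For illustration only, here is a minimal sketch of such a buffer-preparation helper, assuming the factor-of-4 downsampling discussed for Segformer later in this thread; the function name mirrors prepare_logits_buffer, but the signature, the hard-coded factor and the standalone form are assumptions, not the actual Optimum implementation:

import torch

# Hypothetical helper (not the actual Optimum method): assumes the output resolution is
# the input resolution divided by 4 in each spatial dimension (true for Segformer,
# not necessarily for every architecture exportable to ONNX).
def prepare_logits_buffer(config, device, batch_size, height, width):
    output_height, output_width = height // 4, width // 4
    return torch.empty(
        (batch_size, config.num_labels, output_height, output_width),
        dtype=torch.float32,
        device=device,
    )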

CPU inference works as follows:

from optimum.onnxruntime.modeling_ort import ORTModelForImageSegmentation

# onnx_path: path to the exported .onnx model; feature_extractor: the matching
# transformers preprocessor; pil_image: a PIL.Image instance
session = ORTModelForImageSegmentation.load_model(onnx_path)
onnx_model = ORTModelForImageSegmentation(session)
inputs = feature_extractor(pil_image, return_tensors="pt")
outputs = onnx_model(**inputs)

I could not test GPU inference because I could not manage to make onnxruntime-gpu work:

onnx_model.to('cuda:0')

  File "C:\Users\theol\Documents\GitHub\Repositories\optimum\optimum\onnxruntime\modeling_ort.py", line 202, in to
    validate_provider_availability(provider)  # raise error if the provider is not available
  File "C:\Users\theol\Documents\GitHub\Repositories\optimum\optimum\onnxruntime\utils.py", line 227, in validate_provider_availability
    raise ImportError(
ImportError: Asked to use CUDAExecutionProvider, but `onnxruntime-gpu` package was not found. Make sure to install `onnxruntime-gpu` package instead of `onnxruntime`.

This might be due to local venv setup issues on my side; my CUDA installation works fine for transformers with torch models.
Still, GPU inference probably would not work properly yet because of the wrong output size in prepare_logits_buffer.

Remaining tasks

  • Fix the output size used for IO binding
  • Upload a .onnx segmentation model to https://huggingface.co/hf-internal-testing and change the IMAGE_SEGMENTATION_EXAMPLE checkpoint name and image URL to an appropriate example (see the two comments at optimum/onnxruntime/modeling_ort.py lines 1463 and 1533)
  • Change the test class model to a semantic segmentation model in order to get working tests

@michaelbenayoun @JingyaHuang your help would be appreciated 👍

@HuggingFaceDocBuilderDev commented Dec 4, 2022

The documentation is not available anymore as the PR was closed or merged.

@JingyaHuang (Collaborator)

Hi @TheoMrc,

Thanks for opening the PR, it looks great!!

Just some questions for other team members before jumping into a more detailed review:

  • @michaelbenayoun, why did you suggest ORTModelForImageSegmentation as the class name instead of ORTModelForSemanticSegmentation? I feel it would be better to keep consistency with AutoModel in transformers. WDYT?
  • And gently tagging @NielsRogge: for the Segformer model, would it be possible to infer the output height and width of the logits from the shape of pixel_values and the model's config?

(@TheoMrc for the IOBinding, ONNX Runtime supports binding with OrtValue, which doesn't need the exact shape of outputs, check #447. For the moment, Optimum prefers direct binding with a torch tensor to avoid some potential overhead; let's wait for the insight from Niels and see if it would be necessary to add custom output shape support for IOBinding. A short sketch contrasting the two binding styles follows below.)
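To make the difference concrete, here is a minimal standalone sketch of the two binding styles with plain ONNX Runtime; the model path and the input/output names ("pixel_values", "logits") are illustrative assumptions:

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("model.onnx", providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
io_binding = sess.io_binding()

pixel_values = np.random.rand(1, 3, 512, 512).astype(np.float32)
io_binding.bind_cpu_input("pixel_values", pixel_values)

# Style 1 (what Optimum currently prefers): bind a pre-allocated torch tensor through its
# data pointer, which requires knowing the exact output shape up front.
# Style 2 (shape-free): bind only the output name and target device; ONNX Runtime then
# allocates the OrtValue itself at run time.
io_binding.bind_output("logits", device_type="cuda", device_id=0)

sess.run_with_iobinding(io_binding)
logits = io_binding.copy_outputs_to_cpu()[0]  # numpy array copied back to the host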

And for your question on the GPU test: if your CUDA env is correctly set up, make sure that you have uninstalled onnxruntime before installing onnxruntime-gpu. More detail here.
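A quick way to check which package is actually picked up in the environment (plain onnxruntime API, nothing Optimum-specific):

import onnxruntime as ort

# With onnxruntime-gpu correctly installed, this should list CUDAExecutionProvider;
# if only CPUExecutionProvider shows up, the CPU-only package is shadowing it.
print(ort.get_available_providers())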

@michaelbenayoun (Member)

I think the class it is supporting is AutoModelForImageSegmentation.

@NielsRogge

Hi,

We'll deprecate the xxxForImageSegmentation class in Transformers, as we now split it up into xxxForInstanceSegmentation, xxxForSemanticSegmentation and xxxForPanopticSegmentation.

The pipeline, however, is just called "image-segmentation", as it unifies all 3 into a single abstraction.

@michaelbenayoun (Member) left a comment

LGTM, great work @TheoMrc!

Should wait for @NielsRogge's and @JingyaHuang's approvals as well.

@@ -1277,6 +1278,117 @@ def test_compare_to_io_binding(self, *args, **kwargs):
gc.collect()


class ORTModelForImageSegmentationIntegrationTest(unittest.TestCase):
SUPPORTED_ARCHITECTURES_WITH_MODEL_ID = {
"vit": "hf-internal-testing/tiny-random-vit", # Probably have to modify to an onnx segmentation model
Member
No need to modify

Contributor Author
Hmm, I'm wondering why though? Tests are failing on my machine.

Collaborator
ViT doesn't have a segmentation model implemented, let's replace it with Segformer.

@@ -1277,6 +1278,117 @@ def test_compare_to_io_binding(self, *args, **kwargs):
gc.collect()


class ORTModelForImageSegmentationIntegrationTest(unittest.TestCase):
Member

It seems to be similar to ORTModelForImageClassification; maybe we could abstract all of these tests with a base test class. This can be done in another PR.

@TheoMrc (Contributor, Author) commented Dec 5, 2022

Thanks for your feedback!

@michaelbenayoun

I think the class it is supporting is AutoModelForImageSegmentation.

Actually, last time I tried, loading my Segformer with AutoModel only worked with AutoModelForSemanticSegmentation and not with AutoModelForImageSegmentation. I don't know how relevant that is, but it is why I put AutoModelForSemanticSegmentation as auto_model_class in ORTModelForImageSegmentation.
Should we refactor to ORTModelForSemanticSegmentation?

It seems weird to me that there are two classes that are supposed to do the same thing but don't behave the same; maybe XXXForImageSegmentation is for binary outputs?
Anyway, it's interesting to hear that XXXForImageSegmentation will be deprecated.

@NielsRogge

Should we refactor to ORTModelForSemanticSegmentation?

I think yes, to align with Transformers.

@TheoMrc TheoMrc changed the title Add the ORTModelForImageSegmentation class Add the ORTModelForSemanticSegmentation class Dec 5, 2022
@TheoMrc (Contributor, Author) commented Dec 5, 2022

I believe we should probably change the model name to a segmentation model in the test class.

I will also remove the comments from the code soon.

@JingyaHuang (Collaborator)

Seeing this for the output shape, it seems that the output logits' shape would be (batch_size, config.num_labels, pixel_values.size(2) // 4, pixel_values.size(3) // 4), is that so @NielsRogge? (A small illustrative check is sketched below.)
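If that holds, the expected logits shape can be derived directly from the inputs, along these lines (an illustrative check, assuming model is a loaded semantic segmentation model and pixel_values a preprocessed batch, and valid only where the factor-of-4 downsampling applies):

# pixel_values: torch.Tensor of shape (batch_size, num_channels, height, width)
expected_logits_shape = (
    pixel_values.size(0),
    model.config.num_labels,
    pixel_values.size(2) // 4,
    pixel_values.size(3) // 4,
)
assert tuple(model(pixel_values=pixel_values).logits.shape) == expected_logits_shape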

@TheoMrc (Contributor, Author) commented Dec 5, 2022

@JingyaHuang
I think this might be the case for some semantic segmentation models, e.g. SegformerForSemanticSegmentation, but not for all segmentation model types loadable with AutoModelForSemanticSegmentation, which we aim to support with ORTModelForSemanticSegmentation!

@JingyaHuang (Collaborator)

Hi @TheoMrc,

I just merged the support for dynamic shape output IOBinding. Can you update the IOBinding part in your PR the way it is done in ORTModelForCustomTasks?

def prepare_io_binding(self, **kwargs) -> ort.IOBinding:

And maybe try to fix the style with

make style

to pass the code quality check, thank you!

@NielsRogge

Hi @JingyaHuang, yes that's correct.

@TheoMrc (Contributor, Author) commented Dec 9, 2022

Hi @JingyaHuang,

I'll do it soon. To clarify what I should do: you're asking me to copy-paste all the functions linked to io_binding from ORTModelForCustomTasks?
I'm guessing your IOBinding implementation in ORTModelForCustomTasks handles any output shape, so the only difference I can see is outputting the correct SemanticSegmenterOutput.

Shouldn't we just make ORTModelForSemanticSegmentation inherit from ORTModelForCustomTasks instead of ORTModel? We would then also add a class attribute storing cls.OutputClass = SemanticSegmenterOutput.
Copy-pasting feels a bit redundant to me. Maybe in another PR?

And maybe try to fix the style with
make style

I'm not sure how to do that; it looks like a Linux command (although I couldn't find anything about it).
I'm currently on Windows but have Cygwin64 (so the make command is installed). Could you provide a link to any documentation?

@JingyaHuang (Collaborator) commented Dec 9, 2022

Hi @TheoMrc,

It is a great point to avoid copying the method for preparing dynamic shape IOBinding. For structural consistency, what I suggest is to move the prepare_io_binding method to IOBindingHelper instead of inheriting from ORTModelForCustomTasks.

I will open a PR now to make the move. As for the styling, you can alternatively pip install black and isort and run them to do the formatting.

@JingyaHuang (Collaborator) commented Dec 12, 2022

Hi @TheoMrc, the dynamic IOBinding preparation has now been moved to IOBindingHelper:

def prepare_io_binding(ort_model: "ORTModel", **inputs) -> ort.IOBinding:

You can use it for the I/O binding of ORTModelForSemanticSegmentation; feel free to ping me if you need help. (A rough sketch of how a forward pass could plug into it is shown below.)
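For orientation only, a hedged sketch of how such a forward pass could use the helper while returning the expected SemanticSegmenterOutput; the attribute name use_io_binding, the IOBindingHelper import path and the host round-trip for outputs are assumptions, not the code merged in this PR:

import torch
from transformers.modeling_outputs import SemanticSegmenterOutput
from optimum.onnxruntime.utils import IOBindingHelper  # assumed import location, adjust to your Optimum version

def forward(self, pixel_values: torch.Tensor, **kwargs):
    if self.device.type == "cuda" and self.use_io_binding:
        io_binding = IOBindingHelper.prepare_io_binding(self, pixel_values=pixel_values)
        self.model.run_with_iobinding(io_binding)
        # ONNX Runtime allocated the output buffers itself (dynamic shape); copy the logits
        # back to a torch tensor. The real implementation may keep data on device instead.
        logits = torch.from_numpy(io_binding.copy_outputs_to_cpu()[0]).to(self.device)
    else:
        onnx_inputs = {"pixel_values": pixel_values.cpu().detach().numpy()}
        logits = torch.from_numpy(self.model.run(None, onnx_inputs)[0])
    return SemanticSegmenterOutput(logits=logits)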

@TheoMrc TheoMrc force-pushed the ORTModelForImageSegmentation-implementation branch from e19918b to ff63f5a Compare December 13, 2022 20:04
@TheoMrc (Contributor, Author) commented Dec 13, 2022

Hi again @JingyaHuang !

  • I merged the main and upstream/main branches into my branch to get the changes you made to IOBindingHelper available here.

  • Then I basically did the same thing you did in the forward method of ORTModelForCustomTasks, while outputting SemanticSegmenterOutput(logits=outputs["logits"]).

Once again, despite a working CUDA 11.7 env for torch, I did not manage to make onnxruntime-gpu work properly, so I could only test inference on CPU, as follows:

from optimum.onnxruntime.modeling_ort import ORTModelForSemanticSegmentation
from transformers import SegformerFeatureExtractor
from PIL import Image

# onnx_path: path to the exported .onnx model; img_path: path to a test image
session = ORTModelForSemanticSegmentation.load_model(onnx_path)
model = ORTModelForSemanticSegmentation(session)
feature_extractor = SegformerFeatureExtractor()

image = Image.open(img_path)
inputs = feature_extractor(image, return_tensors="pt")
out = model(**inputs)

Additional information about GPU ORT not working:
Any model.to('cuda:0') call throws:

..\..\optimum\onnxruntime\modeling_ort.py:243: in to
    validate_provider_availability(provider)  # raise error if the provider is not available

E   ImportError: Asked to use CUDAExecutionProvider, but `onnxruntime-gpu` package was not found. Make sure to install `onnxruntime-gpu` package instead of `onnxruntime`.

I think this has to do with the fact that I don't have a typical Optimum installation: I just cloned the branch locally and installed dependencies with PyCharm and venv, which do weird stuff from time to time.

  • I applied black and isort formatting following your advice (thanks for the tip btw, I will definitely use them in other projects 😀)

Obviously, we should still test GPU inference.
I think we should change SUPPORTED_ARCHITECTURES_WITH_MODEL_ID in ORTModelForSemanticSegmentationIntegrationTest so the tests pass in the workflows and locally: these tests do not work properly, even though the same setup works fine for ORTModelForImageClassificationIntegrationTest (see below).

E       ValueError: Unrecognized configuration class <class 'transformers.models.vit.configuration_vit.ViTConfig'> for this kind of AutoModel: AutoModelForSemanticSegmentation.
E       Model type should be one of BeitConfig, Data2VecVisionConfig, DPTConfig, MobileNetV2Config, MobileViTConfig, SegformerConfig.

What do you think is left to do?

@JingyaHuang (Collaborator) left a comment

Hey @TheoMrc, thanks for the update.

Just left some nits.

@JingyaHuang (Collaborator) commented Dec 14, 2022

@TheoMrc, thanks for updating the code!

I just tested on my end with the following snippet:

import torch
from transformers import SegformerImageProcessor, pipeline
from PIL import Image
import requests

from optimum.onnxruntime import ORTModelForSemanticSegmentation

device = torch.device("cuda:0")

image_processor = SegformerImageProcessor.from_pretrained("optimum/segformer-b0-finetuned-ade-512-512")
model = ORTModelForSemanticSegmentation.from_pretrained("optimum/segformer-b0-finetuned-ade-512-512")
model.to(device)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = image_processor(images=image, return_tensors="pt").to(device)

out = model(**inputs)

pipe = pipeline("image-segmentation", model=model, feature_extractor=image_processor)
pred = pipe(url)

It is working decently both on CPU and on GPU.

For failing tests:

  • ONNX Runtime tests
    You used ViT in the tests; however, in transformers, ViT doesn't have a semantic segmentation model implemented, which is why all tests using AutoModelForSemanticSegmentation failed. I suggest using Segformer instead.
  • Exporter tests don't seem to be relevant to this PR; I just re-launched the CIs, let's see how it goes.

@TheoMrc (Contributor, Author) commented Dec 15, 2022

Hi @JingyaHuang,

Thanks for the code review!

I believe I made all the requested changes.
It was a bit tricky to understand why tests were still crashing on my end. It turns out there is a pipeline function in optimum as well, that the tests were running on a vanilla torch model instead of an ONNX model, that ORTModel can load vanilla torch models, ... 🥲
So, during my investigations, I had to make multiple commits, hope that's okay!

I had to add 'image-segmentation' to the SUPPORTED_TASKS dict in optimum/pipelines.py and change the test asserts from outputs[0]['score'] >= 0 checks (copied over from the ImageClassification tests) to outputs[0]['mask'] is not None checks (see the illustrative snippet below).
Maybe you will think of a cleaner assertion for image segmentation!
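For illustration, the kind of assertion meant above (simplified; onnx_model, preprocessor and url are assumed to be set up earlier in the test, and this is not the exact test body merged in the PR):

from optimum.pipelines import pipeline

pipe = pipeline("image-segmentation", model=onnx_model, feature_extractor=preprocessor)
outputs = pipe(url)

# A score threshold (outputs[0]['score'] >= 0) makes little sense for segmentation,
# so the test checks that a mask was actually produced instead.
self.assertTrue(outputs[0]["mask"] is not None)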

All tests are running smoothly locally (apart from the GPU one). Hopefully everything will be good on the next CI run :)

I'm glad I could contribute and learn a bunch of things along the way!
Do not hesitate to reach out if you think there's anything else to fix.

See you,

Theo

@JingyaHuang (Collaborator) left a comment

Hi @TheoMrc, the PR looks great, thanks for iterating and contributing to Optimum!

The exporter currently has some breaking changes; the exporter tests are irrelevant to this PR, so we can merge once all the other tests pass.

@JingyaHuang JingyaHuang merged commit e8d9877 into huggingface:main Dec 16, 2022
@TheoMrc TheoMrc deleted the ORTModelForImageSegmentation-implementation branch December 16, 2022 15:51
@allankouidri

Hi @TheoMrc, thank you for your contribution!

I am running into the same issue when running GPU inference.

ImportError: Asked to use CUDAExecutionProvider, but onnxruntime-gpu package was not found. Make sure to install onnxruntime-gpu package instead of onnxruntime.

I made sure that:

  • the versions of CUDA, cuDNN and onnxruntime-gpu are compatible according to the recommendation table
  • the environment path is set up
  • onnxruntime is not installed

Did you manage to make it work on GPU?
Thank you

System information

  • Windows 11
  • ONNX Runtime (gpu) version: 1.12.0
  • ONNX version: 1.12 (tried 1.13 also)
  • Optimum version: 1.6.1
  • Python version: tried with 3.8/ 3.9/ 3.10
  • CUDA/cuDNN version: CUDA 11.4 cuDNN 8.2.2

@TheoMrc (Contributor, Author) commented Jan 28, 2023

Hi @allankouidri,

I haven't gotten around to setting up a clean venv to try GPU inference with ONNX. Since it works fine for @JingyaHuang, maybe it's a matter of OS, since we are both on Windows? Our issue probably has to do with the method used to detect onnxruntime vs onnxruntime-gpu.

I'll try a few things when I have the time and get back to you.

@michaelbenayoun (Member)

Hi @allankouidri,
One possibility is that you have both onnxruntime and onnxruntime-gpu installed.

Try doing this:

pip uninstall onnxruntime
pip uninstall onnxruntime-gpu
pip install onnxruntime-gpu

@fxmarty fxmarty mentioned this pull request Feb 17, 2023