
Unable to use GPU accelerated Optimum Onnx transformer model for inference #580

Closed
2 of 4 tasks
smiraldr opened this issue Dec 13, 2022 · 11 comments · Fixed by #602
Labels
bug Something isn't working

Comments

@smiraldr
Contributor

System Info

Optimum version: 1.5.0
OS: Ubuntu 20.04 (Linux)
Python version: 3.8

Who can help?

@JingyaHuang @echarlaix
When following the documentation at https://huggingface.co/docs/optimum/main/en/onnxruntime/usage_guides/gpu with optimum 1.5.0, we get the following error:


RuntimeError                              Traceback (most recent call last)
<ipython-input-...> in <module>
19 "education",
20 "music"]
---> 21 pred = onnx_z0(sequence_to_classify, candidate_labels, multi_class=False)

8 frames
/usr/local/lib/python3.8/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py in bind_input(self, name, device_type, device_id, element_type, shape, buffer_ptr)
454 :param buffer_ptr: memory pointer to input data
455 """
--> 456 self._iobinding.bind_input(
457 name,
458 C.OrtDevice(

RuntimeError: Error when binding input: There's no data transfer registered for copying tensors from Device:[DeviceType:1 MemoryType:0 DeviceId:0] to Device:[DeviceType:0 MemoryType:0 DeviceId:0]

This is reproducible on a Google Colab GPU instance as well. The error appears starting with version 1.5.0; 1.4.1 works as expected.

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

!pip install optimum[onnxruntime-gpu]==1.5.1
!pip install transformers onnx

from optimum.onnxruntime import ORTModelForSequenceClassification

ort_model = ORTModelForSequenceClassification.from_pretrained(
    "philschmid/tiny-bert-sst2-distilled",
    from_transformers=True,
    provider="CUDAExecutionProvider",
)

from optimum.pipelines import pipeline
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("philschmid/tiny-bert-sst2-distilled")

pipe = pipeline(task="text-classification", model=ort_model, tokenizer=tokenizer)
result = pipe("Both the music and visual were astounding, not to mention the actors performance.")
print(result)

Expected behavior

Inference should run on the GPU without errors; instead it fails with the device error above.

@smiraldr smiraldr added the bug Something isn't working label Dec 13, 2022
@fxmarty
Collaborator

fxmarty commented Dec 13, 2022

Thanks for the report!

The ONNX Runtime pipeline follows the same schema as transformers: https://huggingface.co/docs/transformers/main/en/main_classes/pipelines#transformers.TextClassificationPipeline

So you need to pass device="cuda:0" at pipeline initialization for it to work.

But I agree the error message is not ideal; we can fix that! I'll also add an example with pipelines in the guide you linked.
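
Concretely, using the names from your reproduction script, the pipeline call becomes (the full script is in a comment below):

pipe = pipeline(
    task="text-classification",
    model=ort_model,
    tokenizer=tokenizer,
    device="cuda:0",  # place the pipeline's tensors on the same CUDA device as the model
)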

@smiraldr
Contributor Author

How am I still able to run the same code without device="cuda:0" on 1.4.1? Is there something wrong with what I'm doing here?

@fxmarty
Collaborator

fxmarty commented Dec 13, 2022

The full script, working well with optimum-1.5.1, is:

from optimum.onnxruntime import ORTModelForSequenceClassification

ort_model = ORTModelForSequenceClassification.from_pretrained(
    "philschmid/tiny-bert-sst2-distilled",
    from_transformers=True,
    provider="CUDAExecutionProvider",
)

from optimum.pipelines import pipeline
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("philschmid/tiny-bert-sst2-distilled")

pipe = pipeline(task="text-classification", model=ort_model, tokenizer=tokenizer, device="cuda:0")
result = pipe("Both the music and visual were astounding, not to mention the actors performance.")
print(result)

It works on 1.4.1 as well, but I would advise updating.

@smiraldr
Contributor Author

Got it. It would be helpful if you could update the Optimum GPU docs to reflect the code above and improve the error message. Feel free to close the issue. Thanks for the quick resolution.

@fxmarty
Collaborator

fxmarty commented Dec 13, 2022

Will do, thanks for reporting this gap in the docs! We are open to contributions as well!

@smiraldr
Contributor Author

Oh, I would love to take this up, contribute to the docs, and suggest a better error message. I'll take this up if that's fine?

@fxmarty
Collaborator

fxmarty commented Dec 13, 2022

For sure, thanks a lot! Don't hesitate if you need any guidance!

@smiraldr
Contributor Author

@fxmarty could you help me with where exactly I need to handle this in the code to raise a better error? Here is the stack trace:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-8-4d3f3576cdc8> in <module>
     5 
     6 pipe = pipeline(task="text-classification", model=ort_model, tokenizer=tokenizer)
----> 7 result = pipe("Both the music and visual were astounding, not to mention the actors performance.")
     8 print(result)

8 frames
/usr/local/lib/python3.8/dist-packages/transformers/pipelines/text_classification.py in __call__(self, *args, **kwargs)
   153             If `top_k` is used, one such dictionary is returned per label.
   154         """
--> 155         result = super().__call__(*args, **kwargs)
   156         # TODO try and retrieve it in a nicer way from _sanitize_parameters.
   157         _legacy = "top_k" not in kwargs

/usr/local/lib/python3.8/dist-packages/transformers/pipelines/base.py in __call__(self, inputs, num_workers, batch_size, *args, **kwargs)
  1072             return self.iterate(inputs, preprocess_params, forward_params, postprocess_params)
  1073         else:
-> 1074             return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)
  1075 
  1076     def run_multi(self, inputs, preprocess_params, forward_params, postprocess_params):

/usr/local/lib/python3.8/dist-packages/transformers/pipelines/base.py in run_single(self, inputs, preprocess_params, forward_params, postprocess_params)
  1079     def run_single(self, inputs, preprocess_params, forward_params, postprocess_params):
  1080         model_inputs = self.preprocess(inputs, **preprocess_params)
-> 1081         model_outputs = self.forward(model_inputs, **forward_params)
  1082         outputs = self.postprocess(model_outputs, **postprocess_params)
  1083         return outputs

/usr/local/lib/python3.8/dist-packages/transformers/pipelines/base.py in forward(self, model_inputs, **forward_params)
   988                 with inference_context():
   989                     model_inputs = self._ensure_tensor_on_device(model_inputs, device=self.device)
--> 990                     model_outputs = self._forward(model_inputs, **forward_params)
   991                     model_outputs = self._ensure_tensor_on_device(model_outputs, device=torch.device("cpu"))
   992             else:

/usr/local/lib/python3.8/dist-packages/transformers/pipelines/text_classification.py in _forward(self, model_inputs)
   180 
   181     def _forward(self, model_inputs):
--> 182         return self.model(**model_inputs)
   183 
   184     def postprocess(self, model_outputs, function_to_apply=None, top_k=1, _legacy=True):

/usr/local/lib/python3.8/dist-packages/optimum/modeling_base.py in __call__(self, *args, **kwargs)
    58 
    59     def __call__(self, *args, **kwargs):
---> 60         return self.forward(*args, **kwargs)
    61 
    62     @abstractmethod

/usr/local/lib/python3.8/dist-packages/optimum/onnxruntime/modeling_ort.py in forward(self, input_ids, attention_mask, token_type_ids, **kwargs)
   934     ):
   935         if self.device.type == "cuda" and self.use_io_binding:
--> 936             io_binding, output_shapes, output_buffers = self.prepare_io_binding(
   937                 input_ids, attention_mask, token_type_ids
   938             )

/usr/local/lib/python3.8/dist-packages/optimum/onnxruntime/modeling_ort.py in prepare_io_binding(self, input_ids, attention_mask, token_type_ids)
   869         # bind input ids
   870         input_ids = input_ids.contiguous()
--> 871         io_binding.bind_input(
   872             "input_ids",
   873             input_ids.device.type,

/usr/local/lib/python3.8/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py in bind_input(self, name, device_type, device_id, element_type, shape, buffer_ptr)
   456         self._iobinding.bind_input(
   457             name,
--> 458             C.OrtDevice(
   459                 get_ort_device_type(device_type, device_id),
   460                 C.OrtDevice.default_memory(),

TypeError: __init__(): incompatible constructor arguments. The following argument types are supported:
   1. onnxruntime.capi.onnxruntime_pybind11_state.OrtDevice(arg0: int, arg1: int, arg2: int)

Invoked with: 0, 0, None
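
For what it's worth, the trailing None in "Invoked with: 0, 0, None" is most likely the device index of a CPU tensor: when no device is passed to the pipeline, the inputs stay on CPU, and PyTorch reports their device index as None, which then reaches OrtDevice. A minimal illustration, assuming nothing beyond torch:

import torch

t = torch.zeros(1)     # pipeline inputs default to CPU when no device is given
print(t.device.type)   # "cpu"
print(t.device.index)  # None -- the value OrtDevice rejects as its device id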

@fxmarty
Collaborator

fxmarty commented Dec 19, 2022

Thank you for your PR, it's great!

I think there could be a check here that the tensors are on the right device:

if self.device.type == "cuda" and self.use_io_binding:
    io_binding, output_shapes, output_buffers = self.prepare_io_binding(
        input_ids, attention_mask, token_type_ids
    )

Although I am not sure it is worth introducing more checks directly in ORTModel, what do you think @JingyaHuang?
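
As a rough sketch, reusing the names from the snippet above (not a final design, and exactly the trade-off in question):

if self.device.type == "cuda" and self.use_io_binding:
    named_inputs = (("input_ids", input_ids), ("attention_mask", attention_mask), ("token_type_ids", token_type_ids))
    for name, tensor in named_inputs:
        if tensor is not None and tensor.device.type != self.device.type:
            raise ValueError(
                f"{name} is on {tensor.device} but the model is on {self.device}; "
                "pass device='cuda:0' to the pipeline or move the inputs yourself."
            )
    io_binding, output_shapes, output_buffers = self.prepare_io_binding(
        input_ids, attention_mask, token_type_ids
    )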

Alternatively, it could be possible to raise an error in optimum.pipelines.pipeline() if a device is not passed to the pipeline.
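
Roughly, as a hypothetical guard (the real pipeline() signature has many more arguments):

def pipeline(task=None, model=None, tokenizer=None, device=None, **kwargs):
    # Hypothetical check: fail fast when the model sits on CUDA but the
    # caller did not tell the pipeline which device to place tensors on.
    if device is None and getattr(model, "device", None) is not None and model.device.type == "cuda":
        raise ValueError(
            "The model is on CUDA but no device was passed to pipeline(); "
            "pass device='cuda:0' so inputs land on the same device as the model."
        )
    ...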

@JingyaHuang
Collaborator

@fxmarty

  • In the design, I prefer that we set a default device index (0) when parsing the provider to a device, so as not to interrupt execution. If users want a specific device, they can set it explicitly (see the sketch below).
  • And we can mention in the docs what users should do if they want to use a specific device.
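
A sketch of that default, with a hypothetical helper name (the actual patch may look different):

import torch

def _device_from_provider(provider: str) -> torch.device:
    # Default to device index 0 so execution is not interrupted when the
    # user does not pass a device; an explicitly chosen device still wins.
    if provider == "CUDAExecutionProvider":
        return torch.device("cuda:0")
    return torch.device("cpu")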

@fxmarty
Collaborator

fxmarty commented Dec 20, 2022

@smiraldr So as I understand it, this was in fact a device indexing issue, which @JingyaHuang fixed in #613. Your PR looks good as is, so I'm moving the discussion there!

@fxmarty fxmarty closed this as completed Dec 20, 2022