
Exporting ORTModelForVision2Seq doesn't work correctly on PyTorch 1.11 #1909

Open
2 of 4 tasks
Yosshi999 opened this issue Jun 15, 2024 · 4 comments
Labels
bug Something isn't working

@Yosshi999

System Info

* optimum==1.20.0
* on docker image `nvidia/cuda:11.7.0-devel-ubuntu20.04`
* torch==1.11.0+cpu
* transformers==4.41.2
* Python 3.8.10

Who can help?

@michaelbenayoun

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction (minimal, reproducible, runnable)

Install PyTorch 1.11 (on PyTorch >= 1.12 the bug does not occur), then run the `ORTModelForVision2Seq.forward` example from the documentation: https://huggingface.co/docs/optimum/v1.20.0/en/onnxruntime/package_reference/modeling_ort#optimum.onnxruntime.ORTModelForVision2Seq.forward.example

```python
from transformers import AutoImageProcessor, AutoTokenizer
from optimum.onnxruntime import ORTModelForVision2Seq
from PIL import Image
import requests


processor = AutoImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
model = ORTModelForVision2Seq.from_pretrained("nlpconnect/vit-gpt2-image-captioning", export=True)  # fails here on torch 1.11

# Download a COCO validation image and preprocess it
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(image, return_tensors="pt")

# Generate a caption with the exported ONNX model and decode it
gen_tokens = model.generate(**inputs)
outputs = tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)
print(outputs)
```

The export fails with the following error:

```
Traceback (most recent call last):
  File "convert.py", line 9, in <module>
    model = ORTModelForVision2Seq.from_pretrained("nlpconnect/vit-gpt2-image-captioning", export=True)
  File "/usr/local/lib/python3.8/dist-packages/optimum/onnxruntime/modeling_ort.py", line 669, in from_pretrained
    return super().from_pretrained(
  File "/usr/local/lib/python3.8/dist-packages/optimum/modeling_base.py", line 402, in from_pretrained
    return from_pretrained_method(
  File "/usr/local/lib/python3.8/dist-packages/optimum/onnxruntime/modeling_seq2seq.py", line 1056, in _from_transformers
    main_export(
  File "/usr/local/lib/python3.8/dist-packages/optimum/exporters/onnx/__main__.py", line 352, in main_export
    onnx_export_from_model(
  File "/usr/local/lib/python3.8/dist-packages/optimum/exporters/onnx/convert.py", line 1170, in onnx_export_from_model
    _, onnx_outputs = export_models(
  File "/usr/local/lib/python3.8/dist-packages/optimum/exporters/onnx/convert.py", line 776, in export_models
    export(
  File "/usr/local/lib/python3.8/dist-packages/optimum/exporters/onnx/convert.py", line 881, in export
    export_output = export_pytorch(
  File "/usr/local/lib/python3.8/dist-packages/optimum/exporters/onnx/convert.py", line 577, in export_pytorch
    onnx_export(
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/__init__.py", line 305, in export
    return utils.export(model, args, f, export_params, verbose, training,
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 118, in export
    _export(model, args, f, export_params, verbose, training, input_names, output_names,
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 720, in _export
    _model_to_graph(model, args, verbose, input_names,
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 500, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 441, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 392, in _trace_and_get_graph_from_model
    torch.jit._get_trace_graph(model, args, strict=False, _force_outplace=False, _return_inputs_states=True)
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_trace.py", line 1166, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_trace.py", line 127, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_trace.py", line 118, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1098, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/optimum/exporters/onnx/model_patcher.py", line 231, in patched_forward
    outputs = self.orig_forward(*args, **kwargs)
TypeError: forward() takes from 1 to 12 positional arguments but 13 were given
```

Expected behavior

The model should be exported correctly.
I think this bug is related to pytorch/pytorch#110439, because ORTModelForVision2Seq exports a VisionEncoderDecoderModel, whose forward() accepts **kwargs.
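To make the suspected mechanism concrete, here is a minimal, self-contained sketch that needs neither torch nor transformers. It rests on an assumption that is not confirmed from pytorch/pytorch#110439 or the torch 1.11 source: that the exporter turns the dict of dummy inputs into positional arguments by walking every parameter of `forward()` and padding missing ones with `None`, including the `**kwargs` slot. `FakeVisionEncoderDecoder` and `order_inputs` are hypothetical names; the fake `forward()` is only modelled on the layout of `VisionEncoderDecoderModel.forward` (11 named parameters plus `**kwargs`), which is consistent with the "from 1 to 12" in the error above.

```python
import inspect


class FakeVisionEncoderDecoder:
    """Hypothetical stand-in: same parameter layout as the real forward()
    (11 named parameters, all with defaults, plus **kwargs). No model involved."""

    def forward(self, pixel_values=None, decoder_input_ids=None,
                decoder_attention_mask=None, encoder_outputs=None,
                past_key_values=None, decoder_inputs_embeds=None, labels=None,
                use_cache=None, output_attentions=None,
                output_hidden_states=None, return_dict=None, **kwargs):
        return pixel_values, decoder_input_ids


def order_inputs(forward, inputs, skip_var_params):
    """Hypothetical helper (assumed exporter behaviour): map a dict of named
    inputs onto a positional list by walking the signature of `forward`."""
    args = []
    for name, param in inspect.signature(forward).parameters.items():
        if skip_var_params and param.kind in (
            inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD
        ):
            continue  # *args / **kwargs have no positional slot to fill
        if name in inputs:
            args.append(inputs[name])
        elif param.default is inspect.Parameter.empty:
            args.append(None)  # naive path: **kwargs also gets padded with None
        else:
            args.append(param.default)
    return args


model = FakeVisionEncoderDecoder()
inputs = {"pixel_values": "px", "decoder_input_ids": "ids"}

try:
    # 11 named values + 1 padding for **kwargs = 12 args, plus self = 13
    model.forward(*order_inputs(model.forward, inputs, skip_var_params=False))
except TypeError as err:
    print(err)  # ends with: takes from 1 to 12 positional arguments but 13 were given

# Skipping the variadic parameter yields 11 args and the call succeeds
print(model.forward(*order_inputs(model.forward, inputs, skip_var_params=True)))  # ('px', 'ids')
```

If this assumption holds, the fix amounts to skipping `VAR_POSITIONAL`/`VAR_KEYWORD` parameters when ordering inputs, which would explain why newer PyTorch releases are unaffected.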

Yosshi999 added the bug label on Jun 15, 2024
@fxmarty
Collaborator

fxmarty commented Jun 24, 2024

Hi @Yosshi999, thank you for the report. PyTorch 1.11 is more than two years old. Do you face the same issue with a more recent version of PyTorch?

@Yosshi999
Author

> Hi @Yosshi999, thank you for the report. PyTorch 1.11 is more than two years old. Do you face the same issue with a more recent version of PyTorch?

No, only on 1.11.

@fxmarty
Collaborator

fxmarty commented Jul 1, 2024

Thank you. Is updating to torch==1.13 or torch>=2.0 an option for you?

@Yosshi999
Author

Yes. I switched to 1.13 and the error no longer occurs. I opened this issue only to report the problem on 1.11, which Optimum still supports.
Some older Jetson environments require pytorch==1.11, so I think other users may run into this problem.
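For setups that must stay on 1.11 (such as the older Jetson images mentioned above), a caller could at least detect the affected version before exporting and emit a clearer message than the `TypeError`. This is a hypothetical sketch: `is_affected_torch` and `warn_if_affected` are not Optimum or PyTorch APIs, and the check targets only the 1.11 series because that is the only release reported as affected in this thread.

```python
import re
import warnings


def is_affected_torch(version: str) -> bool:
    """Return True for torch 1.11.x builds such as "1.11.0+cpu".
    Hypothetical helper; only 1.11 is reported as affected in this issue."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        return False  # unparseable version string: don't guess
    return (int(match.group(1)), int(match.group(2))) == (1, 11)


def warn_if_affected(version: str) -> None:
    """Emit a RuntimeWarning before export if `version` is in the 1.11 series."""
    if is_affected_torch(version):
        warnings.warn(
            f"torch {version}: exporting ORTModelForVision2Seq may fail with "
            "'forward() takes from 1 to 12 positional arguments but 13 were given'; "
            "torch 1.13 is reported to work.",
            RuntimeWarning,
        )
```

In a real script you would call `warn_if_affected(torch.__version__)` just before `ORTModelForVision2Seq.from_pretrained(..., export=True)`. Parsing only the major and minor numbers with a regex keeps the sketch free of extra dependencies and ignores local build suffixes like `+cpu`.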


2 participants