
Fix GPT-J onnx conversion #16780

Merged: 5 commits, merged Apr 21, 2022

Conversation

@chainyo (Contributor) commented Apr 14, 2022

What does this PR do?

This PR fixes some problems encountered while converting a GPT-J model to ONNX.
Thanks to @ri938, who found where the bugs needed to be fixed (on the 🤗 Discord).

Models:
gpt2: @patrickvonplaten, @LysandreJik
and @lewtun, because you reviewed the first PR for the GPT-J ONNX config, here: #16274

I'm currently uploading a fully converted EleutherAI/gpt-j-6B model to the Hub, which demonstrates that the conversion command works with these fixes. Find it here.

Here is the command I used (I had to set atol to 1e-04 because validation failed with 1e-05):

python -m transformers.onnx --model=EleutherAI/gpt-j-6B --feature=causal-lm --atol=1e-04 onnx/
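
For anyone who wants to sanity-check an export by hand, something along these lines should work. This is only a rough sketch (not part of the PR): it assumes the command above wrote the graph to onnx/model.onnx and that you have enough memory to load the 6B checkpoint in PyTorch alongside the ONNX Runtime session.

import numpy as np
import onnxruntime
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "EleutherAI/gpt-j-6B"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)

# Reference logits from the original PyTorch model.
inputs = tokenizer("Hello, my name is", return_tensors="pt")
with torch.no_grad():
    ref_logits = model(**inputs).logits.numpy()

# Same inputs through the exported graph; ONNX Runtime expects NumPy arrays.
session = onnxruntime.InferenceSession("onnx/model.onnx", providers=["CPUExecutionProvider"])
onnx_logits = session.run(["logits"], {k: v.numpy() for k, v in inputs.items()})[0]

# atol=1e-04 is the tolerance that passed validation for this model.
print(np.allclose(ref_logits, onnx_logits, atol=1e-4))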

@HuggingFaceDocBuilderDev commented Apr 14, 2022

The documentation is not available anymore as the PR was closed or merged.

@lewtun (Member) left a comment

Thanks for the fix @chainyo! This change to sinusoid_inp looks good to me, but I'd like @patil-suraj to comment on whether he thinks it would have any negative consequences.

src/transformers/models/gptj/modeling_gptj.py (review suggestion, resolved)
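
For readers following along without opening the diff: the discussion concerns GPT-J's fixed_pos_embedding helper. Below is a minimal sketch of the idea, assuming (per the commit message "fix int32 to float to avoid problem in onnx") that the fix is to build the position indices as floats so the einsum and sin/cos ops keep a floating-point dtype during ONNX export; the exact change is in the file referenced above.

import torch

def fixed_pos_embedding_sketch(x, seq_dim=1, seq_len=None):
    # x has shape (..., rotary_dim); the last dimension drives the frequencies.
    dim = x.shape[-1]
    if seq_len is None:
        seq_len = x.shape[seq_dim]
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim))
    # Creating the positions as float (instead of the default int64) avoids the
    # integer/float dtype mismatch that broke the ONNX export.
    positions = torch.arange(seq_len, dtype=torch.float)
    sinusoid_inp = torch.einsum("i , j -> i j", positions, inv_freq).to(x.device).float()
    return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)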
@patil-suraj (Contributor) left a comment

Looks good, thanks for the fix!

src/transformers/models/gptj/modeling_gptj.py (review suggestion, resolved)
@chainyo (Contributor, Author) commented Apr 19, 2022

Fine, I'm going to accept @lewtun's suggestion.

I see there are some failing checks, and I also see that there are problems with linting. Is that normal?

@patil-suraj (Contributor) commented

The run_tests_torch_and_tf failure is unrelated. To fix the check_code_quality test, run the make fixup command and push again.

@chainyo (Contributor, Author) commented Apr 19, 2022

The run_tests_torch_and_tf failure is unrelated. To fix the check_code_quality test, run the make fixup command and push again.

Last time I tried to run make fixup, it changed the linting in more than 87 files unrelated to this PR, so I reverted it.

@lsn1106 commented Apr 20, 2022

@chainyo hi, thanks for your contribution. I tested it with my basic GPT-J model and it works pretty well.

But it doesn't seem to work when I test a model exported with the 'use cache' / 'use past' option. Could you give me an example, or check whether there's anything wrong with my code?

Below is the test I did:

import onnxruntime

# onnx_path points to my exported GPT-J model (with past_key_values inputs)
ort_session = onnxruntime.InferenceSession(onnx_path, providers=['CUDAExecutionProvider'])

# check the session's inputs
for ort_session_input in ort_session.get_inputs():
    print(ort_session_input.name, ort_session_input.shape, ort_session_input.type)
# input_ids ['batch', 'sequence'] tensor(int64)
# past_key_values.0.key ['batch', 16, 'past_sequence + sequence', 256] tensor(float)
# past_key_values.0.value ['batch', 16, 'past_sequence + sequence', 256] tensor(float)
# ...
# past_key_values.27.key ['batch', 16, 'past_sequence + sequence', 256] tensor(float)
# past_key_values.27.value ['batch', 16, 'past_sequence + sequence', 256] tensor(float)
# attention_mask ['batch', 'past_sequence + sequence'] tensor(float)

input_txt_list = [
    'text for test',
    'gptj'
]
# make_onnx_inputs is my own helper; it returns a dict of torch tensors
ort_input = make_onnx_inputs(input_txt_list)
for k, v in ort_input.items():
    print(k, v.size(), v.dtype)
# input_ids torch.Size([2, 3]) torch.int64
# past_key_values.0.key torch.Size([2, 16, 0, 256]) torch.float32
# past_key_values.0.value torch.Size([2, 16, 0, 256]) torch.float32
# ...
# past_key_values.27.key torch.Size([2, 16, 0, 256]) torch.float32
# past_key_values.27.value torch.Size([2, 16, 0, 256]) torch.float32
# attention_mask torch.Size([2, 3]) torch.float32

# TypeError raised here
ort_output = ort_session.run(None, ort_input)

And this is the error message:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_1166/2757688042.py in <module>
----> 1 ort_output = ort_session.run(None, ort_input)

/opt/conda/lib/python3.7/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py in run(self, output_names, input_feed, run_options)
    190             output_names = [output.name for output in self._outputs_meta]
    191         try:
--> 192             return self._sess.run(output_names, input_feed, run_options)
    193         except C.EPFail as err:
    194             if self._enable_fallback:

TypeError: run(): incompatible function arguments. The following argument types are supported:
    1. (self: onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession, arg0: List[str], arg1: Dict[str, object], arg2: onnxruntime.capi.onnxruntime_pybind11_state.RunOptions) -> List[object]

Invoked with: <onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession object at 0x7f319b7e3030>, ['logits', 'present.0.key', 'present.0.value', 'present.1.key', ...'present.27.key', 'present.27.value'], {'input_ids': tensor([[24496,  1956, 15560], [    3,     3, 18566]]), 'past_key_values.0.key': tensor([], size=(2, 16, 0, 256)), 'past_key_values.0.value': tensor([], size=(2, 16, 0, 256)),
...
'past_key_values.27.key': tensor([], size=(2, 16, 0, 256)), 'past_key_values.27.value': tensor([], size=(2, 16, 0, 256)), 'attention_mask': tensor([[1.,1.,1.],[0.,0.,1.]])}, None

@chainyo (Contributor, Author) commented Apr 20, 2022

@chainyo hi, thanks for your contribution. I tested it with my basic GPT-J model and it works pretty well.

But it doesn't seem to work when I test a model exported with the 'use cache' / 'use past' option. Could you give me an example, or check whether there's anything wrong with my code?

Hi @lsn1106, could you try loading your model in netron.app and checking the expected inputs by clicking on the first layer?

I'm not sure, but maybe use_past is a feature that is only implemented under the hood in the Transformers library and is not available in ONNX Runtime.

@lsn1106 commented Apr 20, 2022

@chainyo Thank you for your kind advice. Maybe I should refer to this GitHub code [link].

@lewtun (Member) commented Apr 20, 2022

Hey @lsn1106, looking at your error in ORT:

TypeError: run(): incompatible function arguments. The following argument types are supported:

it seems that you're not passing inputs with the correct types. What happens if you cast your inputs to NumPy arrays and ensure that ort_input truly is a dict?
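
Concretely, something like the following (a sketch on top of your snippet above; ort_input is the dict of torch tensors returned by your make_onnx_inputs helper, and ONNX Runtime's Python API takes NumPy arrays as input feeds):

# ONNX Runtime expects NumPy arrays, so convert every torch tensor before run().
# The empty past_key_values tensors keep their (2, 16, 0, 256) shape and float32 dtype.
ort_input_np = {k: v.detach().cpu().numpy() for k, v in ort_input.items()}

ort_output = ort_session.run(None, ort_input_np)
print(ort_output[0].shape)  # logits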

@lsn1106 commented Apr 20, 2022

@lewtun I've already tried that, but the same error occurred. Thank you :)

@lewtun (Member) commented Apr 20, 2022

Last time I tried to run make fixup, it changed the linting in more than 87 files unrelated to this PR, so I reverted it.

Maybe you can rebase on main and run make fixup again? I'm not entirely sure why it should lint so many files, but this might resolve the problem.

@lewtun (Member) commented Apr 20, 2022

@lsn1106 would you mind sharing a reproducible code snippet that shows how you export the model, create the inputs for ORT, etc.?

@chainyo (Contributor, Author) commented Apr 20, 2022

Last time I tried to run make fixup, it changed the linting in more than 87 files unrelated to this PR, so I reverted it.

Maybe you can rebase on main and run make fixup again? I'm not entirely sure why it should lint so many files, but this might resolve the problem.

Well it seems to be solved, thanks!

@lewtun (Member) commented Apr 20, 2022

I think the last thing we need to do is run make style && make quality, and then this should be good to go 🚀!

@chainyo (Contributor, Author) commented Apr 20, 2022

I think the last thing we need to do is run make style && make quality, and then this should be good to go 🚀!

Yes, sorry, the first make fixup didn't run black; it should be good now!

@lewtun (Member) commented Apr 21, 2022

Great, thanks for fixing the style issues! Merging this, since the issue reported by @lsn1106 is unrelated to the fix provided by this PR.

@lewtun merged commit 0b1e0fc into huggingface:main on Apr 21, 2022
Narsil pushed a commit to Narsil/transformers that referenced this pull request Apr 21, 2022
* add gptj to TOKENIZER_MAPPING_NAMES

* fix int32 to float to avoid problem in onnx

* Update src/transformers/models/gptj/modeling_gptj.py

Co-authored-by: ChainYo <t.chaigneau.tc@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
elusenji pushed a commit to elusenji/transformers that referenced this pull request Jun 12, 2022
* add gptj to TOKENIZER_MAPPING_NAMES

* fix int32 to float to avoid problem in onnx

* Update src/transformers/models/gptj/modeling_gptj.py

Co-authored-by: ChainYo <t.chaigneau.tc@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>