Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Onnx fix test #10663

Merged
merged 6 commits into from Mar 11, 2021
Merged

Onnx fix test #10663

merged 6 commits into from Mar 11, 2021

Conversation

mfuntowicz
Copy link
Member

GPT2 past_keys_values format seems to have changed since last time I checked, now exporting for each layer tuple with 2 elements.

PyTorch's ONNX exporter doesn't seem to handle this format, so it was crashing with an error.

The PR assumes we don't currently support exporting past_keys_values for GPT2 and then disable the return of such values when constructing the model.

In order to support this behavior, pipeline() now ha a model_kwargs: Dict[str, Any] parameter which forwards the dict of parameters to model's from_pretrained(..., **model_kwargs).

Copy link
Collaborator

@sgugger sgugger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks for fixing!

@@ -38,19 +38,23 @@ def forward(self, input_ids, some_other_args, token_type_ids, attention_mask):


class OnnxExportTestCase(unittest.TestCase):
MODEL_TO_TEST = ["bert-base-cased", "gpt2", "roberta-base"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removing roberta-base is on purpose here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, speeding up things a bit, Roberta and Bert share the exact same graph, so basically it's testing the same things twice.

@LysandreJik
Copy link
Member

Merging now to rebase the slow tests and re-run them.

@LysandreJik LysandreJik merged commit 3ab6820 into master Mar 11, 2021
@LysandreJik LysandreJik deleted the onnx_fix_test branch March 11, 2021 18:38
Iwontbecreative pushed a commit to Iwontbecreative/transformers that referenced this pull request Jul 15, 2021
* Allow to pass kwargs to model's from_pretrained when using pipeline.

* Disable the use of past_keys_values for GPT2 when exporting to ONNX.

* style

* Remove comment.

* Appease the documentation gods

* Fix style

Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants