Added support for other features for already supported models #14358
Conversation
These changes are looking really good - the code is much more elegant and modular than before 🤩 !
Before we merge, I think it would be good to do a few sanity checks with "real inputs" for the seq2seq models. For example, just checking that we get agreement with these examples from the docs would be nice:
@property
def atol_for_validation(self) -> float:
    return 1e-2
Cool idea to allow different atol values per model! Does the tolerance of 1e-2 for BART reflect the work in progress on this model? (Naively, I would have expected 1e-3 or smaller.)
I think this is currently because of the work in progress; I will try things out with smaller values before merging.
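For readers skimming this thread, the per-model tolerance idea can be sketched in plain Python. Class names follow the snippet quoted above, but the concrete tolerance values and the validate helper are illustrative, not the library's actual implementation:

```python
class OnnxConfig:
    @property
    def atol_for_validation(self) -> float:
        # Tight default tolerance shared by most models (illustrative value).
        return 1e-5


class BartOnnxConfig(OnnxConfig):
    @property
    def atol_for_validation(self) -> float:
        # Looser per-model tolerance while BART support is work in progress.
        return 1e-2


def validate(config: OnnxConfig, reference: float, exported: float) -> bool:
    # Compare a reference output against the exported model's output
    # using the model-specific absolute tolerance.
    return abs(reference - exported) <= config.atol_for_validation
```

Because the tolerance is a property, each model class can tune it without changing the shared validation code.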
return ordered_inputs

@property
def default_onnx_opset(self) -> int:
Nice idea to define default operator sets this way :)
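As a rough sketch of the pattern being praised here (MyModelOnnxConfig and both opset numbers are hypothetical, not taken from the codebase): the base config advertises a default operator set via a property, and a model config overrides it when its operators need a newer opset.

```python
class OnnxConfig:
    @property
    def default_onnx_opset(self) -> int:
        # Assumed baseline opset; illustrative only.
        return 11


class MyModelOnnxConfig(OnnxConfig):
    @property
    def default_onnx_opset(self) -> int:
        # Hypothetical model whose operators require a newer opset.
        return 13
```

The exporter can then ask any config for its opset without special-casing individual models.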
EXTERNAL_DATA_FORMAT_SIZE_LIMIT,
OnnxConfig,
OnnxConfigWithPast,
OnnxSeq2SeqConfigWithPast,
From the end-user perspective, should one use OnnxSeq2SeqConfigWithPast for all seq2seq models? If yes, we might want to explain this when we extend the documentation.
Yes definitely!
src/transformers/onnx/config.py (outdated)
decoder_shape = (
    batch,
    num_decoder_attention_heads,
    1,
Maybe it's a good idea to add a small comment here noting that we set the decoder sequence length to 1 because we only use the last decoder_input_ids when using pre-computed values with past_key_values. (It's probably obvious to people deeply familiar with the transformers codebase, but might not be obvious to people trying to export their own models.)
You are right, I will add them.
remaining_side_name = "encoder" if num_encoder_layers > num_decoder_layers else "decoder"

for _ in range(min_num_layers):
    common_inputs["past_key_values"].append(
Perhaps we should add a comment here to explain why past_key_values involves tuples of 4 tensors?
You are right!
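A sketch of the structure under discussion (shapes as plain tuples, all sizes illustrative): in a seq2seq model, each decoder layer caches four tensors per step - its self-attention key and value, which grow with the decoded sequence, and its cross-attention key and value over the encoder output, whose length stays fixed.

```python
# Illustrative dimension sizes (not taken from any real model).
batch, num_heads, dec_len, enc_len, head_dim = 2, 4, 5, 9, 8


def layer_past(batch, num_heads, dec_len, enc_len, head_dim):
    # One past_key_values entry per decoder layer: a tuple of 4 shapes.
    return (
        (batch, num_heads, dec_len, head_dim),  # decoder self-attention key
        (batch, num_heads, dec_len, head_dim),  # decoder self-attention value
        (batch, num_heads, enc_len, head_dim),  # cross-attention key (encoder)
        (batch, num_heads, enc_len, head_dim),  # cross-attention value (encoder)
    )


# Hypothetical 6-layer decoder.
past_key_values = [
    layer_past(batch, num_heads, dec_len, enc_len, head_dim) for _ in range(6)
]
```

Decoder-only models cache only the first two, which is why the seq2seq case needs its own handling.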
tests/test_onnx_v2.py (outdated)
("LayoutLM", "microsoft/layoutlm-base-uncased", LayoutLMModel, LayoutLMConfig, LayoutLMOnnxConfig),
("MBart", "sshleifer/tiny-mbart", MBartModel, MBartConfig, MBartOnnxConfig),
# ("T5", "t5-small", T5Model, T5Config, T5OnnxConfig),
PYTORCH_EXPORT_MODELS = {
This is much more elegant!
config.pad_token_id = tokenizer.eos_token_id

model_class = FeaturesManager.get_model_class_for_feature(feature)
model = model_class.from_config(config)
As discussed offline, would it make more sense to load the model from a pretrained checkpoint instead of randomly initializing it from a config?
I think using pretrained weights would be a more realistic test of how the ONNX export is used in real applications.
But maybe we can leave this as a TODO for a follow-up PR since this one is getting pretty large :)
Thank you for this big refactoring and adding proper support for BART / mBART 🚀 !
I've manually tested T5, BART, and mBART with various outputs and the max absolute difference is between 1e-5 and 1e-4, which is perfectly fine IMO :)
Great work - I'm looking forward to building on top of this!
@property
def outputs(self) -> Mapping[str, Mapping[int, str]]:
    if self.task in ["default", "default-with-past", "seq2seq-lm", "seq2seq-lm-with-past"]:
Perhaps I'm misunderstanding something, but why are default-with-past and seq2seq-lm-with-past included in this list? Looking at the if/else logic, it seems we extract the common_inputs for past key values in the else clause, so I wonder if we should be using:
- if self.task in ["default", "default-with-past", "seq2seq-lm", "seq2seq-lm-with-past"]:
+ if self.task in ["default", "seq2seq-lm"]:
You're right, actually the task is never "with-past" as it is parsed before being passed to the OnnxConfig. So we should definitely delete that.
import torch
batch = common_inputs["input_ids"].shape[0]
encoder_seq_length = common_inputs["input_ids"].shape[1]
# decoder_seq_length = ordered_inputs["decoder_input_ids"].shape[1]
Perhaps we can delete this bit of dead code?
return common_inputs

def _flatten_past_key_values_(self, flattened_output, name, idx, t):
    if self.task in ["default", "default-with-past", "seq2seq-lm", "seq2seq-lm-with-past"]:
Do we need to flatten past_key_values if the task is default or seq2seq-lm?
Yes, because you can have default and self.use_past = True (when you are doing default-with-past). As mentioned above, "default-with-past" will be parsed to something like OnnxConfig(task=feature.replace("-with-past", ""), use_past="with-past" in feature).
Perfect, thank you for the clarification!
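The parsing described in this thread can be condensed into a tiny helper (the function name is hypothetical, but the two expressions mirror the snippet quoted in the reply above): the "-with-past" suffix is stripped from the feature name before it reaches the OnnxConfig, and is turned into a use_past flag instead.

```python
def parse_feature(feature: str):
    # Strip the "-with-past" suffix to recover the plain task name...
    task = feature.replace("-with-past", "")
    # ...and record whether past key values should be used as a flag.
    use_past = "with-past" in feature
    return task, use_past
```

This is why the task string seen inside the config is never "with-past": that information lives entirely in use_past.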
"default",
"masked-lm",
"sequence-classification",
# "multiple-choice",
Should we delete this bit of dead code? (And similarly for the other occurrences of multiple-choice?)
I'm not sure, because the plan is to add support for this too one day. I do not know how much work that represents.
OK let's keep it then :)
# ("T5", T5Config)
}
SUPPORTED_WITH_PAST_CONFIGS = {}
# SUPPORTED_WITH_PAST_CONFIGS = {
Should these commented-out configs now be included, or are the unit tests not ready for these architectures yet?
src/transformers/onnx/config.py (outdated)

# Generate decoder inputs
decoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
    tokenizer, batch_size, 1, is_pair, framework
Following up on our discussion offline, should the sequence length be fixed at 1 in this base class? Would it be more appropriate to use seq_length (or some multiple thereof, to test differing encoder / decoder sequence lengths)?
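One way to read this suggestion, as a hypothetical sketch (the function name and the multiplier are invented for illustration): derive the decoder dummy length from seq_length instead of hard-coding 1, so that validation can exercise differing encoder and decoder sequence lengths.

```python
def dummy_shapes(batch_size: int, seq_length: int, decoder_ratio: int = 2):
    # Encoder dummy input uses the requested sequence length directly.
    encoder_shape = (batch_size, seq_length)
    # Decoder dummy input uses a multiple of it, so the two lengths differ
    # and shape mix-ups between encoder and decoder axes surface in tests.
    decoder_shape = (batch_size, seq_length * decoder_ratio)
    return encoder_shape, decoder_shape
```

With a ratio of 1 this degenerates to equal lengths, so existing behaviour remains reachable.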
…d models (huggingface#14358)" (huggingface#14679)" This reverts commit 0f4e39c.
* Revert "Revert "Added support for other features for already supported models (#14358)" (#14679)"
  This reverts commit 0f4e39c.
* is_torch_available test to avoid failing imports
* sorting parameterize parameters to solve ERROR gw0 gw1
* tests fix
* tests fix
* GPT2 with past fix
* Fixed stateful class attribute change that was breaking things when converting multiple models sequentially
* Removed onnx file
* Implemented suggestions
* Fixed __init__ to resolve conflict with master
* Remove commented import

* First commit to add MarianMT to ONNX
* Now MarianModel.forward() automatically generates decoder_input_ids, like BartModel.forward()
* Adjusted MarianOnnxConfig.inputs and outputs to work with seq2seq-lm feature
* Style fix
* Added support for other features for already supported models
* Partial support for causal and seq2seq models
* Partial support for causal and seq2seq models
* Add default task for MarianMT ONNX
* Remove automatic creation of decoder_input_ids
* Extend inputs and outputs for MarianMT ONNX config
* Add MarianMT to ONNX unit tests
* Refactor
* OnnxSeq2SeqConfigWithPast to support seq2seq models
* Parameterized the onnx tests
* Restored run_mlm.py
* Restored run_mlm.py
* [WIP] BART update
* BART and MBART
* Add past_key_values and fix dummy decoder inputs
  Using a sequence length of 1 in generate_dummy_outputs() produces large discrepancies, presumably due to some hidden optimisations.
* Refactor MarianOnnxConfig to remove custom past_key_values logic
* Fix quality
* Revert "Revert "Added support for other features for already supported models (#14358)" (#14679)"
  This reverts commit 0f4e39c.
* is_torch_available test to avoid failing imports
* sorting parameterize parameters to solve ERROR gw0 gw1
* tests fix
* tests fix
* GPT2 with past fix
* Fixed stateful class attribute change that was breaking things when converting multiple models sequentially
* Removed onnx file
* Refactor Marian export to account for base changes
* Fix copies
* Implemented suggestions
* Extend support for causal LM
* Revert "Revert "Added support for other features for already supported models (#14358)" (#14679)"
  This reverts commit 0f4e39c.
* is_torch_available test to avoid failing imports
* sorting parameterize parameters to solve ERROR gw0 gw1
* tests fix
* tests fix
* GPT2 with past fix
* Fixed stateful class attribute change that was breaking things when converting multiple models sequentially
* Removed onnx file
* Implemented suggestions
* Fixed __init__ to resolve conflict with master
* Remove commented import
* Remove ONNX model
* Remove redundant class method
* Tidy up imports
* Fix quality
* Refactor dummy input function
* Add copied from statements to Marian config functions
* Remove false copied from comments
* Fix copy from comment
  Co-authored-by: Massimiliano Bruni <massimiliano.bruni@hcl.com>
  Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>

…gface#14358)
* Added support for other features for already supported models
* Partial support for causal and seq2seq models
* Partial support for causal and seq2seq models
* OnnxSeq2SeqConfigWithPast to support seq2seq models
* Parameterized the onnx tests
* Restored run_mlm.py
* Restored run_mlm.py
* [WIP] BART update
* BART and MBART
* Added comments
* Another sequence length of the past_key_values
@michaelbenayoun @lewtun @Albertobegue
Hey @girishnadiger-gep, this PR was superseded by #14700, which was merged some time ago. Is there a specific issue or missing feature that you're interested in?
Hi @lewtun, I've converted a BART-Large model to an ONNX model using the 'seq2seq-lm' feature, but I thought I was missing something here, so I asked in this forum.
Ah, for that you can probably adapt the example that I used for the Marian PR in #14586. FYI, we also have a forum (https://discuss.huggingface.co/) which is better suited for these types of questions - we try to use GitHub issues for bug reports / feature requests.
What does this PR do?
This PR adds support for almost all the features available for already supported models.
Main contributions:
- OnnxSeq2SeqConfigWithPast: a new class inheriting from OnnxConfigWithPast, designed specifically for seq2seq models; this should make it easier for the community to contribute.
- Support for more features (for instance with past_key_values) that have been requested by the community (check the list of supported features below).

Features now supported: