Add XLNet OnnxConfig #17027
Conversation
Hi @sijunhe Nice PR, but could you rebase the branch to avoid getting all the recent commits on this PR?
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
Any progress here? @lewtun
```diff
@@ -1081,7 +1080,6 @@ def forward(
     output_attentions: Optional[bool] = None,
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
-    **kwargs,  # delete after depreciation warning is removed
```
I think it will break any other usage of this architecture, won't it?
I had to do this because `**kwargs` breaks the ONNX export, as I mentioned in the PR description. It did pass all the unit tests, and I think the deprecation warning has been up for a while.
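To illustrate the root of the issue (an illustrative sketch, not the library's actual export code): the export machinery matches model inputs against the `forward` signature, and a catch-all `**kwargs` parameter gives it no concrete input name to bind a tensor to.

```python
import inspect

# Toy forward with the same shape of signature as the XLNet models:
# explicit tensor arguments followed by a catch-all **kwargs.
def forward_with_kwargs(input_ids=None, attention_mask=None, **kwargs):
    return input_ids

params = inspect.signature(forward_with_kwargs).parameters

# The VAR_KEYWORD parameter is what trips up signature-based input
# matching: it is not a concrete input the exporter can bind to.
var_keyword = [name for name, p in params.items()
               if p.kind is inspect.Parameter.VAR_KEYWORD]
print(var_keyword)  # ['kwargs']
```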
```python
if "use_cache" in kwargs:
    warnings.warn(
        "The `use_cache` argument is deprecated and will be removed in a future version, use `use_mems` instead.",
        FutureWarning,
    )
    use_mems = kwargs["use_cache"]
```
You should probably keep this for another PR, but it's probably the right timing to change `use_cache` to `use_mems`.
We should keep this here. This really has nothing to do with this PR, IMO.
Ah sorry, ok I understand! You need to remove the `kwargs` to make it work for ONNX. Hmm, I sadly don't think we can do this before Transformers v5. @sgugger @LysandreJik what do you think?
Nope, no breaking change until v5 indeed.
OK, let's keep this PR open until v5 is released. Nevertheless, thank you for working on this @sijunhe - let's revisit it once we're able to safely remove the `kwargs` from the forward pass!
It could be awesome yes!
Thanks for the feedback @LysandreJik - this is really helpful!
As you suggest, with monkey patching we could have something like the following inside the `export_pytorch()` (and possibly `export_tensorflow()`) methods: `model.forward = forward_without_kwargs(model.forward)`, where `forward_without_kwargs()` is a function that wraps the original forward pass to strip out the `kwargs`.
WDYT @sijunhe - would you like to have a go at implementing this?
Couldn't we otherwise just remove `kwargs` and replace it with `use_cache=None`, and then raise a warning if `use_cache is not None`?
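A minimal sketch of that suggestion, with a free function standing in for the model's method (the warning text mirrors the existing one, but this is not the actual patch):

```python
import warnings

def forward(input_ids=None, use_mems=None, use_cache=None):
    # Explicit deprecated argument instead of a catch-all **kwargs:
    # the signature stays export-friendly, and old callers still work.
    if use_cache is not None:
        warnings.warn(
            "The `use_cache` argument is deprecated and will be removed in a "
            "future version, use `use_mems` instead.",
            FutureWarning,
        )
        use_mems = use_cache
    return use_mems
```

Old code passing `use_cache=True` keeps working but sees a `FutureWarning`; the trade-off is that the deprecated argument is now visible in the signature.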
This would still ensure backward compatibility.
It would expose the deprecated arg though. So if there is a solution without removing the kwargs, I'd prefer it.
Thanks for the review folks. I tried what @lewtun suggested about stripping the kwargs, but I couldn't really make it work. Instead I took @patrickvonplaten's suggestion and replaced `**kwargs` with an explicit `use_cache=None` argument.
Since it's an edge case I'm ok with this! Thanks for making the change @sijunhe - what do you think @LysandreJik @sgugger?
No, the param is not documented since it's deprecated, and it should stay that way IMO.
If I'm not mistaken, can't we define a wrapper function to strip out the `kwargs`?

```python
from transformers import AutoModel
import inspect
import functools

def forward_without_kwargs(forward):
    @functools.wraps(forward)
    def wrapper(*args, **kwargs):
        return forward(*args, **kwargs)

    # Override signature and strip out kwargs
    sig = inspect.signature(forward)
    sig = sig.replace(parameters=tuple(sig.parameters.values())[:-1])
    wrapper.__signature__ = sig
    return wrapper

# Load an XLNet checkpoint
model = AutoModel.from_pretrained("xlnet-base-cased")

# Has kwargs
inspect.signature(model.forward)

# Has no kwargs
model.forward = forward_without_kwargs(model.forward)
inspect.signature(model.forward)
```

This function could live in … Of course, this would also need to be tested properly - just an idea :)
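One way to exercise such a wrapper without loading a real checkpoint: apply the patch only for the duration of the export and restore the original `forward` afterwards. Everything here (`export_with_patched_forward`, the dummy model) is a hypothetical sketch, not existing library code:

```python
import functools
import inspect

def forward_without_kwargs(forward):
    """Wrap `forward`, advertising its signature minus the trailing **kwargs."""
    @functools.wraps(forward)
    def wrapper(*args, **kwargs):
        return forward(*args, **kwargs)

    sig = inspect.signature(forward)
    # Assumes **kwargs is the last parameter, as in the XLNet forward.
    wrapper.__signature__ = sig.replace(
        parameters=tuple(sig.parameters.values())[:-1]
    )
    return wrapper

def export_with_patched_forward(model, export_fn):
    """Hypothetical helper: patch `forward` during export, then restore it."""
    original_forward = model.forward
    model.forward = forward_without_kwargs(model.forward)
    try:
        return export_fn(model)
    finally:
        model.forward = original_forward

# Dummy stand-in for an XLNet model.
class DummyModel:
    def forward(self, input_ids=None, **kwargs):
        return input_ids

model = DummyModel()
seen = {}

def fake_export(m):
    # During export, the advertised signature has no **kwargs.
    seen["params"] = list(inspect.signature(m.forward).parameters)
    return m.forward(input_ids=42)

result = export_with_patched_forward(model, fake_export)
```

The `try`/`finally` guarantees the original `forward` (with its `**kwargs`) is restored even if the export itself fails.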
Also fine with me
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
What does this PR do?
Adds an `OnnxConfig` for XLNet and removes the `**kwargs` argument in the `forward` function of the `XLNet` models. Seems like the `**kwargs` was under a deprecation warning anyway, and removing it didn't break any tests. Here is the reproduction and the error log of the ONNX export if the `**kwargs` argument doesn't get removed.
Fixes #16308
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case. ONNXConfig: Add a configuration for all available models #16308
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
@chainyo for the OnnxConfig
@patrickvonplaten and @sgugger for the changes in `modeling_xlnet.py`
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.