-
Notifications
You must be signed in to change notification settings - Fork 31.2k
Fix tied weight for Bart (for BC) #42355
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
vasqu
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thx a lot, checked locally (+ with other different fixes) and the integration tests pass then. Just 2 nits but feel free to ignore, not super important
| # Initialize weights and apply final processing | ||
| self.post_init() | ||
|
|
||
| def tie_weights(self, missing_keys: Optional[set[str]] = None, recompute_mapping: bool = True): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we want to add a FIXME/TODO here to cleanup after allowing the reverse direction in tied weights?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We will need to remove all model-specific tie_weights anyway at that time, so fine like this IMO. A TODO does not add much value as we need to search for it and we don't always do anyway
| """ | ||
| ) | ||
| # Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration with Bart->BigBirdPegasus, BART->BIGBIRD_PEGASUS | ||
| # Except `tie_weights()`: everything else Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration with Bart->BigBirdPegasus, BART->BIGBIRD_PEGASUS |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These other classes like QnA are always a mess in legacy models 😢
Tbh, I don't see much being copied in the first place: The forward uses ignore copy and then only init and resize_xxx remain if I see it correctly. I'd pro to just remove the copies altogether or just directly use them on the 3 ish functions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh indeed, did not notice that forward was skipping copy, good catch! Will update on the separate functions then!
|
[For maintainers] Suggested jobs to run (before merge) run-slow: bart, bigbird_pegasus |
What does this PR do?
As per the title. Bart used to check which weight was present instead of simply using the default one, as some main checkpoint have the wrong one saved. See here before we refactored the tied weight
cc @vasqu as you noticed the issue first!