Add final_layer_norm to OPT model #17785
Conversation
The documentation is not available anymore as the PR was closed or merged.
LGTM, we should update the TF and Flax checkpoints.
Thank you very much for the fix! I think that we'll have to change the generation tests a bit for the other models as well.
Great find @thomasw21 - thanks a lot for fixing it! I think the checkpoints were then also incorrectly loaded inside the metaseq codebase - could you maybe double check that the following script gives identical results between fairseq and transformers: https://huggingface.co/patrickvonplaten/opt_metaseq_125m -> The logits should match there (maybe an incorrect configuration in the metaseq model?). Also, could you please update the slow model tests?
@thomasw21 I can update the tests and check the outputs if you want.
@patrickvonplaten from what I understood, the logits equality comparison tests were only done on the 350m checkpoint? @younesbelkada @ArthurZucker if you have the bandwidth, I'd appreciate it! Thanks!
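For illustration only, here is a minimal sketch of the kind of parity check being asked for above. The model id, prompt, logits slice, and tolerance are placeholders; the reference logits would have to come from running metaseq's `restored.pt` on the exact same input ids.

```python
# Rough parity check between the converted transformers OPT-125m and a metaseq
# reference run (reference values are not reproduced here).
import torch
from transformers import AutoTokenizer, OPTForCausalLM

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m", use_fast=False)
model = OPTForCausalLM.from_pretrained("facebook/opt-125m")
model.eval()

input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids
with torch.no_grad():
    hf_logits = model(input_ids).logits

print(hf_logits[0, -1, :8])
# metaseq_logits = ...  # logits from the original checkpoint on the same input_ids
# assert torch.allclose(hf_logits, metaseq_logits, atol=1e-3)
```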
If we don't transfer the `"decoder.version"` to the singleton checkpoint, a very sneaky bug happens which was found by @thomasw21 as part of this PR: huggingface/transformers#17785. If the `decoder.version` param is not present in the state_dict, it follows that upon loading the singleton checkpoint the loaded layer_norm is set to `None` here: https://github.com/facebookresearch/metaseq/blob/e0c4f6b0e4c523906ad8d561f727e3f2ac3a8e73/metaseq/models/transformer.py#L932 So it's absolutely crucial that we include this variable. I will update all of the converted HF checkpoints here later today and then I think we can be sure that OPT works correctly 🥳 https://huggingface.co/models?other=opt_metasq
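As a hedged sketch of what "including this variable" could look like on the checkpoint side: the snippet below patches a converted singleton checkpoint so the `decoder.version` key is present. The checkpoint layout (a top-level `"model"` key) and the version value are assumptions, not the actual metaseq conversion logic.

```python
# Hypothetical patch: keep "decoder.version" in the singleton checkpoint so
# metaseq does not treat it as a pre-layer-norm model and null out layer_norm.
import torch

ckpt = torch.load("restored.pt", map_location="cpu")
state_dict = ckpt.get("model", ckpt)  # layout is an assumption

if "decoder.version" not in state_dict:
    state_dict["decoder.version"] = torch.tensor([3.0])  # value is an assumption

torch.save(ckpt, "restored_with_version.pt")
```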
```diff
@@ -492,7 +492,11 @@ def __init__(self, config: OPTConfig):
         else:
             self.project_in = None

+        self.layer_norm = None
+        if config.do_layer_norm_before:
```
Suggested change:

```diff
-        if config.do_layer_norm_before:
+        if config.do_layer_norm_before and not config._remove_final_layer_norm:
```
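Taken together, the hunk and the suggestion above might end up looking roughly like the self-contained sketch below. The class name, the `final_layer_norm` attribute name (the hunk above uses `self.layer_norm`), and the flag default are my assumptions, not the verbatim PR code.

```python
import torch
import torch.nn as nn


class OPTDecoderTailSketch(nn.Module):
    """Illustrative stand-in for the piece of the OPT decoder this PR touches."""

    def __init__(self, hidden_size: int, do_layer_norm_before: bool,
                 _remove_final_layer_norm: bool = False):
        super().__init__()
        # Only build the final layer norm when the original checkpoint used one;
        # the private backward-compatibility flag lets already-converted
        # checkpoints (which lack the weight) keep loading without it.
        if do_layer_norm_before and not _remove_final_layer_norm:
            self.final_layer_norm = nn.LayerNorm(hidden_size)
        else:
            self.final_layer_norm = None

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # In the real model this runs after the decoder layers, before any
        # output projection; here it is reduced to the layer norm itself.
        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states
```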
Thanks for fixing it! Let's merge the PR once the checkpoints are correctly uploaded under @ArthurZucker's namespace
@patrickvonplaten Yep, I've looked at the changes with your comment, feel free to merge those :D
…checkpoint to run correctly (#164)

* Singleton checkpoint needs to include decoder.version

  If we don't transfer the `"decoder.version"` to the singleton checkpoint, a very sneaky bug happens which was found by @thomasw21 as part of this PR: huggingface/transformers#17785. If the `decoder.version` param is not present in the state_dict, it follows that upon loading the singleton checkpoint the loaded layer_norm is set to `None` here: https://github.com/facebookresearch/metaseq/blob/e0c4f6b0e4c523906ad8d561f727e3f2ac3a8e73/metaseq/models/transformer.py#L932 So it's absolutely crucial that we include this variable. I will update all of the converted HF checkpoints here later today and then I think we can be sure that OPT works correctly 🥳 https://huggingface.co/models?other=opt_metasq

* Update convert_to_singleton.py

Co-authored-by: Stephen Roller <roller@fb.com>
BTW @patrickvonplaten do you have the expected values for the slow test?
Corrected the tests as well now |
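For anyone who needs to refresh those hard-coded values later, one possible approach is simply to run the fixed model once and print the slice that the slow test asserts on. The model id, prompt, and slice below are placeholders, not the actual test code.

```python
# Illustrative way to regenerate expected values for a slow integration test.
import torch
from transformers import AutoTokenizer, OPTForCausalLM

model_id = "facebook/opt-350m"  # placeholder; repeat for each released size
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
model = OPTForCausalLM.from_pretrained(model_id)
model.eval()

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

print(logits[0, 0, :10].tolist())  # paste into the expected-values constant by hand
```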
Thanks for investigating, fixing and humoring my push for backward compatibility :-) Good job @thomasw21!
* Add final_layer_norm to OPT model
* Add JAX and TF version
* Fix Keras name
* Woops
* Allow for non breaking change
* Apply suggestions from code review
* add tests

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
What does this PR do?
Fixes #17653, #17545
OPT models have a final_layer_norm: https://github.com/facebookresearch/metaseq/blob/e0c4f6b0e4c523906ad8d561f727e3f2ac3a8e73/metaseq/models/transformer.py#L466-L477
So we update the HF models and the conversion script to take into account that missing layer norm.
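As a rough illustration of the conversion-script side, the fix amounts to making sure the final layer norm weights are carried over instead of dropped. The key names and checkpoint layout below are assumptions based on common metaseq/HF naming, not the actual `convert_opt_original_pytorch_checkpoint_to_pytorch.py` logic.

```python
# Hypothetical key-mapping sketch: carry the metaseq decoder's final layer norm
# weights into the converted state dict instead of silently dropping them.
import torch

metaseq_ckpt = torch.load("restored.pt", map_location="cpu")
metaseq_sd = metaseq_ckpt.get("model", metaseq_ckpt)  # layout is an assumption

hf_sd = {}
for old_key, tensor in metaseq_sd.items():
    # e.g. "decoder.layer_norm.weight" -> "decoder.final_layer_norm.weight"
    new_key = old_key.replace("decoder.layer_norm.", "decoder.final_layer_norm.")
    hf_sd[new_key] = tensor

assert any("final_layer_norm" in key for key in hf_sd), "final layer norm weights were dropped"
```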
Test on OPT-125m (`restored.pt` file from `patrickvonplaten/opt_metaseq_125m`):

Before submitting
- [ ] Did you read the contributor guideline, Pull Request section?
- [ ] Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?