Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add final_layer_norm to OPT model #17785

Merged
merged 8 commits into from Jun 21, 2022
Merged

Add final_layer_norm to OPT model #17785

merged 8 commits into from Jun 21, 2022

Conversation

thomasw21
Copy link
Contributor

@thomasw21 thomasw21 commented Jun 20, 2022

What does this PR do?

Fixes #17653 , #17545

OPT models have a final_layer_norm: https://github.com/facebookresearch/metaseq/blob/e0c4f6b0e4c523906ad8d561f727e3f2ac3a8e73/metaseq/models/transformer.py#L466-L477
So we update HF models + conversion script to take in account that missing layer norm.

Test on OPT-125m (restored.pt file from patrickvonplaten/opt_metaseq_125m):

>>> model_path="fixed_opt_125m"
>>> prompt="Hello my name is"
>>> log_probs_with_ppl(model_path, prompt)
Input torch.Size([1, 5])
Logits torch.Size([1, 5, 50272])
torch.return_types.max(
values=tensor([[0.2398, 0.2326, 0.3332, 0.9363, 0.0097]], grad_fn=<MaxBackward0>),
indices=tensor([[ 100,    6,  766,   16, 1236]]))
argmax probility: [[0.23982257 0.23258895 0.33315504 0.9362957  0.00967377]]
argmax log probability: [[-1.4278558  -1.4584825  -1.0991473  -0.06582398 -4.6383367 ]]
argmax tokens: I, name is j
cross entropy loss: 4.051314830780029
ppl: 57.47297286987305

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Jun 20, 2022

The documentation is not available anymore as the PR was closed or merged.

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM we should update the tf and flax checkpoints

@younesbelkada
Copy link
Contributor

Thank you very much for the fix! I think that we'll have to change the generation tests a bit for the other models as well

@patrickvonplaten
Copy link
Contributor

Great find @thomasw21 - thanks a lot for fixing it!

Think the checkpoints were then also incorrectly loaded inside the metaseq codebase - could you maybe double check that the following script gives identical results between fairseq and transformers: https://huggingface.co/patrickvonplaten/opt_metaseq_125m -> The logits should match there (maybe an incorrect configuration in the metaseq model?)

Also could you please update the slow model tests?

@ArthurZucker
Copy link
Collaborator

@thomasw21 I can update the tests and check the outputs if you want

@thomasw21
Copy link
Contributor Author

thomasw21 commented Jun 21, 2022

@patrickvonplaten from what I understood logits comparison equality test were only done in 350m? @younesbelkada
I actually manually converted restored.pt from https://huggingface.co/patrickvonplaten/opt_metaseq_125m using the updated conversion script.

@ArthurZucker if you have the bandwidth, I'd appreciate it! Thanks!

@gante gante mentioned this pull request Jun 21, 2022
4 tasks
patrickvonplaten added a commit to patrickvonplaten/metaseq that referenced this pull request Jun 21, 2022
If we don't transfer the `"decoder.version"` to the singleton checkpoint, a very sneaky bug happens which was found by @thomasw21 as part of this PR: huggingface/transformers#17785

If the `decoder.version` param is not present in the state_dict it follows that upon loading the single-ton checkpoint the loaded layer_norm is set to `None` here: https://github.com/facebookresearch/metaseq/blob/e0c4f6b0e4c523906ad8d561f727e3f2ac3a8e73/metaseq/models/transformer.py#L932

So it's absolutely crucial that we include this variable.

I will update all of the converted HF checkpoints here later today and then I think we can be sure that OPT works correctly 🥳 
https://huggingface.co/models?other=opt_metasq
@@ -492,7 +492,11 @@ def __init__(self, config: OPTConfig):
else:
self.project_in = None

self.layer_norm = None
if config.do_layer_norm_before:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if config.do_layer_norm_before:
if config.do_layer_norm_before and not config._remove_final_layer_norm:

Copy link
Contributor

@patrickvonplaten patrickvonplaten left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for fixing it! Let's merge the PR once the checkpoints are correctly uploaded under @ArthurZucker's namespace

@thomasw21
Copy link
Contributor Author

thomasw21 commented Jun 21, 2022

@patrickvonplaten Yep I've looked at the changes with your comment, feel free to merge those : D

@younesbelkada
Copy link
Contributor

When releasing the patch can we merge at the same time #17437 ? The problem of NaNs for batched generation still persists with this fix, but is resolved with #17437

stephenroller added a commit to facebookresearch/metaseq that referenced this pull request Jun 21, 2022
…checkpoint to run correctly (#164)

* Singleton checkpoint needs to include decoder.version

If we don't transfer the `"decoder.version"` to the singleton checkpoint, a very sneaky bug happens which was found by @thomasw21 as part of this PR: huggingface/transformers#17785

If the `decoder.version` param is not present in the state_dict it follows that upon loading the single-ton checkpoint the loaded layer_norm is set to `None` here: https://github.com/facebookresearch/metaseq/blob/e0c4f6b0e4c523906ad8d561f727e3f2ac3a8e73/metaseq/models/transformer.py#L932

So it's absolutely crucial that we include this variable.

I will update all of the converted HF checkpoints here later today and then I think we can be sure that OPT works correctly 🥳 
https://huggingface.co/models?other=opt_metasq

* Update convert_to_singleton.py

Co-authored-by: Stephen Roller <roller@fb.com>
@ArthurZucker
Copy link
Collaborator

BTW @patrickvonplaten do you have the expected values for the slow test?

@patrickvonplaten
Copy link
Contributor

BTW @patrickvonplaten do you have the expected values for the slow test?

Corrected the tests as well now

Copy link
Collaborator

@sgugger sgugger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for investigating, fixing and humoring my push for backward compatibility :-)

@patrickvonplaten
Copy link
Contributor

Good job @thomasw21 !

@patrickvonplaten patrickvonplaten merged commit abc400b into main Jun 21, 2022
@patrickvonplaten patrickvonplaten deleted the thomas/fix_opt branch June 21, 2022 18:26
sgugger pushed a commit that referenced this pull request Jun 21, 2022
* Add final_layer_norm to OPT model

* Add JAX and TF version

* Fix Keras name

* Woops

* Allow for non breaking change

* Apply suggestions from code review

* add tests

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
vfbd pushed a commit to VE-FORBRYDERNE/mesh-transformer-jax that referenced this pull request Jun 21, 2022
sriniiyer pushed a commit to facebookresearch/metaseq that referenced this pull request Jun 21, 2022
…checkpoint to run correctly (#164)

* Singleton checkpoint needs to include decoder.version

If we don't transfer the `"decoder.version"` to the singleton checkpoint, a very sneaky bug happens which was found by @thomasw21 as part of this PR: huggingface/transformers#17785

If the `decoder.version` param is not present in the state_dict it follows that upon loading the single-ton checkpoint the loaded layer_norm is set to `None` here: https://github.com/facebookresearch/metaseq/blob/e0c4f6b0e4c523906ad8d561f727e3f2ac3a8e73/metaseq/models/transformer.py#L932

So it's absolutely crucial that we include this variable.

I will update all of the converted HF checkpoints here later today and then I think we can be sure that OPT works correctly 🥳 
https://huggingface.co/models?other=opt_metasq

* Update convert_to_singleton.py

Co-authored-by: Stephen Roller <roller@fb.com>
younesbelkada pushed a commit to younesbelkada/transformers that referenced this pull request Jun 25, 2022
* Add final_layer_norm to OPT model

* Add JAX and TF version

* Fix Keras name

* Woops

* Allow for non breaking change

* Apply suggestions from code review

* add tests

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
younesbelkada pushed a commit to younesbelkada/transformers that referenced this pull request Jun 29, 2022
* Add final_layer_norm to OPT model

* Add JAX and TF version

* Fix Keras name

* Woops

* Allow for non breaking change

* Apply suggestions from code review

* add tests

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Abnormal behavior of OPT except OPT-350m
6 participants