
Fix loss computation in TFBertForPreTraining#17898

Merged
Rocketknight1 merged 1 commit into main from fix_tf_bert_pretraining_loss
Jun 28, 2022

Conversation

@Rocketknight1 (Member) commented Jun 27, 2022

With thanks to @Sreyan88 for writing up a clean bug report and reproducer, and to @ydshieh for locating the problematic code!

Our hf_compute_loss() function for TFBertForPreTraining was incorrect. However, it still appeared to work when the number of masked positions was evenly divisible by the batch size. Other, more commonly-used models like TFBertForMaskedLM do not have this issue.

The problem was incorrect handling of the reduction for the masked loss, so I took the opportunity to rewrite the function in modern TF. All shapes are now static in the rewritten function as well, which means it should now compile with XLA.
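The reduction fix described above can be sketched as follows. This is a minimal NumPy illustration of the general technique, not the actual Transformers implementation: the function name `masked_mean_loss` and the array shapes are assumptions for the example. The key point is that the masked loss must be averaged over the true number of masked tokens, not over a dimension derived from the batch size.

```python
import numpy as np

def masked_mean_loss(per_token_loss, mask):
    """Average the per-token loss over masked positions only.

    per_token_loss: float array of shape (batch, seq_len)
    mask: 0/1 array of the same shape, 1 where a token was masked
    """
    mask = mask.astype(per_token_loss.dtype)
    # Zero out unmasked positions, then divide by the actual count of
    # masked tokens rather than by anything tied to the batch size.
    total = np.sum(per_token_loss * mask)
    count = np.maximum(np.sum(mask), 1.0)  # avoid division by zero
    return total / count
```

A reduction that instead averages over a reshaped batch dimension only coincides with this correct mean when the number of masked positions happens to divide evenly by the batch size, which matches the failure mode reported in the linked issue.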

Fixes #17883

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Jun 27, 2022

The documentation is not available anymore as the PR was closed or merged.

@sgugger (Collaborator) left a comment

Thanks for fixing. Is this copied somewhere else too or just in BERT?

@gante (Contributor) left a comment

LGTM 👍

@Rocketknight1 (Member, Author) commented

@sgugger I believe it's unique to BERT; I searched the codebase for any similar lines and couldn't find any. I suspect this is how it stayed undetected for so long: it uses the NSP loss, and people generally don't train with that anymore.

@Rocketknight1 Rocketknight1 merged commit 0094565 into main Jun 28, 2022
@Rocketknight1 Rocketknight1 deleted the fix_tf_bert_pretraining_loss branch June 28, 2022 11:44
viclzhu pushed a commit to viclzhu/transformers that referenced this pull request Jul 18, 2022


Successfully merging this pull request may close these issues.

Exception encountered when calling layer "tf_bert_for_pre_training" (type TFBertForPreTraining)

4 participants