Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix sentinel token IDs in data collator for Flax T5 pretraining script #14477

Merged

Conversation

rahuln
Copy link
Contributor

@rahuln rahuln commented Nov 21, 2021

What does this PR do?

Modifies the sentinel token IDs used in the data collator for the Flax T5 pretraining script so that they go in decreasing order starting at len(tokenizer) - 1, which matches the original T5 code.

Fixes #14282

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

@LysandreJik
Copy link
Member

Hey @rahuln! I'm pinging @patrickvonplaten for review as you have worked with him until now, but please note that he's off until next week so he'll review your PR when he's back! Thanks for your understanding.

@patrickvonplaten
Copy link
Contributor

Great! Thanks a lot for digging into this issue and fixing it

@patrickvonplaten patrickvonplaten merged commit 8332327 into huggingface:master Nov 29, 2021
@@ -290,7 +290,7 @@ def create_sentinel_ids(self, mask_indices):
start_indices[:, 0] = mask_indices[:, 0]

sentinel_ids = np.where(start_indices != 0, np.cumsum(start_indices, axis=-1), start_indices)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is that line not completely unecessary since we have the line right after?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well the next line makes use of the just changed sentinel_ids parrameter no?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry what I meant is those two lines should be summarizable to

sentinel_ids = np.where(start_indices != 0, (len(self.tokenizer) - np.cumsum(start_indices, axis=-1)), 0)

Ie what is 0 is kept at 0 and what's not 0 is given a non 0 value, which means that the next where operation uses the same segmentation and thus overrides the values.

cc @patil-suraj

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Mismatch between sentinel token IDs from T5 data collator and T5 tokenizer
4 participants