Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AutoTokenizer] Allow creation of tokenizers by tokenizer type #13668

Conversation

patrickvonplaten
Copy link
Contributor

What does this PR do?

This PR enables the Case #4 as discussed here: #13623 (comment)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

Copy link
Collaborator

@sgugger sgugger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks for updating this new API!

Copy link
Contributor

@SaulLu SaulLu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for the addition @patrickvonplaten 🙌

I just left 2 small questions in comments

Comment on lines 436 to 458
# If we have the tokenizer_type we can leverage it
if tokenizer_type is not None:
tokenizer_class = None
tokenizer_class_tuple = TOKENIZER_MAPPING_NAMES.get(tokenizer_type, None)

if tokenizer_class_tuple is None:
raise ValueError(
f"Passed `tokenizer_type` {tokenizer_type} does not exist. `tokenizer_type` should be one of "
f"{', '.join(c for c in TOKENIZER_MAPPING_NAMES.keys())}."
)

tokenizer_class_name, tokenizer_fast_class_name = tokenizer_class_tuple

if use_fast and tokenizer_fast_class_name is not None:
tokenizer_class = tokenizer_class_from_name(tokenizer_fast_class_name)

if tokenizer_class is None:
tokenizer_class = tokenizer_class_from_name(tokenizer_class_name)

if tokenizer_class is None:
raise ValueError(f"Tokenizer class {tokenizer_class_name} is not currently imported.")

return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was wondering if all these added lines could not be put before the line 432. It seems to me that tokenizer_config and config_tokenizer_class are not used in this if and it would "save" loading the tokenizer_config.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree! Changed it

tokenizer_class = tokenizer_class_from_name(tokenizer_class_name)

if tokenizer_class is None:
raise ValueError(f"Tokenizer class {tokenizer_class_name} is not currently imported.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I'm not sure I figured out why this error message isn't exactly (about the existence of the tokenizer) like the one on the line 478.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tokenizer_class_name has to exist as it's mapped from the tokenizer_type so the only reason this could fail is if the class cannot be important due to some missing packages like tokenizers or sentencepiece

Copy link
Member

@LysandreJik LysandreJik left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me! I agree with Lucile's comments.

@patrickvonplaten patrickvonplaten merged commit 8e908c8 into huggingface:master Sep 21, 2021
@patrickvonplaten patrickvonplaten deleted the add_from_tokenizer_type_auto_tok branch September 21, 2021 22:29
Narsil pushed a commit to Narsil/transformers that referenced this pull request Sep 25, 2021
stas00 pushed a commit to stas00/transformers that referenced this pull request Oct 12, 2021
Albertobegue pushed a commit to Albertobegue/transformers that referenced this pull request Jan 13, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants