
[HUGE] Refactoring tokenizers backend - padding - truncation - pre-tokenized pipeline - fast tokenizers - tests #4510

Merged
23 commits merged on Jun 15, 2020

Conversation


@n1t0 n1t0 commented May 22, 2020

Fix #4015

Edit @thomwolf: I turned this into a large refactoring of the tokenizer code and tests to make it more flexible and give it a better API. Here is a summary of the changes.

Breaking change

There is no breaking change in the user-facing methods (encode, encode_plus, batch_encode_plus, tokenize, convert_XXX).

There is a breaking change in the internal method prepare_for_model, which is now a private method _prepare_for_model with a simplified signature.

A new main user-facing method: __call__ i.e. model_input = tokenizer(text, **kwargs)

The extended encoding methods encode_plus and batch_encode_plus had names that could be intimidating for first-time users.

A new main entry point is created as tokenizer.__call__, which wraps both methods. You can feed __call__ single examples, a pair of sentences to encode together, or batches of single sentences or of pairs.

The signature of __call__ is also a better fit for the 🤗nlp library when it comes to batches of pairs of sequences, since the first and second elements of the pairs are supplied as separate arguments (see below) instead of a zipped list of pairs as in batch_encode_plus.

While all the previously provided methods (encode, encode_plus, batch_encode_plus, tokenize, convert_XXX) are still supported without breaking changes, __call__ is now the recommended way to encode all types of inputs whenever tokenizer.encode (which only returns the list of input indices for a single sentence) is not enough, i.e. for every case beyond simple demo purposes.

Here is how you should use this new entry point for encoding text in all the main use-cases:

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

# 1. When you encode "a single sentence"
encoded_input = tokenizer("Hello I'm a single sentence")
# { 'input_ids': [101, 8667, 146, 112, 182, 170, 1423, 5650, 102],
#   'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0],
#   'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1]}

# 2. When you encode "a pair of sentences in a single input"
encoded_input = tokenizer("How old are you?", "I'm 6 years old")
# { 'input_ids': [101, 1731, 1385, 1132, 1128, 136, 102, 146, 112, 182, 127, 1201, 1385, 102],
#   'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1],
#   'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

# 3. When you encode "a batch of single sentences"
batch_sentences = ["Hello I'm a single sentence",
                   "And another sentence",
                   "And the very very last one"]
encoded_input = tokenizer(batch_sentences)
# { 'input_ids': [[101, 8667, 146, 112, 182, 170, 1423, 5650, 102],
#                 [101, 1262, 1330, 5650, 102],
#                 [101, 1262, 1103, 1304, 1304, 1314, 1141, 102]],
#   'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0],
#                      [0, 0, 0, 0, 0],
#                      [0, 0, 0, 0, 0, 0, 0, 0]],
#   'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1],
#                      [1, 1, 1, 1, 1],
#                      [1, 1, 1, 1, 1, 1, 1, 1]]}

# You can pad (to the longest sequence in the batch) and truncate (to the max model input length)
# with `padding` and `truncation` (see more details in the next section on padding/truncation)
encoded_input = tokenizer(batch_sentences, padding=True, truncation=True)
# { 'input_ids': [[101, 8667, 146, 112, 182, 170, 1423, 5650, 102],
#                 [101, 1262, 1330, 5650, 102, 0, 0, 0, 0],
#                 [101, 1262, 1103, 1304, 1304, 1314, 1141, 102, 0]],
#   'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0],
#                      [0, 0, 0, 0, 0, 0, 0, 0, 0],
#                      [0, 0, 0, 0, 0, 0, 0, 0, 0]],
#   'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1],
#                      [1, 1, 1, 1, 1, 0, 0, 0, 0],
#                      [1, 1, 1, 1, 1, 1, 1, 1, 0]]}

# 4. When you encode "a batch of pairs of sentences"
batch_of_second_sentences = ["I'm a sentence that goes with the first sentence",
                             "And I should be encoded with the second sentence",
                             "And I go with the very last one"]
encoded_input = tokenizer(batch_sentences,
                          batch_of_second_sentences,
                          padding=True,
                          truncation=True)
# { 'input_ids': [[101, 8667, 146, 112, 182, 170, 1423, 5650, 102, 146, 112, 182, 170, 5650, 1115, 2947, 1114, 1103, 1148, 5650, 102],
#                 [101, 1262, 1330, 5650, 102, 1262, 146, 1431, 1129, 12544, 1114, 1103, 1248, 5650, 102, 0, 0, 0, 0, 0, 0],
#                 [101, 1262, 1103, 1304, 1304, 1314, 1141, 102, 1262, 146, 1301, 1114, 1103, 1304, 1314, 1141, 102, 0, 0, 0, 0]],
#   'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
#                      [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
#                      [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0]],
#   'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
#                      [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
#                      [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0]]}

Padding/truncation

The padding and truncation logic was simplified and improved to cover all the major use cases with the simplest possible API.

Here is how to do the two most common use-cases for truncation/padding:

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
batch_sentences = ["Hello I'm a single sentence",
                   "And another sentence",
                   "And the very very last one"]

# 1. No truncation and no padding
encoded_input = tokenizer(batch_sentences)

# 2. Pad to the max sequence length inside the provided batch
# while truncating to the max input length acceptable by the model 
encoded_input = tokenizer(batch_sentences, truncation=True, padding=True)

The new API for padding and truncation uses three arguments to the encoding methods: padding, truncation and max_length. This new way to specify padding/truncation is available in all the user-facing encoding methods: encode, encode_plus, batch_encode_plus and the newly provided __call__.

All the previously provided ways to do padding/truncation (truncation_strategy, max_length, pad_to_max_length) are still supported without breaking changes, but we recommend using the new API.

Here are the details of all the possible values for padding, truncation and max_length (a short example combining them follows the list):

  • padding to control the padding (can be provided with a boolean or a string for finer-grained control). padding accepts the following values:

    • True or 'longest': pad to the longest sequence in the batch (or no padding if only a single sequence is provided),
    • 'max_length': pad to a max length specified in max_length or to the max acceptable input length for the model if no length is provided (max_length=None)
    • False or 'do_not_pad' (default): no padding (i.e. the output batch can contain sequences of uneven lengths)
  • truncation to control truncation (can be provided with a boolean or a string for finer-grained control). truncation accepts the following values:

    • True or 'only_first': truncate to a max length specified in max_length or to the max acceptable input length for the model if no length is provided (max_length=None). This will only truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided,
    • 'only_second': truncate to a max length specified in max_length or to the max acceptable input length for the model if no length is provided (max_length=None). This will only truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided,
    • 'longest_first': truncate to a max length specified in max_length or to the max acceptable input length for the model if no length is provided (max_length=None). This will truncate token by token, removing a token from the longest sequence in the pair if a pair of sequences (or a batch of pairs) is provided,
    • False or 'do_not_truncate' (default): no truncation (i.e. the output batch can contain sequences longer than the maximum admissible input size of the model)
  • max_length to control the length of the padding/truncation (integer or None). max_length accepts the following values:

    • None (default): this will use the predefined model max length if required by one of the truncation/padding parameters. If the model has no specific maximum input length (e.g. XLNet), truncation/padding to a max length is deactivated.
    • any integer value (e.g. 42): Use this specific maximum length value if required by one of the truncation/padding parameters.
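
For instance, the finer-grained strategies combine as follows (a quick sketch reusing the bert-base-cased tokenizer from the examples above; the exact tokens depend on the vocabulary):

# Pad/truncate a pair to exactly 16 tokens, removing tokens only from the second sentence if needed
encoded_input = tokenizer("How old are you?",
                          "I'm 6 years old and I like cookies very much",
                          padding='max_length',
                          truncation='only_second',
                          max_length=16)
# len(encoded_input['input_ids']) == 16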

Now here is a table summarizing the recommended way to set up padding and truncation, as well as the previously provided way to do it (still supported but not recommended), in all cases.

If you use pairs of input sequences in any of the following examples, you can replace truncation=True with a STRATEGY selected in ['only_first', 'only_second', 'longest_first'], i.e. truncation='only_second' or truncation='longest_first', to control how both sequences in the pair are truncated, as detailed just before the table. We don't include all these variants to keep the table from getting too long.

| Truncation | Padding | Recommended way | Previously provided (still supported but not recommended) |
| --- | --- | --- | --- |
| no truncation | no padding | tokenizer(batch_sentences) | tokenizer.batch_encode_plus(batch_sentences) |
| no truncation | padding to max sequence in batch | tokenizer(batch_sentences, padding=True) or tokenizer(batch_sentences, padding='longest') | tokenizer.batch_encode_plus(batch_sentences, pad_to_max_length=True) |
| no truncation | padding to max model input length | tokenizer(batch_sentences, padding='max_length') | Not possible |
| no truncation | padding to specific length | tokenizer(batch_sentences, padding='max_length', max_length=42) | Not possible |
| truncation to max model input length | no padding | tokenizer(batch_sentences, truncation=True) or tokenizer(batch_sentences, truncation=STRATEGY) | tokenizer.batch_encode_plus(batch_sentences, max_length=tokenizer.max_len) |
| truncation to max model input length | padding to max sequence in batch | tokenizer(batch_sentences, padding=True, truncation=True) or tokenizer(batch_sentences, padding=True, truncation=STRATEGY) | Not possible |
| truncation to max model input length | padding to max model input length | tokenizer(batch_sentences, padding='max_length', truncation=True) or tokenizer(batch_sentences, padding='max_length', truncation=STRATEGY) | tokenizer.batch_encode_plus(batch_sentences, pad_to_max_length=True, max_length=tokenizer.max_len) |
| truncation to max model input length | padding to specific length | Not possible | Not possible |
| truncation to specific length | no padding | tokenizer(batch_sentences, truncation=True, max_length=42) or tokenizer(batch_sentences, truncation=STRATEGY, max_length=42) | tokenizer.batch_encode_plus(batch_sentences, max_length=42) |
| truncation to specific length | padding to max sequence in batch | tokenizer(batch_sentences, padding=True, truncation=True, max_length=42) or tokenizer(batch_sentences, padding=True, truncation=STRATEGY, max_length=42) | Not possible |
| truncation to specific length | padding to max model input length | Not possible | Not possible |
| truncation to specific length | padding to specific length | tokenizer(batch_sentences, padding='max_length', truncation=True, max_length=42) or tokenizer(batch_sentences, padding='max_length', truncation=STRATEGY, max_length=42) | tokenizer.batch_encode_plus(batch_sentences, pad_to_max_length=True, max_length=42) |

Pre-tokenized inputs

The tokenizers now accept pre-tokenized inputs, i.e. inputs which are already split into words. The main reason for implementing a specific path for this type of input is to be able to use the fast mapping methods in tokenizers, which provide character <=> token <=> word mappings. This is very handy to easily compute labels and extract predictions, for instance for Named Entity Recognition (NER) or Part-of-Speech (POS) tagging.

If you want to use pre-tokenized inputs, just set is_pretokenized=True in any of the encoding methods. Here are some examples:

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
batch_sentences = [["Hello", "I'm", "a", "single", "sentence"],
                   ["And", "another", "sentence"],
                   ["And", "the", "very", "very", "last", "one"]]

encoded_input = tokenizer(batch_sentences, is_pretokenized=True)

# Pre-tokenized inputs can be used in all cases (single/pair/batch of single/batch of pairs) 
batch_of_second_sentences = ["I'm a sentence that goes with the first sentence".split(),
                             "And I should be encoded with the second sentence".split(),
                             "And I go with the very last one".split()]
encoded_input = tokenizer(batch_sentences,
                          batch_of_second_sentences,
                          is_pretokenized=True,
                          padding=True,
                          truncation=True)
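
Here is a small sketch of how the word/token mappings can then be used to spread word-level labels onto sub-word tokens. It assumes a fast (Rust-backed) tokenizer, since the mapping helpers are only available there, and the label list is made up for the example:

from transformers import BertTokenizerFast

fast_tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')
words = ["Hello", "I'm", "a", "single", "sentence"]
word_labels = ["O", "O", "O", "O", "O"]  # hypothetical per-word NER labels

encoding = fast_tokenizer(words, is_pretokenized=True)
# token_to_word maps each encoded token back to the word it comes from
# (special tokens like [CLS]/[SEP] typically map to None)
word_indices = [encoding.token_to_word(i) for i in range(len(encoding["input_ids"]))]
token_labels = [word_labels[w] if w is not None else "IGNORE" for w in word_indices]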

Verbose

A new verbose argument is provided in all the encoding methods to silence the warnings related to the length of the input as well as to missing special tokens (e.g. missing padding or unknown token).
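
For example (a minimal sketch, reusing the pre-tokenized batch from the previous section):

# Silence the length/special-token warnings for this call
encoded_input = tokenizer(batch_sentences, is_pretokenized=True, padding=True, verbose=False)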

Code organization

tokenization_utils.py was starting to grow out of control and is now split into three files:

  • tokenization_utils.py hosts the code for the PreTrainedTokenizer (slow, pure-Python) tokenizers
  • tokenization_utils_fast.py hosts the code for the PreTrainedTokenizerFast tokenizers
  • tokenization_utils_base.py hosts the methods common to PreTrainedTokenizer and PreTrainedTokenizerFast (mostly the front API) in a newly created PreTrainedTokenizerBase, as well as all the common logic for special tokens (in SpecialTokensMixin) and for the outputs of the encoding (in BatchEncoding).
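
As a quick sanity check of the new layout, the shared pieces can be imported from the base module (a sketch assuming the module and class names listed above):

from transformers.tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase, SpecialTokensMixin
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast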

Full testing of fast tokenizers

The fast tokenizers provided by the tokenizers library are now fully tested and follow the same testing pipeline as the python (slow) tokenizers. Additional consistency tests have been added comparing the outputs of the fast and slow tokenizers under various conditions.
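
For illustration, a consistency check in the spirit of those tests might look like the following (a simplified sketch, not the actual test code):

from transformers import BertTokenizer, BertTokenizerFast

slow_tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
fast_tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')

text = "Hello I'm a single sentence"
slow_output = slow_tokenizer(text, padding='max_length', max_length=16, truncation=True)
fast_output = fast_tokenizer(text, padding='max_length', max_length=16, truncation=True)

# The two backends should agree on every returned field
for key in slow_output:
    assert slow_output[key] == fast_output[key], f"Mismatch for {key}"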

TODO (following PRs)

  • Serialization for Fast tokenizers
  • Some edge cases for add_tokens on Fast tokenizers are not covered (spaces in tokens for byte-level and lower casing of the added tokens).

@n1t0 n1t0 requested review from mfuntowicz and thomwolf May 22, 2020 00:18

@LysandreJik LysandreJik left a comment


LGTM! Having a test for this would be welcome

setup.py Outdated
@@ -108,7 +108,7 @@
packages=find_packages("src"),
install_requires=[
"numpy",
"tokenizers == 0.7.0",
"tokenizers == 0.8.0.dev0",

Nice!


👍


@thomwolf thomwolf left a comment


LGTM, let's add some tests on tokenized inputs at the same time.

@thomwolf thomwolf changed the title from "Use tokenizers pre-tokenized pipeline" to "[HUGE] Refactoring tokenizers backend - padding - truncation - pre-tokenized pipeline - fast tokenizers - tests" on Jun 13, 2020

codecov bot commented Jun 13, 2020

Codecov Report

Merging #4510 into master will increase coverage by 0.54%.
The diff coverage is 92.01%.


@@            Coverage Diff             @@
##           master    #4510      +/-   ##
==========================================
+ Coverage   76.89%   77.43%   +0.54%     
==========================================
  Files         128      130       +2     
  Lines       21854    21966     +112     
==========================================
+ Hits        16804    17010     +206     
+ Misses       5050     4956      -94     
Impacted Files | Coverage Δ
src/transformers/modeling_tf_albert.py | 75.33% <ø> (ø)
src/transformers/tokenization_utils.py | 94.81% <ø> (+5.11%) ⬆️
src/transformers/tokenization_utils_base.py | 91.55% <91.55%> (ø)
src/transformers/tokenization_utils_fast.py | 92.59% <92.59%> (ø)
src/transformers/__init__.py | 99.14% <100.00%> (+0.01%) ⬆️
src/transformers/tokenization_bert.py | 90.45% <100.00%> (-0.80%) ⬇️
src/transformers/tokenization_gpt2.py | 97.08% <100.00%> (+0.25%) ⬆️
src/transformers/tokenization_openai.py | 83.84% <100.00%> (+0.12%) ⬆️
src/transformers/tokenization_transfo_xl.py | 40.82% <100.00%> (+0.14%) ⬆️
src/transformers/tokenization_xlm_roberta.py | 95.23% <0.00%> (-2.39%) ⬇️
... and 11 more

Continue to review full report at Codecov.

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 9931f81...52a30d6.

@@ -62,7 +64,7 @@ def setUp(self):
def get_rust_tokenizer(self, **kwargs):
return BertTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)

-    def get_input_output_texts(self):
+    def get_input_output_texts(self, tokenizer):

Do we ever use this arg?

# Test encode_plus for pretokenized inputs
output_r = tokenizer_r.encode_plus(pretokenized_input_simple, **kwargs)
output_p = tokenizer_p.encode_plus(pretokenized_input_simple, **kwargs)
for key in output_p.keys():

does assertDictEqual work?


thomwolf commented Jun 15, 2020

Ok, I turned this into a large refactoring of the tokenizer code and tests to make it more flexible and give it a better API.

Here is a summary of the changes:

  • there is now a new main user-facing method: __call__, i.e. model_input = tokenizer(text, **kwargs), which should be the main entry point for converting text into model inputs in the future,
  • the padding/truncation logic was refactored to cover more cases and make the most common cases more natural to access,
  • pre-tokenized inputs (e.g. for NER or POS tagging) are handled a lot better,
  • the backend code was refactored and split into several files.

There is no breaking change in the user-facing methods (encode, encode_plus, batch_encode_plus, tokenize, convert_XXX). There is a breaking change in the internal method prepare_for_model which is now a private method _prepare_for_model with a simplified signature.

All the details are given in the updated description of the PR.

cc @LysandreJik @julien-c @patrickvonplaten @sshleifer @mfuntowicz @yjernite @srush @mariamabarham @lhoestq @VictorSanh @jplu @stefan-it @BramVanroy


BramVanroy commented Jun 15, 2020

I always love to see changes that improve usability. I think using __call__ is one that can really make things easier for people. I also like pre-tokenized inputs a lot, since most of my data is pre-tokenized anyway.

The changes are quite big to go over, so just checking: hopefully there are very clear error messages when users choose incompatible options during tokenization. Making the tokenizer easier to use by having a single entry point is great, but not so much if it creates more user mistakes that are unclear to the user. Clear error messages are key.

A feature request, which I discussed with someone before but I don't remember who, is that it would be nice if the tokenizers could have an optional device argument. If we use return_tensors, it should return the tensors immediately on the given device, e.g.

encoded_on_device = tokenizer(["Hello world.", "Who likes cookies?"], device=torch.device("cuda:0"))
# or
encoded_on_device = tokenizer(["Hello world.", "Who likes cookies?"], device=training_args.device)

It might even allow different types of values like device integers or "cuda" or "cpu" strings, and so on.

Great job! Looking forward to using this in practice.


jplu commented Jun 15, 2020

This is awesome!! Really great work and congratulations with this huge rework of the tokenizers!!!

It is a bit too huge to go through everything, but as far as I can see, the way to use the tokenizers is now way more accessible, especially the pre-tokenized part.

> A feature request, which I discussed with someone before but I don't remember who, is that it would be nice if the tokenizers could have an optional device argument. If we use return_tensors, it should return the tensors immediately on the given device

@BramVanroy I don't think it is the place here because it is not compliant with TF :) I think that the tokenizers should stay as much framework agnostic as possible otherwise if we start to say "if you want to use the tokenizer for PT do that, and for TF do this" it becomes more complicated to maintain. Of course this is only my opinion nothing more :)


BramVanroy commented Jun 15, 2020

> @BramVanroy I don't think it is the place here because it is not compliant with TF :) I think that the tokenizers should stay as much framework agnostic as possible otherwise if we start to say "if you want to use the tokenizer for PT do that, and for TF do this" it becomes more complicated to maintain. Of course this is only my opinion nothing more :)

But that's what we do for return_tensors anyway, right?


jplu commented Jun 15, 2020

> But that's what we do for return_tensors anyway, right?

Exactly, and I think the same about this parameter, it adds complexity, while this can be easily done afterward.
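
Doing it afterward amounts to something like the following sketch (assuming PyTorch, a CUDA device, and the bert-base-cased tokenizer from the examples above):

import torch

encoded_input = tokenizer(["Hello world.", "Who likes cookies?"], padding=True, return_tensors="pt")
# Move every returned tensor (input_ids, attention_mask, ...) to the target device
encoded_on_device = {k: v.to(torch.device("cuda:0")) for k, v in encoded_input.items()}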

@BramVanroy

> Exactly, and I think the same about this parameter, it adds complexity, while this can be easily done afterward.

It is true that this can be done easily afterwards, but I suppose this is one of those cases: how much ease-of-use do you want your library to have while also taking into account the complexity of the library itself. My main argument is that, from a usability perspective, it would be awesome to be able to just provide your text to the tokenizer and immediately get back the encoded input that you can feed to your model, without having to do anything else. You could then even do this:

out = model(**tokenizer(input_text, return_tensors="pt", device=device))

This isn't pretty, but it illustrates my point that it makes usage very easy and also easy to understand. It removes a lot of boilerplate that, as a user, you don't want to spend time on. On the other hand, I definitely understand your point that this will lead to more complexity on the library's side. I'd be interested to hear other people's opinions about this.


jplu commented Jun 15, 2020

> how much ease-of-use do you want your library to have while also taking into account the complexity of the library itself.

This is definitely true, I fully agree :) And what you propose makes sense as well. I would be curious to hear other opinions too ^^


@n1t0 n1t0 left a comment


😍 A lot of well-needed changes. Very clean and thorough, great job @thomwolf!


is_pretokenized = kwargs.get("is_pretokenized", False)
assert self.add_prefix_space or not is_pretokenized, (
"You need to instantiate GPT2TokenizerFast with add_prefix_space=False "

Did you mean add_prefix_space=True?


Oops 😅 yes, of course!


@LysandreJik LysandreJik left a comment


This looks really cool, I'm in love with the new API. Love the new docstrings as well; I offer some suggestions below so that they're better formatted.

Comment on lines +50 to +53
SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json"
ADDED_TOKENS_FILE = "added_tokens.json"
TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
FULL_TOKENIZER_FILE = "tokenizer.json"

This is really cool

Comment on lines 180 to 189
""" Get the index of the word corresponding (i.e. comprising) to an encoded token
in a sequence of the batch.

Can be called as:
- self.token_to_word(token_index) if batch size is 1
- self.token_to_word(batch_index, token_index) if batch size is greater than 1

This method is particularly suited when the input sequences are provided as
pre-tokenized sequences (i.e. words are defined by the user). In this case it allows
to easily associate encoded tokens with provided tokenized words.

quick doc suggestion

Suggested change
""" Get the index of the word corresponding (i.e. comprising) to an encoded token
in a sequence of the batch.
Can be called as:
- self.token_to_word(token_index) if batch size is 1
- self.token_to_word(batch_index, token_index) if batch size is greater than 1
This method is particularly suited when the input sequences are provided as
pre-tokenized sequences (i.e. words are defined by the user). In this case it allows
to easily associate encoded tokens with provided tokenized words.
"""
Get the index of the word corresponding (i.e. comprising) to an encoded token
in a sequence of the batch.
Can be called as:
- self.token_to_word(token_index) if batch size is 1
- self.token_to_word(batch_index, token_index) if batch size is greater than 1
This method is particularly suited when the input sequences are provided as
pre-tokenized sequences (i.e. words are defined by the user). In this case it allows
to easily associate encoded tokens with provided tokenized words.

Comment on lines 219 to 231
""" Get the encoded token span corresponding to a word in the sequence of the batch.

Token spans are returned as a TokenSpan NamedTuple with:
start: index of the first token
end: index of the token following the last token

Can be called as:
- self.word_to_tokens(word_index) if batch size is 1
- self.word_to_tokens(batch_index, word_index) if batch size is greater or equal to 1

This method is particularly suited when the input sequences are provided as
pre-tokenized sequences (i.e. words are defined by the user). In this case it allows
to easily associate encoded tokens with provided tokenized words.

Suggested change
""" Get the encoded token span corresponding to a word in the sequence of the batch.
Token spans are returned as a TokenSpan NamedTuple with:
start: index of the first token
end: index of the token following the last token
Can be called as:
- self.word_to_tokens(word_index) if batch size is 1
- self.word_to_tokens(batch_index, word_index) if batch size is greater or equal to 1
This method is particularly suited when the input sequences are provided as
pre-tokenized sequences (i.e. words are defined by the user). In this case it allows
to easily associate encoded tokens with provided tokenized words.
"""
Get the encoded token span corresponding to a word in the sequence of the batch.
Token spans are returned as a TokenSpan NamedTuple with:
- start: index of the first token
- end: index of the token following the last token
Can be called as:
- self.word_to_tokens(word_index) if batch size is 1
- self.word_to_tokens(batch_index, word_index) if batch size is greater or equal to 1
This method is particularly suited when the input sequences are provided as
pre-tokenized sequences (i.e. words are defined by the user). In this case it allows
to easily associate encoded tokens with provided tokenized words.

Comment on lines 264 to 272
""" Get the character span corresponding to an encoded token in a sequence of the batch.

Character spans are returned as a CharSpan NamedTuple with:
start: index of the first character in the original string associated to the token
end: index of the character following the last character in the original string associated to the token

Can be called as:
- self.token_to_chars(token_index) if batch size is 1
- self.token_to_chars(batch_index, token_index) if batch size is greater or equal to 1

Suggested change
""" Get the character span corresponding to an encoded token in a sequence of the batch.
Character spans are returned as a CharSpan NamedTuple with:
start: index of the first character in the original string associated to the token
end: index of the character following the last character in the original string associated to the token
Can be called as:
- self.token_to_chars(token_index) if batch size is 1
- self.token_to_chars(batch_index, token_index) if batch size is greater or equal to 1
"""
Get the character span corresponding to an encoded token in a sequence of the batch.
Character spans are returned as a CharSpan NamedTuple with:
- start: index of the first character in the original string associated to the token
- end: index of the character following the last character in the original string associated to the token
Can be called as:
- self.token_to_chars(token_index) if batch size is 1
- self.token_to_chars(batch_index, token_index) if batch size is greater or equal to 1

Comment on lines 301 to 310
""" Get the index of the token in the encoded output comprising a character
in the original string for a sequence of the batch.

Can be called as:
- self.char_to_token(char_index) if batch size is 1
- self.char_to_token(batch_index, char_index) if batch size is greater or equal to 1

This method is particularly suited when the input sequences are provided as
pre-tokenized sequences (i.e. words are defined by the user). In this case it allows
to easily associate encoded tokens with provided tokenized words.

Suggested change
""" Get the index of the token in the encoded output comprising a character
in the original string for a sequence of the batch.
Can be called as:
- self.char_to_token(char_index) if batch size is 1
- self.char_to_token(batch_index, char_index) if batch size is greater or equal to 1
This method is particularly suited when the input sequences are provided as
pre-tokenized sequences (i.e. words are defined by the user). In this case it allows
to easily associate encoded tokens with provided tokenized words.
"""
Get the index of the token in the encoded output comprising a character
in the original string for a sequence of the batch.
Can be called as:
- self.char_to_token(char_index) if batch size is 1
- self.char_to_token(batch_index, char_index) if batch size is greater or equal to 1
This method is particularly suited when the input sequences are provided as
pre-tokenized sequences (i.e. words are defined by the user). In this case it allows
to easily associate encoded tokens with provided tokenized words.

Comment on lines 336 to 345
""" Get the character span in the original string corresponding to given word in a sequence
of the batch.

Character spans are returned as a CharSpan NamedTuple with:
start: index of the first character in the original string
end: index of the character following the last character in the original string

Can be called as:
- self.word_to_chars(word_index) if batch size is 1
- self.word_to_chars(batch_index, word_index) if batch size is greater or equal to 1

Suggested change
""" Get the character span in the original string corresponding to given word in a sequence
of the batch.
Character spans are returned as a CharSpan NamedTuple with:
start: index of the first character in the original string
end: index of the character following the last character in the original string
Can be called as:
- self.word_to_chars(word_index) if batch size is 1
- self.word_to_chars(batch_index, word_index) if batch size is greater or equal to 1
"""
Get the character span in the original string corresponding to given word in a sequence
of the batch.
Character spans are returned as a CharSpan NamedTuple with:
- start: index of the first character in the original string
- end: index of the character following the last character in the original string
Can be called as:
- self.word_to_chars(word_index) if batch size is 1
- self.word_to_chars(batch_index, word_index) if batch size is greater or equal to 1

Comment on lines 373 to 382
""" Get the word in the original string corresponding to a character in the original string of
a sequence of the batch.

Can be called as:
- self.char_to_word(char_index) if batch size is 1
- self.char_to_word(batch_index, char_index) if batch size is greater than 1

This method is particularly suited when the input sequences are provided as
pre-tokenized sequences (i.e. words are defined by the user). In this case it allows
to easily associate encoded tokens with provided tokenized words.

Suggested change
""" Get the word in the original string corresponding to a character in the original string of
a sequence of the batch.
Can be called as:
- self.char_to_word(char_index) if batch size is 1
- self.char_to_word(batch_index, char_index) if batch size is greater than 1
This method is particularly suited when the input sequences are provided as
pre-tokenized sequences (i.e. words are defined by the user). In this case it allows
to easily associate encoded tokens with provided tokenized words.
"""
Get the word in the original string corresponding to a character in the original string of
a sequence of the batch.
Can be called as:
- self.char_to_word(char_index) if batch size is 1
- self.char_to_word(batch_index, char_index) if batch size is greater than 1
This method is particularly suited when the input sequences are provided as
pre-tokenized sequences (i.e. words are defined by the user). In this case it allows
to easily associate encoded tokens with provided tokenized words.


is_pretokenized = kwargs.get("is_pretokenized", False)
assert self.add_prefix_space or not is_pretokenized, (
"You need to instantiate GPT2TokenizerFast with add_prefix_space=False "

Suggested change
"You need to instantiate GPT2TokenizerFast with add_prefix_space=False "
"You need to instantiate GPT2TokenizerFast with add_prefix_space=True "

""" Base classes common to both the slow and the fast tokenization classes:
PreTrainedTokenizerBase (host all the user fronting encoding methodes)
Special token mixing (host the special tokens logic) and
BatchEncoding (wrap the dictionnary of output with special method for the Fast tokenizers)

Suggested change
BatchEncoding (wrap the dictionnary of output with special method for the Fast tokenizers)
BatchEncoding (wrap the dictionary of output with special method for the Fast tokenizers)


If only French and English could agree on the number of n's they use, then that'd remove some confusion. 😄


class BatchEncoding(UserDict):
""" BatchEncoding hold the output of the encode and batch_encode methods (tokens, attention_masks, etc).
This class is derived from a python Dictionary and can be used as a dictionnary.

Suggested change
This class is derived from a python Dictionary and can be used as a dictionnary.
This class is derived from a python Dictionary and can be used as a dictionary.

Can be called as:

- ``self.word_to_chars(word_index)`` if batch size is 1
- ``self.word_to_chars(batch_index, word_index)`` if batch size is greater or equal to 1

When I look at this example: word_index is mandatory in both cases and batch_index is optional. However, the type hints suggest that word_index is optional (see line 377). I think this should be corrected then :)

Can be called as:

- ``self.char_to_word(char_index)`` if batch size is 1
- ``self.char_to_word(batch_index, char_index)`` if batch size is greater than 1

This also needs some clarification: char_index is mandatory here, but the type hint suggests something else.

"Defaulting to 'only_first' truncation strategy. "
"If you encode pairs of sequences (GLUE-style) with the tokenizer you may want to check this is the right behavior."
)
truncation = "only_first"

Just a pedantic comment, but this could be dangerous, because truncation is originally a boolean.

Falsy/Truthy comparison is done later with:

elif truncation is not False:
            if truncation is True:
                truncation_strategy = (
                    TruncationStrategy.ONLY_FIRST
                )  # Default to truncate the first sequences in pairs of inputs
            else:
                truncation_strategy = TruncationStrategy(truncation)

and we will land in the truncation_strategy = TruncationStrategy(truncation) branch.

But then truncation_strategy will be the same as in the if truncation is True branch - so it is a kind of "hidden" duplicate code here 🤔

return_attention_mask: Optional[bool] = None,
verbose: bool = True,
) -> dict:
""" Pad encoded inputs (on left/right and up to predefined legnth or max length in the batch)

Suggested change
""" Pad encoded inputs (on left/right and up to predefined legnth or max length in the batch)
""" Pad encoded inputs (on left/right and up to predefined length or max length in the batch)

batch_size = len(encoding_or_batch["input_ids"])
assert all(
len(v) == batch_size for v in encoding_or_batch.values()
), "Some items in the output dictionnary have a different batch size than others."

Suggested change
), "Some items in the output dictionnary have a different batch size than others."
), "Some items in the output dictionary have a different batch size than others."

)

# Return tensor is None, then we can remove the leading batch axis
# Overfolwing tokens are returned as a batch of output so we keep them in this case

Suggested change
# Overfolwing tokens are returned as a batch of output so we keep them in this case
# Overflowing tokens are returned as a batch of output so we keep them in this case

@LysandreJik

As seen with @thomwolf, will merge this PR as soon as the tests show all green. I'm updating all the library's docstrings to showcase best practices in a second PR.

@yuhongqian

Thanks for the update! I was writing my own tokenizer for some special inputs and saw the implementation of the longest_first truncation. Is there any reason why tokens are truncated one by one? It seems more efficient to truncate the longer sequence to the same length as the shorter one, and then truncate the same number of tokens from both. That way we only need 3 array slices in total, saving a lot of loops.
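
The suggested slicing approach would look roughly like this (a sketch of the idea, not the library's implementation; how an odd leftover token is split between the two sequences is an arbitrary choice here):

def truncate_longest_first(ids_a, ids_b, num_tokens_to_remove):
    """Remove num_tokens_to_remove tokens from a pair using a few slices instead of a per-token loop."""
    if num_tokens_to_remove <= 0:
        return ids_a, ids_b
    # 1. Trim the longer sequence down towards the length of the shorter one
    first_cut = min(abs(len(ids_a) - len(ids_b)), num_tokens_to_remove)
    if len(ids_a) >= len(ids_b):
        ids_a = ids_a[: len(ids_a) - first_cut]
    else:
        ids_b = ids_b[: len(ids_b) - first_cut]
    # 2. Split whatever is left to remove evenly between the two sequences
    remaining = num_tokens_to_remove - first_cut
    cut_a, cut_b = remaining - remaining // 2, remaining // 2
    if cut_a:
        ids_a = ids_a[: len(ids_a) - cut_a]
    if cut_b:
        ids_b = ids_b[: len(ids_b) - cut_b]
    return ids_a, ids_b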
