Bert Batch Encode Plus adding an extra [SEP] #3502

Closed
creat89 opened this issue Mar 28, 2020 · 6 comments · Fixed by #3517

creat89 commented Mar 28, 2020

🐛 Bug

Information

I'm using the bert-base-multilingual-cased tokenizer and model to build another model. However, batch_encode_plus is adding an extra [SEP] token id in the middle of the sequence.

The problem arises when using:

  • Specific strings to encode, e.g. 16., 3., 10.
  • The bert-base-multilingual-cased tokenizer is first used to tokenize those strings, and
  • batch_encode_plus is then used to convert the tokenized strings

In fact, batch_encode_plus will generate an input_ids list containing two [SEP] ids, such as [101, 10250, 102, 119, 102].

I have seen similar issues, but they don't indicate the version of transformers:

#2658
#3037

Thus, I'm not sure if it is related to transformers version 2.6.0

To reproduce

Steps to reproduce the behavior (simplified steps):

  1. Have a string of type 16. or 6.
  2. Use tokens = bert_tokenizer.tokenize("16.")
  3. Use bert_tokenizer.batch_encode_plus([tokens])

You can reproduce the error with this code

from transformers import BertTokenizer
import unittest

class TestListElements(unittest.TestCase):

    def setUp(self):

        bert_tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')

        problematic_string = "16."

        tokens = bert_tokenizer.tokenize(problematic_string)

        self.encoded_batch_1 = bert_tokenizer.batch_encode_plus([tokens])    #list[list[]]
        self.encoded_batch_2 = bert_tokenizer.batch_encode_plus([problematic_string]) #list[]
        self.encoded_tokens_1 = bert_tokenizer.encode_plus(problematic_string)
        self.encoded_tokens_2 = bert_tokenizer.encode_plus(tokens)

    def test_tokens_vs_tokens(self):
        self.assertListEqual(self.encoded_tokens_1["input_ids"], self.encoded_tokens_2["input_ids"])

    def test_tokens_vs_batch_string(self):
        self.assertListEqual(self.encoded_tokens_1["input_ids"], self.encoded_batch_2["input_ids"][0])

    def test_tokens_vs_batch_list_tokens(self):
        self.assertListEqual(self.encoded_tokens_1["input_ids"], self.encoded_batch_1["input_ids"][0])

if __name__ == "__main__":
    unittest.main(verbosity=2)

The code will break at test test_tokens_vs_batch_list_tokens, with the following summarized output:

- [101, 10250, 119, 102]
+ [101, 10250, 102, 119, 102]
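
In the failing output, the extra 102 sits exactly where a sentence-pair encoding would place its middle separator. The thread does not state the root cause, so the following is only a toy illustration of that reading, not the library's code. The ids come from the output above; the token-to-id mapping (`"16"` to 10250, `"."` to 119) and the helper names `encode_single` and `encode_pair` are assumptions made for the sketch.

```python
# Toy model only: the ids are taken from the issue's output, and the
# token-to-id mapping is inferred from it, not read from the real vocabulary.
CLS_ID, SEP_ID = 101, 102
TOY_VOCAB = {"16": 10250, ".": 119}


def encode_single(tokens):
    """Encode one sequence as [CLS] tokens [SEP]."""
    return [CLS_ID] + [TOY_VOCAB[t] for t in tokens] + [SEP_ID]


def encode_pair(first, second):
    """Encode a sentence pair as [CLS] first [SEP] second [SEP]."""
    ids_a = [TOY_VOCAB[t] for t in first]
    ids_b = [TOY_VOCAB[t] for t in second]
    return [CLS_ID] + ids_a + [SEP_ID] + ids_b + [SEP_ID]


tokens = ["16", "."]  # assumed result of tokenize("16.")

expected = encode_single(tokens)
# Hypothetical misreading: the two tokens unpacked as (text, text_pair).
observed = encode_pair([tokens[0]], [tokens[1]])

print(expected)  # [101, 10250, 119, 102]
print(observed)  # [101, 10250, 102, 119, 102]
```

If this reading is right, any input that tokenizes to exactly two tokens would show the same symptom, which would match the examples 16., 3. and 10. listed earlier.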

Expected behavior

batch_encode_plus should always produce the same input_ids, whether we pass it a list of tokens or a list of strings.

For instance, for the string 16. we should always get [101, 10250, 119, 102]. However, with batch_encode_plus we get [101, 10250, 102, 119, 102] if we pass it input that is already tokenized.
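
Until the behavior is fixed, a cheap guard can catch the symptom in single-sequence batches. This is a minimal sketch, assuming 102 is the [SEP] id as in the outputs quoted here; `has_stray_separator` is a hypothetical helper, not part of transformers, and it is only meaningful for inputs that are not sentence pairs.

```python
SEP_ID = 102  # [SEP] id as it appears in the outputs quoted in this issue


def has_stray_separator(input_ids, sep_id=SEP_ID):
    """Return True if a single-sequence encoding has a separator before its end."""
    return sep_id in input_ids[:-1]


print(has_stray_separator([101, 10250, 119, 102]))       # False
print(has_stray_separator([101, 10250, 102, 119, 102]))  # True
```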

Environment info

  • transformers version: 2.6.0
  • Platform: Linux (Manjaro)
  • Python version: Python 3.8.1 (default, Jan 8 2020, 22:29:32)
  • PyTorch version (GPU?): 1.4.0 (True)
  • Tensorflow version (GPU?): ---
  • Using GPU in script?: False
  • Using distributed or parallel set-up in script?: False
@patrickvonplaten patrickvonplaten self-assigned this Mar 28, 2020

patrickvonplaten commented Mar 28, 2020

Hi @creat89,

Thanks for posting this issue!

You are correct, there is some inconsistent behavior here.

  1. In general, we should probably not allow calling batch_encode_plus() on a single string. The encode_plus() function should be used for that.
  2. It seems like there is an inconsistency between encode_plus([string]) and encode_plus(string). This should probably be fixed.

creat89 commented Mar 28, 2020

Well, the issue doesn't only happen with a single string. In my actual code I was using a batch of size 2; I just used a simple example to demonstrate the issue.

I didn't find any inconsistency between encode_plus([string]) and encode_plus(string), but rather between batch_encode_plus([strings]) and batch_encode_plus([[tokens]]).

patrickvonplaten commented Mar 29, 2020

Sorry, I was skimming through your problem too quickly - I see what you mean now.
I will take a closer look at this.

@patrickvonplaten

Created a PR that fixes this behavior. Thanks for pointing this out @creat89 :-)

@creat89 creat89 closed this as completed Mar 29, 2020

patrickvonplaten commented Apr 7, 2020

There has been a big change in tokenizers recently :-) which adds an is_pretokenized flag to the input, which makes everything much easier. It should be used as follows:
bert_tokenizer.batch_encode_plus([tokens], is_pretokenized=True)
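
Independently of that flag, tokens that have already been through tokenize() can be converted to ids without a second tokenization pass. A minimal sketch, assuming the tokenizer exposes `convert_tokens_to_ids` and `build_inputs_with_special_tokens` (as `BertTokenizer` does); `encode_wordpieces` is a hypothetical helper name. It returns only the input ids, not the attention mask or token type ids that batch_encode_plus also produces.

```python
def encode_wordpieces(tokenizer, tokens):
    """Turn already-tokenized WordPiece tokens into input ids with special tokens.

    `tokenizer` is expected to be a transformers tokenizer such as BertTokenizer.
    """
    # Map each token string to its vocabulary id without re-tokenizing it.
    ids = tokenizer.convert_tokens_to_ids(tokens)
    # Wrap the ids in the model's special tokens ([CLS] ... [SEP] for BERT).
    return tokenizer.build_inputs_with_special_tokens(ids)


# Intended usage (needs transformers and a model download, so left as comments):
#   from transformers import BertTokenizer
#   tok = BertTokenizer.from_pretrained("bert-base-multilingual-cased")
#   input_ids = encode_wordpieces(tok, tok.tokenize("16."))
```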

creat89 commented Apr 7, 2020

Cool, that's awesome and yes, I'm sure that makes everything easier. Cheers!
