
Is it possible to mimic trim_batch using new tokenizer strategies? #5181

Closed
sshleifer opened this issue Jun 22, 2020 · 3 comments · Fixed by #5252
Labels: Core: Tokenization (Internals of the library; Tokenization)

Comments

@sshleifer
Contributor

I am trying to use the new tokenizer kwargs to replace the old workflow of calling batch_encode_plus to make tensors of shape (n_examples, model_max_length) and then calling trim_batch to reduce padding computation.
Is this possible?
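
For context, trim_batch just drops the columns of a padded batch that are made up entirely of pad tokens. A minimal sketch of that helper (assuming PyTorch tensors for input_ids and attention_mask) is roughly:

def trim_batch(input_ids, pad_token_id, attention_mask=None):
    # keep only the columns that contain at least one non-pad token
    keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
    if attention_mask is None:
        return input_ids[:, keep_column_mask]
    return input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask]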
The following code does not seem to truncate inputs longer than 512 (the final assert fails; see the traceback below).

Attempt:

from transformers import BartTokenizer
# trim_batch here is the old helper (sketched above) that drops all-padding columns

tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
# note: this mixes the old pad_to_max_length kwarg with the new padding/truncation kwargs
kw = dict(max_length=512, pad_to_max_length=True, padding=True,
          return_tensors='pt', truncation='only_first')

# the short batch comes back 7 tokens wide, so these two asserts pass
batch = tokenizer(['tiny sentence 1', 'tiny_sentence2'], **kw)
assert batch.input_ids.shape[1] == 7, batch.input_ids.shape[1]
input_ids, mask = trim_batch(**batch, pad_token_id=tokenizer.pad_token_id)
assert input_ids.shape[1] == 7, input_ids.shape[1]

# a sentence far longer than max_length=512 should be truncated, but is not
batch_overflow = tokenizer(['tiny sentence 1' * 1000, 'tiny_sentence2'], **kw)
assert batch_overflow.input_ids.shape[1] == 512, batch_overflow.input_ids.shape[1]

Traceback:

assert batch_overflow.input_ids.shape[1] == 512, batch_overflow.input_ids.shape[1]

AssertionError: 3002

Help much appreciated, @mfuntowicz @thomwolf

sshleifer added the Core: Tokenization label Jun 22, 2020
@thomwolf
Member

Hi @sshleifer, you should read the detailed description on the tokenizers refactoring PR #4510 (comment).

Until it's added to the docs (which will be soon), it's required reading for all core contributors of transformers.

@sshleifer
Contributor Author

sshleifer commented Jun 22, 2020

Thanks. I read that, and I'm still somewhat confused about why I pass truncation=True and get entries that are longer than tokenizer.model_max_length. The PR description says:

[screenshot of the PR description explaining the padding/truncation behavior]

Here is a simplified example:

from transformers import BartTokenizer
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
assert tokenizer.model_max_length == 1024

# tokenizer.batch_encode_plus returns ids shaped (2, 1024)
batch_sentences = ['tiny sentence 1'*1000, 'tiny_sentence2']
ids = tokenizer.batch_encode_plus(batch_sentences, pad_to_max_length=True, max_length=tokenizer.model_max_length,
                                  truncation=True, return_tensors='pt').input_ids
assert ids.shape[1] <= tokenizer.model_max_length, ids.shape[1]

# tokenizer.__call__ returns ids shaped (2, 3002)
ids = tokenizer(batch_sentences, padding=True, truncation=True, return_tensors='pt',
                max_length=tokenizer.model_max_length).input_ids
assert ids.shape[1] <= tokenizer.model_max_length, ids.shape[1]
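
For reference, the new-style call that should mimic the old batch_encode_plus + trim_batch workflow, assuming truncation is applied against model_max_length as documented, is padding='longest' (pad only up to the longest example in the batch) together with truncation=True:

from transformers import BartTokenizer

tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
batch = tokenizer(
    ['tiny sentence 1' * 1000, 'tiny_sentence2'],
    padding='longest',                       # pad only to the longest example in the batch
    truncation=True,                         # truncate anything above max_length
    max_length=tokenizer.model_max_length,
    return_tensors='pt',
)
# the long example is truncated to model_max_length and the short one is padded up to it,
# so no column is all-padding and there is nothing left for trim_batch to remove
assert batch.input_ids.shape[1] <= tokenizer.model_max_length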

@thomwolf
Member

I'll take a look

thomwolf added a commit that referenced this issue Jun 25, 2020
…efault logging level in tests to WARNING (#5252)

* fix-5181

Padding to max sequence length while truncation to another length was wrong on slow tokenizers

* clean up and fix #5155

* fix XLM test

* Fix tests for Transfo-XL

* logging only above WARNING in tests

* switch slow tokenizers tests in @slow

* fix Marian truncation tokenization test

* style and quality

* make the test a lot faster by limiting the sequence length used in tests
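
With that fix in place, one would expect the two workflows from the examples above to agree. A hedged sanity-check sketch, reusing the trim_batch helper sketched near the top of the thread:

import torch
from transformers import BartTokenizer

tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
sents = ['tiny sentence 1' * 1000, 'tiny_sentence2']

# old workflow: pad everything out to model_max_length, then trim all-pad columns
old = tokenizer.batch_encode_plus(sents, pad_to_max_length=True, truncation=True,
                                  max_length=tokenizer.model_max_length,
                                  return_tensors='pt')
old_ids, old_mask = trim_batch(old.input_ids, tokenizer.pad_token_id,
                               attention_mask=old.attention_mask)

# new workflow: pad only to the longest sequence in the batch, truncating at model_max_length
new = tokenizer(sents, padding=True, truncation=True,
                max_length=tokenizer.model_max_length, return_tensors='pt')

assert torch.equal(old_ids, new.input_ids)
assert torch.equal(old_mask, new.attention_mask)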
jplu pushed a commit to jplu/transformers that referenced this issue Jun 29, 2020
…icit - move back the default logging level in tests to WARNING (huggingface#5252)
