[BUG] Unexpected overflowing_tokens in tokenizer.encode_plus #8028

Closed
wangxinyu0922 opened this issue Oct 25, 2020 · 4 comments

wangxinyu0922 commented Oct 25, 2020

Environment info

  • transformers version: 3.4.0
  • Platform: Linux
  • Python version: 3.7
  • PyTorch version (GPU?): 1.3.1
  • Tensorflow version (GPU?):
  • Using GPU in script?: True
  • Using distributed or parallel set-up in script?:

Who can help

tokenizers: @mfuntowicz

Information

When I use the BERT tokenizer, I get unexpected overflowing_tokens. Here is example code to reproduce:

To reproduce

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-multilingual-cased')

# 50 dummy subtoken ids: 1000, 1001, ..., 1049
subtoken_ids_sentence = [x for x in range(1000, 1050)]

encoded_inputs = tokenizer.encode_plus(subtoken_ids_sentence,
                                       max_length=40,
                                       stride=20,
                                       return_overflowing_tokens=True,
                                       truncation=True,
                                       )

print(encoded_inputs['overflowing_tokens'])

The output is: [1029, 1030, 1031, 1032, 1033, 1034, 1035, 1036, 1037, 1038, 1039, 1040, 1041, 1042, 1043, 1044, 1045, 1046, 1047, 1048, 1049, 1048, 1047, 1046, 1045, 1044, 1043, 1042, 1041, 1040, 1039, 1038]

Expected behavior

The expected behavior I want is:
[1018, 1019, 1020, 1021, 1022, 1023, 1024, 1025, 1026, 1027, 1028, 1029, 1030, 1031, 1032, 1033, 1034, 1035, 1036, 1037, 1038, 1039, 1040, 1041, 1042, 1043, 1044, 1045, 1046, 1047, 1048, 1049]
The current output contains [1029, 1030, 1031, 1032, 1033, 1034, 1035, 1036, 1037, 1038, 1039, 1040, 1041, 1042, 1043, 1044, 1045, 1046, 1047, 1048, 1049] followed by an additional reversed run of ids, [1048, 1047, 1046, 1045, 1044, 1043, 1042, 1041, 1040, 1039, 1038], which I think is wrong.
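
For clarity, the expected list is simply the last num_tokens_to_remove + stride ids of the input, i.e. one contiguous window. A minimal sketch of that arithmetic (that two added special tokens, [CLS] and [SEP], account for the 12 removed ids is my assumption):

# Minimal sketch: derive the expected contiguous overflow window by hand.
# Assumes encode_plus adds two special tokens ([CLS], [SEP]) to the single sequence.
ids = list(range(1000, 1050))                         # the 50 input ids
max_length, stride, num_special = 40, 20, 2
num_to_remove = len(ids) + num_special - max_length   # 50 + 2 - 40 = 12

expected_overflow = ids[-(num_to_remove + stride):]   # the last 32 ids, contiguous
print(expected_overflow)                              # [1018, 1019, ..., 1049]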

When I dig into the code, I find that:

if truncation_strategy == TruncationStrategy.LONGEST_FIRST:
    for _ in range(num_tokens_to_remove):
        if pair_ids is None or len(ids) > len(pair_ids):
            if not overflowing_tokens:
                window_len = min(len(ids), stride + 1)
            else:
                window_len = 1
            overflowing_tokens.extend(ids[-window_len:])
            ids = ids[:-1]
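
Running this loop by hand on the ids from the script above reproduces the reported output and shows where the reversed tail comes from. A standalone simulation (again assuming two added special tokens, so 12 ids are removed):

# Standalone simulation of the LONGEST_FIRST loop above for a single sequence
# (pair_ids is None), with the same ids, max_length and stride as the repro script.
ids = list(range(1000, 1050))
stride = 20
num_tokens_to_remove = len(ids) + 2 - 40              # 12 (two special tokens assumed)

overflowing_tokens = []
for _ in range(num_tokens_to_remove):
    if not overflowing_tokens:
        window_len = min(len(ids), stride + 1)        # 21 ids on the first pass
    else:
        window_len = 1                                # a single id on every later pass
    overflowing_tokens.extend(ids[-window_len:])
    ids = ids[:-1]                                    # but only one id is dropped per pass

print(overflowing_tokens)
# [1029, ..., 1049, 1048, 1047, ..., 1038] -- the 21-id window plus 11 repeated,
# reversed ids, exactly the unexpected output reported above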

I wonder why there is a for loop in it, and I think I need truncation_strategy = TruncationStrategy.ONLY_FIRST. However, I failed to switch the truncation strategy to only_first because the code here turns the truncation strategy into longest_first:

if max_length is not None and padding is False and truncation is False:
    if verbose:
        logger.warning(
            "Truncation was not explicitely activated but `max_length` is provided a specific value, "
            "please use `truncation=True` to explicitely truncate examples to max length. "
            "Defaulting to 'longest_first' truncation strategy. "
            "If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy "
            "more precisely by providing a specific strategy to `truncation`."
        )
    truncation = "longest_first"
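
For reference, here is a sketch of naming the strategy explicitly instead of relying on the implicit default; that encode_plus in 3.4.0 accepts the string "only_first" for truncation and then returns the contiguous window is my assumption, not something I have verified:

# Sketch (assumption): pass an explicit string strategy instead of truncation=True,
# so the longest_first default above is never applied to this single sequence.
encoded_inputs = tokenizer.encode_plus(subtoken_ids_sentence,
                                       max_length=40,
                                       stride=20,
                                       return_overflowing_tokens=True,
                                       truncation="only_first",
                                       )
print(encoded_inputs['overflowing_tokens'])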

Can you give me any help?

@djstrong

I confirm the issue. It worked correctly with transformers 3.0.0, but the behavior changed starting with 3.1.0.

@djstrong

And the code:

if pair_ids is None or len(ids) > len(pair_ids):
    if not overflowing_tokens:
        window_len = min(len(ids), stride + 1)
    else:
        window_len = 1
    overflowing_tokens.extend(ids[-window_len:])
    ids = ids[:-1]
else:
    if not overflowing_tokens:
        window_len = min(len(pair_ids), stride + 1)
    else:
        window_len = 1
    overflowing_tokens.extend(pair_ids[-window_len:])
    pair_ids = pair_ids[:-1]

looks buggy; in addition to the above, ids = ids[:-1] should be ids = ids[:-window_len].

djstrong referenced this issue Oct 25, 2020
* Exposing prepare_for_model for both slow & fast tokenizers

* Update method signature

* The traditional style commit

* Hide the warnings behind the verbose flag

* update default truncation strategy and prepare_for_model

* fix tests and prepare_for_models methods

Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com>
@LysandreJik
Member

Pinging @thomwolf


stale bot commented Dec 25, 2020

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
