diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index 62e3abf37ecd..2f966337e669 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -149,7 +149,7 @@ def inner(items): _padding_value = t_padding_value elif key in {"input_values", "pixel_values", "input_features"}: _padding_value = f_padding_value - elif key in {"p_mask"}: + elif key in {"p_mask", "special_tokens_mask"}: _padding_value = 1 elif key in {"attention_mask", "token_type_ids"}: _padding_value = 0 diff --git a/src/transformers/pipelines/token_classification.py b/src/transformers/pipelines/token_classification.py index d14616a9aaab..56fe453dfb21 100644 --- a/src/transformers/pipelines/token_classification.py +++ b/src/transformers/pipelines/token_classification.py @@ -192,7 +192,6 @@ def preprocess(self, sentence, offset_mapping=None): truncation = True if self.tokenizer.model_max_length and self.tokenizer.model_max_length > 0 else False model_inputs = self.tokenizer( sentence, - return_attention_mask=False, return_tensors=self.framework, truncation=truncation, return_special_tokens_mask=True, diff --git a/tests/pipelines/test_pipelines_token_classification.py b/tests/pipelines/test_pipelines_token_classification.py index 94ac7a19ce2f..26cfa0d3be34 100644 --- a/tests/pipelines/test_pipelines_token_classification.py +++ b/tests/pipelines/test_pipelines_token_classification.py @@ -649,6 +649,23 @@ def test_small_model_pt(self): ], ) + # Batch size does not affect outputs (attention_mask are required) + sentences = ["This is a test !", "Another test this is with longer sentence"] + outputs = token_classifier(sentences) + outputs_batched = token_classifier(sentences, batch_size=2) + # Batching does not make a difference in predictions + self.assertEqual(nested_simplify(outputs_batched), nested_simplify(outputs)) + self.assertEqual( + nested_simplify(outputs_batched), + [ + [ + {"entity": "I-MISC", "score": 0.115, "index": 1, "word": "this", "start": 0, "end": 4}, + {"entity": "I-MISC", "score": 0.115, "index": 2, "word": "is", "start": 5, "end": 7}, + ], + [], + ], + ) + @require_torch def test_pt_ignore_subwords_slow_tokenizer_raises(self): model_name = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"