diff --git a/src/transformers/pipelines/zero_shot_classification.py b/src/transformers/pipelines/zero_shot_classification.py index e2e18780177bf..fd441baf4af6c 100644 --- a/src/transformers/pipelines/zero_shot_classification.py +++ b/src/transformers/pipelines/zero_shot_classification.py @@ -88,7 +88,7 @@ def _parse_and_tokenize( hypothesis_template, padding=True, add_special_tokens=True, - truncation=TruncationStrategy.DO_NOT_TRUNCATE, + truncation=TruncationStrategy.ONLY_FIRST, **kwargs ): """ @@ -113,13 +113,31 @@ def _parse_and_tokenize( ) inputs.append(model_input) else: - inputs = self.tokenizer( - sequence_pairs, - add_special_tokens=add_special_tokens, - return_tensors=return_tensors, - padding=padding, - truncation=truncation, - ) + try: + inputs = self.tokenizer( + sequence_pairs, + add_special_tokens=add_special_tokens, + return_tensors=return_tensors, + padding=padding, + truncation=truncation, + ) + except Exception as e: + if "too short" in str(e): + # tokenizers might yell that we want to truncate + # to a value that is not even reached by the input. + # In that case we don't want to truncate. + # It seems there's not a really better way to catch that + # exception. + + inputs = self.tokenizer( + sequence_pairs, + add_special_tokens=add_special_tokens, + return_tensors=return_tensors, + padding=padding, + truncation=TruncationStrategy.DO_NOT_TRUNCATE, + ) + else: + raise e return inputs diff --git a/tests/test_pipelines_zero_shot.py b/tests/test_pipelines_zero_shot.py index 5d0df573964f0..69fd65f71dd3d 100644 --- a/tests/test_pipelines_zero_shot.py +++ b/tests/test_pipelines_zero_shot.py @@ -105,6 +105,20 @@ def run_entailment_id(self, zero_shot_classifier: Pipeline): zero_shot_classifier.model.config.label2id = original_label2id self.assertEqual(original_entailment, zero_shot_classifier.entailment_id) + @require_torch + def test_truncation(self): + zero_shot_classifier = pipeline( + "zero-shot-classification", + model="sshleifer/tiny-distilbert-base-cased-distilled-squad", + framework="pt", + ) + # There was a regression in 4.10 for this + # Adding a test so we don't make the mistake again. + # https://github.com/huggingface/transformers/issues/13381#issuecomment-912343499 + zero_shot_classifier( + "Who are you voting for in 2020?" * 100, candidate_labels=["politics", "public health", "science"] + ) + @require_torch def test_small_model_pt(self): zero_shot_classifier = pipeline(