From a02ca81d94fa087187a686643be9d4f7f385875c Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Mon, 2 Nov 2020 12:33:50 +0100
Subject: [PATCH] Fix the behaviour of DefaultArgumentHandler (removing it).
 (#8180)

* Some work to fix the behaviour of DefaultArgumentHandler by removing it.

* Fixing specific pipelines argument checking.
---
 src/transformers/pipelines.py     | 100 ++++++--------------
 tests/test_pipelines_common.py    | 147 +++++++++++++++---------------
 tests/test_pipelines_fill_mask.py |  28 +++++-
 tests/test_pipelines_zero_shot.py |   2 +-
 4 files changed, 129 insertions(+), 148 deletions(-)

diff --git a/src/transformers/pipelines.py b/src/transformers/pipelines.py
index a8f9fe8ae7609b..f9897ef2dbd3c5 100755
--- a/src/transformers/pipelines.py
+++ b/src/transformers/pipelines.py
@@ -23,9 +23,8 @@
 import warnings
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
-from itertools import chain
 from os.path import abspath, exists
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 from uuid import UUID
 
 import numpy as np
@@ -185,57 +184,6 @@ def __call__(self, *args, **kwargs):
         raise NotImplementedError()
 
 
-class DefaultArgumentHandler(ArgumentHandler):
-    """
-    Default argument parser handling parameters for each :class:`~transformers.pipelines.Pipeline`.
-    """
-
-    @staticmethod
-    def handle_kwargs(kwargs: Dict) -> List:
-        if len(kwargs) == 1:
-            output = list(kwargs.values())
-        else:
-            output = list(chain(kwargs.values()))
-
-        return DefaultArgumentHandler.handle_args(output)
-
-    @staticmethod
-    def handle_args(args: Sequence[Any]) -> List[str]:
-
-        # Only one argument, let's do case by case
-        if len(args) == 1:
-            if isinstance(args[0], str):
-                return [args[0]]
-            elif not isinstance(args[0], list):
-                return list(args)
-            else:
-                return args[0]
-
-        # Multiple arguments (x1, x2, ...)
-        elif len(args) > 1:
-            if all([isinstance(arg, str) for arg in args]):
-                return list(args)
-
-            # If not instance of list, then it should instance of iterable
-            elif isinstance(args, Iterable):
-                return list(chain.from_iterable(chain(args)))
-            else:
-                raise ValueError(
-                    "Invalid input type {}. Pipeline supports Union[str, Iterable[str]]".format(type(args))
-                )
-        else:
-            return []
-
-    def __call__(self, *args, **kwargs):
-        if len(kwargs) > 0 and len(args) > 0:
-            raise ValueError("Pipeline cannot handle mixed args and kwargs")
-
-        if len(kwargs) > 0:
-            return DefaultArgumentHandler.handle_kwargs(kwargs)
-        else:
-            return DefaultArgumentHandler.handle_args(args)
-
-
 class PipelineDataFormat:
     """
     Base class for all the pipeline supported data format both for reading and writing. Supported data formats
@@ -574,7 +522,6 @@ def __init__(
         self.framework = framework
         self.device = device if framework == "tf" else torch.device("cpu" if device < 0 else "cuda:{}".format(device))
         self.binary_output = binary_output
-        self._args_parser = args_parser or DefaultArgumentHandler()
 
         # Special handling
         if self.framework == "pt" and self.device.type == "cuda":
@@ -669,12 +616,11 @@ def check_model_type(self, supported_models: Union[List[str], dict]):
                 f"The model '{self.model.__class__.__name__}' is not supported for {self.task}. Supported models are {supported_models}",
             )
 
-    def _parse_and_tokenize(self, *args, padding=True, add_special_tokens=True, **kwargs):
+    def _parse_and_tokenize(self, inputs, padding=True, add_special_tokens=True, **kwargs):
         """
         Parse arguments and tokenize
         """
         # Parse arguments
-        inputs = self._args_parser(*args, **kwargs)
         inputs = self.tokenizer(
             inputs,
             add_special_tokens=add_special_tokens,
@@ -836,7 +782,7 @@ def __init__(self, *args, **kwargs):
 
     # overriding _parse_and_tokenize to allow for unusual language-modeling tokenizer arguments
 
-    def _parse_and_tokenize(self, *args, padding=True, add_special_tokens=True, **kwargs):
+    def _parse_and_tokenize(self, inputs, padding=True, add_special_tokens=True, **kwargs):
         """
         Parse arguments and tokenize
         """
@@ -845,7 +791,6 @@ def _parse_and_tokenize(self, *args, padding=True, add_special_tokens=True, **kw
             tokenizer_kwargs = {"add_space_before_punct_symbol": True}
         else:
             tokenizer_kwargs = {}
-        inputs = self._args_parser(*args, **kwargs)
         inputs = self.tokenizer(
             inputs,
             add_special_tokens=add_special_tokens,
@@ -858,7 +803,7 @@ def _parse_and_tokenize(self, *args, padding=True, add_special_tokens=True, **kw
 
     def __call__(
         self,
-        *args,
+        text_inputs,
         return_tensors=False,
         return_text=True,
         clean_up_tokenization_spaces=False,
@@ -890,7 +835,6 @@ def __call__(
               - **generated_token_ids** (:obj:`torch.Tensor` or :obj:`tf.Tensor`, present when ``return_tensors=True``)
                 -- The token ids of the generated text.
         """
-        text_inputs = self._args_parser(*args)
 
         results = []
         for prompt_text in text_inputs:
@@ -1094,7 +1038,8 @@ class ZeroShotClassificationPipeline(Pipeline):
     """
 
    def __init__(self, args_parser=ZeroShotClassificationArgumentHandler(), *args, **kwargs):
-        super().__init__(*args, args_parser=args_parser, **kwargs)
+        super().__init__(*args, **kwargs)
+        self._args_parser = args_parser
         if self.entailment_id == -1:
             logger.warning(
                 "Failed to determine 'entailment' label id from the label2id mapping in the model config. Setting to "
@@ -1108,13 +1053,15 @@ def entailment_id(self):
                 return ind
         return -1
 
-    def _parse_and_tokenize(self, *args, padding=True, add_special_tokens=True, **kwargs):
+    def _parse_and_tokenize(
+        self, sequences, candidate_labels, hypothesis_template, padding=True, add_special_tokens=True, **kwargs
+    ):
         """
         Parse arguments and tokenize only_first so that hypothesis (label) is not truncated
         """
-        inputs = self._args_parser(*args, **kwargs)
+        sequence_pairs = self._args_parser(sequences, candidate_labels, hypothesis_template)
         inputs = self.tokenizer(
-            inputs,
+            sequence_pairs,
             add_special_tokens=add_special_tokens,
             return_tensors=self.framework,
             padding=padding,
@@ -1123,7 +1070,13 @@ def _parse_and_tokenize(self, *args, padding=True, add_special_tokens=True, **kw
 
         return inputs
 
-    def __call__(self, sequences, candidate_labels, hypothesis_template="This example is {}.", multi_class=False):
+    def __call__(
+        self,
+        sequences: Union[str, List[str]],
+        candidate_labels,
+        hypothesis_template="This example is {}.",
+        multi_class=False,
+    ):
         """
         Classify the sequence(s) given as inputs. See the :obj:`~transformers.ZeroShotClassificationPipeline`
         documentation for more information.
@@ -1154,8 +1107,11 @@ def __call__(self, sequences, candidate_labels, hypothesis_templ
           - **labels** (:obj:`List[str]`) -- The labels sorted by order of likelihood.
          - **scores** (:obj:`List[float]`) -- The probabilities for each of the labels.
""" + if sequences and isinstance(sequences, str): + sequences = [sequences] + outputs = super().__call__(sequences, candidate_labels, hypothesis_template) - num_sequences = 1 if isinstance(sequences, str) else len(sequences) + num_sequences = len(sequences) candidate_labels = self._args_parser._parse_labels(candidate_labels) reshaped_outputs = outputs.reshape((num_sequences, len(candidate_labels), -1)) @@ -1425,12 +1381,12 @@ def __init__( self.ignore_labels = ignore_labels self.grouped_entities = grouped_entities - def __call__(self, *args, **kwargs): + def __call__(self, inputs: Union[str, List[str]], **kwargs): """ Classify each token of the text(s) given as inputs. Args: - args (:obj:`str` or :obj:`List[str]`): + inputs (:obj:`str` or :obj:`List[str]`): One or several texts (or one list of texts) for token classification. Return: @@ -1444,7 +1400,8 @@ def __call__(self, *args, **kwargs): - **index** (:obj:`int`, only present when ``self.grouped_entities=False``) -- The index of the corresponding token in the sentence. """ - inputs = self._args_parser(*args, **kwargs) + if isinstance(inputs, str): + inputs = [inputs] answers = [] for sentence in inputs: @@ -1659,12 +1616,12 @@ def __init__( tokenizer=tokenizer, modelcard=modelcard, framework=framework, - args_parser=QuestionAnsweringArgumentHandler(), device=device, task=task, **kwargs, ) + self._args_parser = QuestionAnsweringArgumentHandler() self.check_model_type( TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING if self.framework == "tf" else MODEL_FOR_QUESTION_ANSWERING_MAPPING ) @@ -2489,12 +2446,11 @@ def __call__( else: return output - def _parse_and_tokenize(self, *args, **kwargs): + def _parse_and_tokenize(self, inputs, **kwargs): """ Parse arguments and tokenize, adding an EOS token at the end of the user input """ # Parse arguments - inputs = self._args_parser(*args, **kwargs) inputs = self.tokenizer(inputs, add_special_tokens=False, padding=False).get("input_ids", []) for input in inputs: input.append(self.tokenizer.eos_token_id) diff --git a/tests/test_pipelines_common.py b/tests/test_pipelines_common.py index d6acea2da6cc14..697df13a17819f 100644 --- a/tests/test_pipelines_common.py +++ b/tests/test_pipelines_common.py @@ -1,8 +1,9 @@ -import unittest from typing import List, Optional from transformers import is_tf_available, is_torch_available, pipeline -from transformers.pipelines import DefaultArgumentHandler, Pipeline + +# from transformers.pipelines import DefaultArgumentHandler, Pipeline +from transformers.pipelines import Pipeline from transformers.testing_utils import _run_slow_tests, is_pipeline_test, require_tf, require_torch, slow @@ -200,74 +201,74 @@ def _test_pipeline(self, nlp: Pipeline): self.assertRaises(Exception, nlp, self.invalid_inputs) -@is_pipeline_test -class DefaultArgumentHandlerTestCase(unittest.TestCase): - def setUp(self) -> None: - self.handler = DefaultArgumentHandler() - - def test_kwargs_x(self): - mono_data = {"X": "This is a sample input"} - mono_args = self.handler(**mono_data) - - self.assertTrue(isinstance(mono_args, list)) - self.assertEqual(len(mono_args), 1) - - multi_data = {"x": ["This is a sample input", "This is a second sample input"]} - multi_args = self.handler(**multi_data) - - self.assertTrue(isinstance(multi_args, list)) - self.assertEqual(len(multi_args), 2) - - def test_kwargs_data(self): - mono_data = {"data": "This is a sample input"} - mono_args = self.handler(**mono_data) - - self.assertTrue(isinstance(mono_args, list)) - self.assertEqual(len(mono_args), 1) - - multi_data = 
{"data": ["This is a sample input", "This is a second sample input"]} - multi_args = self.handler(**multi_data) - - self.assertTrue(isinstance(multi_args, list)) - self.assertEqual(len(multi_args), 2) - - def test_multi_kwargs(self): - mono_data = {"data": "This is a sample input", "X": "This is a sample input 2"} - mono_args = self.handler(**mono_data) - - self.assertTrue(isinstance(mono_args, list)) - self.assertEqual(len(mono_args), 2) - - multi_data = { - "data": ["This is a sample input", "This is a second sample input"], - "test": ["This is a sample input 2", "This is a second sample input 2"], - } - multi_args = self.handler(**multi_data) - - self.assertTrue(isinstance(multi_args, list)) - self.assertEqual(len(multi_args), 4) - - def test_args(self): - mono_data = "This is a sample input" - mono_args = self.handler(mono_data) - - self.assertTrue(isinstance(mono_args, list)) - self.assertEqual(len(mono_args), 1) - - mono_data = ["This is a sample input"] - mono_args = self.handler(mono_data) - - self.assertTrue(isinstance(mono_args, list)) - self.assertEqual(len(mono_args), 1) - - multi_data = ["This is a sample input", "This is a second sample input"] - multi_args = self.handler(multi_data) - - self.assertTrue(isinstance(multi_args, list)) - self.assertEqual(len(multi_args), 2) - - multi_data = ["This is a sample input", "This is a second sample input"] - multi_args = self.handler(*multi_data) - - self.assertTrue(isinstance(multi_args, list)) - self.assertEqual(len(multi_args), 2) +# @is_pipeline_test +# class DefaultArgumentHandlerTestCase(unittest.TestCase): +# def setUp(self) -> None: +# self.handler = DefaultArgumentHandler() +# +# def test_kwargs_x(self): +# mono_data = {"X": "This is a sample input"} +# mono_args = self.handler(**mono_data) +# +# self.assertTrue(isinstance(mono_args, list)) +# self.assertEqual(len(mono_args), 1) +# +# multi_data = {"x": ["This is a sample input", "This is a second sample input"]} +# multi_args = self.handler(**multi_data) +# +# self.assertTrue(isinstance(multi_args, list)) +# self.assertEqual(len(multi_args), 2) +# +# def test_kwargs_data(self): +# mono_data = {"data": "This is a sample input"} +# mono_args = self.handler(**mono_data) +# +# self.assertTrue(isinstance(mono_args, list)) +# self.assertEqual(len(mono_args), 1) +# +# multi_data = {"data": ["This is a sample input", "This is a second sample input"]} +# multi_args = self.handler(**multi_data) +# +# self.assertTrue(isinstance(multi_args, list)) +# self.assertEqual(len(multi_args), 2) +# +# def test_multi_kwargs(self): +# mono_data = {"data": "This is a sample input", "X": "This is a sample input 2"} +# mono_args = self.handler(**mono_data) +# +# self.assertTrue(isinstance(mono_args, list)) +# self.assertEqual(len(mono_args), 2) +# +# multi_data = { +# "data": ["This is a sample input", "This is a second sample input"], +# "test": ["This is a sample input 2", "This is a second sample input 2"], +# } +# multi_args = self.handler(**multi_data) +# +# self.assertTrue(isinstance(multi_args, list)) +# self.assertEqual(len(multi_args), 4) +# +# def test_args(self): +# mono_data = "This is a sample input" +# mono_args = self.handler(mono_data) +# +# self.assertTrue(isinstance(mono_args, list)) +# self.assertEqual(len(mono_args), 1) +# +# mono_data = ["This is a sample input"] +# mono_args = self.handler(mono_data) +# +# self.assertTrue(isinstance(mono_args, list)) +# self.assertEqual(len(mono_args), 1) +# +# multi_data = ["This is a sample input", "This is a second sample input"] +# multi_args 
= self.handler(multi_data) +# +# self.assertTrue(isinstance(multi_args, list)) +# self.assertEqual(len(multi_args), 2) +# +# multi_data = ["This is a sample input", "This is a second sample input"] +# multi_args = self.handler(*multi_data) +# +# self.assertTrue(isinstance(multi_args, list)) +# self.assertEqual(len(multi_args), 2) diff --git a/tests/test_pipelines_fill_mask.py b/tests/test_pipelines_fill_mask.py index e49454d7416834..16404e8fd76346 100644 --- a/tests/test_pipelines_fill_mask.py +++ b/tests/test_pipelines_fill_mask.py @@ -1,5 +1,7 @@ import unittest +import pytest + from transformers import pipeline from transformers.testing_utils import require_tf, require_torch, slow @@ -37,7 +39,7 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase): pipeline_task = "fill-mask" - pipeline_loading_kwargs = {"topk": 2} + pipeline_loading_kwargs = {"top_k": 2} small_models = ["sshleifer/tiny-distilroberta-base"] # Models tested without the @slow decorator large_models = ["distilroberta-base"] # Models tested with the @slow decorator mandatory_keys = {"sequence", "score", "token"} @@ -51,6 +53,28 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase): ] expected_check_keys = ["sequence"] + @require_torch + def test_torch_topk_deprecation(self): + # At pipeline initialization only it was not enabled at pipeline + # call site before + with pytest.warns(FutureWarning, match=r".*use `top_k`.*"): + pipeline(task="fill-mask", model=self.small_models[0], topk=1) + + @require_torch + def test_torch_fill_mask(self): + valid_inputs = "My name is " + nlp = pipeline(task="fill-mask", model=self.small_models[0]) + outputs = nlp(valid_inputs) + self.assertIsInstance(outputs, list) + + # This passes + outputs = nlp(valid_inputs, targets=[" Patrick", " Clara"]) + self.assertIsInstance(outputs, list) + + # This used to fail with `cannot mix args and kwargs` + outputs = nlp(valid_inputs, something=False) + self.assertIsInstance(outputs, list) + @require_torch def test_torch_fill_mask_with_targets(self): valid_inputs = ["My name is "] @@ -94,7 +118,7 @@ def test_torch_fill_mask_results(self): model=model_name, tokenizer=model_name, framework="pt", - topk=2, + top_k=2, ) mono_result = nlp(valid_inputs[0], targets=valid_targets) diff --git a/tests/test_pipelines_zero_shot.py b/tests/test_pipelines_zero_shot.py index a2d6c590debb57..39bc2dc124ccde 100644 --- a/tests/test_pipelines_zero_shot.py +++ b/tests/test_pipelines_zero_shot.py @@ -17,7 +17,7 @@ def _test_scores_sum_to_one(self, result): sum = 0.0 for score in result["scores"]: sum += score - self.assertAlmostEqual(sum, 1.0) + self.assertAlmostEqual(sum, 1.0, places=5) def _test_entailment_id(self, nlp: Pipeline): config = nlp.model.config
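
Reviewer note (not part of the patch): a minimal sketch of the calling
convention this change enforces, using the small fill-mask test model
referenced in the tests above. A local PyTorch install and a transformers
build containing this patch are assumed.

    from transformers import pipeline

    nlp = pipeline(task="fill-mask", model="sshleifer/tiny-distilroberta-base")

    # A single string in, a list of predictions out.
    print(nlp("My name is <mask>"))

    # A list of strings in, one list of predictions per input.
    print(nlp(["My name is <mask>", "I live in <mask>"]))

    # Keyword arguments now flow to the pipeline itself, so combining a
    # positional input with kwargs no longer raises
    # "Pipeline cannot handle mixed args and kwargs".
    print(nlp("My name is <mask>", targets=[" Patrick", " Clara"]))

With DefaultArgumentHandler gone, each pipeline's __call__ declares an explicit
str-or-List[str] input instead of funnelling everything through one permissive
parser; the zero-shot and question-answering pipelines keep their specialized
argument handlers, assigned directly in their constructors.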