Fix the behaviour of DefaultArgumentHandler (removing it). (huggingface#8180)

* Some work to fix the behaviour of DefaultArgumentHandler by removing it.

* Fixing specific pipelines argument checking.
Narsil authored and fabiocapsouza committed Nov 15, 2020
1 parent 648425c commit a02ca81
Showing 4 changed files with 129 additions and 148 deletions.
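
In practical terms, the commit replaces the one-size-fits-all argument handler with explicit, per-pipeline signatures: each `__call__` now takes `inputs` (a string or a list of strings) and does its own normalization. A minimal sketch of the resulting contract (the `normalize` helper below is illustrative, not code from this diff):

```python
from typing import List, Union


def normalize(inputs: Union[str, List[str]]) -> List[str]:
    # The check each pipeline now performs inline, instead of routing
    # every call through the removed DefaultArgumentHandler.
    if isinstance(inputs, str):
        return [inputs]
    return inputs


assert normalize("one text") == ["one text"]
assert normalize(["first", "second"]) == ["first", "second"]
```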
100 changes: 28 additions & 72 deletions src/transformers/pipelines.py
@@ -23,9 +23,8 @@
 import warnings
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
-from itertools import chain
 from os.path import abspath, exists
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 from uuid import UUID
 
 import numpy as np
@@ -185,57 +184,6 @@ def __call__(self, *args, **kwargs):
         raise NotImplementedError()
 
 
-class DefaultArgumentHandler(ArgumentHandler):
-    """
-    Default argument parser handling parameters for each :class:`~transformers.pipelines.Pipeline`.
-    """
-
-    @staticmethod
-    def handle_kwargs(kwargs: Dict) -> List:
-        if len(kwargs) == 1:
-            output = list(kwargs.values())
-        else:
-            output = list(chain(kwargs.values()))
-
-        return DefaultArgumentHandler.handle_args(output)
-
-    @staticmethod
-    def handle_args(args: Sequence[Any]) -> List[str]:
-
-        # Only one argument, let's do case by case
-        if len(args) == 1:
-            if isinstance(args[0], str):
-                return [args[0]]
-            elif not isinstance(args[0], list):
-                return list(args)
-            else:
-                return args[0]
-
-        # Multiple arguments (x1, x2, ...)
-        elif len(args) > 1:
-            if all([isinstance(arg, str) for arg in args]):
-                return list(args)
-
-            # If not instance of list, then it should instance of iterable
-            elif isinstance(args, Iterable):
-                return list(chain.from_iterable(chain(args)))
-            else:
-                raise ValueError(
-                    "Invalid input type {}. Pipeline supports Union[str, Iterable[str]]".format(type(args))
-                )
-        else:
-            return []
-
-    def __call__(self, *args, **kwargs):
-        if len(kwargs) > 0 and len(args) > 0:
-            raise ValueError("Pipeline cannot handle mixed args and kwargs")
-
-        if len(kwargs) > 0:
-            return DefaultArgumentHandler.handle_kwargs(kwargs)
-        else:
-            return DefaultArgumentHandler.handle_args(args)
-
-
 class PipelineDataFormat:
     """
     Base class for all the pipeline supported data format both for reading and writing. Supported data formats
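
The removed handler flattened positional strings, single lists, and arbitrary keyword values into one list, so keyword names like `X` or `data` carried no meaning. A simplified, self-contained reconstruction of that behaviour (an approximation for illustration, based on the removed tests shown further down, not the exact deleted code):

```python
from itertools import chain
from typing import Any, List


def old_style_handler(*args: Any, **kwargs: Any) -> List[str]:
    # Approximates the deleted DefaultArgumentHandler: everything is
    # squeezed into one flat list of strings.
    if args and kwargs:
        raise ValueError("Pipeline cannot handle mixed args and kwargs")
    values = list(kwargs.values()) if kwargs else list(args)
    if len(values) == 1 and isinstance(values[0], list):
        return values[0]
    if all(isinstance(v, str) for v in values):
        return list(values)
    return list(chain.from_iterable(v if isinstance(v, list) else [v] for v in values))


# "X", "data", or any other keyword produced the same flat list, which
# is the ambiguity this commit removes in favour of explicit signatures.
assert old_style_handler(X="a") == ["a"]
assert old_style_handler(data=["a", "b"], test=["c", "d"]) == ["a", "b", "c", "d"]
```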
@@ -574,7 +522,6 @@ def __init__(
         self.framework = framework
         self.device = device if framework == "tf" else torch.device("cpu" if device < 0 else "cuda:{}".format(device))
         self.binary_output = binary_output
-        self._args_parser = args_parser or DefaultArgumentHandler()
 
         # Special handling
         if self.framework == "pt" and self.device.type == "cuda":
@@ -669,12 +616,11 @@ def check_model_type(self, supported_models: Union[List[str], dict]):
                 f"The model '{self.model.__class__.__name__}' is not supported for {self.task}. Supported models are {supported_models}",
             )
 
-    def _parse_and_tokenize(self, *args, padding=True, add_special_tokens=True, **kwargs):
+    def _parse_and_tokenize(self, inputs, padding=True, add_special_tokens=True, **kwargs):
         """
         Parse arguments and tokenize
         """
         # Parse arguments
-        inputs = self._args_parser(*args, **kwargs)
         inputs = self.tokenizer(
             inputs,
             add_special_tokens=add_special_tokens,
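
With the parser gone, the base `_parse_and_tokenize` hands `inputs` straight to the tokenizer, which already accepts either a single string or a batch. A quick illustration of that tokenizer contract (model name is only an example; assumes `transformers` and `torch` are installed):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# The tokenizer handles both shapes itself, so no intermediate
# argument handler is needed any more.
single = tokenizer("One input", padding=True, return_tensors="pt")
batch = tokenizer(["First input", "Second input"], padding=True, return_tensors="pt")
print(single["input_ids"].shape)  # torch.Size([1, seq_len])
print(batch["input_ids"].shape)   # torch.Size([2, seq_len])
```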
@@ -836,7 +782,7 @@ def __init__(self, *args, **kwargs):
 
     # overriding _parse_and_tokenize to allow for unusual language-modeling tokenizer arguments
 
-    def _parse_and_tokenize(self, *args, padding=True, add_special_tokens=True, **kwargs):
+    def _parse_and_tokenize(self, inputs, padding=True, add_special_tokens=True, **kwargs):
         """
         Parse arguments and tokenize
         """
@@ -845,7 +791,6 @@ def _parse_and_tokenize(self, *args, padding=True, add_special_tokens=True, **kwargs):
             tokenizer_kwargs = {"add_space_before_punct_symbol": True}
         else:
             tokenizer_kwargs = {}
-        inputs = self._args_parser(*args, **kwargs)
         inputs = self.tokenizer(
             inputs,
             add_special_tokens=add_special_tokens,
@@ -858,7 +803,7 @@ def _parse_and_tokenize(self, *args, padding=True, add_special_tokens=True, **kwargs):
 
     def __call__(
         self,
-        *args,
+        text_inputs,
         return_tensors=False,
         return_text=True,
         clean_up_tokenization_spaces=False,
@@ -890,7 +835,6 @@ def __call__(
             - **generated_token_ids** (:obj:`torch.Tensor` or :obj:`tf.Tensor`, present when ``return_tensors=True``)
               -- The token ids of the generated text.
         """
-        text_inputs = self._args_parser(*args)
 
         results = []
         for prompt_text in text_inputs:
@@ -1094,7 +1038,8 @@ class ZeroShotClassificationPipeline(Pipeline):
     """
 
     def __init__(self, args_parser=ZeroShotClassificationArgumentHandler(), *args, **kwargs):
-        super().__init__(*args, args_parser=args_parser, **kwargs)
+        super().__init__(*args, **kwargs)
+        self._args_parser = args_parser
         if self.entailment_id == -1:
             logger.warning(
                 "Failed to determine 'entailment' label id from the label2id mapping in the model config. Setting to "
@@ -1108,13 +1053,15 @@ def entailment_id(self):
                 return ind
         return -1
 
-    def _parse_and_tokenize(self, *args, padding=True, add_special_tokens=True, **kwargs):
+    def _parse_and_tokenize(
+        self, sequences, candidate_labels, hypothesis_template, padding=True, add_special_tokens=True, **kwargs
+    ):
         """
         Parse arguments and tokenize only_first so that hypothesis (label) is not truncated
         """
-        inputs = self._args_parser(*args, **kwargs)
+        sequence_pairs = self._args_parser(sequences, candidate_labels, hypothesis_template)
         inputs = self.tokenizer(
-            inputs,
+            sequence_pairs,
             add_special_tokens=add_special_tokens,
             return_tensors=self.framework,
             padding=padding,
@@ -1123,7 +1070,13 @@ def _parse_and_tokenize(self, *args, padding=True, add_special_tokens=True, **kwargs):
 
         return inputs
 
-    def __call__(self, sequences, candidate_labels, hypothesis_template="This example is {}.", multi_class=False):
+    def __call__(
+        self,
+        sequences: Union[str, List[str]],
+        candidate_labels,
+        hypothesis_template="This example is {}.",
+        multi_class=False,
+    ):
         """
         Classify the sequence(s) given as inputs. See the :obj:`~transformers.ZeroShotClassificationPipeline`
         documentation for more information.
@@ -1154,8 +1107,11 @@ def __call__(self, sequences, candidate_labels, hypothesis_template="This example is {}.", multi_class=False):
             - **labels** (:obj:`List[str]`) -- The labels sorted by order of likelihood.
             - **scores** (:obj:`List[float]`) -- The probabilities for each of the labels.
         """
+        if sequences and isinstance(sequences, str):
+            sequences = [sequences]
+
         outputs = super().__call__(sequences, candidate_labels, hypothesis_template)
-        num_sequences = 1 if isinstance(sequences, str) else len(sequences)
+        num_sequences = len(sequences)
         candidate_labels = self._args_parser._parse_labels(candidate_labels)
         reshaped_outputs = outputs.reshape((num_sequences, len(candidate_labels), -1))

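The reworked zero-shot `__call__` wraps a bare string into a one-element list up front, so `num_sequences` is simply `len(sequences)`. Typical usage, unchanged from the caller's perspective (model name illustrative):

```python
from transformers import pipeline

classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")

# A single string and a one-element list now follow the same code path.
result = classifier(
    "Who are you voting for in 2020?",
    candidate_labels=["politics", "sports", "economics"],
    hypothesis_template="This example is {}.",
)
print(result["labels"][0], result["scores"][0])
```
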
@@ -1425,12 +1381,12 @@ def __init__(
         self.ignore_labels = ignore_labels
         self.grouped_entities = grouped_entities
 
-    def __call__(self, *args, **kwargs):
+    def __call__(self, inputs: Union[str, List[str]], **kwargs):
         """
         Classify each token of the text(s) given as inputs.
 
         Args:
-            args (:obj:`str` or :obj:`List[str]`):
+            inputs (:obj:`str` or :obj:`List[str]`):
                 One or several texts (or one list of texts) for token classification.
 
         Return:
@@ -1444,7 +1400,8 @@ def __call__(self, *args, **kwargs):
             - **index** (:obj:`int`, only present when ``self.grouped_entities=False``) -- The index of the
               corresponding token in the sentence.
         """
-        inputs = self._args_parser(*args, **kwargs)
+        if isinstance(inputs, str):
+            inputs = [inputs]
         answers = []
         for sentence in inputs:

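Token classification applies the same inline normalization, so a bare string and a one-element list behave identically. A short usage sketch (resolves the task's default model):

```python
from transformers import pipeline

ner = pipeline("ner")

# Both calls are equivalent: a bare string is wrapped into a
# one-element list inside __call__.
print(ner("Hugging Face is based in New York City"))
print(ner(["Hugging Face is based in New York City"]))
```
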
@@ -1659,12 +1616,12 @@ def __init__(
             tokenizer=tokenizer,
             modelcard=modelcard,
             framework=framework,
-            args_parser=QuestionAnsweringArgumentHandler(),
             device=device,
             task=task,
             **kwargs,
         )
 
+        self._args_parser = QuestionAnsweringArgumentHandler()
         self.check_model_type(
             TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING if self.framework == "tf" else MODEL_FOR_QUESTION_ANSWERING_MAPPING
         )
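
Question answering keeps its dedicated `QuestionAnsweringArgumentHandler`; since the base `Pipeline.__init__` no longer accepts an `args_parser`, the subclass now assigns it directly. The call styles the handler supports are unaffected (sketch using the task's default model):

```python
from transformers import pipeline

qa = pipeline("question-answering")

# QuestionAnsweringArgumentHandler still turns question/context pairs
# into SquadExample objects behind the scenes.
answer = qa(
    question="Where is Hugging Face based?",
    context="Hugging Face is a company based in New York City.",
)
print(answer["answer"], answer["score"])
```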
@@ -2489,12 +2446,11 @@ def __call__(
         else:
             return output
 
-    def _parse_and_tokenize(self, *args, **kwargs):
+    def _parse_and_tokenize(self, inputs, **kwargs):
         """
         Parse arguments and tokenize, adding an EOS token at the end of the user input
         """
         # Parse arguments
-        inputs = self._args_parser(*args, **kwargs)
         inputs = self.tokenizer(inputs, add_special_tokens=False, padding=False).get("input_ids", [])
         for input in inputs:
             input.append(self.tokenizer.eos_token_id)
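
For the conversational pipeline, `_parse_and_tokenize` now receives the raw texts and appends the EOS token to each encoded input. The loop above is equivalent to this standalone sketch (tokenizer name illustrative; assumes a tokenizer that defines `eos_token_id`, as DialoGPT's does):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")

texts = ["Hi there!", "How are you today?"]
encoded = tokenizer(texts, add_special_tokens=False, padding=False)["input_ids"]
for ids in encoded:
    # Terminate each user turn with EOS, as the pipeline does.
    ids.append(tokenizer.eos_token_id)
print(encoded)
```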
147 changes: 74 additions & 73 deletions tests/test_pipelines_common.py
@@ -1,8 +1,9 @@
 import unittest
 from typing import List, Optional
 
 from transformers import is_tf_available, is_torch_available, pipeline
-from transformers.pipelines import DefaultArgumentHandler, Pipeline
+
+# from transformers.pipelines import DefaultArgumentHandler, Pipeline
+from transformers.pipelines import Pipeline
 from transformers.testing_utils import _run_slow_tests, is_pipeline_test, require_tf, require_torch, slow
 
@@ -200,74 +201,74 @@ def _test_pipeline(self, nlp: Pipeline):
         self.assertRaises(Exception, nlp, self.invalid_inputs)
 

-@is_pipeline_test
-class DefaultArgumentHandlerTestCase(unittest.TestCase):
-    def setUp(self) -> None:
-        self.handler = DefaultArgumentHandler()
-
-    def test_kwargs_x(self):
-        mono_data = {"X": "This is a sample input"}
-        mono_args = self.handler(**mono_data)
-
-        self.assertTrue(isinstance(mono_args, list))
-        self.assertEqual(len(mono_args), 1)
-
-        multi_data = {"x": ["This is a sample input", "This is a second sample input"]}
-        multi_args = self.handler(**multi_data)
-
-        self.assertTrue(isinstance(multi_args, list))
-        self.assertEqual(len(multi_args), 2)
-
-    def test_kwargs_data(self):
-        mono_data = {"data": "This is a sample input"}
-        mono_args = self.handler(**mono_data)
-
-        self.assertTrue(isinstance(mono_args, list))
-        self.assertEqual(len(mono_args), 1)
-
-        multi_data = {"data": ["This is a sample input", "This is a second sample input"]}
-        multi_args = self.handler(**multi_data)
-
-        self.assertTrue(isinstance(multi_args, list))
-        self.assertEqual(len(multi_args), 2)
-
-    def test_multi_kwargs(self):
-        mono_data = {"data": "This is a sample input", "X": "This is a sample input 2"}
-        mono_args = self.handler(**mono_data)
-
-        self.assertTrue(isinstance(mono_args, list))
-        self.assertEqual(len(mono_args), 2)
-
-        multi_data = {
-            "data": ["This is a sample input", "This is a second sample input"],
-            "test": ["This is a sample input 2", "This is a second sample input 2"],
-        }
-        multi_args = self.handler(**multi_data)
-
-        self.assertTrue(isinstance(multi_args, list))
-        self.assertEqual(len(multi_args), 4)
-
-    def test_args(self):
-        mono_data = "This is a sample input"
-        mono_args = self.handler(mono_data)
-
-        self.assertTrue(isinstance(mono_args, list))
-        self.assertEqual(len(mono_args), 1)
-
-        mono_data = ["This is a sample input"]
-        mono_args = self.handler(mono_data)
-
-        self.assertTrue(isinstance(mono_args, list))
-        self.assertEqual(len(mono_args), 1)
-
-        multi_data = ["This is a sample input", "This is a second sample input"]
-        multi_args = self.handler(multi_data)
-
-        self.assertTrue(isinstance(multi_args, list))
-        self.assertEqual(len(multi_args), 2)
-
-        multi_data = ["This is a sample input", "This is a second sample input"]
-        multi_args = self.handler(*multi_data)
-
-        self.assertTrue(isinstance(multi_args, list))
-        self.assertEqual(len(multi_args), 2)
+# @is_pipeline_test
+# class DefaultArgumentHandlerTestCase(unittest.TestCase):
+#     def setUp(self) -> None:
+#         self.handler = DefaultArgumentHandler()
+#
+#     def test_kwargs_x(self):
+#         mono_data = {"X": "This is a sample input"}
+#         mono_args = self.handler(**mono_data)
+#
+#         self.assertTrue(isinstance(mono_args, list))
+#         self.assertEqual(len(mono_args), 1)
+#
+#         multi_data = {"x": ["This is a sample input", "This is a second sample input"]}
+#         multi_args = self.handler(**multi_data)
+#
+#         self.assertTrue(isinstance(multi_args, list))
+#         self.assertEqual(len(multi_args), 2)
+#
+#     def test_kwargs_data(self):
+#         mono_data = {"data": "This is a sample input"}
+#         mono_args = self.handler(**mono_data)
+#
+#         self.assertTrue(isinstance(mono_args, list))
+#         self.assertEqual(len(mono_args), 1)
+#
+#         multi_data = {"data": ["This is a sample input", "This is a second sample input"]}
+#         multi_args = self.handler(**multi_data)
+#
+#         self.assertTrue(isinstance(multi_args, list))
+#         self.assertEqual(len(multi_args), 2)
+#
+#     def test_multi_kwargs(self):
+#         mono_data = {"data": "This is a sample input", "X": "This is a sample input 2"}
+#         mono_args = self.handler(**mono_data)
+#
+#         self.assertTrue(isinstance(mono_args, list))
+#         self.assertEqual(len(mono_args), 2)
+#
+#         multi_data = {
+#             "data": ["This is a sample input", "This is a second sample input"],
+#             "test": ["This is a sample input 2", "This is a second sample input 2"],
+#         }
+#         multi_args = self.handler(**multi_data)
+#
+#         self.assertTrue(isinstance(multi_args, list))
+#         self.assertEqual(len(multi_args), 4)
+#
+#     def test_args(self):
+#         mono_data = "This is a sample input"
+#         mono_args = self.handler(mono_data)
+#
+#         self.assertTrue(isinstance(mono_args, list))
+#         self.assertEqual(len(mono_args), 1)
+#
+#         mono_data = ["This is a sample input"]
+#         mono_args = self.handler(mono_data)
+#
+#         self.assertTrue(isinstance(mono_args, list))
+#         self.assertEqual(len(mono_args), 1)
+#
+#         multi_data = ["This is a sample input", "This is a second sample input"]
+#         multi_args = self.handler(multi_data)
+#
+#         self.assertTrue(isinstance(multi_args, list))
+#         self.assertEqual(len(multi_args), 2)
+#
+#         multi_data = ["This is a sample input", "This is a second sample input"]
+#         multi_args = self.handler(*multi_data)
+#
+#         self.assertTrue(isinstance(multi_args, list))
+#         self.assertEqual(len(multi_args), 2)

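With the handler deleted, its test case is retired above rather than rewritten. Equivalent coverage now belongs with each pipeline; a minimal sketch of what such a replacement test could look like (hypothetical test, not part of this commit):

```python
import unittest


class PipelineInputContractTest(unittest.TestCase):
    """Hypothetical example: the str-vs-list normalization formerly
    centralized in DefaultArgumentHandler is now checked against each
    pipeline's explicit signature."""

    @staticmethod
    def _normalize(inputs):
        # The pattern the pipelines now apply inline.
        return [inputs] if isinstance(inputs, str) else inputs

    def test_single_string_is_wrapped(self):
        self.assertEqual(self._normalize("This is a sample input"), ["This is a sample input"])

    def test_list_passes_through(self):
        data = ["This is a sample input", "This is a second sample input"]
        self.assertEqual(self._normalize(data), data)


if __name__ == "__main__":
    unittest.main()
```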