diff --git a/src/transformers/pipelines.py b/src/transformers/pipelines.py index 4ec85e1f4c1948..e37607c136b8de 100755 --- a/src/transformers/pipelines.py +++ b/src/transformers/pipelines.py @@ -22,7 +22,6 @@ import uuid import warnings from abc import ABC, abstractmethod -from collections.abc import Iterable from contextlib import contextmanager from os.path import abspath, exists from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union @@ -1598,52 +1597,55 @@ class QuestionAnsweringArgumentHandler(ArgumentHandler): command-line supplied arguments. """ - def normalize(self, item): - if isinstance(item, SquadExample): - return item - elif isinstance(item, dict): - for k in ["question", "context"]: - if k not in item: - raise KeyError("You need to provide a dictionary with keys {question:..., context:...}") - elif item[k] is None: - raise ValueError("`{}` cannot be None".format(k)) - elif isinstance(item[k], str) and len(item[k]) == 0: - raise ValueError("`{}` cannot be empty".format(k)) - - return QuestionAnsweringPipeline.create_sample(**item) - raise ValueError("{} argument needs to be of type (SquadExample, dict)".format(item)) - def __call__(self, *args, **kwargs): - # Detect where the actual inputs are + # Position args, handling is sensibly the same as X and data, so forwarding to avoid duplicating if args is not None and len(args) > 0: if len(args) == 1: - inputs = args[0] - elif len(args) == 2 and {type(el) for el in args} == {str}: - inputs = [{"question": args[0], "context": args[1]}] + kwargs["X"] = args[0] else: - inputs = list(args) + kwargs["X"] = list(args) + # Generic compatibility with sklearn and Keras # Batched data - elif "X" in kwargs: - inputs = kwargs["X"] - elif "data" in kwargs: - inputs = kwargs["data"] + if "X" in kwargs or "data" in kwargs: + inputs = kwargs["X"] if "X" in kwargs else kwargs["data"] + + if isinstance(inputs, dict): + inputs = [inputs] + else: + # Copy to avoid overriding arguments + inputs = [i for i in inputs] + + for i, item in enumerate(inputs): + if isinstance(item, dict): + if any(k not in item for k in ["question", "context"]): + raise KeyError("You need to provide a dictionary with keys {question:..., context:...}") + + inputs[i] = QuestionAnsweringPipeline.create_sample(**item) + + elif not isinstance(item, SquadExample): + raise ValueError( + "{} argument needs to be of type (list[SquadExample | dict], SquadExample, dict)".format( + "X" if "X" in kwargs else "data" + ) + ) + + # Tabular input elif "question" in kwargs and "context" in kwargs: - inputs = [{"question": kwargs["question"], "context": kwargs["context"]}] + if isinstance(kwargs["question"], str): + kwargs["question"] = [kwargs["question"]] + + if isinstance(kwargs["context"], str): + kwargs["context"] = [kwargs["context"]] + + inputs = [ + QuestionAnsweringPipeline.create_sample(q, c) for q, c in zip(kwargs["question"], kwargs["context"]) + ] else: raise ValueError("Unknown arguments {}".format(kwargs)) - # Normalize inputs - if isinstance(inputs, dict): + if not isinstance(inputs, list): inputs = [inputs] - elif isinstance(inputs, Iterable): - # Copy to avoid overriding arguments - inputs = [i for i in inputs] - else: - raise ValueError("Invalid arguments {}".format(inputs)) - - for i, item in enumerate(inputs): - inputs[i] = self.normalize(item) return inputs diff --git a/tests/test_pipelines_question_answering.py b/tests/test_pipelines_question_answering.py index 54b306c09d88b3..3f3f6dc83a7241 100644 --- a/tests/test_pipelines_question_answering.py +++ b/tests/test_pipelines_question_answering.py @@ -1,7 +1,6 @@ import unittest -from transformers.data.processors.squad import SquadExample -from transformers.pipelines import Pipeline, QuestionAnsweringArgumentHandler +from transformers.pipelines import Pipeline from .test_pipelines_common import CustomInputPipelineCommonMixin @@ -44,116 +43,5 @@ def _test_pipeline(self, nlp: Pipeline): for key in output_keys: self.assertIn(key, result) for bad_input in invalid_inputs: - self.assertRaises(ValueError, nlp, bad_input) - self.assertRaises(ValueError, nlp, invalid_inputs) - - def test_argument_handler(self): - qa = QuestionAnsweringArgumentHandler() - - Q = "Where was HuggingFace founded ?" - C = "HuggingFace was founded in Paris" - - normalized = qa(Q, C) - self.assertEqual(type(normalized), list) - self.assertEqual(len(normalized), 1) - self.assertEqual({type(el) for el in normalized}, {SquadExample}) - - normalized = qa(question=Q, context=C) - self.assertEqual(type(normalized), list) - self.assertEqual(len(normalized), 1) - self.assertEqual({type(el) for el in normalized}, {SquadExample}) - - normalized = qa(question=Q, context=C) - self.assertEqual(type(normalized), list) - self.assertEqual(len(normalized), 1) - self.assertEqual({type(el) for el in normalized}, {SquadExample}) - - normalized = qa({"question": Q, "context": C}) - self.assertEqual(type(normalized), list) - self.assertEqual(len(normalized), 1) - self.assertEqual({type(el) for el in normalized}, {SquadExample}) - - normalized = qa([{"question": Q, "context": C}]) - self.assertEqual(type(normalized), list) - self.assertEqual(len(normalized), 1) - self.assertEqual({type(el) for el in normalized}, {SquadExample}) - - normalized = qa([{"question": Q, "context": C}, {"question": Q, "context": C}]) - self.assertEqual(type(normalized), list) - self.assertEqual(len(normalized), 2) - self.assertEqual({type(el) for el in normalized}, {SquadExample}) - - normalized = qa(X={"question": Q, "context": C}) - self.assertEqual(type(normalized), list) - self.assertEqual(len(normalized), 1) - self.assertEqual({type(el) for el in normalized}, {SquadExample}) - - normalized = qa(X=[{"question": Q, "context": C}]) - self.assertEqual(type(normalized), list) - self.assertEqual(len(normalized), 1) - self.assertEqual({type(el) for el in normalized}, {SquadExample}) - - normalized = qa(data={"question": Q, "context": C}) - self.assertEqual(type(normalized), list) - self.assertEqual(len(normalized), 1) - self.assertEqual({type(el) for el in normalized}, {SquadExample}) - - def test_argument_handler_error_handling(self): - qa = QuestionAnsweringArgumentHandler() - - Q = "Where was HuggingFace founded ?" - C = "HuggingFace was founded in Paris" - - with self.assertRaises(KeyError): - qa({"context": C}) - with self.assertRaises(KeyError): - qa({"question": Q}) - with self.assertRaises(KeyError): - qa([{"context": C}]) - with self.assertRaises(ValueError): - qa(None, C) - with self.assertRaises(ValueError): - qa("", C) - with self.assertRaises(ValueError): - qa(Q, None) - with self.assertRaises(ValueError): - qa(Q, "") - - with self.assertRaises(ValueError): - qa(question=None, context=C) - with self.assertRaises(ValueError): - qa(question="", context=C) - with self.assertRaises(ValueError): - qa(question=Q, context=None) - with self.assertRaises(ValueError): - qa(question=Q, context="") - - with self.assertRaises(ValueError): - qa({"question": None, "context": C}) - with self.assertRaises(ValueError): - qa({"question": "", "context": C}) - with self.assertRaises(ValueError): - qa({"question": Q, "context": None}) - with self.assertRaises(ValueError): - qa({"question": Q, "context": ""}) - - with self.assertRaises(ValueError): - qa([{"question": Q, "context": C}, {"question": None, "context": C}]) - with self.assertRaises(ValueError): - qa([{"question": Q, "context": C}, {"question": "", "context": C}]) - - with self.assertRaises(ValueError): - qa([{"question": Q, "context": C}, {"question": Q, "context": None}]) - with self.assertRaises(ValueError): - qa([{"question": Q, "context": C}, {"question": Q, "context": ""}]) - - def test_argument_handler_error_handling_odd(self): - qa = QuestionAnsweringArgumentHandler() - with self.assertRaises(ValueError): - qa(None) - - with self.assertRaises(ValueError): - qa(Y=None) - - with self.assertRaises(ValueError): - qa(1) + self.assertRaises(Exception, nlp, bad_input) + self.assertRaises(Exception, nlp, invalid_inputs)