Skip to content

Commit

Permalink
Allow batching for feature-extraction (#106)
Browse files Browse the repository at this point in the history
  • Loading branch information
osanseviero committed Jun 16, 2021
1 parent 9f0fd3b commit 9aa9b20
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 18 deletions.
44 changes: 30 additions & 14 deletions api-inference-community/api_inference_community/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,16 +124,30 @@ class TableQuestionAnsweringInputsCheck(BaseModel):
query: str

@validator("table")
def all_rows_must_have_same_length(
cls, table: Dict[str, List[str]]
):
def all_rows_must_have_same_length(cls, table: Dict[str, List[str]]):
rows = list(table.values())
n = len(rows[0])
if all(len(x) == n for x in rows):
return table
raise ValueError("All rows in the table must be the same length")


class StringOrStringBatchInputCheck(BaseModel):
__root__: Union[List[str], str]

@validator("__root__")
def input_must_not_be_empty(cls, __root__: Union[List[str], str]):
if isinstance(__root__, list):
if len(__root__) == 0:
raise ValueError(
"The inputs are invalid, at least one input is required"
)
return __root__


class StringInput(BaseModel):
__root__: str


PARAMS_MAPPING = {
"conversational": SharedGenerationParams,
Expand All @@ -147,10 +161,21 @@ def all_rows_must_have_same_length(
INPUTS_MAPPING = {
"conversational": ConversationalInputsCheck,
"question-answering": QuestionInputsCheck,
"feature-extraction": StringOrStringBatchInputCheck,
"sentence-similarity": SentenceSimilarityInputsCheck,
"table-question-answering": TableQuestionAnsweringInputsCheck,
"fill-mask": StringInput,
"summarization": StringInput,
"text2text-generation": StringInput,
"text-generation": StringInput,
"text-classification": StringInput,
"token-classification": StringInput,
"translation": StringInput,
"zero-shot-classification": StringInput,
}

BATCH_ENABLED_PIPELINES = ["feature-extraction"]


def check_params(params, tag):
if tag in PARAMS_MAPPING:
Expand All @@ -161,18 +186,9 @@ def check_params(params, tag):
def check_inputs(inputs, tag):
if tag in INPUTS_MAPPING:
INPUTS_MAPPING[tag].parse_obj(inputs)
return True
else:
# Some tasks just expect {inputs: "str"}. Such as:
# feature-extraction
# fill-mask
# text2text-generation
# text-classification
# text-generation
# token-classification
# translation
if not isinstance(inputs, str):
raise ValueError("The inputs is invalid, we expect a string")
return True
raise ValueError(f"{tag} is not a valid pipeline.")


def normalize_payload(
Expand Down
39 changes: 35 additions & 4 deletions api-inference-community/tests/test_nlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,23 @@ class ValidationTestCase(TestCase):
def test_malformed_input(self):
bpayload = b"\xc3\x28"
with self.assertRaises(UnicodeDecodeError):
normalize_payload_nlp(bpayload, "tag")
normalize_payload_nlp(bpayload, "question-answering")

def test_accept_raw_string_for_backward_compatibility(self):
query = "funny cats"
bpayload = query.encode("utf-8")
normalized_inputs, processed_params = normalize_payload_nlp(bpayload, "tag")
normalized_inputs, processed_params = normalize_payload_nlp(
bpayload, "translation"
)
self.assertEqual(processed_params, {})
self.assertEqual(normalized_inputs, query)

def test_invalid_tag(self):
query = "funny cats"
bpayload = query.encode("utf-8")
with self.assertRaises(ValueError):
normalize_payload_nlp(bpayload, "invalid-tag")


class QuestionAnsweringValidationTestCase(TestCase):
def test_valid_input(self):
Expand Down Expand Up @@ -418,15 +426,38 @@ class TextGenerationTestCase(make_text_generation_test_case("text-generation")):
pass


class TasksWithOnlyInputStringTestCase(TestCase):
def test_feature_extraction_accept_string_no_params(self):
class FeatureExtractionTestCase(TestCase):
def test_valid_string(self):
bpayload = json.dumps({"inputs": "whatever"}).encode("utf-8")
normalized_inputs, processed_params = normalize_payload_nlp(
bpayload, "feature-extraction"
)
self.assertEqual(processed_params, {})
self.assertEqual(normalized_inputs, "whatever")

def test_valid_list_of_strings(self):
inputs = ["hugging", "face"]
bpayload = json.dumps({"inputs": inputs}).encode("utf-8")
normalized_inputs, processed_params = normalize_payload_nlp(
bpayload, "feature-extraction"
)
self.assertEqual(processed_params, {})
self.assertEqual(normalized_inputs, inputs)

def test_invalid_list_with_other_type(self):
inputs = ["hugging", [1, 2, 3]]
bpayload = json.dumps({"inputs": inputs}).encode("utf-8")
with self.assertRaises(ValueError):
normalize_payload_nlp(bpayload, "feature-extraction")

def test_invalid_empty_list(self):
inputs = []
bpayload = json.dumps({"inputs": inputs}).encode("utf-8")
with self.assertRaises(ValueError):
normalize_payload_nlp(bpayload, "feature-extraction")


class TasksWithOnlyInputStringTestCase(TestCase):
def test_fill_mask_accept_string_no_params(self):
bpayload = json.dumps({"inputs": "whatever"}).encode("utf-8")
normalized_inputs, processed_params = normalize_payload_nlp(
Expand Down

0 comments on commit 9aa9b20

Please sign in to comment.