Fixing NER pipeline for list inputs. (#10184)
Fixes #10168
Narsil committed Feb 15, 2021
1 parent 587197d commit 900daec
Showing 2 changed files with 60 additions and 20 deletions.
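
A minimal usage sketch of the behavior being fixed, mirroring the test_simple case in the test diff below. The snippet itself is illustrative and not part of the commit; the checkpoint name and grouped_entities flag are taken from that test.

from transformers import pipeline

# The NER pipeline should accept either a single string or a list of strings.
nlp = pipeline(task="ner", model="dslim/bert-base-NER", grouped_entities=True)

# A single string yields one flat list of entity dicts.
single = nlp("Hello Sarah Jessica Parker who Jessica lives in New York")

# A list of strings yields one list of entity dicts per input sentence; the
# second list is empty because "This is a simple test" has no named entities.
batched = nlp(["Hello Sarah Jessica Parker who Jessica lives in New York", "This is a simple test"])
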
13 changes: 8 additions & 5 deletions src/transformers/pipelines/token_classification.py
@@ -28,11 +28,14 @@ class TokenClassificationArgumentHandler(ArgumentHandler):
     Handles arguments for token classification.
     """

-    def __call__(self, *args, **kwargs):
+    def __call__(self, inputs: Union[str, List[str]], **kwargs):

-        if args is not None and len(args) > 0:
-            inputs = list(args)
+        if inputs is not None and isinstance(inputs, (list, tuple)) and len(inputs) > 0:
+            inputs = list(inputs)
             batch_size = len(inputs)
+        elif isinstance(inputs, str):
+            inputs = [inputs]
+            batch_size = 1
         else:
             raise ValueError("At least one input is required.")

@@ -137,11 +140,11 @@ def __call__(self, inputs: Union[str, List[str]], **kwargs):
               Only exists if the offsets are available within the tokenizer
         """

-        inputs, offset_mappings = self._args_parser(inputs, **kwargs)
+        _inputs, offset_mappings = self._args_parser(inputs, **kwargs)

         answers = []

-        for i, sentence in enumerate(inputs):
+        for i, sentence in enumerate(_inputs):

             # Manage correct placement of the tensors
             with self.device_placement():
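
With __call__ now taking an explicit inputs argument, both a plain string and a list of strings normalize to a list. An illustrative sketch of the resulting behavior, with the return values assumed from the handler tests in the second file:

from transformers.pipelines import TokenClassificationArgumentHandler

# __call__ returns the normalized inputs together with any offset_mapping
# (None when none is passed), per the argument-handler tests below.
handler = TokenClassificationArgumentHandler()

inputs, offset_mapping = handler("A single sentence")
# inputs == ["A single sentence"], offset_mapping is None

inputs, offset_mapping = handler(["first sentence", "second sentence"])
# inputs == ["first sentence", "second sentence"], offset_mapping is None
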
67 changes: 52 additions & 15 deletions tests/test_pipelines_ner.py
@@ -14,14 +14,17 @@

 import unittest

-from transformers import AutoTokenizer, pipeline
+from transformers import AutoTokenizer, is_torch_available, pipeline
 from transformers.pipelines import Pipeline, TokenClassificationArgumentHandler
 from transformers.testing_utils import require_tf, require_torch, slow

 from .test_pipelines_common import CustomInputPipelineCommonMixin


-VALID_INPUTS = ["A simple string", ["list of strings"]]
+if is_torch_available():
+    import numpy as np
+
+VALID_INPUTS = ["A simple string", ["list of strings", "A simple string that is quite a bit longer"]]


 class NerPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
@@ -334,17 +337,26 @@ def test_pt_defaults(self):
     @require_torch
     def test_simple(self):
         nlp = pipeline(task="ner", model="dslim/bert-base-NER", grouped_entities=True)
-        output = nlp("Hello Sarah Jessica Parker who Jessica lives in New York")
+        sentence = "Hello Sarah Jessica Parker who Jessica lives in New York"
+        sentence2 = "This is a simple test"
+        output = nlp(sentence)

         def simplify(output):
-            for i in range(len(output)):
-                output[i]["score"] = round(output[i]["score"], 3)
-            return output
+            if isinstance(output, (list, tuple)):
+                return [simplify(item) for item in output]
+            elif isinstance(output, dict):
+                return {simplify(k): simplify(v) for k, v in output.items()}
+            elif isinstance(output, (str, int, np.int64)):
+                return output
+            elif isinstance(output, float):
+                return round(output, 3)
+            else:
+                raise Exception(f"Cannot handle {type(output)}")

-        output = simplify(output)
+        output_ = simplify(output)

         self.assertEqual(
-            output,
+            output_,
             [
                 {
                     "entity_group": "PER",
@@ -358,6 +370,21 @@ def simplify(output):
             ],
         )

+        output = nlp([sentence, sentence2])
+        output_ = simplify(output)
+
+        self.assertEqual(
+            output_,
+            [
+                [
+                    {"entity_group": "PER", "score": 0.996, "word": "Sarah Jessica Parker", "start": 6, "end": 26},
+                    {"entity_group": "PER", "score": 0.977, "word": "Jessica", "start": 31, "end": 38},
+                    {"entity_group": "LOC", "score": 0.999, "word": "New York", "start": 48, "end": 56},
+                ],
+                [],
+            ],
+        )
+
     @require_torch
     def test_pt_small_ignore_subwords_available_for_fast_tokenizers(self):
         for model_name in self.small_models:
@@ -386,33 +413,43 @@ def test_simple(self):
         self.assertEqual(inputs, [string])
         self.assertEqual(offset_mapping, None)

-        inputs, offset_mapping = self.args_parser(string, string)
+        inputs, offset_mapping = self.args_parser([string, string])
         self.assertEqual(inputs, [string, string])
         self.assertEqual(offset_mapping, None)

         inputs, offset_mapping = self.args_parser(string, offset_mapping=[(0, 1), (1, 2)])
         self.assertEqual(inputs, [string])
         self.assertEqual(offset_mapping, [[(0, 1), (1, 2)]])

-        inputs, offset_mapping = self.args_parser(string, string, offset_mapping=[[(0, 1), (1, 2)], [(0, 2), (2, 3)]])
+        inputs, offset_mapping = self.args_parser(
+            [string, string], offset_mapping=[[(0, 1), (1, 2)], [(0, 2), (2, 3)]]
+        )
         self.assertEqual(inputs, [string, string])
         self.assertEqual(offset_mapping, [[(0, 1), (1, 2)], [(0, 2), (2, 3)]])

     def test_errors(self):
         string = "This is a simple input"

-        # 2 sentences, 1 offset_mapping
-        with self.assertRaises(ValueError):
+        # 2 sentences, 1 offset_mapping, args
+        with self.assertRaises(TypeError):
             self.args_parser(string, string, offset_mapping=[[(0, 1), (1, 2)]])

-        # 2 sentences, 1 offset_mapping
-        with self.assertRaises(ValueError):
+        # 2 sentences, 1 offset_mapping, args
+        with self.assertRaises(TypeError):
             self.args_parser(string, string, offset_mapping=[(0, 1), (1, 2)])

+        # 2 sentences, 1 offset_mapping, input_list
+        with self.assertRaises(ValueError):
+            self.args_parser([string, string], offset_mapping=[[(0, 1), (1, 2)]])
+
+        # 2 sentences, 1 offset_mapping, input_list
+        with self.assertRaises(ValueError):
+            self.args_parser([string, string], offset_mapping=[(0, 1), (1, 2)])
+
         # 1 sentences, 2 offset_mapping
         with self.assertRaises(ValueError):
             self.args_parser(string, offset_mapping=[[(0, 1), (1, 2)], [(0, 2), (2, 3)]])

         # 0 sentences, 1 offset_mapping
-        with self.assertRaises(ValueError):
+        with self.assertRaises(TypeError):
             self.args_parser(offset_mapping=[[(0, 1), (1, 2)]])
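
The expected errors above shift from ValueError to TypeError wherever extra positional strings are involved: with the new single inputs parameter, Python itself rejects multiple positional sentences before the handler's own checks run, while a mismatch between an input list and offset_mapping still raises ValueError. A small illustrative sketch, not part of the commit:

from transformers.pipelines import TokenClassificationArgumentHandler

handler = TokenClassificationArgumentHandler()

try:
    # Two positional sentences no longer fit __call__(self, inputs, **kwargs).
    handler("sentence one", "sentence two")
except TypeError:
    pass

try:
    # Two inputs but only one offset_mapping: the handler's own validation still applies.
    handler(["one", "two"], offset_mapping=[[(0, 1), (1, 2)]])
except ValueError:
    pass
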
