diff --git a/src/transformers/pipelines.py b/src/transformers/pipelines.py
index f63071b407db2..e37607c136b8d 100755
--- a/src/transformers/pipelines.py
+++ b/src/transformers/pipelines.py
@@ -1324,6 +1324,29 @@ def __call__(self, *args, targets=None, top_k: Optional[int] = None, **kwargs):
         return results
 
 
+class TokenClassificationArgumentHandler(ArgumentHandler):
+    """
+    Handles arguments for token classification.
+    """
+
+    def __call__(self, *args, **kwargs):
+
+        if args is not None and len(args) > 0:
+            if isinstance(args, str):
+                inputs = [args]
+            else:
+                inputs = args
+            batch_size = len(inputs)
+
+        offset_mapping = kwargs.get("offset_mapping", None)
+        if offset_mapping:
+            if isinstance(offset_mapping, list) and isinstance(offset_mapping[0], tuple):
+                offset_mapping = [offset_mapping]
+            if len(offset_mapping) != batch_size:
+                raise ValueError("offset_mapping should have the same batch size as the input")
+        return inputs, offset_mapping
+
+
 @add_end_docstrings(
     PIPELINE_INIT_ARGS,
     r"""
@@ -1361,13 +1384,14 @@ def __init__(
         ignore_labels=["O"],
         task: str = "",
         grouped_entities: bool = False,
+        ignore_subwords: bool = True,
     ):
         super().__init__(
             model=model,
             tokenizer=tokenizer,
             modelcard=modelcard,
             framework=framework,
-            args_parser=args_parser,
+            args_parser=TokenClassificationArgumentHandler(),
             device=device,
             binary_output=binary_output,
             task=task,
@@ -1382,6 +1406,7 @@ def __init__(
         self._basic_tokenizer = BasicTokenizer(do_lower_case=False)
         self.ignore_labels = ignore_labels
         self.grouped_entities = grouped_entities
+        self.ignore_subwords = ignore_subwords
 
     def __call__(self, inputs: Union[str, List[str]], **kwargs):
         """
@@ -1402,10 +1427,15 @@ def __call__(self, inputs: Union[str, List[str]], **kwargs):
             - **index** (:obj:`int`, only present when ``self.grouped_entities=False``) -- The index of the
              corresponding token in the sentence.
         """
+
        if isinstance(inputs, str):
            inputs = [inputs]
+
+        offset_mappings = kwargs.get("offset_mappings")
+
         answers = []
-        for sentence in inputs:
+
+        for i, sentence in enumerate(inputs):
 
             # Manage correct placement of the tensors
             with self.device_placement():
@@ -1415,7 +1445,18 @@ def __call__(self, inputs: Union[str, List[str]], **kwargs):
                     return_attention_mask=False,
                     return_tensors=self.framework,
                     truncation=True,
+                    return_special_tokens_mask=True,
+                    return_offsets_mapping=self.tokenizer.is_fast,
                 )
+                if self.tokenizer.is_fast:
+                    offset_mapping = tokens["offset_mapping"].cpu().numpy()[0]
+                    del tokens["offset_mapping"]
+                elif offset_mappings:
+                    offset_mapping = offset_mappings[i]
+                else:
+                    raise Exception("To decode [UNK] tokens use a fast tokenizer or provide offset_mapping parameter")
+                special_tokens_mask = tokens["special_tokens_mask"].cpu().numpy()[0]
+                del tokens["special_tokens_mask"]
 
                 # Forward
                 if self.framework == "tf":
@@ -1432,24 +1473,35 @@ def __call__(self, inputs: Union[str, List[str]], **kwargs):
             entities = []
             # Filter to labels not in `self.ignore_labels`
+            # Filter special_tokens
             filtered_labels_idx = [
                 (idx, label_idx)
                 for idx, label_idx in enumerate(labels_idx)
-                if self.model.config.id2label[label_idx] not in self.ignore_labels
+                if (self.model.config.id2label[label_idx] not in self.ignore_labels) and not special_tokens_mask[idx]
             ]
 
             for idx, label_idx in filtered_labels_idx:
+                start_ind, end_ind = offset_mapping[idx]
+                word_ref = sentence[start_ind:end_ind]
+                word = self.tokenizer.convert_ids_to_tokens([int(input_ids[idx])])[0]
+                is_subword = len(word_ref) != len(word)
+
+                if int(input_ids[idx]) == self.tokenizer.unk_token_id:
+                    word = word_ref
+                    is_subword = False
 
                 entity = {
-                    "word": self.tokenizer.convert_ids_to_tokens(int(input_ids[idx])),
+                    "word": word,
                     "score": score[idx][label_idx].item(),
                     "entity": self.model.config.id2label[label_idx],
                     "index": idx,
                 }
 
+                if self.grouped_entities and self.ignore_subwords:
+                    entity["is_subword"] = is_subword
+
                 entities += [entity]
-            # Append grouped entities
             if self.grouped_entities:
                 answers += [self.group_entities(entities)]
             # Append ungrouped entities
@@ -1468,8 +1520,8 @@ def group_sub_entities(self, entities: List[dict]) -> dict:
             entities (:obj:`dict`): The entities predicted by the pipeline.
         """
         # Get the first entity in the entity group
-        entity = entities[0]["entity"]
-        scores = np.mean([entity["score"] for entity in entities])
+        entity = entities[0]["entity"].split("-")[-1]
+        scores = np.nanmean([entity["score"] for entity in entities])
         tokens = [entity["word"] for entity in entities]
 
         entity_group = {
@@ -1494,7 +1546,9 @@ def group_entities(self, entities: List[dict]) -> List[dict]:
         last_idx = entities[-1]["index"]
 
         for entity in entities:
+
             is_last_idx = entity["index"] == last_idx
+            is_subword = self.ignore_subwords and entity["is_subword"]
             if not entity_group_disagg:
                 entity_group_disagg += [entity]
                 if is_last_idx:
@@ -1503,10 +1557,19 @@ def group_entities(self, entities: List[dict]) -> List[dict]:
 
             # If the current entity is similar and adjacent to the previous entity, append it to the disaggregated entity group
             # The split is meant to account for the "B" and "I" suffixes
+            # Shouldn't merge if both entities are B-type
             if (
-                entity["entity"].split("-")[-1] == entity_group_disagg[-1]["entity"].split("-")[-1]
+                (
+                    entity["entity"].split("-")[-1] == entity_group_disagg[-1]["entity"].split("-")[-1]
+                    and entity["entity"].split("-")[0] != "B"
+                )
                 and entity["index"] == entity_group_disagg[-1]["index"] + 1
-            ):
+            ) or is_subword:
+                # Modify subword type to be previous_type
+                if is_subword:
+                    entity["entity"] = entity_group_disagg[-1]["entity"].split("-")[-1]
+                    entity["score"] = np.nan  # set ignored scores to nan and use np.nanmean
+
                 entity_group_disagg += [entity]
                 # Group the entities at the last entity
                 if is_last_idx:
diff --git a/tests/test_pipelines_ner.py b/tests/test_pipelines_ner.py
index 4fb58d5e3c0c9..a4a240a8d9081 100644
--- a/tests/test_pipelines_ner.py
+++ b/tests/test_pipelines_ner.py
@@ -1,8 +1,8 @@
 import unittest
 
-from transformers import pipeline
+from transformers import AutoTokenizer, pipeline
 from transformers.pipelines import Pipeline
-from transformers.testing_utils import require_tf
+from transformers.testing_utils import require_tf, require_torch
 
 from .test_pipelines_common import CustomInputPipelineCommonMixin
 
@@ -19,38 +19,54 @@ class NerPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
 
     def _test_pipeline(self, nlp: Pipeline):
         output_keys = {"entity", "word", "score"}
+        if nlp.grouped_entities:
+            output_keys = {"entity_group", "word", "score"}
         ungrouped_ner_inputs = [
             [
-                {"entity": "B-PER", "index": 1, "score": 0.9994944930076599, "word": "Cons"},
-                {"entity": "B-PER", "index": 2, "score": 0.8025449514389038, "word": "##uelo"},
-                {"entity": "I-PER", "index": 3, "score": 0.9993102550506592, "word": "Ara"},
-                {"entity": "I-PER", "index": 4, "score": 0.9993743896484375, "word": "##új"},
-                {"entity": "I-PER", "index": 5, "score": 0.9992871880531311, "word": "##o"},
-                {"entity": "I-PER", "index": 6, "score": 0.9993029236793518, "word": "No"},
-                {"entity": "I-PER", "index": 7, "score": 0.9981776475906372, "word": "##guera"},
-                {"entity": "B-PER", "index": 15, "score": 0.9998136162757874, "word": "Andrés"},
-                {"entity": "I-PER", "index": 16, "score": 0.999740719795227, "word": "Pas"},
-                {"entity": "I-PER", "index": 17, "score": 0.9997414350509644, "word": "##tran"},
-                {"entity": "I-PER", "index": 18, "score": 0.9996136426925659, "word": "##a"},
-                {"entity": "B-ORG", "index": 28, "score": 0.9989739060401917, "word": "Far"},
-                {"entity": "I-ORG", "index": 29, "score": 0.7188422083854675, "word": "##c"},
+                {"entity": "B-PER", "index": 1, "score": 0.9994944930076599, "is_subword": False, "word": "Cons"},
+                {"entity": "B-PER", "index": 2, "score": 0.8025449514389038, "is_subword": True, "word": "##uelo"},
+                {"entity": "I-PER", "index": 3, "score": 0.9993102550506592, "is_subword": False, "word": "Ara"},
+                {"entity": "I-PER", "index": 4, "score": 0.9993743896484375, "is_subword": True, "word": "##új"},
+                {"entity": "I-PER", "index": 5, "score": 0.9992871880531311, "is_subword": True, "word": "##o"},
+                {"entity": "I-PER", "index": 6, "score": 0.9993029236793518, "is_subword": False, "word": "No"},
+                {"entity": "I-PER", "index": 7, "score": 0.9981776475906372, "is_subword": True, "word": "##guera"},
+                {"entity": "B-PER", "index": 15, "score": 0.9998136162757874, "is_subword": False, "word": "Andrés"},
+                {"entity": "I-PER", "index": 16, "score": 0.999740719795227, "is_subword": False, "word": "Pas"},
+                {"entity": "I-PER", "index": 17, "score": 0.9997414350509644, "is_subword": True, "word": "##tran"},
+                {"entity": "I-PER", "index": 18, "score": 0.9996136426925659, "is_subword": True, "word": "##a"},
+                {"entity": "B-ORG", "index": 28, "score": 0.9989739060401917, "is_subword": False, "word": "Far"},
+                {"entity": "I-ORG", "index": 29, "score": 0.7188422083854675, "is_subword": True, "word": "##c"},
             ],
             [
-                {"entity": "I-PER", "index": 1, "score": 0.9968166351318359, "word": "En"},
-                {"entity": "I-PER", "index": 2, "score": 0.9957635998725891, "word": "##zo"},
-                {"entity": "I-ORG", "index": 7, "score": 0.9986497163772583, "word": "UN"},
+                {"entity": "I-PER", "index": 1, "score": 0.9968166351318359, "is_subword": False, "word": "En"},
+                {"entity": "I-PER", "index": 2, "score": 0.9957635998725891, "is_subword": True, "word": "##zo"},
+                {"entity": "I-ORG", "index": 7, "score": 0.9986497163772583, "is_subword": False, "word": "UN"},
             ],
         ]
+
         expected_grouped_ner_results = [
             [
-                {"entity_group": "B-PER", "score": 0.9710702640669686, "word": "Consuelo Araújo Noguera"},
-                {"entity_group": "B-PER", "score": 0.9997273534536362, "word": "Andrés Pastrana"},
-                {"entity_group": "B-ORG", "score": 0.8589080572128296, "word": "Farc"},
+                {"entity_group": "PER", "score": 0.999369223912557, "word": "Consuelo Araújo Noguera"},
+                {"entity_group": "PER", "score": 0.9997771680355072, "word": "Andrés Pastrana"},
+                {"entity_group": "ORG", "score": 0.9989739060401917, "word": "Farc"},
+            ],
+            [
+                {"entity_group": "PER", "score": 0.9968166351318359, "word": "Enzo"},
+                {"entity_group": "ORG", "score": 0.9986497163772583, "word": "UN"},
+            ],
+        ]
+
+        expected_grouped_ner_results_w_subword = [
+            [
+                {"entity_group": "PER", "score": 0.9994944930076599, "word": "Cons"},
+                {"entity_group": "PER", "score": 0.9663328925768534, "word": "##uelo Araújo Noguera"},
+                {"entity_group": "PER", "score": 0.9997273534536362, "word": "Andrés Pastrana"},
+                {"entity_group": "ORG", "score": 0.8589080572128296, "word": "Farc"},
             ],
             [
-                {"entity_group": "I-PER", "score": 0.9962901175022125, "word": "Enzo"},
-                {"entity_group": "I-ORG", "score": 0.9986497163772583, "word": "UN"},
+                {"entity_group": "PER", "score": 0.9962901175022125, "word": "Enzo"},
+                {"entity_group": "ORG", "score": 0.9986497163772583, "word": "UN"},
             ],
         ]
 
@@ -77,12 +93,80 @@ def _test_pipeline(self, nlp: Pipeline):
                 for key in output_keys:
                     self.assertIn(key, result)
 
-        for ungrouped_input, grouped_result in zip(ungrouped_ner_inputs, expected_grouped_ner_results):
-            self.assertEqual(nlp.group_entities(ungrouped_input), grouped_result)
+        if nlp.grouped_entities:
+            if nlp.ignore_subwords:
+                for ungrouped_input, grouped_result in zip(ungrouped_ner_inputs, expected_grouped_ner_results):
+                    self.assertEqual(nlp.group_entities(ungrouped_input), grouped_result)
+            else:
+                for ungrouped_input, grouped_result in zip(
+                    ungrouped_ner_inputs, expected_grouped_ner_results_w_subword
+                ):
+                    self.assertEqual(nlp.group_entities(ungrouped_input), grouped_result)
 
     @require_tf
     def test_tf_only(self):
         model_name = "Narsil/small"  # This model only has a TensorFlow version
         # We test that if we don't specificy framework='tf', it gets detected automatically
-        nlp = pipeline(task="ner", model=model_name, tokenizer=model_name)
+        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
+        nlp = pipeline(task="ner", model=model_name, tokenizer=tokenizer)
         self._test_pipeline(nlp)
+
+    # offset=tokenizer(VALID_INPUTS[0],return_offsets_mapping=True)['offset_mapping']
+    # pipeline_running_kwargs = {"offset_mapping"}  # Additional kwargs to run the pipeline with
+
+    @require_tf
+    def test_tf_defaults(self):
+        for model_name in self.small_models:
+            tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
+            nlp = pipeline(task="ner", model=model_name, tokenizer=tokenizer, framework="tf")
+            self._test_pipeline(nlp)
+
+    @require_tf
+    def test_tf_small(self):
+        for model_name in self.small_models:
+            print(model_name)
+            tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
+            nlp = pipeline(
+                task="ner",
+                model=model_name,
+                tokenizer=tokenizer,
+                framework="tf",
+                grouped_entities=True,
+                ignore_subwords=True,
+            )
+            self._test_pipeline(nlp)
+
+        for model_name in self.small_models:
+            tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
+            nlp = pipeline(
+                task="ner",
+                model=model_name,
+                tokenizer=tokenizer,
+                framework="tf",
+                grouped_entities=True,
+                ignore_subwords=False,
+            )
+            self._test_pipeline(nlp)
+
+    @require_torch
+    def test_pt_defaults(self):
+        for model_name in self.small_models:
+            tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
+            nlp = pipeline(task="ner", model=model_name, tokenizer=tokenizer)
+            self._test_pipeline(nlp)
+
+    @require_torch
+    def test_torch_small(self):
+        for model_name in self.small_models:
+            tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
+            nlp = pipeline(
+                task="ner", model=model_name, tokenizer=tokenizer, grouped_entities=True, ignore_subwords=True
+            )
+            self._test_pipeline(nlp)
+
+        for model_name in self.small_models:
+            tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
+            nlp = pipeline(
+                task="ner", model=model_name, tokenizer=tokenizer, grouped_entities=True, ignore_subwords=False
+            )
+            self._test_pipeline(nlp)
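Outside the patch itself, the two ideas the pipeline changes rely on can be restated as plain Python. The helper names below (`decode_word`, `should_merge`) are illustrative only and do not appear in the diff; this is a minimal sketch of the logic added to `__call__` and `group_entities`, assuming the same entity dict shape used in the tests above.

```python
# Sketch only: subword detection as added to __call__. A token counts as a subword when
# its decoded vocabulary form (e.g. "##uelo") is longer than the characters it actually
# covers in the sentence ("uelo"); [UNK] tokens fall back to the original text instead.
def decode_word(tokenizer, sentence, token_id, start_ind, end_ind):
    word_ref = sentence[start_ind:end_ind]                      # characters covered in the input
    word = tokenizer.convert_ids_to_tokens([int(token_id)])[0]  # decoded vocabulary token
    if int(token_id) == tokenizer.unk_token_id:
        return word_ref, False
    return word, len(word_ref) != len(word)


# Sketch only: the revised merging rule from group_entities. Adjacent tokens of the same
# type are merged unless the newer token opens a fresh "B-" span; flagged subwords are
# always folded into the previous entity (the diff then sets their score to NaN so
# group_sub_entities can average the remaining scores with np.nanmean).
def should_merge(prev, curr, ignore_subwords=True):
    is_subword = ignore_subwords and curr.get("is_subword", False)
    same_type = curr["entity"].split("-")[-1] == prev["entity"].split("-")[-1]
    not_new_span = curr["entity"].split("-")[0] != "B"
    adjacent = curr["index"] == prev["index"] + 1
    return (same_type and not_new_span and adjacent) or is_subword
```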
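For reference, a minimal usage sketch of the options these tests exercise. The checkpoint name and sentence are only examples; the printed structure follows the `entity_group` / `score` / `word` keys asserted in the test expectations above.

```python
from transformers import AutoTokenizer, pipeline

# Illustrative checkpoint; any token-classification model can be substituted.
model_name = "dbmdz/bert-large-cased-finetuned-conll03-english"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

# grouped_entities merges word pieces into whole spans; ignore_subwords (added in this
# diff, default True) keeps "##"-style continuations from being scored as separate entities.
nlp = pipeline("ner", model=model_name, tokenizer=tokenizer, grouped_entities=True, ignore_subwords=True)

print(nlp("Consuelo Araújo Noguera met Andrés Pastrana."))
# e.g. [{"entity_group": "PER", "score": ..., "word": "Consuelo Araújo Noguera"}, ...]
```

With a slow (non-fast) tokenizer the same call would hit the new exception in `__call__` unless per-sentence character offsets are supplied by the caller, which is what the added `TokenClassificationArgumentHandler` validates.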