[WIP] Ner pipeline grouped_entities fixes #5970

Merged: 23 commits, Nov 3, 2020
Changes from 2 commits
80 changes: 46 additions & 34 deletions src/transformers/pipelines.py
@@ -1299,6 +1299,29 @@ def __call__(self, *args, targets=None, **kwargs):
return results


class TokenClassificationArgumentHandler(ArgumentHandler):
Contributor Author commented:

Added this to check offset_mapping if provided (it does a simple batch_size check).

"""
Handles arguments for token classification.
"""

def __call__(self, *args, **kwargs):

if args is not None and len(args) > 0:
if isinstance(args, str):
inputs = [args]
else:
inputs = args
batch_size = len(inputs)

offset_mapping = kwargs.get("offset_mapping", None)
if offset_mapping:
if isinstance(offset_mapping, list) and isinstance(offset_mapping[0], tuple):
offset_mapping = [offset_mapping]
if len(offset_mapping) != batch_size:
raise ("offset_mapping should have the same batch size as the input")
return inputs, offset_mapping
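
For context, a minimal usage sketch (not part of the diff) of the keyword the handler above validates: a caller with a slow tokenizer passes its own offset_mapping, one (start, end) character span per token and per sentence. The model name and the exact spans are illustrative assumptions.

from transformers import AutoTokenizer, pipeline

# Slow (non-fast) tokenizers cannot return character offsets themselves,
# so the caller supplies them; the handler above only checks the batch size.
model_name = "dbmdz/bert-large-cased-finetuned-conll03-english"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
ner = pipeline("ner", model=model_name, tokenizer=tokenizer)

sentence = "Enzo works at the UN"
# Illustrative spans only: they must line up one-to-one with the tokenizer's
# token indices for this sentence, including special tokens such as [CLS]/[SEP].
offsets = [[(0, 0), (0, 2), (2, 4), (5, 10), (11, 13), (14, 17), (18, 20), (0, 0)]]

entities = ner(sentence, offset_mapping=offsets)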


@add_end_docstrings(
PIPELINE_INIT_ARGS,
r"""
@@ -1343,7 +1366,7 @@ def __init__(
tokenizer=tokenizer,
modelcard=modelcard,
framework=framework,
args_parser=args_parser,
args_parser=TokenClassificationArgumentHandler(),
device=device,
binary_output=binary_output,
task=task,
@@ -1379,12 +1402,11 @@ def __call__(self, *args, **kwargs):
- **index** (:obj:`int`, only present when ``self.grouped_entities=False``) -- The index of the
corresponding token in the sentence.
"""
inputs = self._args_parser(*args, **kwargs)

inputs, offset_mappings = self._args_parser(*args, **kwargs)
answers = []

for i, sentence in enumerate(inputs):
if "offset_mapping" in kwargs:
offset_mapping = kwargs["offset_mapping"][i]

# Manage correct placement of the tensors
with self.device_placement():
@@ -1397,9 +1419,13 @@ def __call__(self, *args, **kwargs):
return_special_tokens_mask=True,
return_offsets_mapping=self.tokenizer.is_fast,
)
if "offset_mapping" in tokens:
if self.tokenizer.is_fast:
offset_mapping = tokens["offset_mapping"].cpu().numpy()[0]
del tokens["offset_mapping"]
elif offset_mappings:
offset_mapping = offset_mappings[i]
else:
raise Exception("To decode [UNK] tokens use a fast tokenizer or provide offset_mapping parameter")
special_tokens_mask = tokens["special_tokens_mask"].cpu().numpy()[0]
del tokens["special_tokens_mask"]

@@ -1426,15 +1452,14 @@ def __call__(self, *args, **kwargs):
]

for idx, label_idx in filtered_labels_idx:
start_ind, end_ind = offset_mapping[idx]
word_ref = sentence[start_ind:end_ind]
word = self.tokenizer.convert_ids_to_tokens([int(input_ids[idx])])[0]
is_subword = len(word_ref) != len(word)
Contributor Author commented:

Changed the is_subword detection logic: it now compares the length of the token (##token) with the original text span from the offset mapping (assuming subword pieces get prefixed by something). In case the user wants some other logic, they can first get ungrouped entities, add an is_subword: bool field to the entities, and call pipeline.group_entities themselves (a sketch of that flow follows this hunk).

if int(input_ids[idx]) == self.tokenizer.unk_token_id:
if offset_mapping is not None:
start_ind, end_ind = offset_mapping[idx]
word = sentence[start_ind:end_ind]
else:
raise Exception("Use a fast tokenizer or provide offset_mapping parameter")
else:
word = self.tokenizer.convert_ids_to_tokens([int(input_ids[idx])])[0]
word = word_ref
is_subword = False

entity = {
"word": word,
@@ -1443,6 +1468,9 @@ def __call__(self, *args, **kwargs):
"index": idx,
}

if self.grouped_entities and self.ignore_subwords:
entity["is_subword"] = is_subword

entities += [entity]

# Append grouped entities
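
Following the contributor's note above about supplying custom subword logic, a hedged sketch (not part of this PR): run the pipeline ungrouped, attach your own is_subword flag, then group manually. That pipeline() forwards grouped_entities/ignore_subwords to the pipeline constructor is assumed here, and the detection rule is only an example.

from transformers import pipeline

# Ungrouped output, so the caller decides what counts as a subword.
ner = pipeline("ner", grouped_entities=False, ignore_subwords=True)

def my_is_subword(token: str) -> bool:
    # Mirror of the WordPiece "##" convention; a SentencePiece model would
    # instead check for the absence of a leading "▁" on word-initial pieces.
    return token.startswith("##")

entities = ner("Consuelo Araújo Noguera met Andrés Pastrana.")
for entity in entities:
    entity["is_subword"] = my_is_subword(entity["word"])

grouped = ner.group_entities(entities)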
@@ -1467,28 +1495,17 @@ def group_sub_entities(self, entities: List[dict]) -> dict:
entity = entities[0]["entity"].split("-")[-1]
scores = np.nanmean([entity["score"] for entity in entities])
tokens = [entity["word"] for entity in entities]

if self.tokenizer.is_fast:
word = self.tokenizer.decode(self.tokenizer.convert_tokens_to_ids(tokens))
else:
word = self.tokenizer.convert_tokens_to_string(tokens)
cceyda (Contributor Author) commented on Oct 26, 2020:

Fixed as suggested! I agree it is much cleaner this way. It looks like fast tokenizers have a convert_tokens_to_string method now? 😕

entity_group = {
"entity_group": entity,
"score": np.mean(scores),
"word": self.convert_tokens_to_string(tokens),
"word": word,
}
return entity_group

def is_subword_fn(self, token: str) -> bool:
if token.startswith("##"):
return True
return False

def convert_tokens_to_string(self, tokens):
""" Converts a sequence of tokens (string) in a single string. """
if hasattr(self.tokenizer, "convert_tokens_to_string"):
# fast tokenizers dont have convert_tokens_to_string?!
return self.tokenizer.convert_tokens_to_string(tokens)
else:
out_string = " ".join(tokens).replace(" ##", "").strip()
return out_string
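
For intuition on the two joining paths above (decode for fast tokenizers, convert_tokens_to_string for slow ones), a small hedged illustration; the exact subword split is model-dependent, so the commented outputs are assumptions.

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-cased")
tokens = tok.tokenize("Pastrana")  # e.g. ['Pas', '##tran', '##a']

# Path taken for fast tokenizers in this diff: ids first, then decode.
print(tok.decode(tok.convert_tokens_to_ids(tokens)))  # "Pastrana"

# Path taken for slow tokenizers: the tokenizer's own joiner, equivalent to the
# removed fallback " ".join(tokens).replace(" ##", "").strip() for WordPiece.
print(tok.convert_tokens_to_string(tokens))  # "Pastrana"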

def group_entities(self, entities: List[dict]) -> List[dict]:
"""
Find and group together the adjacent tokens with the same entity predicted.
@@ -1500,18 +1517,13 @@ def group_entities(self, entities: List[dict]) -> List[dict]:
entity_groups = []
entity_group_disagg = []

if hasattr(self.tokenizer, "is_subword_fn"):
is_subword_fn = self.tokenizer.is_subword_fn
else:
is_subword_fn = self.is_subword_fn

if entities:
last_idx = entities[-1]["index"]

for entity in entities:

is_last_idx = entity["index"] == last_idx
is_subword = self.ignore_subwords and is_subword_fn(entity["word"])
is_subword = self.ignore_subwords and entity["is_subword"]
if not entity_group_disagg:
entity_group_disagg += [entity]
if is_last_idx:
32 changes: 16 additions & 16 deletions tests/test_pipelines.py
@@ -718,24 +718,24 @@ def _test_ner_pipeline(

ungrouped_ner_inputs = [
[
{"entity": "B-PER", "index": 1, "score": 0.9994944930076599, "word": "Cons"},
{"entity": "B-PER", "index": 2, "score": 0.8025449514389038, "word": "##uelo"},
{"entity": "I-PER", "index": 3, "score": 0.9993102550506592, "word": "Ara"},
{"entity": "I-PER", "index": 4, "score": 0.9993743896484375, "word": "##új"},
{"entity": "I-PER", "index": 5, "score": 0.9992871880531311, "word": "##o"},
{"entity": "I-PER", "index": 6, "score": 0.9993029236793518, "word": "No"},
{"entity": "I-PER", "index": 7, "score": 0.9981776475906372, "word": "##guera"},
{"entity": "B-PER", "index": 15, "score": 0.9998136162757874, "word": "Andrés"},
{"entity": "I-PER", "index": 16, "score": 0.999740719795227, "word": "Pas"},
{"entity": "I-PER", "index": 17, "score": 0.9997414350509644, "word": "##tran"},
{"entity": "I-PER", "index": 18, "score": 0.9996136426925659, "word": "##a"},
{"entity": "B-ORG", "index": 28, "score": 0.9989739060401917, "word": "Far"},
{"entity": "I-ORG", "index": 29, "score": 0.7188422083854675, "word": "##c"},
{"entity": "B-PER", "index": 1, "score": 0.9994944930076599, "is_subword": False, "word": "Cons"},
{"entity": "B-PER", "index": 2, "score": 0.8025449514389038, "is_subword": True, "word": "##uelo"},
{"entity": "I-PER", "index": 3, "score": 0.9993102550506592, "is_subword": False, "word": "Ara"},
{"entity": "I-PER", "index": 4, "score": 0.9993743896484375, "is_subword": True, "word": "##új"},
{"entity": "I-PER", "index": 5, "score": 0.9992871880531311, "is_subword": True, "word": "##o"},
{"entity": "I-PER", "index": 6, "score": 0.9993029236793518, "is_subword": False, "word": "No"},
{"entity": "I-PER", "index": 7, "score": 0.9981776475906372, "is_subword": True, "word": "##guera"},
{"entity": "B-PER", "index": 15, "score": 0.9998136162757874, "is_subword": False, "word": "Andrés"},
{"entity": "I-PER", "index": 16, "score": 0.999740719795227, "is_subword": False, "word": "Pas"},
{"entity": "I-PER", "index": 17, "score": 0.9997414350509644, "is_subword": True, "word": "##tran"},
{"entity": "I-PER", "index": 18, "score": 0.9996136426925659, "is_subword": True, "word": "##a"},
{"entity": "B-ORG", "index": 28, "score": 0.9989739060401917, "is_subword": False, "word": "Far"},
{"entity": "I-ORG", "index": 29, "score": 0.7188422083854675, "is_subword": True, "word": "##c"},
],
[
{"entity": "I-PER", "index": 1, "score": 0.9968166351318359, "word": "En"},
{"entity": "I-PER", "index": 2, "score": 0.9957635998725891, "word": "##zo"},
{"entity": "I-ORG", "index": 7, "score": 0.9986497163772583, "word": "UN"},
{"entity": "I-PER", "index": 1, "score": 0.9968166351318359, "is_subword": False, "word": "En"},
{"entity": "I-PER", "index": 2, "score": 0.9957635998725891, "is_subword": True, "word": "##zo"},
{"entity": "I-ORG", "index": 7, "score": 0.9986497163772583, "is_subword": False, "word": "UN"},
],
]
