[WIP] NER pipeline grouped_entities fixes #5970
Changes from 2 commits
@@ -1299,6 +1299,29 @@ def __call__(self, *args, targets=None, **kwargs):
         return results


+class TokenClassificationArgumentHandler(ArgumentHandler):
+    """
+    Handles arguments for token classification.
+    """
+
+    def __call__(self, *args, **kwargs):
+
+        if args is not None and len(args) > 0:
+            # *args is always a tuple: unwrap a lone string or a lone list of strings
+            if len(args) == 1 and isinstance(args[0], str):
+                inputs = [args[0]]
+            elif len(args) == 1 and isinstance(args[0], (list, tuple)):
+                inputs = list(args[0])
+            else:
+                inputs = list(args)
+            batch_size = len(inputs)
+
+        offset_mapping = kwargs.get("offset_mapping", None)
+        if offset_mapping:
+            # a single sentence's mapping is a list of tuples; wrap it into a batch of one
+            if isinstance(offset_mapping, list) and isinstance(offset_mapping[0], tuple):
+                offset_mapping = [offset_mapping]
+            if len(offset_mapping) != batch_size:
+                raise ValueError("offset_mapping should have the same batch size as the input")
+        return inputs, offset_mapping
+
+
 @add_end_docstrings(
     PIPELINE_INIT_ARGS,
     r"""
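A quick sanity sketch of what the new handler returns, assuming only the class as defined in this diff (the calls and asserts below are illustrative, not part of the PR):

handler = TokenClassificationArgumentHandler()

# a single sentence becomes a batch of one, with no offsets
inputs, mappings = handler("HuggingFace is based in NYC")
assert inputs == ["HuggingFace is based in NYC"] and mappings is None

# a single sentence's offset_mapping (a list of tuples) is wrapped to match the batch
inputs, mappings = handler("Hello", offset_mapping=[(0, 5)])
assert mappings == [[(0, 5)]]

# mismatched batch sizes are rejected:
# handler("Hello", offset_mapping=[[(0, 5)], [(0, 2)]])  -> ValueError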
@@ -1343,7 +1366,7 @@ def __init__(
             tokenizer=tokenizer,
             modelcard=modelcard,
             framework=framework,
-            args_parser=args_parser,
+            args_parser=TokenClassificationArgumentHandler(),
             device=device,
             binary_output=binary_output,
             task=task,
@@ -1379,12 +1402,11 @@ def __call__(self, *args, **kwargs):
             - **index** (:obj:`int`, only present when ``self.grouped_entities=False``) -- The index of the
               corresponding token in the sentence.
         """
-        inputs = self._args_parser(*args, **kwargs)
+        inputs, offset_mappings = self._args_parser(*args, **kwargs)
         answers = []
+
         for i, sentence in enumerate(inputs):
-            if "offset_mapping" in kwargs:
-                offset_mapping = kwargs["offset_mapping"][i]

             # Manage correct placement of the tensors
             with self.device_placement():
@@ -1397,9 +1419,13 @@ def __call__(self, *args, **kwargs):
                     return_special_tokens_mask=True,
                     return_offsets_mapping=self.tokenizer.is_fast,
                 )
-                if "offset_mapping" in tokens:
+                if self.tokenizer.is_fast:
                     offset_mapping = tokens["offset_mapping"].cpu().numpy()[0]
                     del tokens["offset_mapping"]
+                elif offset_mappings:
+                    offset_mapping = offset_mappings[i]
+                else:
+                    raise Exception("To decode [UNK] tokens use a fast tokenizer or provide offset_mapping parameter")
                 special_tokens_mask = tokens["special_tokens_mask"].cpu().numpy()[0]
                 del tokens["special_tokens_mask"]
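For context, an offset_mapping is just a list of character (start, end) spans into the original sentence, one per token. A minimal sketch built with naive whitespace splitting (illustrative only; real spans come from a fast tokenizer or from the caller):

sentence = "HuggingFace is based in NYC"
offset_mapping, cursor = [], 0
for tok in sentence.split():
    start = sentence.index(tok, cursor)
    offset_mapping.append((start, start + len(tok)))
    cursor = start + len(tok)

# every span points back at its token
assert [sentence[s:e] for s, e in offset_mapping] == sentence.split()
print(offset_mapping)  # [(0, 11), (12, 14), (15, 20), (21, 23), (24, 27)]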
@@ -1426,15 +1452,14 @@ def __call__(self, *args, **kwargs): | |
] | ||
|
||
for idx, label_idx in filtered_labels_idx: | ||
start_ind, end_ind = offset_mapping[idx] | ||
word_ref = sentence[start_ind:end_ind] | ||
word = self.tokenizer.convert_ids_to_tokens([int(input_ids[idx])])[0] | ||
is_subword = len(word_ref) != len(word) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. changed |
||
|
||
if int(input_ids[idx]) == self.tokenizer.unk_token_id: | ||
if offset_mapping is not None: | ||
start_ind, end_ind = offset_mapping[idx] | ||
word = sentence[start_ind:end_ind] | ||
else: | ||
raise Exception("Use a fast tokenizer or provide offset_mapping parameter") | ||
else: | ||
word = self.tokenizer.convert_ids_to_tokens([int(input_ids[idx])])[0] | ||
word = word_ref | ||
is_subword = False | ||
|
||
entity = { | ||
"word": word, | ||
|
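A standalone sketch of the heuristic above: a token whose length differs from its source span is a subword continuation, and an [UNK] token is replaced by the text its span covers (tokens and spans here are assumed, not produced by a real tokenizer):

sentence = "Annoyingly rainy"
tokens = ["Annoying", "##ly", "[UNK]"]        # assumed model tokens
offset_mapping = [(0, 8), (8, 10), (11, 16)]  # assumed character spans

for (start, end), token in zip(offset_mapping, tokens):
    word_ref = sentence[start:end]
    is_subword = len(word_ref) != len(token)  # "##ly" (4) vs "ly" (2)
    if token == "[UNK]":
        token, is_subword = word_ref, False   # recover the surface form
    print(token, is_subword)
# Annoying False / ##ly True / rainy False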
@@ -1443,6 +1468,9 @@ def __call__(self, *args, **kwargs):
                         "index": idx,
                     }

+                    if self.grouped_entities and self.ignore_subwords:
+                        entity["is_subword"] = is_subword
+
                     entities += [entity]

                 # Append grouped entities
@@ -1467,28 +1495,17 @@ def group_sub_entities(self, entities: List[dict]) -> dict:
         entity = entities[0]["entity"].split("-")[-1]
         scores = np.nanmean([entity["score"] for entity in entities])
         tokens = [entity["word"] for entity in entities]

+        if self.tokenizer.is_fast:
+            word = self.tokenizer.decode(self.tokenizer.convert_tokens_to_ids(tokens))
+        else:
+            word = self.tokenizer.convert_tokens_to_string(tokens)
+
         entity_group = {
             "entity_group": entity,
             "score": np.mean(scores),
-            "word": self.convert_tokens_to_string(tokens),
+            "word": word,
         }
         return entity_group

-    def is_subword_fn(self, token: str) -> bool:
-        if token.startswith("##"):
-            return True
-        return False
-
-    def convert_tokens_to_string(self, tokens):
-        """ Converts a sequence of tokens (string) in a single string. """
-        if hasattr(self.tokenizer, "convert_tokens_to_string"):
-            # fast tokenizers dont have convert_tokens_to_string?!
-            return self.tokenizer.convert_tokens_to_string(tokens)
-        else:
-            out_string = " ".join(tokens).replace(" ##", "").strip()
-            return out_string
-
     def group_entities(self, entities: List[dict]) -> List[dict]:
         """
         Find and group together the adjacent tokens with the same entity predicted.
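The removed fallback joined WordPiece tokens by hand; the same one-liner, runnable on its own (a fast tokenizer's decode() achieves this via the token ids instead):

tokens = ["Hu", "##gging", "Face", "is", "great"]
print(" ".join(tokens).replace(" ##", "").strip())  # Hugging Face is great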
@@ -1500,18 +1517,13 @@ def group_entities(self, entities: List[dict]) -> List[dict]:
         entity_groups = []
         entity_group_disagg = []

-        if hasattr(self.tokenizer, "is_subword_fn"):
-            is_subword_fn = self.tokenizer.is_subword_fn
-        else:
-            is_subword_fn = self.is_subword_fn
-
         if entities:
             last_idx = entities[-1]["index"]

         for entity in entities:

             is_last_idx = entity["index"] == last_idx
-            is_subword = self.ignore_subwords and is_subword_fn(entity["word"])
+            is_subword = self.ignore_subwords and entity["is_subword"]
             if not entity_group_disagg:
                 entity_group_disagg += [entity]
                 if is_last_idx:
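A toy version of the grouping pass under the new is_subword flag: adjacent tokens with the same B-/I- entity type merge into one group, and subwords always stay attached to the previous token (data is made up; the real method also aggregates scores and handles index gaps):

entities = [
    {"word": "Hu", "entity": "B-ORG", "index": 1, "is_subword": False},
    {"word": "##gging", "entity": "I-ORG", "index": 2, "is_subword": True},
    {"word": "Face", "entity": "I-ORG", "index": 3, "is_subword": False},
    {"word": "NYC", "entity": "B-LOC", "index": 6, "is_subword": False},
]

groups, current = [], []
for entity in entities:
    same_type = current and entity["entity"].split("-")[-1] == current[-1]["entity"].split("-")[-1]
    if current and (entity["is_subword"] or same_type):
        current.append(entity)       # extend the running group
    else:
        if current:
            groups.append(current)   # close the previous group
        current = [entity]
if current:
    groups.append(current)

print([[e["word"] for e in g] for g in groups])
# [['Hu', '##gging', 'Face'], ['NYC']]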
Review comment: added this to check offset_mapping if provided (does a simple batch_size check).