diff --git a/pie_extended/tagger.py b/pie_extended/tagger.py index 793fdce..8c2dc9d 100644 --- a/pie_extended/tagger.py +++ b/pie_extended/tagger.py @@ -54,43 +54,30 @@ def iter_tag(self, data: str, iterator: DataIterator, formatter_class: type): # Unzip the batch into the sentences, their sizes and the dictionaries of things that needs # to be reinserted sents, lengths, needs_reinsertion = zip(*chunk) - # Removing punctuation might create empty sentences ! - # Which would crash Torch - empty_sents_indexes = { - index: [] - for index, sent in enumerate(sents) - if len(sent) == 0 - } + + is_empty = [0 == len(sent) for sent in sents] + + tagged, tasks = self.tag( - sents=[sent for sent in sents if len(sent)], + sents=[sent for sent in sents if sent], lengths=lengths ) formatter: Formatter = formatter_class(tasks) # We keep a real sentence index - real_sentence_index = 0 - for sent in tagged: - if not sent: - continue + for sents_index, sent_is_empty in enumerate(is_empty): + if sent_is_empty: + sent = [] + else: + sent = tagged.pop(0) + # Gets things that needs to be reinserted - sent_reinsertion = needs_reinsertion[real_sentence_index] + sent_reinsertion = needs_reinsertion[sents_index] # If the header has not yet be written, write it if not header: yield formatter.write_headers() header = True - # Some sentences can be empty and would have been removed from tagging - # we check and until we get to a non empty sentence - # we increment the real_sentence_index to keep in check with the reinsertion map - while real_sentence_index in empty_sents_indexes: - yield from self.reinsert_full( - formatter, - needs_reinsertion[real_sentence_index], - tasks - ) - real_sentence_index += 1 - yield formatter.write_sentence_beginning() # If we have a disambiguator, we run the results into it @@ -98,7 +85,6 @@ def iter_tag(self, data: str, iterator: DataIterator, formatter_class: type): sent = self.disambiguation(sent, tasks) reinsertion_index = 0 - index = 0 for index, 
(token, tags) in enumerate(sent): while reinsertion_index + index in sent_reinsertion: @@ -125,15 +111,6 @@ def iter_tag(self, data: str, iterator: DataIterator, formatter_class: type): yield formatter.write_sentence_end() - real_sentence_index += 1 - - while real_sentence_index in empty_sents_indexes: - yield from self.reinsert_full( - formatter, - needs_reinsertion[real_sentence_index], - tasks - ) - real_sentence_index += 1 if formatter: yield formatter.write_footer()