From 8c1635d49f185072db14677fdc08a318168e587a Mon Sep 17 00:00:00 2001 From: Nicolay Rusnachenko Date: Sat, 19 Aug 2023 11:09:30 +0100 Subject: [PATCH] #516 fixed --- arekit/common/docs/parsed/providers/entity_service.py | 1 + arekit/common/opinions/annot/algo/pair_based.py | 10 ++++++---- .../pipelines/sources/nerel/extract_text_relations.py | 3 +++ .../sources/nerel_bio/extrat_text_relations.py | 3 +++ .../sources/ruattitudes/extract_text_opinions.py | 1 + .../sources/rusentrel/extract_text_opinions.py | 4 +++- .../sources/sentinerel/extract_text_opinions.py | 4 +++- .../contrib/utils/pipelines/text_opinion/extraction.py | 10 ++++++---- tests/contrib/utils/test_csv_stream_write.py | 1 + .../tutorials/test_tutorial_pipeline_sampling_bert.py | 1 + .../test_tutorial_pipeline_sampling_network.py | 1 + .../test_tutorial_pipeline_sampling_prompt.py | 1 + .../test_tutorial_pipeline_text_opinion_annotation.py | 7 ++++++- 13 files changed, 36 insertions(+), 11 deletions(-) diff --git a/arekit/common/docs/parsed/providers/entity_service.py b/arekit/common/docs/parsed/providers/entity_service.py index 28053030..9e02e4bc 100644 --- a/arekit/common/docs/parsed/providers/entity_service.py +++ b/arekit/common/docs/parsed/providers/entity_service.py @@ -42,6 +42,7 @@ class EntityServiceProvider(BaseParsedDocumentServiceProvider): NAME = "entity-service-provider" def __init__(self, entity_index_func): + assert(callable(entity_index_func)) super(EntityServiceProvider, self).__init__(entity_index_func=entity_index_func) # Initialize API. self.__iter_raw_terms_func = None diff --git a/arekit/common/opinions/annot/algo/pair_based.py b/arekit/common/opinions/annot/algo/pair_based.py index 8d82a42f..31df2d6f 100644 --- a/arekit/common/opinions/annot/algo/pair_based.py +++ b/arekit/common/opinions/annot/algo/pair_based.py @@ -16,7 +16,8 @@ class PairBasedOpinionAnnotationAlgorithm(BaseOpinionAnnotationAlgorithm): [1] Extracting Sentiment Attitudes from Analytical Texts https://arxiv.org/pdf/1808.08932.pdf """ - def __init__(self, dist_in_terms_bound, label_provider, dist_in_sents=0, is_entity_ignored_func=None): + def __init__(self, dist_in_terms_bound, label_provider, entity_index_func, dist_in_sents=0, + is_entity_ignored_func=None): """ dist_in_terms_bound: int max allowed distance in term (less than passed value) @@ -25,6 +26,7 @@ def __init__(self, dist_in_terms_bound, label_provider, dist_in_sents=0, is_enti """ assert(isinstance(dist_in_terms_bound, int) or dist_in_terms_bound is None) assert(isinstance(label_provider, BasePairLabelProvider)) + assert(callable(entity_index_func)) assert(isinstance(dist_in_sents, int)) assert(callable(is_entity_ignored_func) or is_entity_ignored_func is None) @@ -32,6 +34,7 @@ def __init__(self, dist_in_terms_bound, label_provider, dist_in_sents=0, is_enti self.__dist_in_terms_bound = dist_in_terms_bound self.__dist_in_sents = dist_in_sents self.__is_entity_ignored_func = is_entity_ignored_func + self.__entity_index_func = entity_index_func # region private methods @@ -87,9 +90,8 @@ def __filter_pair_func(e1, e2): return key is not None # Initialize providers. - # TODO. Provide here service #245 issue. - opinions_provider = OpinionPairsProvider(entity_index_func=None) - entity_service_provider = EntityServiceProvider(entity_index_func=None) + opinions_provider = OpinionPairsProvider(entity_index_func=self.__entity_index_func) + entity_service_provider = EntityServiceProvider(entity_index_func=self.__entity_index_func) opinions_provider.init_parsed_doc(parsed_doc) entity_service_provider.init_parsed_doc(parsed_doc) diff --git a/arekit/contrib/utils/pipelines/sources/nerel/extract_text_relations.py b/arekit/contrib/utils/pipelines/sources/nerel/extract_text_relations.py index 1ede0b5a..8b510496 100644 --- a/arekit/contrib/utils/pipelines/sources/nerel/extract_text_relations.py +++ b/arekit/contrib/utils/pipelines/sources/nerel/extract_text_relations.py @@ -39,14 +39,17 @@ def create_text_relation_extraction_pipeline(nerel_version, DataType.Train: text_opinion_extraction_pipeline(text_parser=text_parser, get_doc_by_id_func=doc_ops.by_id, annotators=[predefined_annot], + entity_index_func=lambda brat_entity: brat_entity.ID, text_opinion_filters=text_opinion_filters), DataType.Test: text_opinion_extraction_pipeline(text_parser=text_parser, get_doc_by_id_func=doc_ops.by_id, annotators=[predefined_annot], + entity_index_func=lambda brat_entity: brat_entity.ID, text_opinion_filters=text_opinion_filters), DataType.Dev: text_opinion_extraction_pipeline(text_parser=text_parser, get_doc_by_id_func=doc_ops.by_id, annotators=[predefined_annot], + entity_index_func=lambda brat_entity: brat_entity.ID, text_opinion_filters=text_opinion_filters), } diff --git a/arekit/contrib/utils/pipelines/sources/nerel_bio/extrat_text_relations.py b/arekit/contrib/utils/pipelines/sources/nerel_bio/extrat_text_relations.py index e5b72972..69d84b3e 100644 --- a/arekit/contrib/utils/pipelines/sources/nerel_bio/extrat_text_relations.py +++ b/arekit/contrib/utils/pipelines/sources/nerel_bio/extrat_text_relations.py @@ -39,14 +39,17 @@ def create_text_relation_extraction_pipeline(nerel_bio_version, DataType.Train: text_opinion_extraction_pipeline(text_parser=text_parser, get_doc_by_id_func=doc_ops.by_id, annotators=[predefined_annot], + entity_index_func=lambda brat_entity: brat_entity.ID, text_opinion_filters=text_opinion_filters), DataType.Test: text_opinion_extraction_pipeline(text_parser=text_parser, get_doc_by_id_func=doc_ops.by_id, annotators=[predefined_annot], + entity_index_func=lambda brat_entity: brat_entity.ID, text_opinion_filters=text_opinion_filters), DataType.Dev: text_opinion_extraction_pipeline(text_parser=text_parser, get_doc_by_id_func=doc_ops.by_id, annotators=[predefined_annot], + entity_index_func=lambda brat_entity: brat_entity.ID, text_opinion_filters=text_opinion_filters), } diff --git a/arekit/contrib/utils/pipelines/sources/ruattitudes/extract_text_opinions.py b/arekit/contrib/utils/pipelines/sources/ruattitudes/extract_text_opinions.py index 7486c2b9..6a04f0fc 100644 --- a/arekit/contrib/utils/pipelines/sources/ruattitudes/extract_text_opinions.py +++ b/arekit/contrib/utils/pipelines/sources/ruattitudes/extract_text_opinions.py @@ -54,6 +54,7 @@ def create_text_opinion_extraction_pipeline(text_parser, DistanceLimitedTextOpinionFilter(terms_per_context) ], get_doc_by_id_func=doc_provider.by_id, + entity_index_func=lambda brat_entity: brat_entity.ID, text_parser=text_parser) return pipeline diff --git a/arekit/contrib/utils/pipelines/sources/rusentrel/extract_text_opinions.py b/arekit/contrib/utils/pipelines/sources/rusentrel/extract_text_opinions.py index cc117dac..2a4cd222 100644 --- a/arekit/contrib/utils/pipelines/sources/rusentrel/extract_text_opinions.py +++ b/arekit/contrib/utils/pipelines/sources/rusentrel/extract_text_opinions.py @@ -56,6 +56,7 @@ def create_text_opinion_extraction_pipeline(rusentrel_version, DistanceLimitedTextOpinionFilter(terms_per_context) ], get_doc_by_id_func=doc_provider.by_id, + entity_index_func=lambda brat_entity: brat_entity.ID, text_parser=text_parser) return pipeline @@ -69,7 +70,8 @@ def nolabel_annotator(synonyms, terms_per_context, dist_in_sentences=0, no_label return AlgorithmBasedTextOpinionAnnotator( annot_algo=PairBasedOpinionAnnotationAlgorithm(dist_in_sents=dist_in_sentences, dist_in_terms_bound=terms_per_context, - label_provider=ConstantLabelProvider(no_label)), + label_provider=ConstantLabelProvider(no_label), + entity_index_func=lambda brat_entity: brat_entity.ID), create_empty_collection_func=lambda: OpinionCollection( synonyms=synonyms, error_on_duplicates=True, error_on_synonym_end_missed=False), value_to_group_id_func=lambda value: diff --git a/arekit/contrib/utils/pipelines/sources/sentinerel/extract_text_opinions.py b/arekit/contrib/utils/pipelines/sources/sentinerel/extract_text_opinions.py index ab3778eb..4b26f610 100644 --- a/arekit/contrib/utils/pipelines/sources/sentinerel/extract_text_opinions.py +++ b/arekit/contrib/utils/pipelines/sources/sentinerel/extract_text_opinions.py @@ -137,7 +137,8 @@ def create_nolabel_text_opinion_annotator(terms_per_context, no_label, dist_in_s annot_algo=PairBasedOpinionAnnotationAlgorithm( dist_in_sents=dist_in_sents, dist_in_terms_bound=terms_per_context, - label_provider=ConstantLabelProvider(no_label)), + label_provider=ConstantLabelProvider(no_label), + entity_index_func=lambda brat_entity: brat_entity.ID), create_empty_collection_func=lambda: OpinionCollection( synonyms=synonyms, error_on_duplicates=True, @@ -152,6 +153,7 @@ def create_main_pipeline(text_parser, doc_provider, annotators, text_opinion_fil get_doc_by_id_func=doc_provider.by_id, text_parser=text_parser, annotators=annotators, + entity_index_func=lambda brat_entity: brat_entity.ID, text_opinion_filters=text_opinion_filters) diff --git a/arekit/contrib/utils/pipelines/text_opinion/extraction.py b/arekit/contrib/utils/pipelines/text_opinion/extraction.py index aaf84eef..af8e95b6 100644 --- a/arekit/contrib/utils/pipelines/text_opinion/extraction.py +++ b/arekit/contrib/utils/pipelines/text_opinion/extraction.py @@ -13,7 +13,7 @@ from arekit.contrib.utils.pipelines.text_opinion.filters.limitation import FrameworkLimitationsTextOpinionFilter -def __iter_text_opinion_linkages(parsed_doc, annotators, text_opinion_filters): +def __iter_text_opinion_linkages(parsed_doc, annotators, entity_index_func, text_opinion_filters): assert(isinstance(annotators, list)) assert(isinstance(parsed_doc, ParsedDocument)) assert(isinstance(text_opinion_filters, list)) @@ -21,7 +21,7 @@ def __iter_text_opinion_linkages(parsed_doc, annotators, text_opinion_filters): def __to_id(text_opinion): return "{}_{}".format(text_opinion.SourceId, text_opinion.TargetId) - service = ParsedDocumentService(parsed_doc=parsed_doc, providers=[EntityServiceProvider(None)]) + service = ParsedDocumentService(parsed_doc=parsed_doc, providers=[EntityServiceProvider(entity_index_func)]) esp = service.get_provider(EntityServiceProvider.NAME) predefined = set() @@ -52,7 +52,8 @@ def __to_id(text_opinion): yield text_opinion_linkage -def text_opinion_extraction_pipeline(text_parser, get_doc_by_id_func, annotators, text_opinion_filters=None): +def text_opinion_extraction_pipeline(text_parser, get_doc_by_id_func, annotators, entity_index_func, + text_opinion_filters=None): assert(isinstance(text_parser, BaseTextParser)) assert(callable(get_doc_by_id_func)) assert(isinstance(annotators, list)) @@ -71,7 +72,8 @@ def text_opinion_extraction_pipeline(text_parser, get_doc_by_id_func, annotators # (parsed_doc) -> (text_opinions) MapPipelineItem(map_func=lambda parsed_doc: __iter_text_opinion_linkages( - annotators=annotators, parsed_doc=parsed_doc, text_opinion_filters=actual_text_opinion_filters)), + annotators=annotators, parsed_doc=parsed_doc, entity_index_func=entity_index_func, + text_opinion_filters=actual_text_opinion_filters)), # linkages[] -> linkages FlattenIterPipelineItem() diff --git a/tests/contrib/utils/test_csv_stream_write.py b/tests/contrib/utils/test_csv_stream_write.py index ab2b41d1..04ace826 100644 --- a/tests/contrib/utils/test_csv_stream_write.py +++ b/tests/contrib/utils/test_csv_stream_write.py @@ -77,6 +77,7 @@ def __launch(self, writer, target_extention): DistanceLimitedTextOpinionFilter(terms_per_context=50) ], get_doc_by_id_func=doc_provider.by_id, + entity_index_func=lambda brat_entity: brat_entity.ID, text_parser=text_parser) ##### diff --git a/tests/tutorials/test_tutorial_pipeline_sampling_bert.py b/tests/tutorials/test_tutorial_pipeline_sampling_bert.py index b00c934e..844ce370 100644 --- a/tests/tutorials/test_tutorial_pipeline_sampling_bert.py +++ b/tests/tutorials/test_tutorial_pipeline_sampling_bert.py @@ -116,6 +116,7 @@ def test(self): DistanceLimitedTextOpinionFilter(terms_per_context=50) ], get_doc_by_id_func=doc_provider.by_id, + entity_index_func=lambda brat_entity: brat_entity.ID, text_parser=text_parser) ##### diff --git a/tests/tutorials/test_tutorial_pipeline_sampling_network.py b/tests/tutorials/test_tutorial_pipeline_sampling_network.py index 8233217b..880c2257 100644 --- a/tests/tutorials/test_tutorial_pipeline_sampling_network.py +++ b/tests/tutorials/test_tutorial_pipeline_sampling_network.py @@ -120,6 +120,7 @@ def test(self): DistanceLimitedTextOpinionFilter(terms_per_context=50) ], get_doc_by_id_func=doc_provider.by_id, + entity_index_func=lambda brat_entity: brat_entity.ID, text_parser=text_parser) ##### diff --git a/tests/tutorials/test_tutorial_pipeline_sampling_prompt.py b/tests/tutorials/test_tutorial_pipeline_sampling_prompt.py index 446de301..e39b2181 100644 --- a/tests/tutorials/test_tutorial_pipeline_sampling_prompt.py +++ b/tests/tutorials/test_tutorial_pipeline_sampling_prompt.py @@ -117,6 +117,7 @@ def test(self): DistanceLimitedTextOpinionFilter(terms_per_context=50) ], get_doc_by_id_func=doc_provider.by_id, + entity_index_func=lambda brat_entity: brat_entity.ID, text_parser=text_parser) ##### diff --git a/tests/tutorials/test_tutorial_pipeline_text_opinion_annotation.py b/tests/tutorials/test_tutorial_pipeline_text_opinion_annotation.py index d67efcf7..16fc90a1 100644 --- a/tests/tutorials/test_tutorial_pipeline_text_opinion_annotation.py +++ b/tests/tutorials/test_tutorial_pipeline_text_opinion_annotation.py @@ -52,11 +52,15 @@ def test(self): synonyms = StemmerBasedSynonymCollection(stemmer=MystemWrapper(), is_read_only=False) + # How we basically index the source entities. + entity_index_func = lambda brat_entity: brat_entity.ID + nolabel_annotator = AlgorithmBasedTextOpinionAnnotator( annot_algo=PairBasedOpinionAnnotationAlgorithm( dist_in_sents=0, dist_in_terms_bound=50, - label_provider=ConstantLabelProvider(NoLabel())), + label_provider=ConstantLabelProvider(NoLabel()), + entity_index_func=entity_index_func), create_empty_collection_func=lambda: OpinionCollection( synonyms=synonyms, error_on_duplicates=True, error_on_synonym_end_missed=False), value_to_group_id_func=lambda value: @@ -77,6 +81,7 @@ def test(self): DistanceLimitedTextOpinionFilter(terms_per_context=50) ], get_doc_by_id_func=doc_provider.by_id, + entity_index_func=entity_index_func, text_parser=text_parser) # Running the pipeline.