Skip to content

Commit

Permalink
#516 fixed
Browse files (browse the repository at this point in the history)
  • Loading branch information
nicolay-r committed Aug 19, 2023
1 parent 4ba5efc commit 8c1635d
Show file tree
Hide file tree
Showing 13 changed files with 36 additions and 11 deletions.
1 change: 1 addition & 0 deletions arekit/common/docs/parsed/providers/entity_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class EntityServiceProvider(BaseParsedDocumentServiceProvider):
NAME = "entity-service-provider"

def __init__(self, entity_index_func):
assert(callable(entity_index_func))
super(EntityServiceProvider, self).__init__(entity_index_func=entity_index_func)
# Initialize API.
self.__iter_raw_terms_func = None
Expand Down
10 changes: 6 additions & 4 deletions arekit/common/opinions/annot/algo/pair_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ class PairBasedOpinionAnnotationAlgorithm(BaseOpinionAnnotationAlgorithm):
[1] Extracting Sentiment Attitudes from Analytical Texts https://arxiv.org/pdf/1808.08932.pdf
"""

def __init__(self, dist_in_terms_bound, label_provider, dist_in_sents=0, is_entity_ignored_func=None):
def __init__(self, dist_in_terms_bound, label_provider, entity_index_func, dist_in_sents=0,
is_entity_ignored_func=None):
"""
dist_in_terms_bound: int
max allowed distance in term (less than passed value)
Expand All @@ -25,13 +26,15 @@ def __init__(self, dist_in_terms_bound, label_provider, dist_in_sents=0, is_enti
"""
assert(isinstance(dist_in_terms_bound, int) or dist_in_terms_bound is None)
assert(isinstance(label_provider, BasePairLabelProvider))
assert(callable(entity_index_func))
assert(isinstance(dist_in_sents, int))
assert(callable(is_entity_ignored_func) or is_entity_ignored_func is None)

self.__label_provider = label_provider
self.__dist_in_terms_bound = dist_in_terms_bound
self.__dist_in_sents = dist_in_sents
self.__is_entity_ignored_func = is_entity_ignored_func
self.__entity_index_func = entity_index_func

# region private methods

Expand Down Expand Up @@ -87,9 +90,8 @@ def __filter_pair_func(e1, e2):
return key is not None

# Initialize providers.
# TODO. Provide here service #245 issue.
opinions_provider = OpinionPairsProvider(entity_index_func=None)
entity_service_provider = EntityServiceProvider(entity_index_func=None)
opinions_provider = OpinionPairsProvider(entity_index_func=self.__entity_index_func)
entity_service_provider = EntityServiceProvider(entity_index_func=self.__entity_index_func)
opinions_provider.init_parsed_doc(parsed_doc)
entity_service_provider.init_parsed_doc(parsed_doc)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,17 @@ def create_text_relation_extraction_pipeline(nerel_version,
DataType.Train: text_opinion_extraction_pipeline(text_parser=text_parser,
get_doc_by_id_func=doc_ops.by_id,
annotators=[predefined_annot],
entity_index_func=lambda brat_entity: brat_entity.ID,
text_opinion_filters=text_opinion_filters),
DataType.Test: text_opinion_extraction_pipeline(text_parser=text_parser,
get_doc_by_id_func=doc_ops.by_id,
annotators=[predefined_annot],
entity_index_func=lambda brat_entity: brat_entity.ID,
text_opinion_filters=text_opinion_filters),
DataType.Dev: text_opinion_extraction_pipeline(text_parser=text_parser,
get_doc_by_id_func=doc_ops.by_id,
annotators=[predefined_annot],
entity_index_func=lambda brat_entity: brat_entity.ID,
text_opinion_filters=text_opinion_filters),
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,17 @@ def create_text_relation_extraction_pipeline(nerel_bio_version,
DataType.Train: text_opinion_extraction_pipeline(text_parser=text_parser,
get_doc_by_id_func=doc_ops.by_id,
annotators=[predefined_annot],
entity_index_func=lambda brat_entity: brat_entity.ID,
text_opinion_filters=text_opinion_filters),
DataType.Test: text_opinion_extraction_pipeline(text_parser=text_parser,
get_doc_by_id_func=doc_ops.by_id,
annotators=[predefined_annot],
entity_index_func=lambda brat_entity: brat_entity.ID,
text_opinion_filters=text_opinion_filters),
DataType.Dev: text_opinion_extraction_pipeline(text_parser=text_parser,
get_doc_by_id_func=doc_ops.by_id,
annotators=[predefined_annot],
entity_index_func=lambda brat_entity: brat_entity.ID,
text_opinion_filters=text_opinion_filters),
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def create_text_opinion_extraction_pipeline(text_parser,
DistanceLimitedTextOpinionFilter(terms_per_context)
],
get_doc_by_id_func=doc_provider.by_id,
entity_index_func=lambda brat_entity: brat_entity.ID,
text_parser=text_parser)

return pipeline
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def create_text_opinion_extraction_pipeline(rusentrel_version,
DistanceLimitedTextOpinionFilter(terms_per_context)
],
get_doc_by_id_func=doc_provider.by_id,
entity_index_func=lambda brat_entity: brat_entity.ID,
text_parser=text_parser)

return pipeline
Expand All @@ -69,7 +70,8 @@ def nolabel_annotator(synonyms, terms_per_context, dist_in_sentences=0, no_label
return AlgorithmBasedTextOpinionAnnotator(
annot_algo=PairBasedOpinionAnnotationAlgorithm(dist_in_sents=dist_in_sentences,
dist_in_terms_bound=terms_per_context,
label_provider=ConstantLabelProvider(no_label)),
label_provider=ConstantLabelProvider(no_label),
entity_index_func=lambda brat_entity: brat_entity.ID),
create_empty_collection_func=lambda: OpinionCollection(
synonyms=synonyms, error_on_duplicates=True, error_on_synonym_end_missed=False),
value_to_group_id_func=lambda value:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,8 @@ def create_nolabel_text_opinion_annotator(terms_per_context, no_label, dist_in_s
annot_algo=PairBasedOpinionAnnotationAlgorithm(
dist_in_sents=dist_in_sents,
dist_in_terms_bound=terms_per_context,
label_provider=ConstantLabelProvider(no_label)),
label_provider=ConstantLabelProvider(no_label),
entity_index_func=lambda brat_entity: brat_entity.ID),
create_empty_collection_func=lambda: OpinionCollection(
synonyms=synonyms,
error_on_duplicates=True,
Expand All @@ -152,6 +153,7 @@ def create_main_pipeline(text_parser, doc_provider, annotators, text_opinion_fil
get_doc_by_id_func=doc_provider.by_id,
text_parser=text_parser,
annotators=annotators,
entity_index_func=lambda brat_entity: brat_entity.ID,
text_opinion_filters=text_opinion_filters)


Expand Down
10 changes: 6 additions & 4 deletions arekit/contrib/utils/pipelines/text_opinion/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@
from arekit.contrib.utils.pipelines.text_opinion.filters.limitation import FrameworkLimitationsTextOpinionFilter


def __iter_text_opinion_linkages(parsed_doc, annotators, text_opinion_filters):
def __iter_text_opinion_linkages(parsed_doc, annotators, entity_index_func, text_opinion_filters):
assert(isinstance(annotators, list))
assert(isinstance(parsed_doc, ParsedDocument))
assert(isinstance(text_opinion_filters, list))

def __to_id(text_opinion):
return "{}_{}".format(text_opinion.SourceId, text_opinion.TargetId)

service = ParsedDocumentService(parsed_doc=parsed_doc, providers=[EntityServiceProvider(None)])
service = ParsedDocumentService(parsed_doc=parsed_doc, providers=[EntityServiceProvider(entity_index_func)])
esp = service.get_provider(EntityServiceProvider.NAME)

predefined = set()
Expand Down Expand Up @@ -52,7 +52,8 @@ def __to_id(text_opinion):
yield text_opinion_linkage


def text_opinion_extraction_pipeline(text_parser, get_doc_by_id_func, annotators, text_opinion_filters=None):
def text_opinion_extraction_pipeline(text_parser, get_doc_by_id_func, annotators, entity_index_func,
text_opinion_filters=None):
assert(isinstance(text_parser, BaseTextParser))
assert(callable(get_doc_by_id_func))
assert(isinstance(annotators, list))
Expand All @@ -71,7 +72,8 @@ def text_opinion_extraction_pipeline(text_parser, get_doc_by_id_func, annotators

# (parsed_doc) -> (text_opinions)
MapPipelineItem(map_func=lambda parsed_doc: __iter_text_opinion_linkages(
annotators=annotators, parsed_doc=parsed_doc, text_opinion_filters=actual_text_opinion_filters)),
annotators=annotators, parsed_doc=parsed_doc, entity_index_func=entity_index_func,
text_opinion_filters=actual_text_opinion_filters)),

# linkages[] -> linkages
FlattenIterPipelineItem()
Expand Down
1 change: 1 addition & 0 deletions tests/contrib/utils/test_csv_stream_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def __launch(self, writer, target_extention):
DistanceLimitedTextOpinionFilter(terms_per_context=50)
],
get_doc_by_id_func=doc_provider.by_id,
entity_index_func=lambda brat_entity: brat_entity.ID,
text_parser=text_parser)
#####

Expand Down
1 change: 1 addition & 0 deletions tests/tutorials/test_tutorial_pipeline_sampling_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def test(self):
DistanceLimitedTextOpinionFilter(terms_per_context=50)
],
get_doc_by_id_func=doc_provider.by_id,
entity_index_func=lambda brat_entity: brat_entity.ID,
text_parser=text_parser)
#####

Expand Down
1 change: 1 addition & 0 deletions tests/tutorials/test_tutorial_pipeline_sampling_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def test(self):
DistanceLimitedTextOpinionFilter(terms_per_context=50)
],
get_doc_by_id_func=doc_provider.by_id,
entity_index_func=lambda brat_entity: brat_entity.ID,
text_parser=text_parser)
#####

Expand Down
1 change: 1 addition & 0 deletions tests/tutorials/test_tutorial_pipeline_sampling_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def test(self):
DistanceLimitedTextOpinionFilter(terms_per_context=50)
],
get_doc_by_id_func=doc_provider.by_id,
entity_index_func=lambda brat_entity: brat_entity.ID,
text_parser=text_parser)
#####

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,15 @@ def test(self):

synonyms = StemmerBasedSynonymCollection(stemmer=MystemWrapper(), is_read_only=False)

# Index source entities by their `ID` attribute.
entity_index_func = lambda brat_entity: brat_entity.ID

nolabel_annotator = AlgorithmBasedTextOpinionAnnotator(
annot_algo=PairBasedOpinionAnnotationAlgorithm(
dist_in_sents=0,
dist_in_terms_bound=50,
label_provider=ConstantLabelProvider(NoLabel())),
label_provider=ConstantLabelProvider(NoLabel()),
entity_index_func=entity_index_func),
create_empty_collection_func=lambda: OpinionCollection(
synonyms=synonyms, error_on_duplicates=True, error_on_synonym_end_missed=False),
value_to_group_id_func=lambda value:
Expand All @@ -77,6 +81,7 @@ def test(self):
DistanceLimitedTextOpinionFilter(terms_per_context=50)
],
get_doc_by_id_func=doc_provider.by_id,
entity_index_func=entity_index_func,
text_parser=text_parser)

# Running the pipeline.
Expand Down

0 comments on commit 8c1635d

Please sign in to comment.