Commit 53b5688
#236 done.
nicolay-r committed Dec 30, 2021
1 parent 44cd493 commit 53b5688
Showing 7 changed files with 86 additions and 114 deletions.
114 changes: 23 additions & 91 deletions arekit/common/data/input/providers/opinions.py
@@ -1,105 +1,37 @@
import collections

from arekit.common.data.input.sample import InputSampleBase
from arekit.common.experiment.pipelines.text_opinoins_input import process_input_text_opinions
from arekit.common.linkage.text_opinions import TextOpinionsLinkage
from arekit.common.news.parsed.providers.entity_service import EntityServiceProvider
from arekit.common.news.parsed.providers.text_opinion_pairs import TextOpinionPairsProvider
from arekit.common.pipeline.base import BasePipeline
from arekit.common.pipeline.context import PipelineContext


class OpinionProvider(object):
"""
TextOpinion iterator.
- Filter text_opinions by provided func.
- Assigns the IDs.
"""

def __init__(self, text_opinions_linkages_it_func):
assert(callable(text_opinions_linkages_it_func))
self.__text_opinions_linkages_it_func = text_opinions_linkages_it_func

# region private methods

@staticmethod
def __iter_linked_text_opinion_lists(
text_opinion_pairs_provider,
iter_opins_for_extraction,
filter_text_opinion_func):

assert (isinstance(text_opinion_pairs_provider, TextOpinionPairsProvider))
assert (isinstance(iter_opins_for_extraction, collections.Iterable))
assert (callable(filter_text_opinion_func))

for opinion in iter_opins_for_extraction:
linked_text_opinions = TextOpinionsLinkage(text_opinion_pairs_provider.iter_from_opinion(opinion))
filtered_text_opinions = list(filter(filter_text_opinion_func, linked_text_opinions))

if len(filtered_text_opinions) == 0:
continue

yield filtered_text_opinions

@staticmethod
def __iter_linked_text_opins(news_opins_for_extraction_func, parse_news_func,
value_to_group_id_func, terms_per_context, doc_ids_it):
"""
Extracting text-level opinions based on doc-level opinions in documents,
obtained by information in experiment.
NOTE:
1. Assumes to provide the same label (doc level opinion) onto related text-level opinions.
"""
assert(callable(parse_news_func))
assert(callable(value_to_group_id_func))
assert(isinstance(doc_ids_it, collections.Iterable))

curr_id = 0

value_to_group_id_func = value_to_group_id_func

for doc_id in doc_ids_it:

parsed_news = parse_news_func(doc_id)
entity_service = EntityServiceProvider(parsed_news)

linked_text_opinion_lists = OpinionProvider.__iter_linked_text_opinion_lists(
# TODO. To be refactored.
text_opinion_pairs_provider=TextOpinionPairsProvider(
parsed_news=parsed_news,
value_to_group_id_func=value_to_group_id_func),
iter_opins_for_extraction=news_opins_for_extraction_func(doc_id=parsed_news.RelatedDocID),
filter_text_opinion_func=lambda text_opinion: InputSampleBase.check_ability_to_create_sample(
entity_service=entity_service,
text_opinion=text_opinion,
window_size=terms_per_context))

for linked_text_opinion_list in linked_text_opinion_lists:

# Assign IDs.
for text_opinion in linked_text_opinion_list:
text_opinion.set_text_opinion_id(curr_id)
curr_id += 1

yield parsed_news, TextOpinionsLinkage(linked_text_opinion_list)
def __init__(self, pipeline):
assert(isinstance(pipeline, BasePipeline))
self.__pipeline = pipeline

# endregion

@classmethod
def create(cls, iter_news_opins_for_extraction, value_to_group_id_func,
def create(cls, iter_doc_opins, value_to_group_id_func,
parse_news_func, terms_per_context):
assert(callable(iter_news_opins_for_extraction))
assert(callable(iter_doc_opins))
assert(callable(value_to_group_id_func))
assert(isinstance(terms_per_context, int))
assert(callable(parse_news_func))

def it_func(doc_ids_it):
return cls.__iter_linked_text_opins(
value_to_group_id_func=value_to_group_id_func,
news_opins_for_extraction_func=lambda doc_id: iter_news_opins_for_extraction(doc_id=doc_id),
terms_per_context=terms_per_context,
doc_ids_it=doc_ids_it,
parse_news_func=lambda doc_id: parse_news_func(doc_id))

return cls(text_opinions_linkages_it_func=it_func)

def iter_linked_opinions(self, doc_ids_it):
return self.__text_opinions_linkages_it_func(doc_ids_it)
pipeline = process_input_text_opinions(
parse_news_func=parse_news_func,
value_to_group_id_func=value_to_group_id_func,
iter_doc_opins=iter_doc_opins,
terms_per_context=terms_per_context)

return cls(pipeline)

def iter_linked_opinions(self, doc_ids):
ctx = PipelineContext({"src": doc_ids})
self.__pipeline.run(ctx)
for linkage in ctx.provide("src"):
assert(isinstance(linkage, TextOpinionsLinkage))
yield linkage
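
The refactoring above makes OpinionProvider a thin wrapper over the BasePipeline built by process_input_text_opinions: iteration, filtering and id assignment all happen inside the pipeline, and each yielded TextOpinionsLinkage now carries its ParsedNews in the Tag property. A minimal usage sketch (my_opinions, my_docs and the parameter values below are hypothetical stand-ins for the real experiment objects; see run_serializer.py further down for the actual call site):

from arekit.common.data.input.providers.opinions import OpinionProvider

provider = OpinionProvider.create(
    iter_doc_opins=lambda doc_id: my_opinions.iter_for_doc(doc_id),   # hypothetical source of doc-level opinions
    value_to_group_id_func=lambda value: 0,                           # hypothetical grouping callback
    parse_news_func=lambda doc_id: my_docs.parse_doc(doc_id),         # hypothetical document parser
    terms_per_context=50)                                             # arbitrary window size

for linkage in provider.iter_linked_opinions([0, 1, 2]):
    parsed_news = linkage.Tag        # ParsedNews attached by the pipeline
    for text_opinion in linkage:     # text-opinion ids are already assigned
        pass
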
10 changes: 8 additions & 2 deletions arekit/common/data/input/providers/rows/base.py
@@ -2,6 +2,7 @@
import logging

from arekit.common.data.input.providers.opinions import OpinionProvider
from arekit.common.linkage.text_opinions import TextOpinionsLinkage
from arekit.common.news.parsed.providers.entity_service import EntityServiceProvider

logger = logging.getLogger(__name__)
@@ -22,10 +23,15 @@ def iter_by_rows(self, opinion_provider, doc_ids_iter, idle_mode):
assert(isinstance(opinion_provider, OpinionProvider))
assert(isinstance(doc_ids_iter, collections.Iterable))

for parsed_news, linkage in opinion_provider.iter_linked_opinions(doc_ids_iter):
for linkage in opinion_provider.iter_linked_opinions(doc_ids_iter):
assert(isinstance(linkage, TextOpinionsLinkage))

            parsed_news = linkage.Tag
            # NOTE: the pipeline in text_opinoins_input.py already built an
            # EntityServiceProvider for this ParsedNews; constructing another
            # one here duplicates that work ("double parsing").
            entity_service = EntityServiceProvider(parsed_news)

rows_it = self._provide_rows(parsed_news=parsed_news,
entity_service=EntityServiceProvider(parsed_news),
entity_service=entity_service,
text_opinion_linkage=linkage,
idle_mode=idle_mode)

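
The NOTE above flags the duplication: the pipeline already constructs an EntityServiceProvider for every ParsedNews, and iter_by_rows builds a second one from linkage.Tag. One conceivable remedy, sketched here purely as an illustration and not part of this commit, is to memoize the provider per ParsedNews instance:

from arekit.common.news.parsed.providers.entity_service import EntityServiceProvider

_entity_service_cache = {}

def get_entity_service(parsed_news):
    # Key by object identity; assumes the ParsedNews objects stay alive
    # for the duration of the iteration.
    key = id(parsed_news)
    if key not in _entity_service_cache:
        _entity_service_cache[key] = EntityServiceProvider(parsed_news)
    return _entity_service_cache[key]
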
60 changes: 43 additions & 17 deletions arekit/common/experiment/pipelines/text_opinoins_input.py
@@ -1,4 +1,5 @@
from arekit.common.data.input.sample import InputSampleBase
from arekit.common.linkage.text_opinions import TextOpinionsLinkage
from arekit.common.news.parsed.providers.entity_service import EntityServiceProvider
from arekit.common.news.parsed.providers.text_opinion_pairs import TextOpinionPairsProvider
from arekit.common.opinions.base import Opinion
@@ -9,54 +10,79 @@
from arekit.common.text_opinions.base import TextOpinion


def to_text_opinions_iter(provider, opinions, filter_func):
def to_text_opinon_linkages(provider, opinions, tag_value_func, filter_func):
assert(isinstance(provider, TextOpinionPairsProvider))
assert(callable(tag_value_func))
assert(callable(filter_func))

for opinion in opinions:
assert(isinstance(opinion, Opinion))

text_opinions = []

for text_opinion in provider.iter_from_opinion(opinion):
assert(isinstance(text_opinion, TextOpinion))

if not filter_func(text_opinion):
continue
yield text_opinion

text_opinions.append(text_opinion)

if len(text_opinions) == 0:
continue

linkage = TextOpinionsLinkage(text_opinions)

if tag_value_func is not None:
linkage.set_tag(tag_value_func(linkage))

yield linkage


def process_input_text_opinions(parse_news_func, value_to_group_id_func, terms_per_context):
def process_input_text_opinions(parse_news_func, iter_doc_opins,
value_to_group_id_func, terms_per_context):
""" Opinion collection generation pipeline.
"""

def __assign_ids(text_opinion, curr_id_list):
assert(isinstance(text_opinion, TextOpinion))
current_id = curr_id_list[0]
text_opinion.set_text_opinion_id(current_id)
curr_id_list[0] += 1
def __assign_ids(linkage, curr_id_list):
assert(isinstance(linkage, TextOpinionsLinkage))
for text_opinion in linkage:
assert(isinstance(text_opinion, TextOpinion))
current_id = curr_id_list[0]
text_opinion.set_text_opinion_id(current_id)
curr_id_list[0] += 1

    # One-element list: a mutable cell, so that __assign_ids can advance the shared id counter.
curr_id = [0]

return BasePipeline([
# (id) -> (id, opinions)
MapPipelineItem(map_func=lambda doc_id: (doc_id, list(iter_doc_opins(doc_id)))),

# (id, opinions) -> (parsed_news, opinions).
MapPipelineItem(map_func=lambda data: (parse_news_func(data[0]), data[1]) ),

        # (parsed_news, opinions) -> (parsed_news, opins_provider, entities_provider, opinions).
MapPipelineItem(map_func=lambda data: (
data[0],
TextOpinionPairsProvider(parsed_news=data[0], value_to_group_id_func=value_to_group_id_func),
EntityServiceProvider(parsed_news=data[0]),
data[1])),

# (opins_provider, entities_provider, opinions) -> text_opinions[].
MapPipelineItem(map_func=lambda data: to_text_opinions_iter(
provider=data[0],
opinions=data[2],
        # (parsed_news, opins_provider, entities_provider, opinions) -> linkages[].
MapPipelineItem(map_func=lambda data: to_text_opinon_linkages(
provider=data[1],
opinions=data[3],
# Assign parsed news.
tag_value_func=lambda linkage: data[0],
filter_func=lambda text_opinion: InputSampleBase.check_ability_to_create_sample(
entity_service=data[1],
entity_service=data[2],
text_opinion=text_opinion,
window_size=terms_per_context))),

# text_opinions[] -> text_opinions.
# linkages[] -> linkages.
FlattenIterPipelineItem(),

# Assign id.
HandleIterPipelineItem(handle_func=lambda text_opinion: __assign_ids(text_opinion=text_opinion,
curr_id_list=curr_id))
])
HandleIterPipelineItem(handle_func=lambda linkage: __assign_ids(linkage=linkage, curr_id_list=curr_id))
])
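
Two details of the pipeline above are worth spelling out. First, the stages communicate through growing tuples, so the data[...] indices in later stages must track the comment annotations exactly. Second, curr_id = [0] is a mutable-cell idiom: __assign_ids receives the counter as an argument, and rebinding an int parameter would be local to a single call, so the counter lives in a one-element list that is mutated in place. A stripped-down, self-contained sketch of the same idiom:

counter = [0]  # one-element list acts as a shared mutable cell

def assign_id(collected, counter_cell):
    # Mutating the list's contents is visible across calls,
    # unlike rebinding an integer parameter.
    collected.append(counter_cell[0])
    counter_cell[0] += 1

ids = []
assign_id(ids, counter)
assign_id(ids, counter)
assert ids == [0, 1] and counter[0] == 2
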
8 changes: 8 additions & 0 deletions arekit/common/linkage/text_opinions.py
@@ -6,13 +6,21 @@ class TextOpinionsLinkage(LinkedDataWrapper):

def __init__(self, text_opinions_it):
super(TextOpinionsLinkage, self).__init__(linked_data=text_opinions_it)
self.__tag = None

def set_tag(self, value):
self.__tag = value

@property
def First(self):
first = super(TextOpinionsLinkage, self).First
assert(isinstance(first, TextOpinion))
return first

@property
def Tag(self):
return self.__tag

@property
def RelatedDocID(self):
return self.First.DocID
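
The new Tag slot makes the linkage a generic carrier for side data; in this commit the pipeline stores the originating ParsedNews there, so downstream row providers no longer need a separate (parsed_news, linkage) pair. A small sketch, assuming text_opinions is an iterable of TextOpinion and parsed_news is the document's ParsedNews:

from arekit.common.linkage.text_opinions import TextOpinionsLinkage

linkage = TextOpinionsLinkage(text_opinions)  # text_opinions: assumed, see lead-in
linkage.set_tag(parsed_news)                  # parsed_news: assumed, see lead-in

assert linkage.Tag is parsed_news             # payload read back via the property
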
4 changes: 2 additions & 2 deletions arekit/common/pipeline/item_flatten.py
@@ -8,7 +8,7 @@ class FlattenIterPipelineItem(BasePipelineItem):
""" Considered to flat iterations of items that represent iterations.
"""

def __flat(self, iter_data):
def __flat_iter(self, iter_data):
for iter_item in iter_data:
for item in iter_item:
yield item
@@ -17,4 +17,4 @@ def apply(self, pipeline_ctx):
assert (isinstance(pipeline_ctx, PipelineContext))
iter_data = pipeline_ctx.provide("src")
assert (isinstance(iter_data, collections.Iterable))
pipeline_ctx.update(param="src", value=self.__flat(iter_data))
pipeline_ctx.update(param="src", value=self.__flat_iter(iter_data))
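
The renamed __flat_iter is a plain two-level generator; outside the pipeline machinery the same flattening behaviour can be demonstrated standalone:

def flat_iter(iter_data):
    # Yield every item of every inner iterable, preserving order.
    for inner in iter_data:
        for item in inner:
            yield item

assert list(flat_iter([[1, 2], [3], (4, 5)])) == [1, 2, 3, 4, 5]
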
2 changes: 1 addition & 1 deletion arekit/contrib/bert/run_serializer.py
@@ -57,7 +57,7 @@ def __handle_iteration(self, data_type):
opinion_provider = OpinionProvider.create(
value_to_group_id_func=None,
parse_news_func=lambda doc_id: self._experiment.DocumentOperations.parse_doc(doc_id),
iter_news_opins_for_extraction=lambda doc_id:
iter_doc_opins=lambda doc_id:
self._experiment.OpinionOperations.iter_opinions_for_extraction(doc_id=doc_id, data_type=data_type),
terms_per_context=self._experiment.DataIO.TermsPerContext)

2 changes: 1 addition & 1 deletion arekit/contrib/networks/core/input/helper.py
@@ -137,7 +137,7 @@ def prepare(experiment, terms_per_context, balance, value_to_group_id_func=None)
opinion_provider = OpinionProvider.create(
value_to_group_id_func=value_to_group_id_func, # TODO. Remove this parameter.
parse_news_func=lambda doc_id: experiment.DocumentOperations.parse_doc(doc_id),
iter_news_opins_for_extraction=lambda doc_id:
iter_doc_opins=lambda doc_id:
experiment.OpinionOperations.iter_opinions_for_extraction(doc_id=doc_id, data_type=data_type),
terms_per_context=terms_per_context)

