Commit
#250 refactoring. Annot is now part of the pipeline.
nicolay-r committed May 28, 2022
1 parent 181d33a commit 745d663
Showing 7 changed files with 97 additions and 151 deletions.
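At a glance: the standalone annotation pass (previously BaseAnnotator.iter_annotated_collections, driven via OpinionOperations.iter_annot_collections) is replaced by pipeline items, so annotation now runs as the first stage of input serialization. Below is a minimal sketch of the new composition, mirroring the code this commit adds to both serializer.py and helper.py; the annotator, data_type, doc_ops, opin_ops, value_to_group_id_func and terms_per_context names are assumed to be pre-configured experiment components, not defined here.

from arekit.common.pipeline.base import BasePipeline
from arekit.contrib.utils.pipeline import ppl_parsed_news_to_opinion_linkages, \
    ppl_text_ids_to_parsed_news, ppl_text_ids_to_annotated

# Each ppl_* builder returns a plain list of pipeline items, so the three
# stages concatenate with `+` before a single BasePipeline is built:
# (id) -> (id, opinions) -> (parsed_news, opinions) -> linkages.
pipeline = BasePipeline(
    ppl_text_ids_to_annotated(annotator=annotator,
                              data_type=data_type,
                              doc_ops=doc_ops,
                              opin_ops=opin_ops)
    +
    ppl_text_ids_to_parsed_news(
        parse_news_func=lambda doc_id: doc_ops.parse_doc(doc_id),
        iter_doc_opins=lambda doc_id: opin_ops.iter_opinions_for_extraction(
            doc_id=doc_id, data_type=data_type))
    +
    ppl_parsed_news_to_opinion_linkages(
        value_to_group_id_func=value_to_group_id_func,
        terms_per_context=terms_per_context))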
34 changes: 5 additions & 29 deletions arekit/common/experiment/annot/base.py
@@ -1,11 +1,5 @@
import logging

from arekit.common.experiment.api.enums import BaseDocumentTag
from arekit.common.experiment.api.ops_doc import DocumentOperations
from arekit.common.experiment.api.ops_opin import OpinionOperations
from arekit.common.news.parsed.base import ParsedNews
from arekit.common.utils import progress_bar_iter

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

@@ -19,32 +13,14 @@ class BaseAnnotator(object):
def __init__(self):
logger.info("Init annotator: [{}]".format(self.__class__))

# region private methods

def __iter_annotated_collections(self, data_type, doc_ops, opin_ops):
assert(isinstance(doc_ops, DocumentOperations))
assert(isinstance(opin_ops, OpinionOperations))

logged_parsed_news_iter = progress_bar_iter(
iterable=doc_ops.iter_parsed_docs(doc_ops.iter_tagget_doc_ids(BaseDocumentTag.Annotate)),
desc="Annotating parsed news [{}]".format(data_type))

for parsed_news in logged_parsed_news_iter:
assert(isinstance(parsed_news, ParsedNews))
yield parsed_news.RelatedDocID, \
self._annot_collection_core(parsed_news=parsed_news, data_type=data_type, opin_ops=opin_ops)

# endregion

def _annot_collection_core(self, parsed_news, data_type, opin_ops):
raise NotImplementedError

# region public methods

def iter_annotated_collections(self, data_type, doc_ops, opin_ops):
assert(isinstance(opin_ops, OpinionOperations))
return self.__iter_annotated_collections(data_type=data_type,
doc_ops=doc_ops,
opin_ops=opin_ops)
def annotate_collection(self, data_type, doc_id, doc_ops, opin_ops):
parsed_news = doc_ops.parse_doc(doc_id)
return parsed_news.RelatedDocID, \
self._annot_collection_core(parsed_news=parsed_news, data_type=data_type, opin_ops=opin_ops)

# endregion
# endregion
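The annotator contract thus shifts from iterating (doc_id, collection) pairs over all tagged documents to annotating a single document per call. A hypothetical minimal subclass, sketched only to illustrate the new per-document interface; the class name and its trivial body are illustrative, not part of this commit:

from arekit.common.experiment.annot.base import BaseAnnotator

class NoOpAnnotator(BaseAnnotator):
    """ Illustrative stub: returns an empty collection for every document. """

    def _annot_collection_core(self, parsed_news, data_type, opin_ops):
        # A real annotator would derive opinions from parsed_news here.
        return opin_ops.create_opinion_collection()

# Assuming pre-configured data_type, doc_ops and opin_ops:
# doc_id, collection = NoOpAnnotator().annotate_collection(
#     data_type=data_type, doc_id=some_doc_id, doc_ops=doc_ops, opin_ops=opin_ops)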
9 changes: 0 additions & 9 deletions arekit/common/experiment/api/ops_opin.py
@@ -16,15 +16,6 @@ def LabelsFormatter(self):

# region extraction

# TODO. #248 remove doc_ops (see details in the issue description).
def iter_annot_collections(self, exp_ctx, doc_ops, data_type):

collections_it = exp_ctx.Annotator.iter_annotated_collections(
data_type=data_type, doc_ops=doc_ops, opin_ops=self)

for doc_id, collection in collections_it:
yield doc_id, collection

def iter_opinions_for_extraction(self, doc_id, data_type):
""" providing opinions for further context-level opinion extraction process.
in terms of sentiment attitude extraction, this is a general method
48 changes: 19 additions & 29 deletions arekit/contrib/bert/handlers/serializer.py
@@ -1,4 +1,3 @@
from arekit.common.data.input.pipeline import text_opinions_iter_pipeline
from arekit.common.data.input.providers.columns.opinion import OpinionColumnsProvider
from arekit.common.data.input.providers.columns.sample import SampleColumnsProvider
from arekit.common.data.input.providers.opinions import InputTextOpinionProvider
@@ -12,7 +11,10 @@
from arekit.common.experiment.data_type import DataType
from arekit.common.experiment.handler import ExperimentIterationHandler
from arekit.common.labels.str_fmt import StringLabelsFormatter
from arekit.common.pipeline.base import BasePipeline
from arekit.contrib.bert.samplers.factory import create_bert_sample_provider
from arekit.contrib.utils.pipeline import ppl_text_ids_to_parsed_news, ppl_parsed_news_to_opinion_linkages, \
ppl_text_ids_to_annotated


class BertExperimentInputSerializerIterationHandler(ExperimentIterationHandler):
@@ -61,14 +63,22 @@ def __handle_iteration(self, data_type):
rows_provider=sample_rows_provider,
storage=BaseRowsStorage())

# TODO. #250. Expand this pipeline with the annotation (in advance).
# TODO. Check out the same comment at NetworkInputHelper.
pipeline = text_opinions_iter_pipeline(
parse_news_func=lambda doc_id: self.__doc_ops.parse_doc(doc_id),
value_to_group_id_func=self.__value_to_group_id_func,
iter_doc_opins=lambda doc_id: self.__opin_ops.iter_opinions_for_extraction(
doc_id=doc_id, data_type=data_type),
terms_per_context=self.__exp_ctx.TermsPerContext)
pipeline = BasePipeline(
ppl_text_ids_to_annotated(
annotator=self.__exp_ctx.Annotator,
data_type=data_type,
doc_ops=self.__doc_ops,
opin_ops=self.__opin_ops)
+
ppl_text_ids_to_parsed_news(
parse_news_func=lambda doc_id: self.__doc_ops.parse_doc(doc_id),
iter_doc_opins=lambda doc_id: self.__opin_ops.iter_opinions_for_extraction(
doc_id=doc_id, data_type=data_type))
+
ppl_parsed_news_to_opinion_linkages(
value_to_group_id_func=self.__value_to_group_id_func,
terms_per_context=self.__exp_ctx.TermsPerContext)
)

# Create opinion provider
opinion_provider = InputTextOpinionProvider(pipeline)
@@ -104,24 +114,4 @@ def on_iteration(self, iter_index):
for data_type in self.__exp_ctx.DataFolding.iter_supported_data_types():
self.__handle_iteration(data_type)

def on_before_iteration(self):
for data_type in self.__exp_ctx.DataFolding.iter_supported_data_types():

# TODO. #250. A part of the further pipeline.
# TODO. This might be included in InputTextOpinionProvider, as an initial operation
# TODO. In a whole pipeline. This code duplicates the one in NetworkInputHelper.
collections_it = self.__opin_ops.iter_annot_collections(
exp_ctx=self.__exp_ctx, doc_ops=self.__doc_ops, data_type=data_type)

for doc_id, collection in collections_it:

target = self.__exp_io.create_opinion_collection_target(
doc_id=doc_id,
data_type=data_type)

self.__exp_io.write_opinion_collection(
collection=collection,
target=target,
labels_formatter=self.__annot_label_formatter)

# endregion
9 changes: 5 additions & 4 deletions arekit/contrib/experiment_rusentrel/annot/two_scale.py
@@ -42,12 +42,13 @@ def _annot_collection_core(self, parsed_news, data_type, opin_ops):

# region public methods

def iter_annotated_collections(self, data_type, doc_ops, opin_ops):
def annotate_collection(self, data_type, doc_id, doc_ops, opin_ops):

if data_type == DataType.Train:
return
# Return empty collection.
return opin_ops.create_opinion_collection()

super(TwoScaleTaskAnnotator, self).iter_annotated_collections(
data_type, doc_ops=doc_ops, opin_ops=opin_ops)
super(TwoScaleTaskAnnotator, self).annotate_collection(
data_type, doc_ops=doc_ops, doc_id=doc_id, opin_ops=opin_ops)

# endregion
48 changes: 17 additions & 31 deletions arekit/contrib/networks/core/input/helper.py
@@ -1,7 +1,6 @@
import collections
import logging

from arekit.common.data.input.pipeline import text_opinions_iter_pipeline
from arekit.common.data.input.providers.columns.opinion import OpinionColumnsProvider
from arekit.common.data.input.providers.columns.sample import SampleColumnsProvider
from arekit.common.data.input.providers.opinions import InputTextOpinionProvider
@@ -10,6 +9,7 @@
from arekit.common.data.input.repositories.sample import BaseInputSamplesRepository
from arekit.common.data.storages.base import BaseRowsStorage
from arekit.common.experiment.data_type import DataType
from arekit.common.pipeline.base import BasePipeline
from arekit.contrib.networks.core.input.ctx_serialization import NetworkSerializationContext
from arekit.contrib.networks.core.input.formatters.pos_mapper import PosTermsMapper
from arekit.contrib.networks.core.input.providers.sample import NetworkSampleRowProvider
@@ -18,6 +18,8 @@
from arekit.contrib.networks.core.input.terms_mapping import StringWithEmbeddingNetworkTermMapping
from arekit.contrib.networks.core.input.embedding.matrix import create_term_embedding_matrix
from arekit.contrib.networks.embeddings.base import Embedding
from arekit.contrib.utils.pipeline import ppl_parsed_news_to_opinion_linkages, ppl_text_ids_to_parsed_news, \
ppl_text_ids_to_annotated

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
@@ -89,17 +91,6 @@ def __perform_writing(exp_ctx, exp_io, doc_ops, data_type, opinion_provider,
opinions_repo.write(writer=exp_io.create_opinions_writer(),
target=exp_io.create_opinions_writer_target(data_type=data_type))

# TODO. #250 this part is not related to a helper.
# TODO. This is a particular implementation, which is considered to be
# TODO. Implemented at iter_opins_for_extraction (OpinionOperations)
@staticmethod
def __save_annotation(exp_io, labels_fmt, data_type, collections_it):
for doc_id, collection in collections_it:
target = exp_io.create_opinion_collection_target(doc_id=doc_id, data_type=data_type)
exp_io.write_opinion_collection(collection=collection,
target=target,
labels_formatter=labels_fmt)

# endregion

@staticmethod
@@ -124,25 +115,20 @@ def prepare(exp_ctx, exp_io, doc_ops, opin_ops, terms_per_context, balance, valu

for data_type in exp_ctx.DataFolding.iter_supported_data_types():

# Perform annotation
# TODO. #250. This should be transformed into pipeline element.
# TODO. And then combined (embedded) into pipeline below.
NetworkInputHelper.__save_annotation(
exp_io=exp_io,
labels_fmt=opin_ops.LabelsFormatter,
data_type=data_type,
collections_it=opin_ops.iter_annot_collections())

# TODO. #250. Organize a complete pipeline.
# TODO. Now InputTextOpinionProvider has only a part from annotated opinions
# TODO. We need to extend our pipeline with the related pre-processing (annotation).
# TODO. See text_opinions_iter_pipeline method.
pipeline = text_opinions_iter_pipeline(
parse_news_func=lambda doc_id: doc_ops.parse_doc(doc_id),
value_to_group_id_func=value_to_group_id_func,
iter_doc_opins=lambda doc_id:
opin_ops.iter_opinions_for_extraction(doc_id=doc_id, data_type=data_type),
terms_per_context=terms_per_context)
pipeline = BasePipeline(
ppl_text_ids_to_annotated(annotator=exp_ctx.Annotator,
data_type=data_type,
doc_ops=doc_ops,
opin_ops=opin_ops)
+
ppl_text_ids_to_parsed_news(
parse_news_func=lambda doc_id: doc_ops.parse_doc(doc_id),
iter_doc_opins=lambda doc_id: opin_ops.iter_opinions_for_extraction(
doc_id=doc_id, data_type=data_type))
+
ppl_parsed_news_to_opinion_linkages(value_to_group_id_func=value_to_group_id_func,
terms_per_context=terms_per_context)
)

NetworkInputHelper.__perform_writing(
exp_ctx=exp_ctx,
Empty file.
arekit/contrib/utils/pipeline.py
@@ -1,70 +1,46 @@
from arekit.common.data.input.sample import InputSampleBase
from arekit.common.experiment.annot.base import BaseAnnotator
from arekit.common.linkage.text_opinions import TextOpinionsLinkage
from arekit.common.news.parsed.providers.entity_service import EntityServiceProvider
from arekit.common.news.parsed.providers.text_opinion_pairs import TextOpinionPairsProvider
from arekit.common.news.parsed.service import ParsedNewsService
from arekit.common.opinions.base import Opinion
from arekit.common.pipeline.base import BasePipeline
from arekit.common.pipeline.items.flatten import FlattenIterPipelineItem
from arekit.common.pipeline.item_map import MapPipelineItem
from arekit.common.pipeline.items.flatten import FlattenIterPipelineItem
from arekit.common.text_opinions.base import TextOpinion


def __to_text_opinion_linkages(provider, opinions, tag_value_func, filter_func):
assert(isinstance(provider, TextOpinionPairsProvider))
assert(callable(tag_value_func))
assert(callable(filter_func))

for opinion in opinions:
assert(isinstance(opinion, Opinion))

text_opinions = []

for text_opinion in provider.iter_from_opinion(opinion):
assert(isinstance(text_opinion, TextOpinion))
def ppl_text_ids_to_annotated(annotator, data_type, doc_ops, opin_ops):
assert(isinstance(annotator, BaseAnnotator))

if not filter_func(text_opinion):
continue

text_opinions.append(text_opinion)

if len(text_opinions) == 0:
continue
return [
# (id) -> (id, opinions)
MapPipelineItem(map_func=lambda doc_id: (
doc_id, annotator.annotate_collection(data_type=data_type,
doc_id=doc_id,
doc_ops=doc_ops,
opin_ops=opin_ops)))
]

linkage = TextOpinionsLinkage(text_opinions)

if tag_value_func is not None:
linkage.set_tag(tag_value_func(linkage))
def ppl_text_ids_to_parsed_news(parse_news_func, iter_doc_opins):
assert(callable(parse_news_func))
assert(callable(iter_doc_opins))

yield linkage
return [
# (id, opinions) -> (parsed_news, opinions).
MapPipelineItem(map_func=lambda data: (parse_news_func(data[0]), data[1])),
]


def text_opinions_iter_pipeline(parse_news_func, iter_doc_opins,
value_to_group_id_func, terms_per_context):
def ppl_parsed_news_to_opinion_linkages(value_to_group_id_func, terms_per_context):
""" Opinion collection generation pipeline.
NOTE: Here we do not perform IDs assignation!
"""
# TODO. #250, separate. Part 1. parameters.
assert(callable(parse_news_func))
assert(callable(iter_doc_opins))

# TODO. Part separate. Part 2. parameters.
assert(callable(value_to_group_id_func))
assert(isinstance(terms_per_context, int))

return BasePipeline([

### TODO. #250. Separate
### TODO. #250. Separate PART 1. (Related to ids -> (parsed_news, opinions)

# (id) -> (id, opinions)
MapPipelineItem(map_func=lambda doc_id: (doc_id, list(iter_doc_opins(doc_id)))),

# (id, opinions) -> (parsed_news, opinions).
MapPipelineItem(map_func=lambda data: (parse_news_func(data[0]), data[1])),

### TODO. Separate
### TODO. Separate PART 2. (parsed_news, opinions) -> linkages[]
return [

# (parsed_news, opinions) -> (opins_provider, entities_provider, opinions).
MapPipelineItem(map_func=lambda data: (
@@ -85,9 +61,35 @@ def text_opinions_iter_pipeline(parse_news_func, iter_doc_opins,
text_opinion=text_opinion,
window_size=terms_per_context))),

### TODO. #250. Separate
### TODO. #250. Separate Part 3. Flatten. linkage[] -> linkages.

# linkages[] -> linkages.
FlattenIterPipelineItem()
])
]


def __to_text_opinion_linkages(provider, opinions, tag_value_func, filter_func):
assert(isinstance(provider, TextOpinionPairsProvider))
assert(callable(tag_value_func))
assert(callable(filter_func))

for opinion in opinions:
assert(isinstance(opinion, Opinion))

text_opinions = []

for text_opinion in provider.iter_from_opinion(opinion):
assert(isinstance(text_opinion, TextOpinion))

if not filter_func(text_opinion):
continue

text_opinions.append(text_opinion)

if len(text_opinions) == 0:
continue

linkage = TextOpinionsLinkage(text_opinions)

if tag_value_func is not None:
linkage.set_tag(tag_value_func(linkage))

yield linkage
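
Since each builder above returns a plain list of MapPipelineItem / FlattenIterPipelineItem objects rather than a ready BasePipeline, stages can be freely recombined before construction. The following self-contained toy (plain Python, no AREkit imports; every name is illustrative) re-enacts the staged data flow those items implement:

def run_stages(doc_ids, iter_doc_opins, parse_doc, to_linkages):
    """ Toy analogue of the three ppl_* stages composed above. """
    for doc_id in doc_ids:
        # Stage 1: (id) -> (id, opinions).
        opinions = list(iter_doc_opins(doc_id))
        # Stage 2: (id, opinions) -> (parsed_news, opinions).
        parsed_news = parse_doc(doc_id)
        # Stage 3: (parsed_news, opinions) -> linkages[], then flattened.
        for linkage in to_linkages(parsed_news, opinions):
            yield linkage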
