Skip to content

Commit

Permalink
#249 refactoring done.
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed May 28, 2022
1 parent 438f322 commit 34b8623
Show file tree
Hide file tree
Showing 14 changed files with 89 additions and 113 deletions.
3 changes: 1 addition & 2 deletions arekit/common/experiment/annot/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@ def _annot_collection_core(self, parsed_news, data_type, opin_ops):

# region public methods

def annotate_collection(self, data_type, doc_id, doc_ops, opin_ops):
parsed_news = doc_ops.parse_doc(doc_id)
def annotate_collection(self, data_type, parsed_news, opin_ops):
return parsed_news.RelatedDocID, \
self._annot_collection_core(parsed_news=parsed_news, data_type=data_type, opin_ops=opin_ops)

Expand Down
17 changes: 1 addition & 16 deletions arekit/common/experiment/api/ops_doc.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,14 @@
from arekit.common.experiment.api.ctx_base import ExperimentContext
from arekit.common.news.parser import NewsParser
from arekit.common.text.parser import BaseTextParser


class DocumentOperations(object):
"""
Provides operations with documents
"""

def __init__(self, exp_ctx, text_parser=None):
def __init__(self, exp_ctx):
assert(isinstance(exp_ctx, ExperimentContext) or exp_ctx is None)
assert(isinstance(text_parser, BaseTextParser) or text_parser is None)
self._exp_ctx = exp_ctx
self.__text_parser = text_parser

# region abstract methods

Expand Down Expand Up @@ -40,15 +36,4 @@ def iter_doc_ids(self, data_type):
for doc_id in data_types_splits[data_type]:
yield doc_id

def parse_doc(self, doc_id):
return self.__parse_doc(doc_id=doc_id)

# endregion

# region private methods

def __parse_doc(self, doc_id):
news = self.get_doc(doc_id=doc_id)
return NewsParser.parse(news=news, text_parser=self.__text_parser)

# endregion
19 changes: 13 additions & 6 deletions arekit/contrib/bert/handlers/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,26 @@
from arekit.common.experiment.data_type import DataType
from arekit.common.experiment.handler import ExperimentIterationHandler
from arekit.common.labels.str_fmt import StringLabelsFormatter
from arekit.common.news.parser import NewsParser
from arekit.common.pipeline.base import BasePipeline
from arekit.common.text.parser import BaseTextParser
from arekit.contrib.bert.samplers.factory import create_bert_sample_provider
from arekit.contrib.utils.pipeline import ppl_text_ids_to_parsed_news, ppl_parsed_news_to_opinion_linkages, \
ppl_text_ids_to_annotated
ppl_parsed_to_annotation


class BertExperimentInputSerializerIterationHandler(ExperimentIterationHandler):

def __init__(self, exp_io, exp_ctx, doc_ops, opin_ops,
sample_labels_fmt, annot_labels_fmt, value_to_group_id_func,
sample_provider_type, entity_formatter, balance_train_samples):
sample_provider_type, entity_formatter, balance_train_samples,
text_parser):
assert(isinstance(exp_io, BaseIOUtils))
assert(isinstance(doc_ops, DocumentOperations))
assert(isinstance(opin_ops, OpinionOperations))
assert(isinstance(sample_labels_fmt, StringLabelsFormatter))
assert(isinstance(annot_labels_fmt, StringLabelsFormatter))
assert(isinstance(text_parser, BaseTextParser))
assert(callable(value_to_group_id_func))
super(BertExperimentInputSerializerIterationHandler, self).__init__()

Expand All @@ -40,6 +44,7 @@ def __init__(self, exp_io, exp_ctx, doc_ops, opin_ops,
self.__exp_ctx = exp_ctx
self.__doc_ops = doc_ops
self.__opin_ops = opin_ops
self.__text_parser = text_parser

# region private methods

Expand All @@ -64,14 +69,16 @@ def __handle_iteration(self, data_type):
storage=BaseRowsStorage())

pipeline = BasePipeline(
ppl_text_ids_to_annotated(
ppl_text_ids_to_parsed_news(
parse_news_func=lambda doc_id: NewsParser.parse(
news=self.__doc_ops.get_doc(doc_id),
text_parser=self.__text_parser))
+
ppl_parsed_to_annotation(
annotator=self.__exp_ctx.Annotator,
data_type=data_type,
doc_ops=self.__doc_ops,
opin_ops=self.__opin_ops)
+
ppl_text_ids_to_parsed_news(parse_news_func=lambda doc_id: self.__doc_ops.parse_doc(doc_id))
+
ppl_parsed_news_to_opinion_linkages(
value_to_group_id_func=self.__value_to_group_id_func,
terms_per_context=self.__exp_ctx.TermsPerContext)
Expand Down
4 changes: 2 additions & 2 deletions arekit/contrib/experiment_rusentrel/annot/two_scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,13 @@ def _annot_collection_core(self, parsed_news, data_type, opin_ops):

# region public methods

def annotate_collection(self, data_type, doc_id, doc_ops, opin_ops):
def annotate_collection(self, data_type, parsed_news, opin_ops):

if data_type == DataType.Train:
# Return empty collection.
return opin_ops.create_opinion_collection()

super(TwoScaleTaskAnnotator, self).annotate_collection(
data_type, doc_ops=doc_ops, doc_id=doc_id, opin_ops=opin_ops)
data_type, parsed_news=parsed_news, opin_ops=opin_ops)

# endregion
4 changes: 2 additions & 2 deletions arekit/contrib/experiment_rusentrel/exp_ds/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

class RuAttitudesDocumentOperations(DocumentOperations):

def __init__(self, exp_ctx, text_parser, ru_attitudes):
def __init__(self, exp_ctx, ru_attitudes):
assert(isinstance(ru_attitudes, dict))
super(RuAttitudesDocumentOperations, self).__init__(exp_ctx=exp_ctx, text_parser=text_parser)
super(RuAttitudesDocumentOperations, self).__init__(exp_ctx=exp_ctx)
self.__ru_attitudes = ru_attitudes

# region DocumentOperations
Expand Down
13 changes: 3 additions & 10 deletions arekit/contrib/experiment_rusentrel/exp_ds/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,30 +13,23 @@
logging.basicConfig(level=logging.INFO)


def create_ruattitudes_experiment(exp_ctx, exp_io, version, load_docs, ra_doc_ids_func, ppl_items):
def create_ruattitudes_experiment(exp_ctx, exp_io, version, load_docs, ra_doc_ids_func):
""" Application of distant supervision, especially for pretraining purposes.
Suggested for use with large RuAttitudes-format collections (v2.0-large).
"""
assert(isinstance(version, RuAttitudesVersions))
assert(isinstance(exp_io, BaseIOUtils))
assert(isinstance(load_docs, bool))
assert(isinstance(ppl_items, list) or ppl_items is None)

ru_attitudes = read_ruattitudes_in_memory(version=version,
doc_id_func=ra_doc_ids_func,
keep_doc_ids_only=not load_docs)

text_parser = create_text_parser(exp_ctx=exp_ctx,
entities_parser=RuAttitudesTextEntitiesParser(),
value_to_group_id_func=None,
ppl_items=ppl_items)

logger.info("Create document operations ...")
doc_ops = RuAttitudesDocumentOperations(exp_ctx=exp_ctx,
ru_attitudes=ru_attitudes,
text_parser=text_parser)
ru_attitudes=ru_attitudes)

logger.info("Create opinion operations ...")
opin_ops = RuAttitudesOpinionOperations(ru_attitudes=ru_attitudes)

return BaseExperiment(exp_ctx=exp_ctx, exp_io=exp_io, opin_ops=opin_ops, doc_ops=doc_ops)
return BaseExperiment(exp_ctx=exp_ctx, exp_io=exp_io, opin_ops=opin_ops, doc_ops=doc_ops)
4 changes: 2 additions & 2 deletions arekit/contrib/experiment_rusentrel/exp_joined/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

class RuSentrelWithRuAttitudesDocumentOperations(DocumentOperations):

def __init__(self, rusentrel_doc_ids, rusentrel_doc, get_ruattitudes_doc, text_parser):
def __init__(self, rusentrel_doc_ids, rusentrel_doc, get_ruattitudes_doc):
assert(isinstance(rusentrel_doc_ids, set))
assert(isinstance(rusentrel_doc, RuSentrelDocumentOperations))
assert(callable(get_ruattitudes_doc))
Expand All @@ -15,7 +15,7 @@ def __init__(self, rusentrel_doc_ids, rusentrel_doc, get_ruattitudes_doc, text_p
# RuAttitude data-folding considered as `auxiliary`.
super(RuSentrelWithRuAttitudesDocumentOperations, self).__init__(
# Note: temporary hack in terms of exp_ctx == None.
exp_ctx=None, text_parser=text_parser)
exp_ctx=None)

self.__rusentrel_doc = rusentrel_doc
self.__rusentrel_doc_ids = rusentrel_doc_ids
Expand Down
35 changes: 6 additions & 29 deletions arekit/contrib/experiment_rusentrel/exp_joined/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@


def create_rusentrel_with_ruattitudes_expriment(exp_ctx, exp_io, folding_type, ra_doc_id_func,
ruattitudes_version, rusentrel_version, load_docs,
ppl_items):
ruattitudes_version, rusentrel_version, load_docs):
"""
IO for the experiment with distant supervision for sentiment attitude extraction task.
Original Paper (RuAttitudes-1.0): https://www.aclweb.org/anthology/R19-1118/
Expand All @@ -33,29 +32,18 @@ def create_rusentrel_with_ruattitudes_expriment(exp_ctx, exp_io, folding_type, r
assert(isinstance(rusentrel_version, RuSentRelVersions))
assert(isinstance(folding_type, FoldingType))
assert(isinstance(exp_io, BaseIOUtils))
assert(isinstance(ppl_items, list) or ppl_items is None)
assert(callable(ra_doc_id_func))

optional_data = OptnionalDataProvider(exp_ctx=exp_ctx,
ruattitudes_version=ruattitudes_version,
rusentrel_version=rusentrel_version,
load_docs=load_docs,
ra_doc_id_func=ra_doc_id_func,
ppl_items=ppl_items)

# init text parser.
# TODO. Limitation, depending on document, entities parser may vary.
text_parser = create_text_parser(
exp_ctx=exp_ctx,
entities_parser=BratTextEntitiesParser(),
value_to_group_id_func=optional_data.get_synonyms().get_synonym_group_index,
ppl_items=ppl_items)
ra_doc_id_func=ra_doc_id_func)

# init documents.
rusentrel_doc = RuSentrelDocumentOperations(exp_ctx=exp_ctx,
version=rusentrel_version,
get_synonyms_func=optional_data.get_synonyms,
text_parser=text_parser)
get_synonyms_func=optional_data.get_synonyms)

# Init opinions
rusentrel_op = RuSentrelOpinionOperations(exp_ctx=exp_ctx,
Expand All @@ -69,8 +57,7 @@ def create_rusentrel_with_ruattitudes_expriment(exp_ctx, exp_io, folding_type, r
doc_ops = RuSentrelWithRuAttitudesDocumentOperations(
rusentrel_doc_ids=set(all_rusentrel_doc_ids),
rusentrel_doc=rusentrel_doc,
get_ruattitudes_doc=optional_data.get_or_load_ruattitudes_doc_ops,
text_parser=text_parser)
get_ruattitudes_doc=optional_data.get_or_load_ruattitudes_doc_ops)

opin_ops = RuSentrelWithRuAttitudesOpinionOperations(
rusentrel_op=rusentrel_op,
Expand All @@ -83,9 +70,8 @@ def create_rusentrel_with_ruattitudes_expriment(exp_ctx, exp_io, folding_type, r

class OptnionalDataProvider(object):

def __init__(self, exp_ctx, ruattitudes_version, rusentrel_version, load_docs, ra_doc_id_func, ppl_items):
def __init__(self, exp_ctx, ruattitudes_version, rusentrel_version, load_docs, ra_doc_id_func):
assert(isinstance(exp_ctx, ExperimentContext))
assert(isinstance(ppl_items, list) or ppl_items is None)
self.__exp_ctx = exp_ctx
self.__load_docs = load_docs
self.__ruattitudes_version = ruattitudes_version
Expand Down Expand Up @@ -119,15 +105,6 @@ def __init_ruattitudes_and_doc_ops(self):
doc_id_func=self.__ra_doc_id_func,
keep_doc_ids_only=not self.__load_docs)

text_parser = create_text_parser(
exp_ctx=self.__exp_ctx,
entities_parser=RuAttitudesTextEntitiesParser(),
value_to_group_id_func=self.__synonyms_provider.get_or_load_synonyms_collection().get_synonym_group_index,
ppl_items=self.__ppl_items)

# Completing initialization.
self.__ruattitudes_doc = RuAttitudesDocumentOperations(
exp_ctx=self.__exp_ctx,
ru_attitudes=ru_attitudes,
text_parser=text_parser)
self.__ruattitudes_doc = RuAttitudesDocumentOperations(exp_ctx=self.__exp_ctx, ru_attitudes=ru_attitudes)
self.__ru_attitudes = ru_attitudes
4 changes: 2 additions & 2 deletions arekit/contrib/experiment_rusentrel/exp_sl/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ class RuSentrelDocumentOperations(DocumentOperations):
Limitations: only the train/test collection format is supported
"""

def __init__(self, exp_ctx, text_parser, version, get_synonyms_func):
def __init__(self, exp_ctx, version, get_synonyms_func):
assert(isinstance(version, RuSentRelVersions))
assert(callable(get_synonyms_func))
super(RuSentrelDocumentOperations, self).__init__(exp_ctx=exp_ctx, text_parser=text_parser)
super(RuSentrelDocumentOperations, self).__init__(exp_ctx=exp_ctx)

self.__version = version
self.__get_synonyms_func = get_synonyms_func
Expand Down
12 changes: 1 addition & 11 deletions arekit/contrib/experiment_rusentrel/exp_sl/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,16 @@
from arekit.common.experiment.api.io_utils import BaseIOUtils
from arekit.common.folding.types import FoldingType
from arekit.contrib.experiment_rusentrel import common
from arekit.contrib.experiment_rusentrel.common import create_text_parser
from arekit.contrib.experiment_rusentrel.exp_sl.documents import RuSentrelDocumentOperations
from arekit.contrib.experiment_rusentrel.exp_sl.opinions import RuSentrelOpinionOperations
from arekit.contrib.experiment_rusentrel.synonyms.provider import RuSentRelSynonymsCollectionProvider
from arekit.contrib.source.brat.entities.parser import BratTextEntitiesParser
from arekit.contrib.source.rusentrel.io_utils import RuSentRelVersions

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def create_rusentrel_experiment(exp_ctx, exp_io, version, folding_type, ppl_items):
def create_rusentrel_experiment(exp_ctx, exp_io, version, folding_type):
"""
Represents a cv-based experiment over RuSentRel collection,
which supports train/test separation.
Expand All @@ -26,7 +24,6 @@ def create_rusentrel_experiment(exp_ctx, exp_io, version, folding_type, ppl_item
assert(isinstance(version, RuSentRelVersions))
assert(isinstance(folding_type, FoldingType))
assert(isinstance(exp_io, BaseIOUtils))
assert(isinstance(ppl_items, list) or ppl_items is None)

synonyms_provider = OptionalSynonymsProvider(version)

Expand All @@ -36,15 +33,8 @@ def create_rusentrel_experiment(exp_ctx, exp_io, version, folding_type, ppl_item
exp_io=exp_io,
get_synonyms_func=synonyms_provider.get_or_load_synonyms_collection)

text_parser = create_text_parser(
exp_ctx=exp_ctx,
entities_parser=BratTextEntitiesParser(),
value_to_group_id_func=synonyms_provider.get_or_load_synonyms_collection().get_synonym_group_index,
ppl_items=ppl_items)

doc_ops = RuSentrelDocumentOperations(exp_ctx=exp_ctx,
version=version,
text_parser=text_parser,
get_synonyms_func=synonyms_provider.get_or_load_synonyms_collection)

return BaseExperiment(exp_ctx=exp_ctx, exp_io=exp_io, doc_ops=doc_ops, opin_ops=opin_ops)
Expand Down
Loading

0 comments on commit 34b8623

Please sign in to comment.