Skip to content

Commit

Permalink
#212 Refactoring.
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Oct 20, 2021
1 parent ebb3d15 commit d0a6a74
Show file tree
Hide file tree
Showing 11 changed files with 38 additions and 40 deletions.
3 changes: 2 additions & 1 deletion arekit/common/experiment/annot/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging

from arekit.common.experiment.api.enums import BaseDocumentTag
from arekit.common.experiment.api.ops_doc import DocumentOperations
from arekit.common.experiment.api.ops_opin import OpinionOperations
from arekit.common.news.parsed.base import ParsedNews
Expand Down Expand Up @@ -29,7 +30,7 @@ def __iter_annotated_collections(self, data_type, doc_ops, opin_ops):
assert(isinstance(opin_ops, OpinionOperations))

logged_parsed_news_iter = progress_bar_iter(
iterable=doc_ops.iter_parsed_docs(doc_ops.iter_doc_ids_to_annotate()),
iterable=doc_ops.iter_parsed_docs(doc_ops.iter_tagget_doc_ids(BaseDocumentTag.Annotate)),
desc="Annotating parsed news [{}]".format(data_type))

for parsed_news in logged_parsed_news_iter:
Expand Down
4 changes: 2 additions & 2 deletions arekit/common/experiment/api/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from arekit.common.evaluation.utils import OpinionCollectionsToCompareUtils
from arekit.common.experiment.api.ctx_base import DataIO
from arekit.common.experiment.api.ctx_training import TrainingData
from arekit.common.experiment.api.enums import BaseDocumentTag
from arekit.common.experiment.api.io_utils import BaseIOUtils
from arekit.common.experiment.api.ops_doc import DocumentOperations
from arekit.common.experiment.api.ops_opin import OpinionOperations
Expand Down Expand Up @@ -95,8 +96,7 @@ def evaluate(self, data_type, epoch_index):
assert(isinstance(self.__experiment_data, TrainingData))

# Extracting all docs to cmp and those that is related to data_type.
# TODO. 212. Pass tag ("compare")
cmp_doc_ids_iter = self.__doc_ops.iter_doc_ids_to_compare()
cmp_doc_ids_iter = self.__doc_ops.iter_tagget_doc_ids(BaseDocumentTag.Compare)
doc_ids_iter = self.__doc_ops.iter_doc_ids(data_type=data_type)
cmp_doc_ids_set = set(cmp_doc_ids_iter)

Expand Down
14 changes: 14 additions & 0 deletions arekit/common/experiment/api/enums.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from enum import Enum


class BaseDocumentTag(Enum):

""" Denotes a document that utilized by annotator algorithm in order to
provide the related labeling of annotated attitudes in it.
By default, we consider an empty set, so there is no need to utilize annotator.
"""
Annotate = 1

""" Denotes a document that utilized in model evaluation process
"""
Compare = 2
13 changes: 2 additions & 11 deletions arekit/common/experiment/api/ops_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,8 @@ def DataFolding(self):
def get_doc(self, doc_id):
raise NotImplementedError()

# TODO. 212. Unify, add tag.
def iter_doc_ids_to_annotate(self):
""" provides set of documents that utilized by annotator algorithm in order to
provide the related labeling of annotated attitudes in it.
By default, we consider an empty set, so there is no need to utilize annotator.
"""
raise NotImplementedError()

# TODO. 212. Unify, add tag.
def iter_doc_ids_to_compare(self):
""" provides set of documents that utilized in model evaluation process
def iter_tagget_doc_ids(self, tag):
""" Document identifiers which are grouped by a particular tag.
"""
raise NotImplementedError()

Expand Down
4 changes: 2 additions & 2 deletions arekit/contrib/bert/run_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from arekit.common.data.views.output_multiple import MulticlassOutputView
from arekit.common.experiment.api.ctx_training import TrainingData
from arekit.common.experiment.api.enums import BaseDocumentTag
from arekit.common.experiment.engine import ExperimentEngine
from arekit.common.linked.helper import create_and_fill_opinion_collection
from arekit.common.labels.scaler import BaseLabelScaler
Expand Down Expand Up @@ -86,8 +87,7 @@ def _handle_iteration(self, iter_index):

# TODO. This should be removed as this is a part of the particular
# experiment, not source!.
# TODO. 212. Pass tag ("compare")
cmp_doc_ids_set = set(self._experiment.DocumentOperations.iter_doc_ids_to_compare())
cmp_doc_ids_set = set(self._experiment.DocumentOperations.iter_tagget_doc_ids(BaseDocumentTag.Compare))

if callback.check_log_exists():
self._log_info("Skipping [Log file already exist]")
Expand Down
6 changes: 4 additions & 2 deletions arekit/contrib/experiment_rusentrel/exp_ds/documents.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from arekit.common.experiment.api.ctx_base import DataIO
from arekit.common.experiment.api.ctx_serialization import SerializationData
from arekit.common.experiment.api.enums import BaseDocumentTag
from arekit.common.experiment.api.ops_doc import DocumentOperations
from arekit.contrib.source.ruattitudes.news.parse_options import RuAttitudesParseOptions

Expand Down Expand Up @@ -27,8 +28,9 @@ def _create_parse_options(self):
return RuAttitudesParseOptions(stemmer=self.__exp_data.Stemmer,
frame_variants_collection=self.__exp_data.FrameVariantCollection)

# TODO. 212. Rename, add tag parameter.
def iter_doc_ids_to_annotate(self):
def iter_tagget_doc_ids(self, tag):
assert(isinstance(tag, BaseDocumentTag))
assert(tag == BaseDocumentTag.Annotate or tag == BaseDocumentTag.Compare)
return
yield

Expand Down
9 changes: 2 additions & 7 deletions arekit/contrib/experiment_rusentrel/exp_joined/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,7 @@ def _create_parse_options(self):
# Therefore we provide rusentrel_doc by default.
return self.__rusentrel_doc._create_parse_options()

# TODO. 212. Unify (use only one method for both below)
def iter_doc_ids_to_annotate(self):
return self.__rusentrel_doc.iter_doc_ids_to_annotate()

# TODO. 212. Unify (use only one method for this and one above)
def iter_doc_ids_to_compare(self):
return self.__rusentrel_doc.iter_doc_ids_to_compare()
def iter_tagget_doc_ids(self, tag):
return self.__rusentrel_doc.iter_tagget_doc_ids(tag)

# endregion
1 change: 0 additions & 1 deletion arekit/contrib/experiment_rusentrel/exp_joined/opinions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from arekit.common.experiment.api.ops_opin import OpinionOperations
from arekit.common.experiment.data_type import DataType
from arekit.contrib.experiment_rusentrel.exp_sl.opinions import RuSentrelOpinionOperations


Expand Down
10 changes: 4 additions & 6 deletions arekit/contrib/experiment_rusentrel/exp_sl/documents.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from arekit.common.experiment.api.ctx_base import DataIO
from arekit.common.experiment.api.ctx_serialization import SerializationData
from arekit.common.experiment.api.enums import BaseDocumentTag
from arekit.common.experiment.api.ops_doc import DocumentOperations
from arekit.contrib.source.rusentrel.io_utils import RuSentRelVersions
from arekit.contrib.source.rusentrel.news.base import RuSentRelNews
Expand All @@ -23,12 +24,9 @@ def __init__(self, exp_data, folding, version, get_synonyms_func):

# region DocumentOperations

# TODO. 212. Unify with the one below, add tag.
def iter_doc_ids_to_annotate(self):
return self.DataFolding.iter_doc_ids()

# TODO. 212. Unify with the one above, add tag.
def iter_doc_ids_to_compare(self):
def iter_tagget_doc_ids(self, tag):
assert(isinstance(tag, BaseDocumentTag))
assert(tag == BaseDocumentTag.Compare or tag == BaseDocumentTag.Annotate)
return self.DataFolding.iter_doc_ids()

def get_doc(self, doc_id):
Expand Down
4 changes: 2 additions & 2 deletions arekit/contrib/networks/core/callback/utils_model_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from arekit.common.data import const
from arekit.common.data.storages.base import BaseRowsStorage
from arekit.common.data.views.output_multiple import MulticlassOutputView
from arekit.common.experiment.api.enums import BaseDocumentTag
from arekit.common.experiment.api.ops_doc import DocumentOperations
from arekit.common.experiment.api.ops_opin import OpinionOperations
from arekit.common.experiment.data_type import DataType
Expand Down Expand Up @@ -103,8 +104,7 @@ def __convert_output_to_opinion_collections(exp_io, opin_ops, doc_ops, labels_sc
assert(isinstance(label_calc_mode, LabelCalculationMode))
assert(isinstance(labels_formatter, StringLabelsFormatter))

# TODO. 212. Pass tag ("compare")
cmp_doc_ids_set = set(doc_ops.iter_doc_ids_to_compare())
cmp_doc_ids_set = set(doc_ops.iter_tagget_doc_ids(BaseDocumentTag.Compare))

output_view = MulticlassOutputView(labels_scaler=labels_scaler,
storage=output_storage)
Expand Down
10 changes: 4 additions & 6 deletions examples/network/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from arekit.common.experiment.api.enums import BaseDocumentTag
from arekit.common.experiment.api.ops_doc import DocumentOperations
from arekit.common.experiment.data_type import DataType
from arekit.common.folding.nofold import NoFolding
Expand All @@ -7,14 +8,11 @@ class SingleDocOperations(DocumentOperations):
""" Operations over a single document for inference.
"""

# TODO. 212. Rename, add tag.
def iter_doc_ids_to_annotate(self):
def iter_tagget_doc_ids(self, tag):
assert(isinstance(tag, BaseDocumentTag))
assert(tag == BaseDocumentTag.Annotate)
return 0

# TODO. 212. Remove (we don't need it in such case).
def iter_doc_ids_to_compare(self):
raise NotImplementedError()

def __init__(self, news):
folding = NoFolding(doc_ids_to_fold=[0], supported_data_types=[DataType.Test])
super(SingleDocOperations, self).__init__(folding)
Expand Down

0 comments on commit d0a6a74

Please sign in to comment.