Skip to content

Commit

Permalink
#320 related refactoring. Moving OpnionCollection into experiment_rus…
Browse files Browse the repository at this point in the history
…entrel
  • Loading branch information
nicolay-r committed Jul 12, 2022
1 parent 4a70069 commit 55eb5df
Show file tree
Hide file tree
Showing 10 changed files with 41 additions and 34 deletions.
9 changes: 1 addition & 8 deletions arekit/common/experiment/api/base.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,16 @@
from arekit.common.experiment.api.ctx_base import ExperimentContext
from arekit.common.experiment.api.io_utils import BaseIOUtils
from arekit.common.experiment.api.ops_doc import DocumentOperations
from arekit.common.experiment.api.ops_opin import OpinionOperations


class BaseExperiment(object):

def __init__(self, exp_ctx, exp_io, opin_ops, doc_ops):
def __init__(self, exp_ctx, exp_io, doc_ops):
assert(isinstance(exp_ctx, ExperimentContext))
assert(isinstance(exp_io, BaseIOUtils))
assert(isinstance(opin_ops, OpinionOperations))
assert(isinstance(doc_ops, DocumentOperations))
self.__exp_ctx = exp_ctx
self.__exp_io = exp_io
self.__opin_ops = opin_ops
self.__doc_ops = doc_ops

# region Properties
Expand All @@ -26,10 +23,6 @@ def ExperimentContext(self):
def ExperimentIO(self):
return self.__exp_io

@property
def OpinionOperations(self):
return self.__opin_ops

@property
def DocumentOperations(self):
return self.__doc_ops
Expand Down
2 changes: 1 addition & 1 deletion arekit/common/experiment/api/ctx_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def LabelsCount(self):
def Annotator(self):
""" Provides an instance of annotator that might be utilized
for attitudes labeling within a specific set of documents,
declared in a particular experiment (see OpinionOperations).
declared in a particular experiment.
"""
return self.__annot

Expand Down
14 changes: 14 additions & 0 deletions arekit/contrib/experiment_rusentrel/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from arekit.common.experiment.api.base import BaseExperiment
from arekit.contrib.experiment_rusentrel.ops_opin import OpinionOperations


class CustomExperiment(BaseExperiment):

def __init__(self, exp_ctx, exp_io, doc_ops, opin_ops):
assert(isinstance(opin_ops, OpinionOperations))
super(CustomExperiment, self).__init__(exp_ctx=exp_ctx, exp_io=exp_io, doc_ops=doc_ops)
self.__opin_ops = opin_ops

@property
def OpinionOperations(self):
return self.__opin_ops
4 changes: 2 additions & 2 deletions arekit/contrib/experiment_rusentrel/exp_ds/factory.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging

from arekit.common.experiment.api.base import BaseExperiment
from arekit.common.experiment.api.io_utils import BaseIOUtils
from arekit.contrib.experiment_rusentrel.base import CustomExperiment
from arekit.contrib.experiment_rusentrel.exp_ds.documents import RuAttitudesDocumentOperations
from arekit.contrib.experiment_rusentrel.exp_ds.opinions import RuAttitudesOpinionOperations
from arekit.contrib.experiment_rusentrel.exp_ds.utils import read_ruattitudes_in_memory
Expand Down Expand Up @@ -30,4 +30,4 @@ def create_ruattitudes_experiment(exp_ctx, exp_io, version, load_docs, ra_doc_id
logger.info("Create opinion operations ...")
opin_ops = RuAttitudesOpinionOperations(ru_attitudes=ru_attitudes)

return BaseExperiment(exp_ctx=exp_ctx, exp_io=exp_io, opin_ops=opin_ops, doc_ops=doc_ops)
return CustomExperiment(exp_ctx=exp_ctx, exp_io=exp_io, opin_ops=opin_ops, doc_ops=doc_ops)
2 changes: 1 addition & 1 deletion arekit/contrib/experiment_rusentrel/exp_ds/opinions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from arekit.common.experiment.api.ops_opin import OpinionOperations
from arekit.common.experiment.data_type import DataType
from arekit.contrib.experiment_rusentrel.ops_opin import OpinionOperations
from arekit.contrib.experiment_rusentrel.labels.scalers.ruattitudes import ExperimentRuAttitudesLabelConverter
from arekit.contrib.source.ruattitudes.opinions.utils import RuAttitudesSentenceOpinionUtils

Expand Down
4 changes: 2 additions & 2 deletions arekit/contrib/experiment_rusentrel/exp_sl/factory.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import logging

from arekit.common.experiment.api.base import BaseExperiment
from arekit.common.experiment.api.io_utils import BaseIOUtils
from arekit.common.folding.types import FoldingType
from arekit.contrib.experiment_rusentrel import common
from arekit.contrib.experiment_rusentrel.base import CustomExperiment
from arekit.contrib.experiment_rusentrel.exp_sl.documents import RuSentrelDocumentOperations
from arekit.contrib.experiment_rusentrel.exp_sl.opinions import RuSentrelOpinionOperations
from arekit.contrib.experiment_rusentrel.synonyms.provider import RuSentRelSynonymsCollectionProvider
Expand Down Expand Up @@ -37,7 +37,7 @@ def create_rusentrel_experiment(exp_ctx, exp_io, version, folding_type):
version=version,
get_synonyms_func=synonyms_provider.get_or_load_synonyms_collection)

return BaseExperiment(exp_ctx=exp_ctx, exp_io=exp_io, doc_ops=doc_ops, opin_ops=opin_ops)
return CustomExperiment(exp_ctx=exp_ctx, exp_io=exp_io, doc_ops=doc_ops, opin_ops=opin_ops)


class OptionalSynonymsProvider(object):
Expand Down
2 changes: 1 addition & 1 deletion arekit/contrib/experiment_rusentrel/exp_sl/opinions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

from arekit.common.experiment.api.ctx_base import ExperimentContext
from arekit.common.experiment.api.io_utils import BaseIOUtils
from arekit.common.experiment.api.ops_opin import OpinionOperations
from arekit.common.experiment.data_type import DataType
from arekit.common.opinions.collection import OpinionCollection
from arekit.contrib.experiment_rusentrel.labels.formatters.neut_label import ExperimentNeutralLabelsFormatter
from arekit.contrib.experiment_rusentrel.labels.formatters.rusentrel import RuSentRelExperimentLabelsFormatter
from arekit.contrib.experiment_rusentrel.ops_opin import OpinionOperations
from arekit.contrib.source.rusentrel.io_utils import RuSentRelVersions
from arekit.contrib.source.rusentrel.opinions.collection import RuSentRelOpinionCollection

Expand Down
File renamed without changes.
24 changes: 11 additions & 13 deletions arekit/contrib/utils/handlers/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from arekit.common.evaluation.result import BaseEvalResult
from arekit.common.experiment.api.enums import BaseDocumentTag
from arekit.common.experiment.api.ops_doc import DocumentOperations
from arekit.common.experiment.api.ops_opin import OpinionOperations
from arekit.common.experiment.data_type import DataType
from arekit.common.experiment.handler import ExperimentIterationHandler
from arekit.common.utils import progress_bar_iter
Expand All @@ -18,19 +17,24 @@ class EvalIterationHandler(ExperimentIterationHandler):
различных оценок в виде отдельных функций.
"""

def __init__(self, data_type, doc_ops, opin_ops, epoch_indices, evaluator):
def __init__(self, data_type, doc_ops, epoch_indices, evaluator,
get_test_doc_collection_func, get_etalon_doc_collection_func):
""" get_doc_collection_func: func
(doc_id) -> collection (Any type)
"""
assert(isinstance(data_type, DataType))
assert(isinstance(doc_ops, DocumentOperations))
# TODO. #355 related. OpinionOperations limit this onto `Opinion` type only.
assert(isinstance(opin_ops, OpinionOperations))
assert(isinstance(epoch_indices, list))
assert(isinstance(evaluator, BaseEvaluator))
assert(callable(get_test_doc_collection_func))
assert(callable(get_etalon_doc_collection_func))

self.__data_type = data_type
self.__doc_ops = doc_ops
self.__opin_ops = opin_ops
self.__epoch_indices = epoch_indices
self.__evaluator = evaluator
self.__get_test_doc_collection_func = get_test_doc_collection_func
self.__get_etalon_doc_collection_func = get_etalon_doc_collection_func

def __evaluate(self, data_type, epoch_index):
"""
Expand All @@ -54,14 +58,8 @@ def __evaluate(self, data_type, epoch_index):
# Compose cmp pairs iterator.
cmp_pairs_iter = DataPairsIterators.iter_func_based_collections(
doc_ids=[doc_id for doc_id in doc_ids_iter if doc_id in cmp_doc_ids_set],
# TODO. #355 related. OpinionOperations limit this onto `Opinion` type only.
read_etalon_collection_func=lambda doc_id: self.__opin_ops.get_etalon_opinion_collection(
doc_id=doc_id),
# TODO. #355 related. OpinionOperations limit this onto `Opinion` type only.
read_test_collection_func=lambda doc_id: self.__opin_ops.get_result_opinion_collection(
data_type=data_type,
doc_id=doc_id,
epoch_index=epoch_index))
read_etalon_collection_func=lambda doc_id: self.__get_test_doc_collection_func(doc_id),
read_test_collection_func=lambda doc_id: self.__get_etalon_doc_collection_func(doc_id))

# evaluate every document.
logged_cmp_pairs_it = progress_bar_iter(cmp_pairs_iter, desc="Evaluate", unit='pairs')
Expand Down
14 changes: 8 additions & 6 deletions arekit/contrib/utils/handlers/to_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from arekit.common.data.views.linkages.multilabel import MultilableOpinionLinkagesView
from arekit.common.experiment.api.enums import BaseDocumentTag
from arekit.common.experiment.api.ops_doc import DocumentOperations
from arekit.common.experiment.api.ops_opin import OpinionOperations
from arekit.common.experiment.data_type import DataType
from arekit.common.experiment.handler import ExperimentIterationHandler
from arekit.common.labels.scaler.base import BaseLabelScaler
Expand All @@ -16,21 +15,25 @@

class BaseOutputConverterIterationHandler(ExperimentIterationHandler):

def __init__(self, exp_io, doc_ops, opin_ops, data_type, label_scaler, labels_formatter):
def __init__(self, exp_io, doc_ops, create_opinion_collection_func,
data_type, label_scaler, labels_formatter):
""" create_opinion_collection_func: func
func () -> OpinionCollection (empty)
"""
assert(isinstance(exp_io, DefaultBertIOUtils))
assert(isinstance(doc_ops, DocumentOperations))
assert(isinstance(opin_ops, OpinionOperations))
assert(isinstance(data_type, DataType))
assert(isinstance(label_scaler, BaseLabelScaler))
assert(isinstance(labels_formatter, StringLabelsFormatter))
assert(callable(create_opinion_collection_func))
super(BaseOutputConverterIterationHandler, self).__init__(exp_io=exp_io)
self._data_type = data_type

self.__exp_io = exp_io
self.__doc_ops = doc_ops
self.__opin_ops = opin_ops
self.__labels_formatter = labels_formatter
self.__label_scaler = label_scaler
self.__create_opinion_collection_func = create_opinion_collection_func

def __convert(self, output_storage, target_func):
""" From `output_storage` to `target` conversion.
Expand All @@ -52,8 +55,7 @@ def __convert(self, output_storage, target_func):
doc_id=doc_id,
opinions_view=self.__exp_io.create_opinions_view(self._data_type)),
doc_ids_set=cmp_doc_ids_set,
# TODO: #320 related. Create a separate parameter.
create_opinion_collection_func=self.__opin_ops.create_opinion_collection,
create_opinion_collection_func=self.__create_opinion_collection_func,
labels_scaler=self.__label_scaler,
label_calc_mode=LabelCalculationMode.AVERAGE)

Expand Down

0 comments on commit 55eb5df

Please sign in to comment.