Commit

#262 done.
nicolay-r committed Jan 27, 2022
1 parent 285c76c commit 618d3a7
Showing 17 changed files with 297 additions and 281 deletions.
66 changes: 0 additions & 66 deletions arekit/common/experiment/api/base.py
@@ -1,17 +1,7 @@
import logging
from arekit.common.evaluation.evaluators.base import BaseEvaluator
from arekit.common.evaluation.results.base import BaseEvalResult
from arekit.common.evaluation.utils import OpinionCollectionsToCompareUtils
from arekit.common.experiment.api.ctx_base import ExperimentContext
from arekit.common.experiment.api.ctx_training import ExperimentTrainingContext
from arekit.common.experiment.api.enums import BaseDocumentTag
from arekit.common.experiment.api.io_utils import BaseIOUtils
from arekit.common.experiment.api.ops_doc import DocumentOperations
from arekit.common.experiment.api.ops_opin import OpinionOperations
from arekit.common.experiment.data_type import DataType
from arekit.common.utils import progress_bar_iter

logger = logging.getLogger(__name__)


class BaseExperiment(object):
@@ -34,8 +24,6 @@ def ExperimentContext(self):

@property
def ExperimentIO(self):
""" Filepaths, related to experiment
"""
return self.__exp_io

@property
@@ -47,57 +35,3 @@ def DocumentOperations(self):
return self.__doc_ops

# endregion

def _init_log_flag(self, do_log):
assert(isinstance(do_log, bool))
self._do_log = do_log

def log_info(self, message, forced=False):
assert (isinstance(message, str))
if not self._do_log and not forced:
return
logger.info(message)

def evaluate(self, data_type, epoch_index):
"""
Perform evaluation of the related model for a certain
`data_type` at a certain `epoch_index`.
data_type: DataType
    used as the data source (for document ids)
epoch_index: int or None
NOTE: assumes that results have already been written and converted into doc-level opinions.
"""
assert(isinstance(data_type, DataType))
assert(isinstance(epoch_index, int))
assert(isinstance(self.__exp_ctx, ExperimentTrainingContext))

# Extract all docs to compare, and those that are related to data_type.
cmp_doc_ids_iter = self.__doc_ops.iter_tagget_doc_ids(BaseDocumentTag.Compare)
doc_ids_iter = self.__doc_ops.iter_doc_ids(data_type=data_type)
cmp_doc_ids_set = set(cmp_doc_ids_iter)

# Compose cmp pairs iterator.
cmp_pairs_iter = OpinionCollectionsToCompareUtils.iter_comparable_collections(
doc_ids=[doc_id for doc_id in doc_ids_iter if doc_id in cmp_doc_ids_set],
read_etalon_collection_func=lambda doc_id: self.__opin_ops.get_etalon_opinion_collection(
doc_id=doc_id),
read_result_collection_func=lambda doc_id: self.__opin_ops.get_result_opinion_collection(
data_type=data_type,
doc_id=doc_id,
epoch_index=epoch_index))

# Get the evaluator.
evaluator = self.__exp_ctx.Evaluator
assert(isinstance(evaluator, BaseEvaluator))

# evaluate every document.
logged_cmp_pairs_it = progress_bar_iter(cmp_pairs_iter, desc="Evaluate", unit='pairs')
result = evaluator.evaluate(cmp_pairs=logged_cmp_pairs_it)
assert(isinstance(result, BaseEvalResult))

# calculate results.
result.calculate()

return result
23 changes: 5 additions & 18 deletions arekit/common/experiment/engine.py
@@ -1,6 +1,4 @@
import logging

from arekit.common.experiment.api.base import BaseExperiment
from arekit.common.folding.base import BaseDataFolding


class ExperimentEngine(object):
@@ -9,22 +7,11 @@ class ExperimentEngine(object):
iteration in and runs handler during the latter.
"""

def __init__(self, experiment):
assert(isinstance(experiment, BaseExperiment))
self._experiment = experiment
self._logger = self.__create_logger()
def __init__(self, data_folding):
assert(isinstance(data_folding, BaseDataFolding))
self.__data_folding = data_folding
self.__handlers = None

@staticmethod
def __create_logger():
stream_handler = logging.StreamHandler()
formatter = logging.Formatter('%(asctime)s %(levelname)8s %(name)s | %(message)s')
stream_handler.setFormatter(formatter)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(stream_handler)
return logger

def __call_all_handlers(self, call_func):
assert(callable(call_func))

@@ -56,6 +43,6 @@ def run(self, handlers=None):
"""
self.__handlers = handlers
self._before_running()
for iter_index, _ in enumerate(self._experiment.ExperimentContext.DataFolding.iter_states()):
for iter_index, _ in enumerate(self.__data_folding.iter_states()):
self._handle_iteration(iter_index)
self._after_running()
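
Not part of the diff above: a minimal usage sketch of the refactored engine, under the assumption that run() dispatches on_iteration to each supplied handler (as the __call_all_handlers helper and the EvalIterationHandler added below suggest). PrintIterationHandler and run_with_handlers are illustrative names only, and data_folding stands for any BaseDataFolding implementation provided by the caller.

from arekit.common.experiment.engine import ExperimentEngine
from arekit.common.experiment.handler import ExperimentIterationHandler


class PrintIterationHandler(ExperimentIterationHandler):
    """ Hypothetical handler: only reports the iteration index. """

    def on_iteration(self, iter_index):
        print("iteration:", iter_index)


def run_with_handlers(data_folding, extra_handlers=()):
    """ data_folding is assumed to be a BaseDataFolding instance
        (e.g. a fixed or cross-validation folding) supplied by the caller. """
    engine = ExperimentEngine(data_folding=data_folding)
    engine.run(handlers=[PrintIterationHandler()] + list(extra_handlers))
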
7 changes: 0 additions & 7 deletions arekit/common/experiment/handler.py
@@ -1,12 +1,5 @@
from arekit.common.experiment.api.ctx_base import ExperimentContext


class ExperimentIterationHandler(object):

def __init__(self, exp_ctx):
assert(isinstance(exp_ctx, ExperimentContext))
self._exp_ctx = exp_ctx

def on_before_iteration(self):
pass

Empty file.
75 changes: 75 additions & 0 deletions arekit/common/experiment/handlers/eval.py
@@ -0,0 +1,75 @@
from arekit.common.evaluation.evaluators.base import BaseEvaluator
from arekit.common.evaluation.results.base import BaseEvalResult
from arekit.common.evaluation.utils import OpinionCollectionsToCompareUtils
from arekit.common.experiment.api.ctx_base import ExperimentContext
from arekit.common.experiment.api.ctx_training import ExperimentTrainingContext
from arekit.common.experiment.api.enums import BaseDocumentTag
from arekit.common.experiment.api.ops_doc import DocumentOperations
from arekit.common.experiment.api.ops_opin import OpinionOperations
from arekit.common.experiment.data_type import DataType
from arekit.common.experiment.handler import ExperimentIterationHandler
from arekit.common.utils import progress_bar_iter


class EvalIterationHandler(ExperimentIterationHandler):

def __init__(self, data_type, exp_ctx, doc_ops, opin_ops, epoch_indices):
assert(isinstance(data_type, DataType))
assert(isinstance(exp_ctx, ExperimentContext))
assert(isinstance(doc_ops, DocumentOperations))
assert(isinstance(opin_ops, OpinionOperations))
assert(isinstance(epoch_indices, list))

self.__data_type = data_type
self.__exp_ctx = exp_ctx
self.__doc_ops = doc_ops
self.__opin_ops = opin_ops
self.__epoch_indices = epoch_indices

def __evaluate(self, data_type, epoch_index):
"""
Perform evaluation of the related model for a certain
`data_type` at a certain `epoch_index`.
data_type: DataType
    used as the data source (for document ids)
epoch_index: int or None
NOTE: assumes that results have already been written and converted into doc-level opinions.
"""
assert(isinstance(data_type, DataType))
assert(isinstance(epoch_index, int))
assert(isinstance(self.__exp_ctx, ExperimentTrainingContext))

# Extract all docs to compare, and those that are related to data_type.
cmp_doc_ids_iter = self.__doc_ops.iter_tagget_doc_ids(BaseDocumentTag.Compare)
doc_ids_iter = self.__doc_ops.iter_doc_ids(data_type=data_type)
cmp_doc_ids_set = set(cmp_doc_ids_iter)

# Compose cmp pairs iterator.
cmp_pairs_iter = OpinionCollectionsToCompareUtils.iter_comparable_collections(
doc_ids=[doc_id for doc_id in doc_ids_iter if doc_id in cmp_doc_ids_set],
read_etalon_collection_func=lambda doc_id: self.__opin_ops.get_etalon_opinion_collection(
doc_id=doc_id),
read_result_collection_func=lambda doc_id: self.__opin_ops.get_result_opinion_collection(
data_type=data_type,
doc_id=doc_id,
epoch_index=epoch_index))

# Get the evaluator.
evaluator = self.__exp_ctx.Evaluator
assert(isinstance(evaluator, BaseEvaluator))

# evaluate every document.
logged_cmp_pairs_it = progress_bar_iter(cmp_pairs_iter, desc="Evaluate", unit='pairs')
result = evaluator.evaluate(cmp_pairs=logged_cmp_pairs_it)
assert(isinstance(result, BaseEvalResult))

# calculate results.
result.calculate()

return result

def on_iteration(self, iter_index):
for epoch in self.__epoch_indices:
self.__evaluate(data_type=self.__data_type, epoch_index=epoch)
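
Also not part of the commit: a hedged wiring sketch showing how the new EvalIterationHandler would plug into the engine. The caller-provided exp_ctx, doc_ops, opin_ops, data_folding and epochs are assumptions here (they previously lived inside BaseExperiment), and the engine is assumed to call on_iteration for each handler on every folding iteration.

from arekit.common.experiment.engine import ExperimentEngine
from arekit.common.experiment.handlers.eval import EvalIterationHandler


def run_evaluation(data_type, exp_ctx, doc_ops, opin_ops, data_folding, epochs):
    """ data_type is a DataType value; exp_ctx should be an
        ExperimentTrainingContext (see the assert in __evaluate above). """
    handler = EvalIterationHandler(data_type=data_type,
                                   exp_ctx=exp_ctx,
                                   doc_ops=doc_ops,
                                   opin_ops=opin_ops,
                                   epoch_indices=list(epochs))
    # One evaluation pass per folding state.
    ExperimentEngine(data_folding).run(handlers=[handler])
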
Empty file.
@@ -5,32 +5,35 @@
from arekit.common.data.input.repositories.opinions import BaseInputOpinionsRepository
from arekit.common.data.input.repositories.sample import BaseInputSamplesRepository
from arekit.common.data.storages.base import BaseRowsStorage
from arekit.common.experiment.api.base import BaseExperiment
from arekit.common.experiment.api.io_utils import BaseIOUtils
from arekit.common.experiment.api.ops_doc import DocumentOperations
from arekit.common.experiment.api.ops_opin import OpinionOperations
from arekit.common.experiment.data_type import DataType
from arekit.common.experiment.engine import ExperimentEngine
from arekit.common.experiment.handler import ExperimentIterationHandler
from arekit.common.labels.str_fmt import StringLabelsFormatter
from arekit.contrib.bert.samplers.factory import create_bert_sample_provider


# TODO. 262. Refactor as handler (weird inheritance, limits capabilities).
class BertExperimentInputSerializer(ExperimentEngine):
class BertExperimentInputSerializerIterationHandler(ExperimentIterationHandler):

def __init__(self, experiment,
labels_formatter,
skip_if_folder_exists,
sample_provider_type,
entity_formatter,
balance_train_samples):
assert(isinstance(experiment, BaseExperiment))
def __init__(self, exp_io, exp_ctx, doc_ops, opin_ops, labels_formatter, skip_if_folder_exists,
sample_provider_type, entity_formatter, balance_train_samples):
assert(isinstance(skip_if_folder_exists, bool))
assert(isinstance(exp_io, BaseIOUtils))
assert(isinstance(doc_ops, DocumentOperations))
assert(isinstance(opin_ops, OpinionOperations))
assert(isinstance(labels_formatter, StringLabelsFormatter))
super(BertExperimentInputSerializer, self).__init__(experiment)
super(BertExperimentInputSerializerIterationHandler, self).__init__()

self.__skip_if_folder_exists = skip_if_folder_exists
self.__entity_formatter = entity_formatter
self.__sample_provider_type = sample_provider_type
self.__balance_train_samples = balance_train_samples
self.__labels_formatter = labels_formatter
self.__exp_io = exp_io
self.__exp_ctx = exp_ctx
self.__doc_ops = doc_ops
self.__opin_ops = opin_ops

# region private methods

@@ -41,7 +44,7 @@ def __handle_iteration(self, data_type):
sample_rows_provider = create_bert_sample_provider(
labels_formatter=self.__labels_formatter,
provider_type=self.__sample_provider_type,
label_scaler=self._experiment.ExperimentContext.LabelsScaler,
label_scaler=self.__exp_ctx.LabelsScaler,
entity_formatter=self.__entity_formatter)

# Create repositories
@@ -57,59 +60,55 @@ def __handle_iteration(self, data_type):
# Create opinion provider
opinion_provider = InputTextOpinionProvider.create(
value_to_group_id_func=None,
parse_news_func=lambda doc_id: self._experiment.DocumentOperations.parse_doc(doc_id),
iter_doc_opins=lambda doc_id:
self._experiment.OpinionOperations.iter_opinions_for_extraction(doc_id=doc_id, data_type=data_type),
terms_per_context=self._experiment.ExperimentContext.TermsPerContext)
parse_news_func=lambda doc_id: self.__doc_ops.parse_doc(doc_id),
iter_doc_opins=lambda doc_id: self.__opin_ops.iter_opinions_for_extraction(
doc_id=doc_id, data_type=data_type),
terms_per_context=self.__exp_ctx.TermsPerContext)

# Populate repositories
opinions_repo.populate(opinion_provider=opinion_provider,
doc_ids=list(self._experiment.DocumentOperations.iter_doc_ids(data_type)),
doc_ids=list(self.__doc_ops.iter_doc_ids(data_type)),
desc="opinion")

samples_repo.populate(opinion_provider=opinion_provider,
doc_ids=list(self._experiment.DocumentOperations.iter_doc_ids(data_type)),
doc_ids=list(self.__doc_ops.iter_doc_ids(data_type)),
desc="sample")

if self._experiment.ExperimentIO.balance_samples(data_type=data_type, balance=self.__balance_train_samples):
if self.__exp_io.balance_samples(data_type=data_type, balance=self.__balance_train_samples):
samples_repo.balance()

# Save repositories
samples_repo.write(
target=self._experiment.ExperimentIO.create_samples_writer_target(data_type),
writer=self._experiment.ExperimentIO.create_samples_writer())
target=self.__exp_io.create_samples_writer_target(data_type),
writer=self.__exp_io.create_samples_writer())

opinions_repo.write(
target=self._experiment.ExperimentIO.create_opinions_writer_target(data_type),
writer=self._experiment.ExperimentIO.create_opinions_writer())
target=self.__exp_io.create_opinions_writer_target(data_type),
writer=self.__exp_io.create_opinions_writer())

# endregion

# region protected methods

def _handle_iteration(self, it_index):
def on_iteration(self, iter_index):
""" Performing data serialization for a particular iteration
"""
for data_type in self._experiment.ExperimentContext.DataFolding.iter_supported_data_types():
for data_type in self.__exp_ctx.DataFolding.iter_supported_data_types():
self.__handle_iteration(data_type)

def _before_running(self):
self._logger.info("Perform annotation ...")
def on_before_iteration(self):
for data_type in self.__exp_ctx.DataFolding.iter_supported_data_types():

for data_type in self._experiment.ExperimentContext.DataFolding.iter_supported_data_types():

collections_it = self._experiment.ExperimentContext.Annotator.iter_annotated_collections(
data_type=data_type,
opin_ops=self._experiment.OpinionOperations,
doc_ops=self._experiment.DocumentOperations)
collections_it = self.__exp_ctx.Annotator.iter_annotated_collections(
data_type=data_type, opin_ops=self.__opin_ops, doc_ops=self.__doc_ops)

for doc_id, collection in collections_it:

target = self._experiment.ExperimentIO.create_opinion_collection_target(
target = self.__exp_io.create_opinion_collection_target(
doc_id=doc_id,
data_type=data_type)

self._experiment.write_opinion_collection(
self.__exp_io.write_opinion_collection(
collection=collection,
target=target,
labels_formatter=self.__labels_formatter)
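
A final hedged sketch (not part of the commit) of constructing the serializer handler with its new signature. The module path of BertExperimentInputSerializerIterationHandler is not shown in this excerpt, so the class is passed in by the caller; the remaining arguments mirror the constructor above, and the True flags are arbitrary example values.

from arekit.common.experiment.engine import ExperimentEngine


def run_bert_serialization(handler_cls, exp_io, exp_ctx, doc_ops, opin_ops,
                           labels_formatter, sample_provider_type,
                           entity_formatter, data_folding):
    """ handler_cls is assumed to be BertExperimentInputSerializerIterationHandler. """
    handler = handler_cls(exp_io=exp_io,
                          exp_ctx=exp_ctx,
                          doc_ops=doc_ops,
                          opin_ops=opin_ops,
                          labels_formatter=labels_formatter,
                          skip_if_folder_exists=True,      # example value
                          sample_provider_type=sample_provider_type,
                          entity_formatter=entity_formatter,
                          balance_train_samples=True)      # example value
    ExperimentEngine(data_folding).run(handlers=[handler])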