Skip to content

Commit

Permalink
Refactoring Output opinion provider (related to #198)
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Oct 2, 2021
1 parent 3050956 commit 6c383d1
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 65 deletions.
57 changes: 7 additions & 50 deletions arekit/common/experiment/output/opinions/converter.py
Original file line number Diff line number Diff line change
@@ -1,65 +1,22 @@
from arekit.common.experiment.input.views.opinions import BaseOpinionStorageView
from arekit.common.experiment.output.utils import fill_opinion_collection
from arekit.common.experiment.output.views.base import BaseOutputView
from arekit.common.model.labeling.modes import LabelCalculationMode
from arekit.common.model.labeling.single import SingleLabelsHelper
from arekit.common.labels.scaler import BaseLabelScaler
from arekit.common.opinions.base import Opinion


class OutputToOpinionCollectionsConverter(object):

# TODO. To output_view. Provide opinions_iter for collection organization only!
@staticmethod
def iter_opinion_collections(opinions_view,
# TODO. Remove
labels_scaler,
keep_doc_id_func,
# TODO. Collection???
create_opinion_collection_func,
# TODO. Remove
label_calculation_mode,
# TODO. Remove
supported_labels,
output_formatter):
assert(callable(keep_doc_id_func))
# TODO. Only for collection filling (remove)
assert(isinstance(labels_scaler, BaseLabelScaler))
def iter_opinion_collections(output_view, opinions_view, keep_doc_id_func, to_collection_func):
assert(isinstance(output_view, BaseOutputView))
assert(isinstance(opinions_view, BaseOpinionStorageView))
# TODO. Collection???
assert(callable(create_opinion_collection_func))
# TODO. Only for collection filling (remove)
assert(isinstance(label_calculation_mode, LabelCalculationMode))
# TODO. Only for collection filling (remove)
assert(isinstance(supported_labels, set) or supported_labels is None)
assert(isinstance(output_formatter, BaseOutputView))

labels_helper = SingleLabelsHelper(labels_scaler)
assert(callable(keep_doc_id_func))
assert(callable(to_collection_func))

for news_id in output_formatter.iter_news_ids():
for news_id in output_view.iter_news_ids():

if not keep_doc_id_func(news_id):
continue

# TODO. Collection???
collection = create_opinion_collection_func()

linked_iter = output_formatter.iter_linked_opinions(news_id=news_id,
linked_data_iter = output_view.iter_linked_opinions(news_id=news_id,
opinions_view=opinions_view)

fill_opinion_collection(
collection=collection,
linked_data_iter=linked_iter,
labels_helper=labels_helper,
to_opinion_func=OutputToOpinionCollectionsConverter.__to_label,
label_calc_mode=label_calculation_mode,
supported_labels=supported_labels)

yield news_id, collection

@staticmethod
def __to_label(item, label):
assert(isinstance(item, Opinion))
return Opinion(source_value=item.SourceValue,
target_value=item.TargetValue,
sentiment=label)
yield news_id, to_collection_func(linked_data_iter)
8 changes: 5 additions & 3 deletions arekit/common/experiment/output/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,21 @@
from arekit.common.linked.data import LinkedDataWrapper
from arekit.common.model.labeling.base import LabelsHelper
from arekit.common.opinions.base import Opinion
from arekit.common.opinions.collection import OpinionCollection


def fill_opinion_collection(collection, linked_data_iter, labels_helper, to_opinion_func,
def fill_opinion_collection(create_opinion_collection, linked_data_iter,
labels_helper, to_opinion_func,
label_calc_mode, supported_labels=None):
""" to_opinion_func: (item, label) -> opinion
"""
assert(isinstance(collection, OpinionCollection))
assert(callable(create_opinion_collection))
assert(isinstance(linked_data_iter, collections.Iterable))
assert(isinstance(labels_helper, LabelsHelper))
assert(callable(to_opinion_func))
assert(isinstance(supported_labels, set) or supported_labels is None)

collection = create_opinion_collection()

for linked in linked_data_iter:
assert(isinstance(linked, LinkedDataWrapper))

Expand Down
30 changes: 24 additions & 6 deletions arekit/contrib/bert/run_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@
from arekit.common.experiment.api.ctx_training import TrainingData
from arekit.common.experiment.engine import ExperimentEngine
from arekit.common.experiment.output.opinions.converter import OutputToOpinionCollectionsConverter
from arekit.common.experiment.output.utils import fill_opinion_collection
from arekit.common.experiment.output.views.multiple import MulticlassOutputView
from arekit.common.labels.scaler import BaseLabelScaler
from arekit.common.labels.str_fmt import StringLabelsFormatter
from arekit.common.model.labeling.modes import LabelCalculationMode
from arekit.common.model.labeling.single import SingleLabelsHelper
from arekit.common.opinions.base import Opinion
from arekit.common.utils import join_dir_with_subfolder_name
from arekit.contrib.bert.callback import Callback
from arekit.contrib.bert.output.eval_helper import EvalHelper
Expand Down Expand Up @@ -114,19 +117,18 @@ def _handle_iteration(self, iter_index):

# We utilize google bert format, where every row
# consist of label probabilities per every class
output = MulticlassOutputView(
output_view = MulticlassOutputView(
labels_scaler=self.__label_scaler,
storage=storage)

# iterate opinion collections.
collections_iter = OutputToOpinionCollectionsConverter.iter_opinion_collections(
output_view=output_view,
opinions_view=exp_io.create_opinions_view(self.__data_type),
labels_scaler=self.__label_scaler,
create_opinion_collection_func=self._experiment.OpinionOperations.create_opinion_collection,
keep_doc_id_func=lambda doc_id: doc_id in cmp_doc_ids_set,
label_calculation_mode=LabelCalculationMode.AVERAGE,
supported_labels=exp_data.SupportedCollectionLabels,
output_formatter=output)
to_collection_func=lambda linked_iter: self.__create_opinion_collection(
supported_labels=exp_data.SupportedCollectionLabels,
linked_iter=linked_iter))

for doc_id, collection in collections_iter:

Expand Down Expand Up @@ -160,3 +162,19 @@ def _before_running(self):
# Providing a root dir for logging.
callback = self._experiment.DataIO.Callback
callback.set_log_dir(self.__get_target_dir())

def __create_opinion_collection(self, linked_iter, supported_labels):
return fill_opinion_collection(
create_opinion_collection=self._experiment.OpinionOperations.create_opinion_collection,
linked_data_iter=linked_iter,
labels_helper=SingleLabelsHelper(self.__label_scaler),
to_opinion_func=LanguageModelExperimentEvaluator.__create_labeled_opinion,
label_calc_mode=LabelCalculationMode.AVERAGE,
supported_labels=supported_labels)

@staticmethod
def __create_labeled_opinion(item, label):
assert(isinstance(item, Opinion))
return Opinion(source_value=item.SourceValue,
target_value=item.TargetValue,
sentiment=label)
35 changes: 29 additions & 6 deletions arekit/contrib/networks/core/callback/utils_model_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@
from arekit.common.experiment.api.ops_opin import OpinionOperations
from arekit.common.experiment.data_type import DataType
from arekit.common.experiment.output.opinions.converter import OutputToOpinionCollectionsConverter
from arekit.common.experiment.output.utils import fill_opinion_collection
from arekit.common.experiment.output.views.multiple import MulticlassOutputView
from arekit.common.experiment.storages.base import BaseRowsStorage
from arekit.common.labels.scaler import BaseLabelScaler
from arekit.common.labels.str_fmt import StringLabelsFormatter
from arekit.common.model.labeling.modes import LabelCalculationMode
from arekit.common.model.labeling.single import SingleLabelsHelper
from arekit.common.opinions.base import Opinion
from arekit.common.opinions.provider import OpinionCollectionsProvider
from arekit.common.utils import progress_bar_iter
from arekit.contrib.networks.core.callback.utils_hidden_states import save_minibatch_all_input_dependent_hidden_values
Expand Down Expand Up @@ -104,16 +107,19 @@ def __convert_output_to_opinion_collections(exp_io, opin_ops, doc_ops, labels_sc

cmp_doc_ids_set = set(doc_ops.iter_doc_ids_to_compare())

output_view = MulticlassOutputView(labels_scaler=labels_scaler,
storage=output_storage)

# Extract iterator.
collections_iter = OutputToOpinionCollectionsConverter.iter_opinion_collections(
output_view=output_view,
opinions_view=exp_io.create_opinions_view(data_type),
labels_scaler=labels_scaler,
create_opinion_collection_func=opin_ops.create_opinion_collection,
keep_doc_id_func=lambda doc_id: doc_id in cmp_doc_ids_set,
label_calculation_mode=label_calc_mode,
supported_labels=supported_collection_labels,
output_formatter=MulticlassOutputView(labels_scaler=labels_scaler,
storage=output_storage))
to_collection_func=lambda linked_iter: __create_opinion_collection(
linked_iter=linked_iter,
supported_labels=supported_collection_labels,
create_opinion_collection=opin_ops.create_opinion_collection,
label_scaler=labels_scaler))

# Save collection.
for doc_id, collection in __log_wrap_collections_conversion_iter(collections_iter):
Expand All @@ -131,6 +137,23 @@ def __convert_output_to_opinion_collections(exp_io, opin_ops, doc_ops, labels_sc
target=target)


def __create_opinion_collection(linked_iter, supported_labels, label_scaler, create_opinion_collection):
return fill_opinion_collection(
create_opinion_collection=create_opinion_collection,
linked_data_iter=linked_iter,
labels_helper=SingleLabelsHelper(label_scaler),
to_opinion_func=__create_labeled_opinion,
label_calc_mode=LabelCalculationMode.AVERAGE,
supported_labels=supported_labels)


def __create_labeled_opinion(item, label):
assert(isinstance(item, Opinion))
return Opinion(source_value=item.SourceValue,
target_value=item.TargetValue,
sentiment=label)


def __log_wrap_samples_iter(it):
return progress_bar_iter(iterable=it,
desc='Writing output',
Expand Down

0 comments on commit 6c383d1

Please sign in to comment.