Skip to content

Commit

Permalink
Refactoring. Removed opinion converter (#198 related)
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Oct 2, 2021
1 parent 6c383d1 commit 0fbf35f
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 47 deletions.
Empty file.
22 changes: 0 additions & 22 deletions arekit/common/experiment/output/opinions/converter.py

This file was deleted.

31 changes: 23 additions & 8 deletions arekit/common/experiment/output/views/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,18 @@ def __iter_linked_opinions_df(self, news_id):
linked_opins_df = news_df[news_df[const.ID].str.contains(opin_id_pattern)]
yield linked_opins_df

def __iter_linked_opinions(self, news_id, opinions_view):
    """ Yields a LinkedOpinionWrapper for every dataframe provided by
        __iter_linked_opinions_df for the given document.

        news_id: int
            identifier of the document whose linked opinions are iterated.
        opinions_view: BaseOpinionStorageView
            forwarded unchanged to _iter_by_opinions.
    """
    assert isinstance(news_id, int)
    assert isinstance(opinions_view, BaseOpinionStorageView)

    for rows_df in self.__iter_linked_opinions_df(news_id=news_id):
        assert isinstance(rows_df, pd.DataFrame)

        # The wrapper receives the result of _iter_by_opinions as is.
        yield LinkedOpinionWrapper(
            linked_data=self._iter_by_opinions(linked_df=rows_df,
                                               opinions_view=opinions_view))

# endregion

# region protected methods
Expand Down Expand Up @@ -71,16 +83,19 @@ def iter_news_ids(self):
unique_news_ids = set(self._storage.iter_column_values(column_name=const.NEWS_ID))
return unique_news_ids

def iter_linked_opinions(self, news_id, opinions_view):
assert (isinstance(news_id, int))
assert (isinstance(opinions_view, BaseOpinionStorageView))
def iter_opinion_collections(self, opinions_view, keep_doc_id_func, to_collection_func):
assert(isinstance(opinions_view, BaseOpinionStorageView))
assert(callable(keep_doc_id_func))
assert(callable(to_collection_func))

for linked_df in self.__iter_linked_opinions_df(news_id=news_id):
assert (isinstance(linked_df, pd.DataFrame))
for news_id in self.iter_news_ids():

opinions_iter = self._iter_by_opinions(linked_df=linked_df,
opinions_view=opinions_view)
if not keep_doc_id_func(news_id):
continue

yield LinkedOpinionWrapper(linked_data=opinions_iter)
linked_data_iter = self.__iter_linked_opinions(news_id=news_id,
opinions_view=opinions_view)

yield news_id, to_collection_func(linked_data_iter)

# endregion
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from arekit.common.opinions.base import Opinion


def fill_opinion_collection(create_opinion_collection, linked_data_iter,
labels_helper, to_opinion_func,
label_calc_mode, supported_labels=None):
def create_and_fill_opinion_collection(
create_opinion_collection, linked_data_iter,
labels_helper, to_opinion_func,
label_calc_mode, supported_labels=None):
""" to_opinion_func: (item, label) -> opinion
"""
assert(callable(create_opinion_collection))
Expand Down Expand Up @@ -38,4 +39,4 @@ def fill_opinion_collection(create_opinion_collection, linked_data_iter,

collection.add_opinion(agg_opinion)

return collection
return collection
10 changes: 4 additions & 6 deletions arekit/contrib/bert/run_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@

from arekit.common.experiment.api.ctx_training import TrainingData
from arekit.common.experiment.engine import ExperimentEngine
from arekit.common.experiment.output.opinions.converter import OutputToOpinionCollectionsConverter
from arekit.common.experiment.output.utils import fill_opinion_collection
from arekit.common.linked.helper import create_and_fill_opinion_collection
from arekit.common.experiment.output.views.multiple import MulticlassOutputView
from arekit.common.labels.scaler import BaseLabelScaler
from arekit.common.labels.str_fmt import StringLabelsFormatter
Expand Down Expand Up @@ -122,8 +121,7 @@ def _handle_iteration(self, iter_index):
storage=storage)

# iterate opinion collections.
collections_iter = OutputToOpinionCollectionsConverter.iter_opinion_collections(
output_view=output_view,
collections_iter = output_view.iter_opinion_collections(
opinions_view=exp_io.create_opinions_view(self.__data_type),
keep_doc_id_func=lambda doc_id: doc_id in cmp_doc_ids_set,
to_collection_func=lambda linked_iter: self.__create_opinion_collection(
Expand Down Expand Up @@ -164,8 +162,8 @@ def _before_running(self):
callback.set_log_dir(self.__get_target_dir())

def __create_opinion_collection(self, linked_iter, supported_labels):
return fill_opinion_collection(
create_opinion_collection=self._experiment.OpinionOperations.create_opinion_collection,
return create_and_fill_opinion_collection(
create_opinion_collection=self._experiment.OpinionOperations.create_and_fill_opinion_collection,
linked_data_iter=linked_iter,
labels_helper=SingleLabelsHelper(self.__label_scaler),
to_opinion_func=LanguageModelExperimentEvaluator.__create_labeled_opinion,
Expand Down
2 changes: 1 addition & 1 deletion arekit/contrib/experiment_rusentrel/annot/two_scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def _annot_collection_core(self, parsed_news, data_type, doc_ops, opin_ops):
assert(isinstance(data_type, DataType))

doc_id = parsed_news.RelatedNewsID
neut_collection = opin_ops.create_opinion_collection()
neut_collection = opin_ops.create_and_fill_opinion_collection()
assert(isinstance(neut_collection, OpinionCollection))

# We copy all the opinions from etalon collection
Expand Down
9 changes: 3 additions & 6 deletions arekit/contrib/networks/core/callback/utils_model_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
from arekit.common.experiment.api.ops_doc import DocumentOperations
from arekit.common.experiment.api.ops_opin import OpinionOperations
from arekit.common.experiment.data_type import DataType
from arekit.common.experiment.output.opinions.converter import OutputToOpinionCollectionsConverter
from arekit.common.experiment.output.utils import fill_opinion_collection
from arekit.common.experiment.output.views.multiple import MulticlassOutputView
from arekit.common.experiment.storages.base import BaseRowsStorage
from arekit.common.labels.scaler import BaseLabelScaler
Expand Down Expand Up @@ -111,14 +109,13 @@ def __convert_output_to_opinion_collections(exp_io, opin_ops, doc_ops, labels_sc
storage=output_storage)

# Extract iterator.
collections_iter = OutputToOpinionCollectionsConverter.iter_opinion_collections(
output_view=output_view,
collections_iter = output_view.iter_opinion_collections(
opinions_view=exp_io.create_opinions_view(data_type),
keep_doc_id_func=lambda doc_id: doc_id in cmp_doc_ids_set,
to_collection_func=lambda linked_iter: __create_opinion_collection(
linked_iter=linked_iter,
supported_labels=supported_collection_labels,
create_opinion_collection=opin_ops.create_opinion_collection,
create_opinion_collection=opin_ops.create_and_fill_opinion_collection,
label_scaler=labels_scaler))

# Save collection.
Expand All @@ -138,7 +135,7 @@ def __convert_output_to_opinion_collections(exp_io, opin_ops, doc_ops, labels_sc


def __create_opinion_collection(linked_iter, supported_labels, label_scaler, create_opinion_collection):
return fill_opinion_collection(
return create_opinion_collection(
create_opinion_collection=create_opinion_collection,
linked_data_iter=linked_iter,
labels_helper=SingleLabelsHelper(label_scaler),
Expand Down

0 comments on commit 0fbf35f

Please sign in to comment.