Skip to content

Commit

Permalink
#240 related.
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Dec 24, 2021
1 parent 309691f commit fc007a0
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 30 deletions.
28 changes: 7 additions & 21 deletions arekit/common/data/views/ouput_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from arekit.common.data.views.base import BaseStorageView
from arekit.common.data.views.opinions import BaseOpinionStorageView
from arekit.common.linkage.opinions import OpinionsLinkage
from arekit.common.opinions.base import Opinion


class BaseOutputView(BaseStorageView):
Expand All @@ -20,16 +19,11 @@ def __init__(self, ids_provider, storage):
# region private methods

@staticmethod
def __iter_opinion_linkages_df(doc_df, row_ids):
def _iter_opinion_linkages_df(doc_df, row_ids):
for row_id in row_ids:
df_linkage = doc_df[doc_df[const.ID].str.contains(row_id)]
yield df_linkage

def __iter_id_patterns(self, opinion_ids):
for opinion_id in set(opinion_ids):
yield self._ids_provider.create_pattern(id_value=opinion_id,
p_type=BaseIDProvider.OPINION)

def __iter_doc_opinion_ids(self, doc_df):
assert (isinstance(doc_df, pd.DataFrame))
return [self._ids_provider.parse_opinion_in_opinion_id(row_id)
Expand All @@ -47,18 +41,6 @@ def __iter_opinions_by_linkages(self, linkages_df, opinions_view):
def _iter_by_opinions(self, linked_df, opinions_view):
raise NotImplementedError()

def _compose_opinion_by_opinion_id(self, sample_id, opinions_view, calc_label_func):
assert(isinstance(sample_id, str))
assert(isinstance(opinions_view, BaseOpinionStorageView))
assert(callable(calc_label_func))

opinion_id = self._ids_provider.convert_sample_id_to_opinion_id(sample_id=sample_id)
source, target = opinions_view.provide_opinion_info_by_opinion_id(opinion_id=opinion_id)

return Opinion(source_value=source,
target_value=target,
sentiment=calc_label_func())

# endregion

# region public methods
Expand All @@ -71,8 +53,12 @@ def iter_opinion_linkages(self, doc_id, opinions_view):
doc_df = self._storage.find_by_value(column_name=const.DOC_ID, value=doc_id)

doc_opin_ids = self.__iter_doc_opinion_ids(doc_df)
doc_opin_id_patterns = self.__iter_id_patterns(doc_opin_ids)
linkages_df = self.__iter_opinion_linkages_df(doc_df=doc_df, row_ids=doc_opin_id_patterns)

doc_opin_id_patterns = map(
lambda opinion_id: self._ids_provider.create_pattern(id_value=opinion_id, p_type=BaseIDProvider.OPINION),
doc_opin_ids)

linkages_df = self._iter_opinion_linkages_df(doc_df=doc_df, row_ids=doc_opin_id_patterns)
opinions_iter = self.__iter_opinions_by_linkages(linkages_df, opinions_view=opinions_view)

return map(lambda opinions: OpinionsLinkage(opinions), opinions_iter)
Expand Down
6 changes: 4 additions & 2 deletions arekit/common/data/views/output_multiple.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from arekit.common.data import const
from arekit.common.data.row_ids.multiple import MultipleIDProvider
from arekit.common.data.views import utils
from arekit.common.data.views.opinions import BaseOpinionStorageView
from arekit.common.data.views.ouput_base import BaseOutputView
from arekit.common.labels.scaler import BaseLabelScaler
Expand Down Expand Up @@ -37,8 +38,9 @@ def _iter_by_opinions(self, linked_df, opinions_view):
assert(isinstance(linked_df, pd.DataFrame))
assert(isinstance(opinions_view, BaseOpinionStorageView))

for index, series in linked_df.iterrows():
yield self._compose_opinion_by_opinion_id(
for _, series in linked_df.iterrows():
yield utils.compose_opinion_by_opinion_id(
ids_provider=self._ids_provider,
sample_id=series[const.ID],
opinions_view=opinions_view,
calc_label_func=lambda: self.__calculate_label(series))
Expand Down
17 changes: 17 additions & 0 deletions arekit/common/data/views/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from arekit.common.data.row_ids.base import BaseIDProvider
from arekit.common.data.views.opinions import BaseOpinionStorageView
from arekit.common.opinions.base import Opinion


def compose_opinion_by_opinion_id(ids_provider, sample_id, opinions_view, calc_label_func):
    """Build an Opinion for the opinion that a sample row belongs to.

    ids_provider: BaseIDProvider, used to convert the sample id into
        the corresponding opinion id.
    sample_id: str, identifier of the sample row.
    opinions_view: BaseOpinionStorageView, queried for the
        (source, target) pair of the resolved opinion id.
    calc_label_func: callable without arguments; its result becomes
        the sentiment of the composed opinion.

    Returns an Opinion with the looked-up source/target values.
    """
    assert isinstance(ids_provider, BaseIDProvider)
    assert isinstance(sample_id, str)
    assert isinstance(opinions_view, BaseOpinionStorageView)
    assert callable(calc_label_func)

    # Resolve which opinion the given sample row refers to.
    resolved_opinion_id = ids_provider.convert_sample_id_to_opinion_id(sample_id=sample_id)

    # The view returns the opinion participants as a (source, target) pair.
    opinion_info = opinions_view.provide_opinion_info_by_opinion_id(opinion_id=resolved_opinion_id)
    source, target = opinion_info

    # The label is computed only after the opinion info lookup succeeded.
    label = calc_label_func()

    return Opinion(source_value=source, target_value=target, sentiment=label)
20 changes: 13 additions & 7 deletions arekit/contrib/bert/views/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from arekit.common.data import const
from arekit.common.data.row_ids.base import BaseIDProvider
from arekit.common.data.views import utils
from arekit.common.data.views.opinions import BaseOpinionStorageView
from arekit.common.data.views.ouput_base import BaseOutputView
from arekit.common.labels.scaler import BaseLabelScaler
Expand Down Expand Up @@ -46,14 +47,19 @@ def _iter_by_opinions(self, linked_df, opinions_view):
assert(isinstance(linked_df, pd.DataFrame))
assert(isinstance(opinions_view, BaseOpinionStorageView))

for opinion_ind in self.__iter_linked_opinion_indices(linked_df=linked_df):
ind_pattern = self._ids_provider.create_pattern(id_value=opinion_ind,
p_type=BaseIDProvider.INDEX)
opinion_df = linked_df[linked_df[const.ID].str.contains(ind_pattern)]
opinion_ids = self.__iter_linked_opinion_indices(linked_df=linked_df)

yield self._compose_opinion_by_opinion_id(
sample_id=opinion_df[const.ID].iloc[0],
id_patterns_iter = map(
lambda opinion_id: self._ids_provider.create_pattern(id_value=opinion_id, p_type=BaseIDProvider.INDEX),
opinion_ids)

linkages_dfs = self._iter_opinion_linkages_df(doc_df=linked_df, row_ids=id_patterns_iter)

for opinion_linkage_df in linkages_dfs:
yield utils.compose_opinion_by_opinion_id(
ids_provider=self._ids_provider,
sample_id=opinion_linkage_df[const.ID].iloc[0],
opinions_view=opinions_view,
calc_label_func=lambda: self.__calculate_label(df=opinion_df))
calc_label_func=lambda: self.__calculate_label(df=opinion_linkage_df))

# endregion

0 comments on commit fc007a0

Please sign in to comment.