Skip to content

Commit

Permalink
#182, refactoring.
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Dec 10, 2021
1 parent af01adb commit 294b380
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 29 deletions.
8 changes: 8 additions & 0 deletions arekit/common/data/input/providers/instances/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
class BaseLinkedTextOpinionsInstancesProvider(object):
    """ Base interface for providers that turn a single linked text
        opinions wrapper into one or more instances.
    """

    def iter_instances(self, linked_wrap):
        """ Iterate over the instances produced for `linked_wrap`.
            Not implemented here; subclasses are expected to override it.
        """
        raise NotImplementedError()

    @staticmethod
    def provide_label(linked_wrap):
        """ Return the `Sentiment` of the first opinion (`First`)
            of the given `linked_wrap`.
        """
        first_opinion = linked_wrap.First
        return first_opinion.Sentiment
27 changes: 27 additions & 0 deletions arekit/common/data/input/providers/instances/multiple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from arekit.common.data.input.providers.instances.base import BaseLinkedTextOpinionsInstancesProvider
from arekit.common.linked.text_opinions.wrapper import LinkedTextOpinionsWrapper
from arekit.common.text_opinions.base import TextOpinion


class MultipleLinkedTextOpinionsInstancesProvider(BaseLinkedTextOpinionsInstancesProvider):
    """ Produces one instance per supported label: every instance is a new
        wrapper whose first text opinion is a relabeled copy of the original
        first opinion.
    """

    def __init__(self, supported_labels):
        """ supported_labels: list
                labels to enumerate; one instance is yielded for each of them.
        """
        assert isinstance(supported_labels, list)
        self.__supported_labels = supported_labels

    def iter_instances(self, linked_wrap):
        """ Yield a separate wrapper for every supported label, in the order
            in which the labels were passed to the constructor.
        """
        for candidate_label in self.__supported_labels:
            yield self.__modify_first_and_copy_linked_wrap(linked_wrap, candidate_label)

    @staticmethod
    def __modify_first_and_copy_linked_wrap(linked_wrap, label):
        """ Build a new LinkedTextOpinionsWrapper in which the first opinion is
            replaced by the result of `TextOpinion.create_copy` with `label`
            set on it; all the remaining opinions are passed through as is.
        """
        assert isinstance(linked_wrap, LinkedTextOpinionsWrapper)

        opinions = list(linked_wrap)

        # Only the copy receives the new label, not the opinion taken from `linked_wrap`.
        relabeled_first = TextOpinion.create_copy(other=opinions[0])
        relabeled_first.set_label(label=label)

        return LinkedTextOpinionsWrapper(
            linked_text_opinions=[relabeled_first] + opinions[1:])
8 changes: 8 additions & 0 deletions arekit/common/data/input/providers/instances/single.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from arekit.common.data.input.providers.instances.base import BaseLinkedTextOpinionsInstancesProvider


class SingleLinkedTextOpinionsInstancesProvider(BaseLinkedTextOpinionsInstancesProvider):
    """ Provider that yields the passed wrapper itself as the only instance.
    """

    def iter_instances(self, wrapper):
        """ Yield exactly one instance: the given `wrapper`, unchanged.
        """
        # The generator finishes right after the single item.
        yield wrapper
43 changes: 14 additions & 29 deletions arekit/common/data/input/providers/rows/samples.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from collections import OrderedDict

from arekit.common.data import const
from arekit.common.data.input.providers.instances.multiple import MultipleLinkedTextOpinionsInstancesProvider
from arekit.common.data.input.providers.instances.single import SingleLinkedTextOpinionsInstancesProvider
from arekit.common.data.input.providers.label.base import LabelProvider
from arekit.common.data.input.providers.label.multiple import MultipleLabelProvider
from arekit.common.data.input.providers.rows.base import BaseRowProvider
Expand Down Expand Up @@ -29,6 +31,7 @@ def __init__(self, label_provider, text_provider):
self._label_provider = label_provider
self.__text_provider = text_provider
self.__row_ids_provider = self.__create_row_ids_provider(label_provider)
self.__instances_provider = self.__create_instances_provider(label_provider)
self.__store_labels = None

# region properties
Expand Down Expand Up @@ -108,6 +111,13 @@ def __create_row_ids_provider(label_provider):
if isinstance(label_provider, MultipleLabelProvider):
return MultipleIDProvider()

@staticmethod
def __create_instances_provider(label_provider):
    """ Select the instances provider that corresponds to the type
        of the given `label_provider`.

        BinaryLabelProvider   -> MultipleLinkedTextOpinionsInstancesProvider
                                 (created from `label_provider.SupportedLabels`)
        MultipleLabelProvider -> SingleLinkedTextOpinionsInstancesProvider
        any other type        -> None
    """
    # The binary check goes first, exactly as in the original order of checks.
    if isinstance(label_provider, BinaryLabelProvider):
        labels = label_provider.SupportedLabels
        return MultipleLinkedTextOpinionsInstancesProvider(labels)

    if isinstance(label_provider, MultipleLabelProvider):
        return SingleLinkedTextOpinionsInstancesProvider()

    # No matching provider type: the result is None.
    return None

def __provide_rows(self, row_dict, parsed_news, linked_wrap, index_in_linked, idle_mode):
"""
Providing Rows depending on row_id_formatter type
Expand All @@ -116,27 +126,14 @@ def __provide_rows(self, row_dict, parsed_news, linked_wrap, index_in_linked, id
assert(isinstance(row_dict, OrderedDict))
assert(isinstance(linked_wrap, LinkedTextOpinionsWrapper))

origin = linked_wrap.First
if isinstance(self.__row_ids_provider, BinaryIDProvider):
"""
Enumerate all opinions as if it would be with the different label types.
"""
for label in self._label_provider.SupportedLabels:
yield self.__create_row(row=row_dict,
parsed_news=parsed_news,
linked_wrap=self.__copy_modified_linked_wrap(linked_wrap, label),
index_in_linked=index_in_linked,
# TODO. provide uint_label
etalon_label=origin.Sentiment,
idle_mode=idle_mode)

if isinstance(self.__row_ids_provider, MultipleIDProvider):
etalon_label = self.__instances_provider.provide_label(linked_wrap)
for instance in self.__instances_provider.iter_instances(linked_wrap):
yield self.__create_row(row=row_dict,
parsed_news=parsed_news,
linked_wrap=linked_wrap,
linked_wrap=instance,
index_in_linked=index_in_linked,
# TODO. provide uint_label
etalon_label=origin.Sentiment,
etalon_label=etalon_label,
idle_mode=idle_mode)

def __create_row(self, row, parsed_news, linked_wrap, index_in_linked, etalon_label, idle_mode):
Expand Down Expand Up @@ -177,18 +174,6 @@ def __create_row(self, row, parsed_news, linked_wrap, index_in_linked, etalon_la
t_ind=t_ind)
return row

@staticmethod
def __copy_modified_linked_wrap(linked_wrap, label):
    """ Return a new LinkedTextOpinionsWrapper whose first opinion is the
        result of `TextOpinion.create_copy` with `label` set on it; the rest
        of the opinions are the same objects as in `linked_wrap`.
    """
    assert isinstance(linked_wrap, LinkedTextOpinionsWrapper)

    opinions = list(linked_wrap)

    # The label is assigned to the copy only.
    relabeled_first = TextOpinion.create_copy(other=opinions[0])
    relabeled_first.set_label(label=label)

    return LinkedTextOpinionsWrapper(
        linked_text_opinions=[relabeled_first] + opinions[1:])

@staticmethod
def __get_opinion_end_indices(parsed_news, text_opinion):
assert(isinstance(parsed_news, ParsedNews))
Expand Down

0 comments on commit 294b380

Please sign in to comment.