Skip to content

Commit

Permalink
#376 done. #492 related
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Jul 11, 2023
1 parent a1eb5b4 commit 5ad9545
Show file tree
Hide file tree
Showing 7 changed files with 17 additions and 109 deletions.
12 changes: 12 additions & 0 deletions arekit/common/data/input/providers/rows/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,29 @@ class BaseRowProvider(object):
""" Base provider for rows that are supposed to be filled into BaseRowsStorage.
"""

def __init__(self):
    """Create a provider that has no active row counter yet.

    The counter attribute starts out as ``None``; ``iter_by_rows`` replaces
    it with a fresh ``collections.Counter`` while rows are being iterated.
    """
    # No iteration is in progress at construction time.
    self.__rows_counter = None

# region protected methods

# TODO. This might be also generalized.
# TODO. Idle-mode is also an implementation- and task-specific parameter, i.e. might be removed from here.
def _provide_rows(self, parsed_doc, entity_service, text_opinion_linkage, idle_mode):
    """Hook for building rows from the given document data.

    The base implementation does nothing but raise, so concrete providers
    are expected to override this method.
    NOTE(review): presumably the override returns an iterable of rows --
    confirm against the overriding providers.

    :raises NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError

def _count_row(self):
    """Return the current ``"rows_iterated"`` count, then increment it.

    The first call after the counter is installed returns 0, the next 1,
    and so on. The counter only exists while ``iter_by_rows`` is running
    (it is ``None`` otherwise), so calling this outside of an iteration
    fails on the subscript below.

    :return: the counter value before the increment.
    """
    counter = self.__rows_counter
    # A missing key in a Counter reads as 0, so no explicit seeding is needed.
    current = counter["rows_iterated"]
    counter["rows_iterated"] = current + 1
    return current

# endregion

def iter_by_rows(self, contents_provider, doc_ids_iter, idle_mode):
assert(isinstance(contents_provider, ContentsProvider))
assert(isinstance(doc_ids_iter, collections.Iterable))

self.__rows_counter = collections.Counter()

for linked_data in contents_provider.from_doc_ids(doc_ids=doc_ids_iter, idle_mode=idle_mode):
assert(isinstance(linked_data, LinkedDataWrapper))
assert(isinstance(linked_data.Tag, ParsedDocumentService))
Expand All @@ -39,3 +49,5 @@ def iter_by_rows(self, contents_provider, doc_ids_iter, idle_mode):

for row in rows_it:
yield linked_data.RelatedDocID, row

self.__rows_counter = None
22 changes: 4 additions & 18 deletions arekit/common/data/input/providers/rows/samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
from arekit.common.data.input.providers.label.multiple import MultipleLabelProvider
from arekit.common.data.input.providers.rows.base import BaseRowProvider
from arekit.common.data.input.providers.text.single import BaseSingleTextProvider
from arekit.common.data.row_ids.binary import BinaryIDProvider
from arekit.common.data.row_ids.multiple import MultipleIDProvider
from arekit.common.entities.base import Entity
from arekit.common.labels.base import Label

Expand All @@ -34,7 +32,6 @@ def __init__(self, label_provider, text_provider):

self._label_provider = label_provider
self.__text_provider = text_provider
self.__row_ids_provider = self.__create_row_ids_provider(label_provider)
self.__instances_provider = self.__create_instances_provider(label_provider)
self.__store_labels = None

Expand Down Expand Up @@ -65,10 +62,7 @@ def _fill_row_core(self, row, text_opinion_linkage, index_in_linked, etalon_labe
def __assign_value(column, value):
row[column] = value

row[const.ID] = self.__row_ids_provider.create_sample_id(
linked_opinions=text_opinion_linkage,
index_in_linked=index_in_linked,
label_scaler=self._label_provider.LabelScaler)
row[const.ID] = self._count_row()

row[const.OPINION_ID] = text_opinion_linkage.First.TextOpinionID

Expand Down Expand Up @@ -127,18 +121,9 @@ def _provide_rows(self, parsed_doc, entity_service, text_opinion_linkage, idle_m

# region private methods

@staticmethod
def __create_row_ids_provider(label_provider):
# TODO. #376 related. This should be removed after refactoring, because
# TODO. we consider an ordinary IDs, that not based on the other data.
if isinstance(label_provider, BinaryLabelProvider):
return BinaryIDProvider()
if isinstance(label_provider, MultipleLabelProvider):
return MultipleIDProvider()

@staticmethod
def __create_instances_provider(label_provider):
# TODO. #473 related: thiese label providers are based on text opinion extraction task!
# TODO. #473 related: these label providers are based on text opinion extraction task!
if isinstance(label_provider, BinaryLabelProvider):
return MultipleInstancesLinkedTextOpinionsProvider(label_provider.SupportedLabels)
if isinstance(label_provider, MultipleLabelProvider):
Expand All @@ -156,6 +141,7 @@ def __provide_rows(self, row_dict, parsed_doc, entity_service,
etalon_label = self.__instances_provider.provide_label(text_opinion_linkage)
for instance in self.__instances_provider.iter_instances(text_opinion_linkage):
yield self.__create_row(row=row_dict,
row_id=0,
parsed_doc=parsed_doc,
entity_service=entity_service,
text_opinions_linkage=instance,
Expand All @@ -164,7 +150,7 @@ def __provide_rows(self, row_dict, parsed_doc, entity_service,
etalon_label=etalon_label,
idle_mode=idle_mode)

def __create_row(self, row, parsed_doc, entity_service, text_opinions_linkage,
def __create_row(self, row, row_id, parsed_doc, entity_service, text_opinions_linkage,
index_in_linked, etalon_label, idle_mode):
"""
Composing row in following format:
Expand Down
Empty file.
45 changes: 0 additions & 45 deletions arekit/common/data/row_ids/base.py

This file was deleted.

27 changes: 0 additions & 27 deletions arekit/common/data/row_ids/binary.py

This file was deleted.

14 changes: 0 additions & 14 deletions arekit/common/data/row_ids/multiple.py

This file was deleted.

6 changes: 1 addition & 5 deletions arekit/common/data/views/samples.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
from arekit.common.data import const
from arekit.common.data.row_ids.base import BaseIDProvider
from arekit.common.data.storages.base import BaseRowsStorage


# TODO. This is a particular type of view, and is expected to be moved out of the core.
class LinkedSamplesStorageView(object):

def __init__(self, row_ids_provider):
assert(isinstance(row_ids_provider, BaseIDProvider))
self.__row_ids_provider = row_ids_provider

def iter_from_storage(self, storage):
assert(isinstance(storage, BaseRowsStorage))
undefined = -1
Expand Down

0 comments on commit 5ad9545

Please sign in to comment.