Skip to content

Commit

Permalink
Refactoring: #195
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Sep 19, 2021
1 parent 0a574e8 commit 92a4084
Show file tree
Hide file tree
Showing 26 changed files with 276 additions and 170 deletions.
9 changes: 6 additions & 3 deletions arekit/common/experiment/input/providers/columns/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@ class SampleColumnsProvider(BaseColumnsProvider):
[id, text_a] -- for test
"""

def __init__(self, store_labels, text_column_names):
assert(isinstance(text_column_names, list))
def __init__(self, store_labels):
super(SampleColumnsProvider, self).__init__()
self.__store_labels = store_labels
self.__text_column_names = text_column_names
self.__text_column_names = None

# region properties

Expand Down Expand Up @@ -49,3 +48,7 @@ def get_columns_list_with_types(self):
dtypes_list.append((const.T_IND, 'int32'))

return dtypes_list

def set_text_column_names(self, text_column_names):
assert(isinstance(text_column_names, list))
self.__text_column_names = text_column_names
23 changes: 18 additions & 5 deletions arekit/common/experiment/input/providers/rows/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,13 @@ class BaseRowProvider(object):
""" Base provider for rows that suppose to be filled into BaseRowsStorage.
"""

def __init__(self, storage):
assert(isinstance(storage, BaseRowsStorage))
self._storage = storage
def __init__(self):
""" NOTE: storage is considered to be intialized later on,
once repository will be created.
"""
self._storage = None

# region private methods

def __iter_by_rows(self, opinion_provider, idle_mode):
assert(isinstance(opinion_provider, OpinionProvider))
Expand All @@ -26,11 +30,18 @@ def __iter_by_rows(self, opinion_provider, idle_mode):
for row in rows_it:
yield row

# endregion

# region protected methods

def _provide_rows(self, parsed_news, linked_wrapper, idle_mode):
raise NotImplementedError()

# endregion

def format(self, opinion_provider, desc=""):
assert(isinstance(opinion_provider, OpinionProvider))
assert(self._storage is not None)

logged_rows_it = progress_bar_iter(self.__iter_by_rows(opinion_provider, idle_mode=True),
desc="Calculating rows count",
Expand All @@ -53,5 +64,7 @@ def format(self, opinion_provider, desc=""):

self._storage.log_info()

def save(self):
self._storage.save()
def set_storage(self, storage):
assert(isinstance(storage, BaseRowsStorage))
assert(self._storage is None)
self._storage = storage
31 changes: 22 additions & 9 deletions arekit/common/experiment/input/providers/rows/samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from arekit.common.experiment.input.providers.row_ids.multiple import MultipleIDProvider
from arekit.common.experiment.input.providers.rows.base import BaseRowProvider
from arekit.common.experiment.input.providers.text.single import BaseSingleTextProvider
from arekit.common.experiment.input.storages.sample import BaseSampleStorage
from arekit.common.labels.base import Label
from arekit.common.linked.text_opinions.wrapper import LinkedTextOpinionsWrapper
from arekit.common.news.parsed.base import ParsedNews
Expand All @@ -22,19 +21,27 @@ class BaseSampleRowProvider(BaseRowProvider):
""" Rows provider for samples storage.
"""

def __init__(self, storage, label_provider, text_provider):
assert(isinstance(storage, BaseSampleStorage))
def __init__(self, label_provider, text_provider):
assert(isinstance(label_provider, LabelProvider))
assert(isinstance(text_provider, BaseSingleTextProvider))
super(BaseSampleRowProvider, self).__init__(storage=storage)

# Initializing storage.
self._storage.set_output_labels_uint(label_provider.OutputLabelsUint)
self._storage.init_empty()
super(BaseSampleRowProvider, self).__init__()

self._label_provider = label_provider
self.__text_provider = text_provider
self.__row_ids_provider = self.__create_row_ids_provider(label_provider)
self.__store_labels = None

# region properties

@property
def LabelProvider(self):
return self._label_provider

@property
def TextProvider(self):
return self.__text_provider

# endregion

# region protected methods

Expand All @@ -44,6 +51,7 @@ def _iter_sentence_terms(parsed_news, sentence_ind):

def _fill_row_core(self, row, linked_wrap, index_in_linked, etalon_label,
parsed_news, sentence_ind, s_ind, t_ind):
assert(isinstance(self.__store_labels, bool))

def __assign_value(column, value):
row[column] = value
Expand All @@ -57,7 +65,7 @@ def __assign_value(column, value):

expected_label = linked_wrap.get_linked_label()

if self._storage.StoreLabels:
if self.__store_labels:
row[const.LABEL] = self._label_provider.calculate_output_uint_label(
expected_uint_label=self._label_provider.LabelScaler.label_to_uint(expected_label),
etalon_uint_label=self._label_provider.LabelScaler.label_to_uint(etalon_label))
Expand Down Expand Up @@ -195,3 +203,8 @@ def __get_opinion_end_indices(parsed_news, text_opinion):
return (s_ind, t_ind)

# endregion

def set_store_labels(self, store_labels):
assert(isinstance(store_labels, bool))
assert(self.__store_labels is None)
self.__store_labels = store_labels
Empty file.
44 changes: 44 additions & 0 deletions arekit/common/experiment/input/repositories/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from arekit.common.experiment.input.providers.columns.base import BaseColumnsProvider
from arekit.common.experiment.input.providers.opinions import OpinionProvider
from arekit.common.experiment.input.providers.rows.base import BaseRowProvider
from arekit.common.experiment.input.storages.base import BaseRowsStorage


class BaseInputRepository(object):

def __init__(self, columns_provider, rows_provider, storage):
assert(isinstance(columns_provider, BaseColumnsProvider))
assert(isinstance(rows_provider, BaseRowProvider))
assert(isinstance(storage, BaseRowsStorage))

self._columns_provider = columns_provider
self._rows_provider = rows_provider
self._storage = storage

# Do setup operations.
self._setup_columns_provider()
self._setup_rows_provider()
self._setup_storage()

# region protected methods

def _setup_columns_provider(self):
pass

def _setup_rows_provider(self):
self._rows_provider.set_storage(self._storage)

def _setup_storage(self):
self._storage.set_columns_provider(self._columns_provider)

# endregion

def populate(self, opinion_provider, target, desc=""):
assert(isinstance(opinion_provider, OpinionProvider))
assert(isinstance(self._storage, BaseRowsStorage))

self._storage.init_empty()

with self._storage as storage:
self._rows_provider.format(opinion_provider, desc)
return storage.save(target)
5 changes: 5 additions & 0 deletions arekit/common/experiment/input/repositories/opinions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from arekit.common.experiment.input.repositories.base import BaseInputRepository


class BaseInputOpinionsRepository(BaseInputRepository):
pass
24 changes: 24 additions & 0 deletions arekit/common/experiment/input/repositories/sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from arekit.common.experiment.input.providers.rows.samples import BaseSampleRowProvider
from arekit.common.experiment.input.repositories.base import BaseInputRepository


class BaseInputSamplesRepository(BaseInputRepository):

def _setup_rows_provider(self):
""" Setup store labels.
"""
assert(isinstance(self._rows_provider, BaseSampleRowProvider))
self._rows_provider.set_store_labels(self._columns_provider.StoreLabels)

def _setup_columns_provider(self):
""" Setup text column names.
"""
text_column_names = list(self._rows_provider.TextProvider.iter_columns())
self._columns_provider.set_text_column_names(text_column_names)

def _setup_storage(self):
""" Setup output labels uint
"""
super(BaseInputSamplesRepository, self)._setup_storage()
self._storage.set_output_labels_uint(
labels_uint=self._rows_provider.LabelProvider.OutputLabelsUint)
12 changes: 8 additions & 4 deletions arekit/common/experiment/input/storages/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@

class BaseRowsStorage(object):

def __init__(self, columns_provider):
assert(isinstance(columns_provider, BaseColumnsProvider))
self._columns_provider = columns_provider
def __init__(self):
self._columns_provider = None
self._df = None

def _create_empty(self):
Expand Down Expand Up @@ -41,10 +40,15 @@ def _dispose_dataframe(self):

# region public methods

def set_columns_provider(self, columns_provider):
assert(isinstance(columns_provider, BaseColumnsProvider))
assert(self._columns_provider is None)
self._columns_provider = columns_provider

def init_empty(self):
self._df = self._create_empty()

def save(self):
def save(self, target):
raise NotImplementedError()

# endregion
Expand Down
8 changes: 4 additions & 4 deletions arekit/common/experiment/input/storages/opinion.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@

class BaseOpinionsStorage(BaseRowsStorage):

def __init__(self, columns_provider):
super(BaseOpinionsStorage, self).__init__(columns_provider)
def __init__(self):
super(BaseOpinionsStorage, self).__init__()

def save(self):
def save(self, target):
""" In Memory solution, there is no need to write it.
"""
pass
super(BaseOpinionsStorage, self).save(target)
10 changes: 3 additions & 7 deletions arekit/common/experiment/input/storages/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,10 @@

class BaseSampleStorage(BaseRowsStorage):

def __init__(self, columns_provider):
super(BaseSampleStorage, self).__init__(columns_provider)
def __init__(self):
super(BaseSampleStorage, self).__init__()
self._output_labels_uint = None

@property
def StoreLabels(self):
return self._columns_provider.StoreLabels

# region private methods

def __fast_init_df(self, df, rows_count):
Expand All @@ -39,7 +35,7 @@ def set_output_labels_uint(self, labels_uint):
raise Exception("Output labels already defined!")
self._output_labels_uint = labels_uint

def save(self):
def save(self, target):
""" This might be implemented in nested classes.
The default, i.e. pandas-based storage is not considered
to be saved into the particular target.
Expand Down
15 changes: 8 additions & 7 deletions arekit/common/experiment/input/storages/tsv_opinion.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,18 @@

class TsvOpinionsStorage(BaseOpinionsStorage):

def __init__(self, filepath, column_provider):
super(TsvOpinionsStorage, self).__init__(column_provider)
self.__filepath = filepath
def __init__(self):
super(TsvOpinionsStorage, self).__init__()

def save(self):
logger.info("Saving... : {}".format(self.__filepath))
def save(self, target):
assert(isinstance(target, str))

create_dir_if_not_exists(self.__filepath)
logger.info("Saving... : {}".format(target))

create_dir_if_not_exists(target)

self._df.sort_values(by=[const.ID], ascending=True)
self._df.to_csv(self.__filepath,
self._df.to_csv(target,
sep='\t',
encoding='utf-8',
columns=[c for c in self._df.columns if c != self._columns_provider.ROW_ID],
Expand Down
16 changes: 7 additions & 9 deletions arekit/common/experiment/input/storages/tsv_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,16 @@

class TsvSampleStorage(BaseSampleStorage):

def __init__(self, filepath, columns_provider, balance, write_header):
assert(isinstance(filepath, str))
def __init__(self, balance, write_header):
assert(isinstance(balance, bool))
super(TsvSampleStorage, self).__init__(columns_provider)

super(TsvSampleStorage, self).__init__()
self.__balance = balance
self.__filepath = filepath
self.__write_header = write_header

def save(self):
def save(self, target):
assert(isinstance(target, str))

create_dir_if_not_exists(self.__filepath)
create_dir_if_not_exists(target)

if self.__balance:
logger.info("Start balancing...")
Expand All @@ -36,9 +34,9 @@ def save(self):

logger.info("Saving... {shape}: {filepath}".format(
shape=self._df.shape, # self._df.shape,
filepath=self.__filepath))
filepath=target))
self._df.sort_values(by=[const.ID], ascending=True)
self._df.to_csv(self.__filepath,
self._df.to_csv(target,
sep='\t',
encoding='utf-8',
columns=[c for c in self._df.columns if c != self._columns_provider.ROW_ID],
Expand Down
8 changes: 7 additions & 1 deletion arekit/common/experiment/io_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,13 @@ def create_opinions_reader(self, data_type):
def create_samples_writer(self, data_type, balance):
raise NotImplementedError()

def create_opinions_writer(self, data_type):
def create_opinions_writer(self):
raise NotImplementedError()

def create_samples_writer_target(self, data_type):
raise NotImplementedError()

def create_opinions_writer_target(self, data_type):
raise NotImplementedError()

def create_result_opinion_collection_filepath(self, data_type, doc_id, epoch_index):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from arekit.common.experiment.input.providers.row_ids.multiple import MultipleIDProvider
from arekit.common.experiment.input.readers.tsv_opinion import TsvInputOpinionReader
from arekit.common.experiment.input.readers.tsv_sample import TsvInputSampleReader
from arekit.common.experiment.input.storages.tsv_opinion import TsvOpinionsStorage
from arekit.common.experiment.input.storages.tsv_sample import TsvSampleStorage
from arekit.common.experiment.io_utils import BaseIOUtils
from arekit.common.utils import join_dir_with_subfolder_name

Expand All @@ -30,6 +32,19 @@ def create_opinions_reader(self, data_type):
opinions_tsv_filepath = self.get_input_opinions_filepath(data_type)
return TsvInputOpinionReader.from_tsv(opinions_tsv_filepath, compression='infer')

def create_opinions_writer_target(self, data_type):
return self.get_input_opinions_filepath(data_type)

def create_samples_writer_target(self, data_type):
return self.get_input_sample_filepath(data_type)

def create_samples_writer(self, data_type, balance):
return TsvSampleStorage(balance=balance and data_type == DataType.Train,
write_header=True)

def create_opinions_writer(self):
return TsvOpinionsStorage()

def create_result_opinion_collection_filepath(self, data_type, doc_id, epoch_index):
""" Utilized for results evaluation.
"""
Expand Down
Loading

0 comments on commit 92a4084

Please sign in to comment.