Skip to content

Commit

Permalink
#459 done, caused #460, #458 fixed. Large update that provides a stor…
Browse files Browse the repository at this point in the history
…age that keeps only one row in memory! The latter is supplemented by CSV writer.
  • Loading branch information
nicolay-r committed May 8, 2023
1 parent 5e92893 commit cf6a5de
Show file tree
Hide file tree
Showing 15 changed files with 262 additions and 27 deletions.
20 changes: 17 additions & 3 deletions arekit/common/data/input/repositories/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from arekit.common.data.input.providers.opinions import InputTextOpinionProvider
from arekit.common.data.input.providers.rows.base import BaseRowProvider
from arekit.common.data.storages.base import BaseRowsStorage
from arekit.contrib.utils.data.storages.row_cache import RowCacheStorage
from arekit.contrib.utils.data.writers.base import BaseWriter


class BaseInputRepository(object):
Expand Down Expand Up @@ -30,11 +32,13 @@ def _setup_rows_provider(self):
# endregion

# TODO. Generalize, TextOpinion -> Any provider.
def populate(self, opinion_provider, doc_ids, desc=""):
def populate(self, opinion_provider, doc_ids, desc="", writer=None, target=None):
# TODO. Generalize, TextOpinion -> Any provider.
assert(isinstance(opinion_provider, InputTextOpinionProvider))
assert(isinstance(self._storage, BaseRowsStorage))
assert(isinstance(doc_ids, list))
assert(isinstance(writer, BaseWriter) or writer is None)
assert(isinstance(target, str) or target is None)

def iter_rows(idle_mode):
return self._rows_provider.iter_by_rows(
Expand All @@ -44,12 +48,22 @@ def iter_rows(idle_mode):

self._storage.init_empty(columns_provider=self._columns_provider)

is_async_write_mode_on = writer is not None and target is not None

if is_async_write_mode_on:
writer.open_target(target)

self._storage.fill(lambda idle_mode: iter_rows(idle_mode),
columns_provider=self._columns_provider,
row_handler=lambda: writer.commit_line(self._storage) if is_async_write_mode_on else None,
desc=desc)

def write(self, writer, target, free_storage=True):
writer.write(self._storage, target)
if is_async_write_mode_on:
writer.close_target()

def push(self, writer, target, free_storage=True):
if not isinstance(self._storage, RowCacheStorage):
writer.write_all(self._storage, target)

# After writing we free the contents of the storage.
if free_storage:
Expand Down
38 changes: 29 additions & 9 deletions arekit/common/data/storages/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,29 @@
import logging

from arekit.common.data.input.providers.columns.base import BaseColumnsProvider
from arekit.common.utils import progress_bar_defined
from arekit.common.utils import progress_bar

logger = logging.getLogger(__name__)


class BaseRowsStorage(object):

# region protected methods

def _begin_filling_row(self, row_ind):
pass

# endregion

# region abstract methods

def _set_value(self, row_ind, column, value):
def _set_row_value(self, row_ind, column, value):
raise NotImplemented()

def _iter_rows(self):
""" returns: tuple(int, list)
provides the index (int) and the related content of the row (list)
"""
raise NotImplemented()

def _get_rows_count(self):
Expand All @@ -41,21 +51,31 @@ def init_empty(self, columns_provider):
def iter_shuffled(self):
raise NotImplemented()

def iter_column_names(self):
raise NotImplemented()

# endregion

def fill(self, iter_rows_func, columns_provider, rows_count=None, desc=""):
def fill(self, iter_rows_func, columns_provider, row_handler=None, rows_count=None, desc=""):
assert(callable(iter_rows_func))
assert(isinstance(columns_provider, BaseColumnsProvider))
assert(callable(row_handler) or row_handler is None)

it = progress_bar_defined(iterable=iter_rows_func(False),
desc="{fmt}".format(fmt=desc),
total=rows_count)
it = progress_bar(iterable=iter_rows_func(False),
desc="{fmt}".format(fmt=desc),
total=rows_count)

for row_index, row in enumerate(it):

self._begin_filling_row(row_index)

for column, value in row.items():
self._set_value(row_ind=row_index,
column=column,
value=value)
self._set_row_value(row_ind=row_index,
column=column,
value=value)

if row_handler is not None:
row_handler()

def free(self):
gc.collect()
Expand Down
7 changes: 7 additions & 0 deletions arekit/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@ def split_by_whitespaces(text):
return text.split()


def progress_bar(iterable, total, desc="", unit="it"):
    """ Wraps `iterable` with a progress bar.

        Dispatches to a bar with a known length when `total` is given,
        and to an indefinite (spinner-style) bar otherwise.
    """
    if total is None:
        return progress_bar_iter(iterable=iterable, desc=desc, unit=unit)
    return progress_bar_defined(iterable=iterable, total=total, desc=desc, unit=unit)


def progress_bar_defined(iterable, total, desc="", unit="it"):
return tqdm(iterable=iterable,
total=total,
Expand Down
8 changes: 6 additions & 2 deletions arekit/contrib/utils/data/storages/pandas_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@ def __fill_with_blank_rows(self, row_id_column_name, rows_count):

# region protected methods

def _set_value(self, row_ind, column, value):
def iter_column_names(self):
return iter(self._df.columns)

def _set_row_value(self, row_ind, column, value):
self._df.at[row_ind, column] = value

def _iter_rows(self):
Expand All @@ -59,7 +62,7 @@ def _get_rows_count(self):

# region public methods

def fill(self, iter_rows_func, columns_provider, rows_count=None, desc=""):
def fill(self, iter_rows_func, columns_provider, row_handler=None, rows_count=None, desc=""):
""" NOTE: We provide the rows counting which is required
in order to know an expected amount of rows in advance
due to the specifics of the pandas memory allocation
Expand All @@ -83,6 +86,7 @@ def fill(self, iter_rows_func, columns_provider, rows_count=None, desc=""):
logger.info("Completed!")

super(PandasBasedRowsStorage, self).fill(iter_rows_func=iter_rows_func,
row_handler=row_handler,
columns_provider=columns_provider,
rows_count=rows_count)

Expand Down
30 changes: 30 additions & 0 deletions arekit/contrib/utils/data/storages/row_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from arekit.common.data.input.providers.columns.base import BaseColumnsProvider
from arekit.common.data.storages.base import BaseRowsStorage


class RowCacheStorage(BaseRowsStorage):
    """ Row-caching storage kernel, based on a python dictionary.

        Keeps only the single row currently being filled in memory;
        intended for streaming scenarios in which a writer consumes
        each row (via `RowCache`) right after it has been filled.
    """

    def __init__(self):
        # Mapping: column name -> value of the row currently being filled.
        self.__row_cache = {}
        # Ordered column names; populated by `init_empty`.
        self.__columns = []

    @property
    def RowCache(self):
        """ The contents of the most recently filled row. """
        return self.__row_cache

    def init_empty(self, columns_provider):
        assert(isinstance(columns_provider, BaseColumnsProvider))
        # Rebuild (rather than extend) the column list so that repeated
        # initialization does not accumulate duplicate column names.
        self.__columns = [col_name for col_name, _ in
                          columns_provider.get_columns_list_with_types()]

    def iter_column_names(self):
        return iter(self.__columns)

    def _set_row_value(self, row_ind, column, value):
        # `row_ind` is ignored: only one row is kept at a time.
        self.__row_cache[column] = value

    def _begin_filling_row(self, row_ind):
        # Drop the previous row before the next one starts filling.
        self.__row_cache.clear()
20 changes: 19 additions & 1 deletion arekit/contrib/utils/data/writers/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,22 @@
class BaseWriter(object):
    """ Base interface for storage writers.

        Two usage modes are supported:
          * streaming mode -- `open_target`, then `commit_line` per row,
            then `close_target`; all three are no-ops by default;
          * bulk mode -- `write_all`, which dumps a whole storage at once
            and must be provided by subclasses that support it.
    """

    def open_target(self, target):
        """ Prepares the `target` destination for streaming output. """
        pass

    def commit_line(self, storage):
        """ Flushes the row currently cached in `storage` to the target. """
        pass

    def close_target(self):
        """ Finalizes the streaming output. """
        pass

    def write_all(self, storage, target):
        """ Performs the writing process of the whole storage.
            The implementation and support of the related operation
            may vary and depends on the nature of the storage, which
            briefly might keep all the data in memory (available)
            or cache only temporary information (unavailable)
            storage: BaseRowsStorage
            target: str
        """
        raise NotImplementedError()
41 changes: 41 additions & 0 deletions arekit/contrib/utils/data/writers/csv_native.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import csv
from arekit.common.data.storages.base import BaseRowsStorage
from arekit.contrib.utils.data.storages.row_cache import RowCacheStorage
from arekit.contrib.utils.data.writers.base import BaseWriter


class NativeCsvWriter(BaseWriter):
    """ CSV writer based on the native python `csv` module.
        Supports both streaming (row-by-row) and bulk writing modes.
    """

    def __init__(self, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL):
        """ delimiter / quotechar / quoting: forwarded to `csv.writer`. """
        self.__target_f = None
        self.__writer = None
        self.__create_writer_func = lambda f: csv.writer(
            f, delimiter=delimiter, quotechar=quotechar, quoting=quoting)

    def open_target(self, target):
        """ Opens the `target` filepath for streaming row-by-row output. """
        # newline="" is required by the `csv` module; without it the
        # output contains extra blank lines on platforms with \r\n endings.
        self.__target_f = open(target, "w", newline="")
        self.__writer = self.__create_writer_func(self.__target_f)

    def close_target(self):
        """ Finalizes the streaming output and releases the file handle.
            Safe to call more than once.
        """
        if self.__target_f is not None:
            self.__target_f.close()
            self.__target_f = None
            self.__writer = None

    def commit_line(self, storage):
        """ Writes the row currently cached in `storage` as one CSV line. """
        assert(isinstance(storage, RowCacheStorage))
        assert(self.__writer is not None)
        # NOTE(review): columns absent from the cache are skipped, so a row
        # with missing values comes out shorter than the column list --
        # confirm this misalignment is intended rather than padding with "".
        line_data = [storage.RowCache[col_name]
                     for col_name in storage.iter_column_names()
                     if col_name in storage.RowCache]
        self.__writer.writerow(line_data)

    def write_all(self, storage, target):
        """ Writes all the `storage` rows
            into the `target` filepath, formatted as CSV.
        """
        assert(isinstance(storage, BaseRowsStorage))

        with open(target, "w", newline="") as f:
            writer = self.__create_writer_func(f)
            for _, row in storage:
                writer.writerow(list(row))
2 changes: 1 addition & 1 deletion arekit/contrib/utils/data/writers/csv_pd.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def __init__(self, write_header):
super(PandasCsvWriter, self).__init__()
self.__write_header = write_header

def write(self, storage, target):
def write_all(self, storage, target):
assert(isinstance(storage, PandasBasedRowsStorage))
assert(isinstance(target, str))

Expand Down
2 changes: 1 addition & 1 deletion arekit/contrib/utils/data/writers/json_opennre.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __write_bag(bag, json_file):
json.dump(bag, json_file, separators=(",", ":"), ensure_ascii=False)
json_file.write("\n")

def write(self, storage, target):
def write_all(self, storage, target):
assert(isinstance(storage, BaseRowsStorage))
assert(isinstance(target, str))

Expand Down
6 changes: 4 additions & 2 deletions arekit/contrib/utils/pipelines/items/sampling/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

class BertExperimentInputSerializerPipelineItem(BasePipelineItem):

def __init__(self, sample_rows_provider, samples_io, save_labels_func, balance_func):
def __init__(self, sample_rows_provider, samples_io, save_labels_func, balance_func, storage):
""" sample_rows_formatter:
how we format input texts for a BERT model, for example:
- single text
Expand All @@ -25,6 +25,7 @@ def __init__(self, sample_rows_provider, samples_io, save_labels_func, balance_f
self.__balance_func = balance_func
self.__samples_io = samples_io
self.__save_labels_func = save_labels_func
self.__storage = storage

# region private methods

Expand All @@ -34,7 +35,8 @@ def __serialize_iteration(self, data_type, pipeline, data_folding):
repos = {
"sample": InputDataSerializationHelper.create_samples_repo(
keep_labels=self.__save_labels_func(data_type),
rows_provider=self.__sample_rows_provider),
rows_provider=self.__sample_rows_provider,
storage=self.__storage),
}

writer_and_targets = {
Expand Down
6 changes: 4 additions & 2 deletions arekit/contrib/utils/pipelines/items/sampling/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
class NetworksInputSerializerPipelineItem(BasePipelineItem):

def __init__(self, vectorizers, save_labels_func, str_entity_fmt, ctx,
samples_io, emb_io, balance_func, save_embedding):
samples_io, emb_io, balance_func, save_embedding, storage):
""" This pipeline item allows to perform a data preparation for neural network models.
considering a list of the whole data_types with the related pipelines,
Expand Down Expand Up @@ -65,6 +65,7 @@ def __init__(self, vectorizers, save_labels_func, str_entity_fmt, ctx,
self.__save_embedding = save_embedding and vectorizers is not None
self.__save_labels_func = save_labels_func
self.__balance_func = balance_func
self.__storage = storage

self.__term_embedding_pairs = collections.OrderedDict()

Expand Down Expand Up @@ -102,7 +103,8 @@ def __serialize_iteration(self, data_type, pipeline, rows_provider, data_folding
repos = {
"sample": InputDataSerializationHelper.create_samples_repo(
keep_labels=self.__save_labels_func(data_type),
rows_provider=rows_provider),
rows_provider=rows_provider,
storage=self.__storage),
}

writer_and_targets = {
Expand Down
13 changes: 9 additions & 4 deletions arekit/contrib/utils/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from arekit.common.data.input.providers.rows.base import BaseRowProvider
from arekit.common.data.input.repositories.base import BaseInputRepository
from arekit.common.data.input.repositories.sample import BaseInputSamplesRepository
from arekit.common.data.storages.base import BaseRowsStorage
from arekit.common.pipeline.base import BasePipeline
from arekit.contrib.utils.data.service.balance import StorageBalancing
from arekit.contrib.utils.data.storages.pandas_based import PandasBasedRowsStorage
Expand All @@ -19,13 +20,14 @@
class InputDataSerializationHelper(object):

@staticmethod
def create_samples_repo(keep_labels, rows_provider):
def create_samples_repo(keep_labels, rows_provider, storage):
assert(isinstance(rows_provider, BaseRowProvider))
assert(isinstance(keep_labels, bool))
assert(isinstance(storage, BaseRowsStorage))
return BaseInputSamplesRepository(
columns_provider=SampleColumnsProvider(store_labels=keep_labels),
rows_provider=rows_provider,
storage=PandasBasedRowsStorage())
storage=storage)

@staticmethod
def fill_and_write(pipeline, repo, target, writer, doc_ids_iter, desc="", do_balance=False):
Expand All @@ -35,9 +37,12 @@ def fill_and_write(pipeline, repo, target, writer, doc_ids_iter, desc="", do_bal
assert(isinstance(do_balance, bool))

doc_ids = list(doc_ids_iter)

repo.populate(opinion_provider=InputTextOpinionProvider(pipeline),
doc_ids=doc_ids,
desc=desc)
desc=desc,
writer=writer,
target=target)

if do_balance:
balanced_storage = StorageBalancing.create_balanced_from(
Expand All @@ -47,4 +52,4 @@ def fill_and_write(pipeline, repo, target, writer, doc_ids_iter, desc="", do_bal
rows_provider=repo._rows_provider,
storage=balanced_storage)

repo.write(writer=writer, target=target)
repo.push(writer=writer, target=target)
Loading

0 comments on commit cf6a5de

Please sign in to comment.