Skip to content

Commit

Permalink
#362 done.
Browse files: browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Jul 11, 2022
1 parent a699b52 commit 897393e
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 30 deletions.
12 changes: 7 additions & 5 deletions arekit/contrib/bert/handlers/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,13 @@ def __handle_iteration(self, data_type, pipeline):

InputDataSerializationHelper.serialize(
pipeline=pipeline,
exp_io=self.__exp_io,
iter_doc_ids_func=lambda dtype: self.__doc_ops.iter_doc_ids(dtype),
keep_labels_func=lambda dtype: self.__save_labels_func(data_type),
balance_func=lambda dtype: self.__balance_func(data_type),
data_type=data_type,
doc_ids_iter=self.__doc_ops.iter_doc_ids(data_type),
keep_labels=self.__save_labels_func(data_type),
do_balance=self.__balance_func(data_type),
opinions_writer=self.__exp_io.create_opinions_writer(),
samples_writer=self.__exp_io.create_samples_writer(),
samples_target=self.__exp_io.create_samples_writer_target(data_type),
opinions_target=self.__exp_io.create_opinions_writer_target(data_type),
sample_rows_provider=self.__sample_rows_provider)

# endregion
Expand Down
12 changes: 7 additions & 5 deletions arekit/contrib/networks/handlers/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,13 @@ def __handle_iteration(self, data_type, pipeline, rows_provider):
# Perform data serialization.
InputDataSerializationHelper.serialize(
pipeline=pipeline,
exp_io=self.__exp_io,
iter_doc_ids_func=lambda dtype: self.__doc_ops.iter_doc_ids(dtype),
keep_labels_func=lambda dtype: self.__save_labels_func(data_type),
balance_func=lambda dtype: self.__balance_func(data_type),
data_type=data_type,
doc_ids_iter=self.__doc_ops.iter_doc_ids(data_type),
keep_labels=self.__save_labels_func(data_type),
do_balance=self.__balance_func(data_type),
opinions_writer=self.__exp_io.create_opinions_writer(),
samples_writer=self.__exp_io.create_samples_writer(),
samples_target=self.__exp_io.create_samples_writer_target(data_type),
opinions_target=self.__exp_io.create_opinions_writer_target(data_type),
sample_rows_provider=rows_provider)

# endregion
Expand Down
34 changes: 14 additions & 20 deletions arekit/contrib/utils/serializer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import collections
import logging

from arekit.common.data.input.providers.columns.opinion import OpinionColumnsProvider
Expand All @@ -16,51 +17,44 @@
class InputDataSerializationHelper(object):

@staticmethod
def serialize(pipeline, exp_io, iter_doc_ids_func, keep_labels_func,
balance_func, data_type, sample_rows_provider):
def serialize(pipeline, doc_ids_iter, keep_labels, do_balance, sample_rows_provider,
samples_writer, samples_target, opinions_writer, opinions_target):
""" pipeline:
note, it is important to provide a pipeline which results in linked opinions iteration
for a particular document.
document (id, instance) -> ... -> linked opinion list
keep_labels_func: function
data_type -> bool
balance_func: function
data_type -> bool
iter_doc_ids:
func(data_type)
"""
assert(isinstance(pipeline, BasePipeline))
assert(callable(iter_doc_ids_func))
assert(callable(keep_labels_func))
assert(callable(balance_func))
assert(isinstance(doc_ids_iter, collections.Iterable))
assert(isinstance(keep_labels, bool))
assert(isinstance(do_balance, bool))

opinions_repo = BaseInputOpinionsRepository(
columns_provider=OpinionColumnsProvider(),
rows_provider=BaseOpinionsRowProvider(),
storage=BaseRowsStorage())

samples_repo = BaseInputSamplesRepository(
columns_provider=SampleColumnsProvider(store_labels=keep_labels_func(data_type)),
columns_provider=SampleColumnsProvider(store_labels=keep_labels),
rows_provider=sample_rows_provider,
storage=BaseRowsStorage())

opinion_provider = InputTextOpinionProvider(pipeline)

doc_ids = list(doc_ids_iter)

# Populate repositories
opinions_repo.populate(opinion_provider=opinion_provider,
doc_ids=list(iter_doc_ids_func(data_type)),
doc_ids=doc_ids,
desc="opinion")

samples_repo.populate(opinion_provider=opinion_provider,
doc_ids=list(iter_doc_ids_func(data_type)),
doc_ids=doc_ids,
desc="sample")

if balance_func(data_type):
if do_balance:
samples_repo.balance()

# Write repositories
samples_repo.write(writer=exp_io.create_samples_writer(),
target=exp_io.create_samples_writer_target(data_type=data_type))

opinions_repo.write(writer=exp_io.create_opinions_writer(),
target=exp_io.create_opinions_writer_target(data_type=data_type))
samples_repo.write(writer=samples_writer, target=samples_target())
opinions_repo.write(writer=opinions_writer, target=opinions_target())

0 comments on commit 897393e

Please sign in to comment.