Skip to content

Commit

Permalink
#360 fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Jul 7, 2022
1 parent cb63424 commit 5659100
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 14 deletions.
3 changes: 0 additions & 3 deletions arekit/common/experiment/api/io_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,6 @@ def _create_opinion_collection_writer(self):

# region public methods

def balance_samples(self, data_type, balance):
return balance and data_type == DataType.Train

def create_opinion_collection_target(self, doc_id, data_type, check_existance=False):
return self._create_annotated_collection_target(
doc_id=doc_id,
Expand Down
6 changes: 3 additions & 3 deletions arekit/contrib/bert/handlers/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
class BertExperimentInputSerializerIterationHandler(ExperimentIterationHandler):

def __init__(self, data_type_pipelines, sample_rows_provider, exp_io,
exp_ctx, doc_ops, save_labels_func, balance_train_samples):
exp_ctx, doc_ops, save_labels_func, balance_func):
""" sample_rows_formatter:
how we format input texts for a BERT model, for example:
- single text
Expand All @@ -31,7 +31,7 @@ def __init__(self, data_type_pipelines, sample_rows_provider, exp_io,
super(BertExperimentInputSerializerIterationHandler, self).__init__()

self.__sample_rows_provider = sample_rows_provider
self.__balance_train_samples = balance_train_samples
self.__balance_func = balance_func
self.__exp_io = exp_io
self.__exp_ctx = exp_ctx
self.__doc_ops = doc_ops
Expand All @@ -48,7 +48,7 @@ def __handle_iteration(self, data_type, pipeline):
exp_io=self.__exp_io,
iter_doc_ids_func=lambda dtype: self.__doc_ops.iter_doc_ids(dtype),
keep_labels_func=lambda dtype: self.__save_labels_func(data_type),
balance=self.__balance_train_samples,
balance_func=lambda dtype: self.__balance_func(data_type),
data_type=data_type,
sample_rows_provider=self.__sample_rows_provider)

Expand Down
11 changes: 6 additions & 5 deletions arekit/contrib/networks/handlers/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
class NetworksInputSerializerExperimentIteration(ExperimentIterationHandler):

def __init__(self, data_type_pipelines, vectorizers, save_labels_func,
str_entity_fmt, exp_ctx, exp_io, doc_ops, balance, save_embedding):
str_entity_fmt, exp_ctx, exp_io, doc_ops, balance_func, save_embedding):
""" This hanlder allows to perform a data preparation for neural network models.
considering a list of the whole data_types with the related pipelines,
Expand Down Expand Up @@ -60,19 +60,20 @@ def __init__(self, data_type_pipelines, vectorizers, save_labels_func,
assert(isinstance(doc_ops, DocumentOperations))
assert(isinstance(str_entity_fmt, StringEntitiesFormatter))
assert(isinstance(vectorizers, dict))
assert(isinstance(balance, bool))
assert(isinstance(save_embedding, bool))
assert(callable(save_labels_func))
assert(callable(balance_func))
super(NetworksInputSerializerExperimentIteration, self).__init__()

self.__data_type_pipelines = data_type_pipelines
self.__exp_ctx = exp_ctx
self.__exp_io = exp_io
self.__doc_ops = doc_ops
self.__save_labels_func = save_labels_func
self.__vectorizers = vectorizers
self.__balance = balance
self.__save_embedding = save_embedding
self.__str_entity_fmt = str_entity_fmt
self.__save_labels_func = save_labels_func
self.__balance_func = balance_func

# region protected methods

Expand All @@ -92,7 +93,7 @@ def __handle_iteration(self, data_type, pipeline, rows_provider):
exp_io=self.__exp_io,
iter_doc_ids_func=lambda dtype: self.__doc_ops.iter_doc_ids(dtype),
keep_labels_func=lambda dtype: self.__save_labels_func(data_type),
balance=self.__balance,
balance_func=lambda dtype: self.__balance_func(data_type),
data_type=data_type,
sample_rows_provider=rows_provider)

Expand Down
9 changes: 6 additions & 3 deletions arekit/contrib/utils/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,23 @@
class InputDataSerializationHelper(object):

@staticmethod
def serialize(pipeline, exp_io, iter_doc_ids_func, keep_labels_func, balance, data_type, sample_rows_provider):
def serialize(pipeline, exp_io, iter_doc_ids_func, keep_labels_func,
balance_func, data_type, sample_rows_provider):
""" pipeline:
note, it is important to provide a pipeline which results in linked opinions iteration
for a particular document.
document (id, instance) -> ... -> linked opinion list
keep_labels_func: function
data_type -> bool
balance_func: function
data_type -> bool
iter_doc_ids:
func(data_type)
"""
assert(isinstance(pipeline, BasePipeline))
assert(callable(iter_doc_ids_func))
assert(callable(keep_labels_func))
assert(isinstance(balance, bool))
assert(callable(balance_func))

opinions_repo = BaseInputOpinionsRepository(
columns_provider=OpinionColumnsProvider(),
Expand All @@ -52,7 +55,7 @@ def serialize(pipeline, exp_io, iter_doc_ids_func, keep_labels_func, balance, da
doc_ids=list(iter_doc_ids_func(data_type)),
desc="sample")

if exp_io.balance_samples(data_type=data_type, balance=balance):
if balance_func(data_type):
samples_repo.balance()

# Write repositories
Expand Down

0 comments on commit 5659100

Please sign in to comment.