Skip to content

Commit

Permalink
Related to #343 and #322. We now provide an API for managing the presence …
Browse files Browse the repository at this point in the history
…of label column in serialized data.
  • Loading branch information
nicolay-r committed Jun 21, 2022
1 parent 11ba370 commit c4ee4bb
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 4 deletions.
8 changes: 7 additions & 1 deletion arekit/contrib/bert/handlers/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@

class BertExperimentInputSerializerIterationHandler(ExperimentIterationHandler):

def __init__(self, pipeline, sample_rows_provider, exp_io, exp_ctx, doc_ops, balance_train_samples, data_types):
def __init__(self, pipeline, sample_rows_provider, exp_io, exp_ctx, doc_ops,
save_labels_func, balance_train_samples, data_types):
""" sample_rows_formatter:
how we format input texts for a BERT model, for example:
- single text
Expand All @@ -16,6 +17,9 @@ def __init__(self, pipeline, sample_rows_provider, exp_io, exp_ctx, doc_ops, bal
doc_id -> parsed_news -> annot -> opinion linkages
for example, function: sentiment_attitude_extraction_default_pipeline
save_labels_func: function
data_type -> bool
data_types: list
data_types, for which the data will be generated; required for:
- document ids
Expand All @@ -32,6 +36,7 @@ def __init__(self, pipeline, sample_rows_provider, exp_io, exp_ctx, doc_ops, bal
self.__doc_ops = doc_ops
self.__pipeline = pipeline
self.__data_types = data_types
self.__save_labels_func = save_labels_func

# region private methods

Expand All @@ -42,6 +47,7 @@ def __handle_iteration(self, data_type):
pipeline=self.__pipeline,
exp_io=self.__exp_io,
iter_doc_ids_func=lambda dtype: self.__doc_ops.iter_doc_ids(dtype),
keep_labels_func=lambda dtype: self.__save_labels_func(data_type),
balance=self.__balance_train_samples,
data_type=data_type,
sample_rows_provider=self.__sample_rows_provider)
Expand Down
7 changes: 6 additions & 1 deletion arekit/contrib/networks/handlers/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

class NetworksInputSerializerExperimentIteration(ExperimentIterationHandler):

def __init__(self, data_type_pipelines, exp_ctx, exp_io, doc_ops, balance):
def __init__(self, data_type_pipelines, save_labels_func, exp_ctx, exp_io, doc_ops, balance):
""" This hanlder allows to perform a data preparation for neural network models.
considering a list of the whole data_types with the related pipelines,
Expand All @@ -29,6 +29,9 @@ def __init__(self, data_type_pipelines, exp_ctx, exp_io, doc_ops, balance):
balance: bool
declares whether there is a need to balance Train samples
save_labels_func: function
data_type -> bool
data_type_pipelines: dict of, for example:
{
DataType.Train: BasePipeline,
Expand All @@ -50,6 +53,7 @@ def __init__(self, data_type_pipelines, exp_ctx, exp_io, doc_ops, balance):
self.__exp_io = exp_io
self.__doc_ops = doc_ops
self.__balance = balance
self.__save_labels_func = save_labels_func

# region protected methods

Expand All @@ -68,6 +72,7 @@ def __handle_iteration(self, data_type, pipeline, rows_provider):
pipeline=pipeline,
exp_io=self.__exp_io,
iter_doc_ids_func=lambda dtype: self.__doc_ops.iter_doc_ids(dtype),
keep_labels_func=lambda dtype: self.__save_labels_func(data_type),
balance=self.__balance,
data_type=data_type,
sample_rows_provider=rows_provider)
Expand Down
7 changes: 5 additions & 2 deletions arekit/contrib/utils/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,19 @@
class InputDataSerializationHelper(object):

@staticmethod
def serialize(pipeline, exp_io, iter_doc_ids_func, balance, data_type, sample_rows_provider):
def serialize(pipeline, exp_io, iter_doc_ids_func, keep_labels_func, balance, data_type, sample_rows_provider):
""" pipeline:
note, it is important to provide a pipeline which results in linked opinions iteration
for a particular document.
document (id, instance) -> ... -> linked opinion list
keep_labels_func: function
data_type -> bool
iter_doc_ids:
func(data_type)
"""
assert(isinstance(pipeline, BasePipeline))
assert(callable(iter_doc_ids_func))
assert(callable(keep_labels_func))
assert(isinstance(balance, bool))

opinions_repo = BaseInputOpinionsRepository(
Expand All @@ -34,7 +37,7 @@ def serialize(pipeline, exp_io, iter_doc_ids_func, balance, data_type, sample_ro
storage=BaseRowsStorage())

samples_repo = BaseInputSamplesRepository(
columns_provider=SampleColumnsProvider(store_labels=True),
columns_provider=SampleColumnsProvider(store_labels=keep_labels_func(data_type)),
rows_provider=sample_rows_provider,
storage=BaseRowsStorage())

Expand Down

0 comments on commit c4ee4bb

Please sign in to comment.