Skip to content

Commit

Permalink
#262 associated, Simplify DataIO.
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Jan 20, 2022
1 parent f9ed7d1 commit 5cb7ac7
Show file tree
Hide file tree
Showing 5 changed files with 6 additions and 20 deletions.
10 changes: 0 additions & 10 deletions arekit/common/experiment/api/ctx_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,6 @@


class DataIO(object):
""" This base class aggregates all the data necessary for
cv-based experiment organization
(data-serialization, training, etc.).
"""

def __init__(self):
self.__model_io = None
Expand All @@ -22,12 +18,6 @@ def ModelIO(self):
def LabelsCount(self):
raise NotImplementedError()

@property
def SupportedCollectionLabels(self):
""" All labels considered as supported and might appear in OpinionCollection by default.
"""
return None

def set_model_io(self, model_io):
""" Providing model_io in experiment data.
"""
Expand Down
9 changes: 4 additions & 5 deletions arekit/common/experiment/pipelines/opinion_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __linkages_to_opinions(linkages_iter, labels_helper, label_calc_mode):
yield __create_labeled_opinion(linkage.First, agg_label)


def __create_and_fill_opinion_collection(opinions_iter, collection, supported_labels):
def __fill_opinion_collection(opinions_iter, collection, supported_labels):
assert(isinstance(opinions_iter, collections.Iterable))
assert(isinstance(collection, OpinionCollection))
assert(isinstance(supported_labels, set) or supported_labels is None)
Expand All @@ -59,12 +59,11 @@ def __create_and_fill_opinion_collection(opinions_iter, collection, supported_la
def output_to_opinion_collections_pipeline(doc_ids_set, labels_scaler,
iter_opinion_linkages_func,
create_opinion_collection_func,
label_calc_mode, supported_labels):
label_calc_mode):
""" Opinion collection generation pipeline.
"""
assert(isinstance(labels_scaler, BaseLabelScaler))
assert(isinstance(label_calc_mode, LabelCalculationMode))
assert(isinstance(supported_labels, set) or supported_labels is None)
assert(callable(iter_opinion_linkages_func))
assert(callable(create_opinion_collection_func))

Expand All @@ -84,8 +83,8 @@ def output_to_opinion_collections_pipeline(doc_ids_set, labels_scaler,
# Filling opinion collection.
MapPipelineItem(lambda data:
(data[0],
__create_and_fill_opinion_collection(
__fill_opinion_collection(
opinions_iter=data[1],
collection=create_opinion_collection_func(),
supported_labels=supported_labels))),
supported_labels=None))),
])
1 change: 0 additions & 1 deletion arekit/contrib/bert/run_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ def __run_pipeline(self, epoch_index, iter_index):
doc_ids_set=cmp_doc_ids_set,
create_opinion_collection_func=self._experiment.OpinionOperations.create_opinion_collection,
labels_scaler=self.__label_scaler,
supported_labels=exp_data.SupportedCollectionLabels,
label_calc_mode=LabelCalculationMode.AVERAGE)

# Writing opinion collection.
Expand Down
3 changes: 1 addition & 2 deletions arekit/contrib/experiment_rusentrel/model_eval_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,7 @@ def evaluate_model(experiment, label_scaler, data_type, epoch_index, model,
doc_ids_set=set(experiment.DocumentOperations.iter_tagget_doc_ids(BaseDocumentTag.Compare)),
create_opinion_collection_func=experiment.OpinionOperations.create_opinion_collection,
labels_scaler=label_scaler,
label_calc_mode=label_calc_mode,
supported_labels=None)
label_calc_mode=label_calc_mode)

# Writing opinion collection.
save_item = HandleIterPipelineItem(
Expand Down
3 changes: 1 addition & 2 deletions tests/contrib/networks/test_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ def test_output_formatter(self):
iter_opinion_linkages_func=lambda doc_id: linkages_view.iter_opinion_linkages(
doc_id=doc_id,
opinions_view=opinion_view),
label_calc_mode=LabelCalculationMode.AVERAGE,
supported_labels=None)
label_calc_mode=LabelCalculationMode.AVERAGE)

doc_ids = set(opinion_storage.iter_column_values(column_name=const.DOC_ID, dtype=int))

Expand Down

0 comments on commit 5cb7ac7

Please sign in to comment.