Skip to content

Commit

Permalink
#240 related refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Dec 25, 2021
1 parent 02b39ac commit 0fe89d3
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 22 deletions.
17 changes: 6 additions & 11 deletions arekit/common/experiment/pipelines/opinion_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,26 +55,23 @@ def __create_and_fill_opinion_collection(opinions_iter, collection, supported_la
# endregion


def output_to_opinion_collections(exp_io, opin_ops, doc_ids_set, labels_scaler, output_view,
data_type, label_calc_mode, supported_labels):
def output_to_opinion_collections(opin_ops, doc_ids_set, labels_scaler,
iter_opinion_linkages_func,
label_calc_mode, supported_labels):
""" Opinion collection generation pipeline.
"""
assert(isinstance(opin_ops, OpinionOperations))
assert(isinstance(labels_scaler, BaseLabelScaler))
assert(isinstance(exp_io, NetworkIOUtils))
assert(isinstance(data_type, DataType))
assert(isinstance(label_calc_mode, LabelCalculationMode))
assert(isinstance(supported_labels, set) or supported_labels is None)
assert(callable(iter_opinion_linkages_func))

# Opinion collections iterator pipeline
ppl = BasePipeline([
return BasePipeline([
FilterPipelineItem(filter_func=lambda doc_id: doc_id in doc_ids_set),

# Iterate opinion linkages.
MapPipelineItem(lambda doc_id:
(doc_id, output_view.iter_opinion_linkages(
doc_id=doc_id,
opinions_view=exp_io.create_opinions_view(data_type)))),
MapPipelineItem(lambda doc_id: (doc_id, iter_opinion_linkages_func(doc_id))),

# Convert linkages to opinions.
MapPipelineItem(lambda doc_id, linkages_iter:
Expand All @@ -90,5 +87,3 @@ def output_to_opinion_collections(exp_io, opin_ops, doc_ids_set, labels_scaler,
collection=opin_ops.create_opinion_collection(),
supported_labels=supported_labels))),
])

return ppl
8 changes: 4 additions & 4 deletions arekit/contrib/bert/run_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,14 +136,14 @@ def _handle_iteration(self, iter_index):
storage=storage)

ppl = output_to_opinion_collections(
exp_io=exp_io,
iter_opinion_linkages_func=lambda doc_id: output_view.iter_opinion_linkages(
doc_id=doc_id,
opinions_view=exp_io.create_opinions_view(self.__data_type)),
doc_ids_set=cmp_doc_ids_set,
opin_ops=self._experiment.OpinionOperations,
labels_scaler=self.__label_scaler,
data_type=self.__data_type,
supported_labels=exp_data.SupportedCollectionLabels,
label_calc_mode=LabelCalculationMode.AVERAGE,
output_view=output_view)
label_calc_mode=LabelCalculationMode.AVERAGE)

# Writing opinion collection.
save_item = HandleIterPipelineItem(
Expand Down
14 changes: 7 additions & 7 deletions arekit/contrib/networks/core/callback/utils_model_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ def evaluate_model(experiment, label_scaler, data_type, epoch_index, model,
assert (isinstance(idhp, NetworkInputDependentVariables))

samples_view = experiment.ExperimentIO.create_samples_view(data_type)
doc_id_by_sample_id = samples_view.calculate_doc_id_by_sample_id_dict()

# TODO. Filepath-dependency should be removed!
# Create and save output.
Expand All @@ -50,24 +49,25 @@ def evaluate_model(experiment, label_scaler, data_type, epoch_index, model,
sample_id_with_uint_labels_iter = labeled_samples.iter_non_duplicated_labeled_sample_row_ids()

# TODO. This is a limitation, as we focus only tsv.
doc_id_by_sample_id = samples_view.calculate_doc_id_by_sample_id_dict()
with TsvPredictProvider(filepath=result_filepath) as out:
out.load(sample_id_with_uint_labels_iter=__log_wrap_samples_iter(sample_id_with_uint_labels_iter),
column_extra_funcs=[(const.DOC_ID, lambda sample_id: doc_id_by_sample_id[sample_id])],
labels_scaler=label_scaler)

output_view = MulticlassOutputView(labels_scaler=label_scaler,
# TODO. Pass here the original storage. (NO API for now out there).
storage=None)
output_view = MulticlassOutputView(
labels_scaler=label_scaler,
storage=None) # TODO. Pass here the original storage. (NO API for now out there).

# Convert output to result.
ppl = output_to_opinion_collections(
iter_opinion_linkages_func=lambda doc_id: output_view.iter_opinion_linkages(
doc_id=doc_id,
opinions_view=experiment.ExperimentIO.create_opinions_view(data_type)),
doc_ids_set=set(experiment.DocumentOperations.iter_tagget_doc_ids(BaseDocumentTag.Compare)),
exp_io=experiment.ExperimentIO,
opin_ops=experiment.OpinionOperations,
labels_scaler=label_scaler,
data_type=data_type,
label_calc_mode=label_calc_mode,
output_view=output_view,
supported_labels=None)

# Writing opinion collection.
Expand Down

0 comments on commit 0fe89d3

Please sign in to comment.