Skip to content

Commit

Permalink
#240 and #242 related
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Dec 25, 2021
1 parent 7555043 commit b044e0c
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
8 changes: 4 additions & 4 deletions arekit/contrib/bert/run_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,19 +150,19 @@ def _handle_iteration(self, iter_index):

# Writing opinion collection.
save_item = HandleIterPipelineItem(
lambda doc_id, collection:
lambda data:
exp_io.write_opinion_collection(
collection=collection,
collection=data[1],
labels_formatter=self.__labels_formatter,
target=exp_io.create_result_opinion_collection_target(
data_type=self.__data_type,
epoch_index=epoch_index,
doc_id=doc_id)))
doc_id=data[0])))

# Executing pipeline.
ppl.append(save_item)
pipeline_ctx = PipelineContext({
"src": set(storage.iter_column_values(column_name=const.DOC_ID))
"src": set(output_storage.iter_column_values(column_name=const.DOC_ID))
})
ppl.run(pipeline_ctx)

Expand Down
10 changes: 5 additions & 5 deletions arekit/contrib/networks/core/callback/utils_model_eval.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging

from arekit.common.data import const
from arekit.common.data.storages.base import BaseRowsStorage
from arekit.common.data.views.linkages.multilabel import MultilableOpinionLinkagesView
from arekit.common.experiment.api.enums import BaseDocumentTag
from arekit.common.experiment.data_type import DataType
Expand Down Expand Up @@ -59,8 +60,7 @@ def evaluate_model(experiment, label_scaler, data_type, epoch_index, model,
labels_scaler=label_scaler)
out.write(title=title, contents_it=contents_it)

# TODO. Pass here the original storage. (NO API for now out there).
storage = None
storage = BaseRowsStorage.from_tsv(filepath=result_filepath)

linkages_view = MultilableOpinionLinkagesView(
labels_scaler=label_scaler,
Expand All @@ -79,14 +79,14 @@ def evaluate_model(experiment, label_scaler, data_type, epoch_index, model,

# Writing opinion collection.
save_item = HandleIterPipelineItem(
lambda doc_id, collection:
lambda data:
experiment.ExperimentIO.write_opinion_collection(
collection=collection,
collection=data[1],
labels_formatter=labels_formatter,
target=experiment.ExperimentIO.create_result_opinion_collection_target(
data_type=data_type,
epoch_index=epoch_index,
doc_id=doc_id)))
doc_id=data[0])))

# Executing pipeline.
ppl.append(save_item)
Expand Down

0 comments on commit b044e0c

Please sign in to comment.