#240 refactoring.
nicolay-r committed Dec 25, 2021
1 parent 34de336 commit cc4294e
Showing 5 changed files with 18 additions and 19 deletions.
4 changes: 0 additions & 4 deletions arekit/common/data/views/linkages/base.py
@@ -41,10 +41,6 @@ def _iter_by_opinions(self, linked_df, opinions_view):
 
     # region public methods
 
-    # TODO. #240 This is just a wrapper over storage.
-    def iter_doc_ids(self):
-        return set(self._storage.iter_column_values(column_name=const.DOC_ID))
-
     def iter_opinion_linkages(self, doc_id, opinions_view):
         assert(isinstance(opinions_view, BaseOpinionStorageView))
         doc_df = self._storage.find_by_value(column_name=const.DOC_ID, value=doc_id)
7 changes: 4 additions & 3 deletions arekit/common/data/views/linkages/utils.py
@@ -1,3 +1,4 @@
+from arekit.common.data import const
 from arekit.common.data.row_ids.base import BaseIDProvider
 from arekit.common.data.views.opinions import BaseOpinionStorageView
 from arekit.common.opinions.base import Opinion
@@ -10,10 +11,10 @@ def compose_opinion_by_opinion_id(ids_provider, sample_id, opinions_view, calc_l
     assert(callable(calc_label_func))
 
     opinion_id = ids_provider.convert_sample_id_to_opinion_id(sample_id=sample_id)
-    source, target = opinions_view.provide_opinion_info_by_opinion_id(opinion_id=opinion_id)
+    row = opinions_view.row_by_id(opinion_id=opinion_id)
 
-    return Opinion(source_value=source,
-                   target_value=target,
+    return Opinion(source_value=row[const.SOURCE],
+                   target_value=row[const.TARGET],
                    sentiment=calc_label_func())


12 changes: 3 additions & 9 deletions arekit/common/data/views/opinions.py
@@ -8,13 +8,7 @@ def __init__(self, storage):
         assert(isinstance(storage, BaseRowsStorage))
         self._storage = storage
 
-    def provide_opinion_info_by_opinion_id(self, opinion_id):
+    def row_by_id(self, opinion_id):
         assert(isinstance(opinion_id, str))
-
-        row = self._storage.find_first_by_value(column_name=const.ID,
-                                                value=opinion_id)
-
-        source = row[const.SOURCE]
-        target = row[const.TARGET]
-
-        return source, target
+        return self._storage.find_first_by_value(column_name=const.ID,
+                                                 value=opinion_id)
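
Taken together, the two hunks above swap the unpacked (source, target) accessor for a raw-row accessor: row_by_id returns the storage row and the caller reads const.SOURCE / const.TARGET itself, as compose_opinion_by_opinion_id now does. Below is a minimal sketch of that calling convention; the view class, column names, and data are hypothetical stand-ins, not the AREkit classes themselves.

# Minimal sketch of the calling-convention change (stand-ins, not AREkit API).
SOURCE, TARGET = "source", "target"  # mirror the arekit.common.data.const names

class FakeOpinionView:
    """Stand-in for BaseOpinionStorageView backed by an in-memory dict."""
    def __init__(self, rows):
        self._rows = rows  # opinion_id -> {column: value}

    def row_by_id(self, opinion_id):
        # New API from this commit: hand back the whole row.
        return self._rows[opinion_id]

view = FakeOpinionView({"o1": {SOURCE: "source_entity", TARGET: "target_entity"}})

# Before: source, target = view.provide_opinion_info_by_opinion_id(opinion_id="o1")
# After: the caller unpacks the columns from the returned row.
row = view.row_by_id(opinion_id="o1")
source, target = row[SOURCE], row[TARGET]
print(source, target)  # source_entity target_entity
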
5 changes: 4 additions & 1 deletion arekit/contrib/bert/run_evaluation.py
@@ -1,6 +1,7 @@
 import logging
 from os.path import exists, join
 
+from arekit.common.data import const
 from arekit.common.data.views.linkages.multilabel import MultilableOpinionLinkagesView
 from arekit.common.experiment.api.ctx_training import TrainingData
 from arekit.common.experiment.api.enums import BaseDocumentTag
@@ -158,7 +159,9 @@ def _handle_iteration(self, iter_index):
 
         # Executing pipeline.
         ppl.append(save_item)
-        pipeline_ctx = PipelineContext({"src": output_view.iter_doc_ids()})
+        pipeline_ctx = PipelineContext({
+            "src": set(storage.iter_column_values(column_name=const.DOC_ID))
+        })
         ppl.run(pipeline_ctx)
 
         # iterate over the result.
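
With the iter_doc_ids() wrapper removed from the linkages view (first hunk of this commit), the call sites now build the set of document ids straight from the rows storage. A minimal sketch of that pattern follows, using a hypothetical stand-in storage; in AREkit the real object is a BaseRowsStorage and the column name comes from arekit.common.data.const.DOC_ID.

# Hypothetical stand-in for the rows storage; only iter_column_values is sketched.
DOC_ID = "doc_id"  # stands in for arekit.common.data.const.DOC_ID

class FakeStorage:
    def __init__(self, rows):
        self._rows = rows  # list of {column: value} dicts

    def iter_column_values(self, column_name):
        # Yield the value of the given column for every row, duplicates included.
        for row in self._rows:
            yield row[column_name]

storage = FakeStorage([{DOC_ID: 1}, {DOC_ID: 1}, {DOC_ID: 2}])

# The pipeline "src" is now the de-duplicated set of document ids,
# which is what the removed iter_doc_ids() wrapper used to return.
doc_ids = set(storage.iter_column_values(column_name=DOC_ID))
print(doc_ids)  # {1, 2}
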
9 changes: 7 additions & 2 deletions arekit/contrib/networks/core/callback/utils_model_eval.py
@@ -55,9 +55,12 @@ def evaluate_model(experiment, label_scaler, data_type, epoch_index, model,
         column_extra_funcs=[(const.DOC_ID, lambda sample_id: doc_id_by_sample_id[sample_id])],
         labels_scaler=label_scaler)
 
+    # TODO. Pass here the original storage. (NO API for now out there).
+    storage = None
+
     output_view = MultilableOpinionLinkagesView(
         labels_scaler=label_scaler,
-        storage=None)  # TODO. Pass here the original storage. (NO API for now out there).
+        storage=storage)
 
     # Convert output to result.
     ppl = output_to_opinion_collections(
@@ -83,7 +86,9 @@ def evaluate_model(experiment, label_scaler, data_type, epoch_index, model,
 
     # Executing pipeline.
     ppl.append(save_item)
-    pipeline_ctx = PipelineContext({"src": output_view.iter_doc_ids()})
+    pipeline_ctx = PipelineContext({
+        "src": set(storage.iter_column_values(column_name=const.DOC_ID))
+    })
     ppl.run(pipeline_ctx)
 
     # iterate over the result.
