Skip to content

Commit

Permalink
#240 related refactoring. Separating output onto view, related to opi…
Browse files Browse the repository at this point in the history
…nion linkages.
  • Loading branch information
nicolay-r committed Dec 25, 2021
1 parent 0fe89d3 commit 34de336
Show file tree
Hide file tree
Showing 12 changed files with 67 additions and 58 deletions.
13 changes: 0 additions & 13 deletions arekit/common/data/views/base.py

This file was deleted.

Empty file.
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
import pandas as pd

import arekit.common.data.views.linkages.utils as utils

from arekit.common.data import const
from arekit.common.data.row_ids.base import BaseIDProvider
from arekit.common.data.storages.base import BaseRowsStorage
from arekit.common.data.views import utils
from arekit.common.data.views.base import BaseStorageView
from arekit.common.data.views.opinions import BaseOpinionStorageView
from arekit.common.linkage.opinions import OpinionsLinkage


class BaseOutputView(BaseStorageView):
class BaseOpinionLinkagesView(object):
""" Base view onto source in terms of opinion linkages.
"""

def __init__(self, ids_provider, storage):
assert(isinstance(ids_provider, BaseIDProvider))
assert(isinstance(storage, BaseRowsStorage))
super(BaseOutputView, self).__init__(storage=storage)
self._ids_provider = ids_provider
self._storage = storage

# region private methods

Expand All @@ -39,6 +41,7 @@ def _iter_by_opinions(self, linked_df, opinions_view):

# region public methods

# TODO. #240 This is just a wrapper over storage.
def iter_doc_ids(self):
    """ Return the distinct values of the `const.DOC_ID` column as a set. """
    doc_id_values = self._storage.iter_column_values(column_name=const.DOC_ID)
    return set(doc_id_values)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,20 @@

from arekit.common.data import const
from arekit.common.data.row_ids.multiple import MultipleIDProvider
from arekit.common.data.views import utils
import arekit.common.data.views.linkages.utils as utils
from arekit.common.data.views.linkages.base import BaseOpinionLinkagesView
from arekit.common.data.views.opinions import BaseOpinionStorageView
from arekit.common.data.views.ouput_base import BaseOutputView
from arekit.common.labels.scaler import BaseLabelScaler


class MulticlassOutputView(BaseOutputView):
class MultilableOpinionLinkagesView(BaseOpinionLinkagesView):
""" View onto sorce, where each row, related to opinion, has multiple labels.
"""

def __init__(self, labels_scaler, storage):
assert(isinstance(labels_scaler, BaseLabelScaler))
super(MulticlassOutputView, self).__init__(ids_provider=MultipleIDProvider(),
storage=storage)
super(MultilableOpinionLinkagesView, self).__init__(ids_provider=MultipleIDProvider(),
storage=storage)
self.__labels_scaler = labels_scaler

# region private methods
Expand Down
File renamed without changes.
8 changes: 6 additions & 2 deletions arekit/common/data/views/opinions.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from arekit.common.data import const
from arekit.common.data.views.base import BaseStorageView
from arekit.common.data.storages.base import BaseRowsStorage


class BaseOpinionStorageView(BaseStorageView):
class BaseOpinionStorageView(object):

def __init__(self, storage):
    """ Wrap the given rows storage to provide an opinion-oriented view.

        storage: BaseRowsStorage
            backend that keeps the actual rows; kept on the protected
            `_storage` attribute for use by view methods and subclasses.
    """
    assert(isinstance(storage, BaseRowsStorage))
    self._storage = storage

def provide_opinion_info_by_opinion_id(self, opinion_id):
assert(isinstance(opinion_id, str))
Expand Down
37 changes: 17 additions & 20 deletions arekit/common/data/views/samples.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,35 @@
from arekit.common.data import const
from arekit.common.data.row_ids.base import BaseIDProvider
from arekit.common.data.views.base import BaseStorageView
from arekit.common.data.storages.base import BaseRowsStorage


class BaseSampleStorageView(BaseStorageView):
class BaseSampleStorageView(object):
"""
Pandas-based input samples provider
"""

def __init__(self, storage, row_ids_provider):
assert(isinstance(row_ids_provider, BaseIDProvider))
super(BaseSampleStorageView, self).__init__(storage)
assert(isinstance(storage, BaseRowsStorage))
self.__row_ids_provider = row_ids_provider
self._storage = storage

# TODO. #240 This is just a wrapper over storage.
def iter_rows(self, handle_rows):
    """ Iterate rows of the underlying storage.

        handle_rows: callable or None
            when None, yields the raw (row_index, row) pairs from the
            storage; otherwise yields handle_rows(row) for every row.
    """
    assert(handle_rows is None or callable(handle_rows))

    if handle_rows is None:
        # Pass the storage pairs through untouched.
        for index_and_row in self._storage:
            yield index_and_row
    else:
        # Only the row itself is handed to the callback.
        for _, row in self._storage:
            yield handle_rows(row)

# TODO. #240 This is just a wrapper over storage.
def extract_ids(self):
    """ Return every value of the `const.ID` column as a list of str. """
    id_values = self._storage.iter_column_values(column_name=const.ID, dtype=str)
    return list(id_values)

# TODO. #240 This is just a wrapper over storage.
def extract_doc_ids(self):
    """ Return every value of the `const.DOC_ID` column as a list of int. """
    doc_id_values = self._storage.iter_column_values(column_name=const.DOC_ID, dtype=int)
    return list(doc_id_values)

Expand Down Expand Up @@ -45,20 +59,3 @@ def iter_rows_linked_by_text_opinions(self):

if len(linked) > 0:
yield linked

def calculate_doc_id_by_sample_id_dict(self):
    """ Compose a mapping from sample identifier (`const.ID` column)
        onto the related document identifier (`const.DOC_ID` column).

        Returns a dict {sample_id: doc_id}; when the same sample_id
        occurs in several storage rows, the first occurrence wins.
    """
    doc_id_by_sample_id = {}

    for row_index, row in self._storage:

        sample_id = row[const.ID]

        # Keep only the first doc_id registered for a given sample_id.
        if sample_id in doc_id_by_sample_id:
            continue

        doc_id_by_sample_id[sample_id] = row[const.DOC_ID]

    return doc_id_by_sample_id
2 changes: 0 additions & 2 deletions arekit/common/experiment/pipelines/opinion_collections.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from arekit.common.experiment.api.ops_opin import OpinionOperations
from arekit.common.experiment.data_type import DataType
from arekit.common.labels.scaler import BaseLabelScaler
from arekit.common.linkage.base import LinkedDataWrapper
from arekit.common.model.labeling.modes import LabelCalculationMode
Expand All @@ -9,7 +8,6 @@
from arekit.common.pipeline.base import BasePipeline
from arekit.common.pipeline.item_iter import FilterPipelineItem
from arekit.common.pipeline.item_map import MapPipelineItem
from arekit.contrib.networks.core.io_utils import NetworkIOUtils


# region private functions
Expand Down
4 changes: 2 additions & 2 deletions arekit/contrib/bert/run_evaluation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from os.path import exists, join

from arekit.common.data.views.output_multiple import MulticlassOutputView
from arekit.common.data.views.linkages.multilabel import MultilableOpinionLinkagesView
from arekit.common.experiment.api.ctx_training import TrainingData
from arekit.common.experiment.api.enums import BaseDocumentTag
from arekit.common.experiment.engine import ExperimentEngine
Expand Down Expand Up @@ -131,7 +131,7 @@ def _handle_iteration(self, iter_index):

# We utilize google bert format, where every row
# consist of label probabilities per every class
output_view = MulticlassOutputView(
output_view = MultilableOpinionLinkagesView(
labels_scaler=self.__label_scaler,
storage=storage)

Expand Down
12 changes: 6 additions & 6 deletions arekit/contrib/bert/views/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,22 @@

from arekit.common.data import const
from arekit.common.data.row_ids.base import BaseIDProvider
from arekit.common.data.views import utils
from arekit.common.data.views.linkages import utils
from arekit.common.data.views.linkages.base import BaseOpinionLinkagesView
from arekit.common.data.views.opinions import BaseOpinionStorageView
from arekit.common.data.views.ouput_base import BaseOutputView
from arekit.common.labels.scaler import BaseLabelScaler
from arekit.contrib.bert.input.providers.row_ids_binary import BinaryIDProvider


class BertBinaryOutputView(BaseOutputView):
class BertBinaryOpinionLinkagesView(BaseOpinionLinkagesView):

YES = 'yes'
NO = 'no'

def __init__(self, labels_scaler, storage):
assert(isinstance(labels_scaler, BaseLabelScaler))
super(BertBinaryOutputView, self).__init__(ids_provider=BinaryIDProvider(),
storage=storage)
super(BertBinaryOpinionLinkagesView, self).__init__(ids_provider=BinaryIDProvider(),
storage=storage)
self.__labels_scaler = labels_scaler

# region private methods
Expand All @@ -27,7 +27,7 @@ def __calculate_label(self, df):
Calculate label by relying on a 'YES' column probability values
paper: https://www.aclweb.org/anthology/N19-1035.pdf
"""
ind_max = df[BertBinaryOutputView.YES].idxmax()
ind_max = df[BertBinaryOpinionLinkagesView.YES].idxmax()
sample_id = df.loc[ind_max][const.ID]
uint_label = self._ids_provider.parse_label_in_sample_id(sample_id)
return self.__labels_scaler.uint_to_label(value=uint_label)
Expand Down
24 changes: 21 additions & 3 deletions arekit/contrib/networks/core/callback/utils_model_eval.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging

from arekit.common.data import const
from arekit.common.data.views.output_multiple import MulticlassOutputView
from arekit.common.data.views.linkages.multilabel import MultilableOpinionLinkagesView
from arekit.common.experiment.api.enums import BaseDocumentTag
from arekit.common.experiment.data_type import DataType
from arekit.common.experiment.pipelines.opinion_collections import output_to_opinion_collections
Expand Down Expand Up @@ -49,13 +49,13 @@ def evaluate_model(experiment, label_scaler, data_type, epoch_index, model,
sample_id_with_uint_labels_iter = labeled_samples.iter_non_duplicated_labeled_sample_row_ids()

# TODO. This is a limitation, as we focus only tsv.
doc_id_by_sample_id = samples_view.calculate_doc_id_by_sample_id_dict()
doc_id_by_sample_id = __calculate_doc_id_by_sample_id_dict(samples_view.iter_rows(None))
with TsvPredictProvider(filepath=result_filepath) as out:
out.load(sample_id_with_uint_labels_iter=__log_wrap_samples_iter(sample_id_with_uint_labels_iter),
column_extra_funcs=[(const.DOC_ID, lambda sample_id: doc_id_by_sample_id[sample_id])],
labels_scaler=label_scaler)

output_view = MulticlassOutputView(
output_view = MultilableOpinionLinkagesView(
labels_scaler=label_scaler,
storage=None) # TODO. Pass here the original storage. (NO API for now out there).

Expand Down Expand Up @@ -105,6 +105,24 @@ def evaluate_model(experiment, label_scaler, data_type, epoch_index, model,
return result


def __calculate_doc_id_by_sample_id_dict(rows_iter):
    """ Build a mapping from sample identifier onto its document identifier.

        rows_iter: iterable
            yields (row_index, row) pairs, where each row supports lookup
            by the `const.ID` and `const.DOC_ID` columns.

        Returns a dict {sample_id: doc_id}; when the same sample_id occurs
        in several rows, the first occurrence wins.
    """
    doc_id_by_sample_id = {}

    # Row index is not needed; only the row contents matter here.
    for _, row in rows_iter:

        sample_id = row[const.ID]

        # Keep only the first doc_id registered for a given sample_id.
        if sample_id in doc_id_by_sample_id:
            continue

        doc_id_by_sample_id[sample_id] = row[const.DOC_ID]

    return doc_id_by_sample_id


def __log_wrap_samples_iter(it):
return progress_bar_iter(iterable=it,
desc='Writing output',
Expand Down
2 changes: 1 addition & 1 deletion arekit/contrib/networks/core/ctx_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def __read_for_data_type(samples_view, is_external_vocab,
input_shapes=input_shapes,
pos_tags=row.PartOfSpeechTags))

rows_it = samples_view.iter_handled_rows(
rows_it = samples_view.iter_rows(
handle_rows=lambda row: InferenceContext.__parse_row(row))

labeled_sample_row_ids = list(rows_it)
Expand Down

0 comments on commit 34de336

Please sign in to comment.