Skip to content

Commit

Permalink
Refactoring output data. Related to #198. Using views instead of form…
Browse files Browse the repository at this point in the history
…atters
  • Loading branch information
nicolay-r committed Oct 2, 2021
1 parent 6c44dcc commit 3050956
Show file tree
Hide file tree
Showing 24 changed files with 96 additions and 144 deletions.
2 changes: 1 addition & 1 deletion arekit/common/experiment/input/repositories/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from arekit.common.experiment.input.providers.columns.base import BaseColumnsProvider
from arekit.common.experiment.input.providers.opinions import OpinionProvider
from arekit.common.experiment.input.providers.rows.base import BaseRowProvider
from arekit.common.experiment.input.storages.base import BaseRowsStorage
from arekit.common.experiment.storages.base import BaseRowsStorage


class BaseInputRepository(object):
Expand Down
2 changes: 1 addition & 1 deletion arekit/common/experiment/input/views/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from arekit.common.experiment.input.storages.base import BaseRowsStorage
from arekit.common.experiment.storages.base import BaseRowsStorage


class BaseStorageView(object):
Expand Down
2 changes: 1 addition & 1 deletion arekit/common/experiment/input/writers/tsv.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

from arekit.common.experiment import const
from arekit.common.experiment.input.providers.columns.base import BaseColumnsProvider
from arekit.common.experiment.input.storages.base import BaseRowsStorage
from arekit.common.experiment.input.writers.base import BaseWriter
from arekit.common.experiment.storages.base import BaseRowsStorage
from arekit.common.utils import create_dir_if_not_exists

logger = logging.getLogger(__name__)
Expand Down
16 changes: 13 additions & 3 deletions arekit/common/experiment/output/opinions/converter.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,38 @@
from arekit.common.experiment.input.views.opinions import BaseOpinionStorageView
from arekit.common.experiment.output.formatters.base import BaseOutputFormatter
from arekit.common.experiment.output.utils import fill_opinion_collection
from arekit.common.labels.scaler import BaseLabelScaler
from arekit.common.experiment.output.views.base import BaseOutputView
from arekit.common.model.labeling.modes import LabelCalculationMode
from arekit.common.model.labeling.single import SingleLabelsHelper
from arekit.common.labels.scaler import BaseLabelScaler
from arekit.common.opinions.base import Opinion


class OutputToOpinionCollectionsConverter(object):

# TODO. To output_view. Provide opinions_iter for collection organization only!
@staticmethod
def iter_opinion_collections(opinions_view,
# TODO. Remove
labels_scaler,
keep_doc_id_func,
# TODO. Collection???
create_opinion_collection_func,
# TODO. Remove
label_calculation_mode,
# TODO. Remove
supported_labels,
output_formatter):
assert(callable(keep_doc_id_func))
# TODO. Only for collection filling (remove)
assert(isinstance(labels_scaler, BaseLabelScaler))
assert(isinstance(opinions_view, BaseOpinionStorageView))
# TODO. Collection???
assert(callable(create_opinion_collection_func))
# TODO. Only for collection filling (remove)
assert(isinstance(label_calculation_mode, LabelCalculationMode))
# TODO. Only for collection filling (remove)
assert(isinstance(supported_labels, set) or supported_labels is None)
assert(isinstance(output_formatter, BaseOutputFormatter))
assert(isinstance(output_formatter, BaseOutputView))

labels_helper = SingleLabelsHelper(labels_scaler)

Expand All @@ -32,6 +41,7 @@ def iter_opinion_collections(opinions_view,
if not keep_doc_id_func(news_id):
continue

# TODO. Collection???
collection = create_opinion_collection_func()

linked_iter = output_formatter.iter_linked_opinions(news_id=news_id,
Expand Down
12 changes: 0 additions & 12 deletions arekit/common/experiment/output/providers/base.py

This file was deleted.

32 changes: 0 additions & 32 deletions arekit/common/experiment/output/providers/tsv.py

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,42 +1,37 @@
import pandas as pd

from arekit.common.experiment import const
from arekit.common.experiment.input.views.base import BaseStorageView
from arekit.common.experiment.input.views.opinions import BaseOpinionStorageView
from arekit.common.experiment.output.providers.base import BaseOutputProvider
from arekit.common.experiment.row_ids.base import BaseIDProvider
from arekit.common.experiment.storages.base import BaseRowsStorage
from arekit.common.linked.opinions.wrapper import LinkedOpinionWrapper
from arekit.common.opinions.base import Opinion


class BaseOutputFormatter(object):
class BaseOutputView(BaseStorageView):
""" Results output represents a table, which stored in pandas dataframe.
This dataframe assumes to provide the following columns:
- id -- is a row identifier, which is compatible with row_inds in serialized opinions.
- news_id -- is a related news_id towards which the related output corresponds to.
- labels -- uint labels (amount of columns depends on the scaler)
"""

def __init__(self, ids_provider, output_provider):
def __init__(self, ids_provider, storage):
assert(isinstance(ids_provider, BaseIDProvider))
assert(isinstance(output_provider, BaseOutputProvider))
assert(isinstance(storage, BaseRowsStorage))
super(BaseOutputView, self).__init__(storage=storage)
self._ids_provider = ids_provider
self.__provider = output_provider

# region properties

@property
def _df(self):
return self.__provider.DataFrame

# endregion

# region private methods

def __iter_linked_opinions_df(self, news_id):
assert(isinstance(news_id, int))

# TODO. 206 from storage. (filter by column API)
news_df = self._df[self._df[const.NEWS_ID] == news_id]
# TODO. Proceed refactoring
news_df = self._storage.find_by_value(column_name=const.NEWS_ID,
value=news_id)

opinion_ids = [self._ids_provider.parse_opinion_in_opinion_id(opinion_id)
for opinion_id in news_df[const.ID]]

Expand Down Expand Up @@ -73,8 +68,8 @@ def _compose_opinion_by_opinion_id(self, sample_id, opinions_view, calc_label_fu
# region public methods

def iter_news_ids(self):
assert (const.NEWS_ID in self._df.columns)
return set(self._df[const.NEWS_ID])
unique_news_ids = set(self._storage.iter_column_values(column_name=const.NEWS_ID))
return unique_news_ids

def iter_linked_opinions(self, news_id, opinions_view):
assert (isinstance(news_id, int))
Expand All @@ -88,10 +83,4 @@ def iter_linked_opinions(self, news_id, opinions_view):

yield LinkedOpinionWrapper(linked_data=opinions_iter)

def load(self, source):
self.__provider.load(source)

# endregion

def __len__(self):
return len(self._df.index)
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,17 @@

from arekit.common.experiment import const
from arekit.common.experiment.input.views.opinions import BaseOpinionStorageView
from arekit.common.experiment.output.formatters.base import BaseOutputFormatter
from arekit.common.experiment.output.views.base import BaseOutputView
from arekit.common.experiment.row_ids.multiple import MultipleIDProvider
from arekit.common.labels.scaler import BaseLabelScaler


class MulticlassOutputFormatter(BaseOutputFormatter):
class MulticlassOutputView(BaseOutputView):

def __init__(self, labels_scaler, output_provider):
def __init__(self, labels_scaler, storage):
assert(isinstance(labels_scaler, BaseLabelScaler))
super(MulticlassOutputFormatter, self).__init__(ids_provider=MultipleIDProvider(),
output_provider=output_provider)
super(MulticlassOutputView, self).__init__(ids_provider=MultipleIDProvider(),
storage=storage)
self.__labels_scaler = labels_scaler

# region private methods
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,13 @@ def DataFrame(self):
# endregion

@classmethod
def from_tsv(cls, filepath, sep='\t', compression='gzip', encoding='utf-8'):
def from_tsv(cls, filepath, sep='\t', compression='gzip', encoding='utf-8', header="infer"):
instance = cls()
instance._df = pd.read_csv(filepath,
sep=sep,
encoding=encoding,
compression=compression)
compression=compression,
header=header)
return instance

# region private methods
Expand All @@ -55,6 +56,9 @@ def __set_value(self, row_ind, column, value):
def __log_info(self):
logger.info(self._df.info())

def __filter(self, column_name, value):
return self._df[self._df[column_name] == value]

# endregion

# region protected methods
Expand All @@ -81,8 +85,13 @@ def _balance(self, column_name):

# region public methods

def find_by_value(self, column_name, value):
# TODO. Return new storage. (Encapsulation)
return self.__filter(column_name=column_name, value=value)

def find_first_by_value(self, column_name, value):
row = self._df[self._df[column_name] == value]
# TODO. Return new storage. (Encapsulation)
row = self.__filter(column_name=column_name, value=value)
return row.iloc[0]

def iter_column_values(self, column_name, dtype=None):
Expand Down
2 changes: 1 addition & 1 deletion arekit/contrib/bert/io_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

from arekit.common.experiment.api.io_utils import BaseIOUtils
from arekit.common.experiment.data_type import DataType
from arekit.common.experiment.input.storages.base import BaseRowsStorage
from arekit.common.experiment.input.views.opinions import BaseOpinionStorageView
from arekit.common.experiment.input.views.samples import BaseSampleStorageView
from arekit.common.experiment.input.writers.tsv import TsvWriter
from arekit.common.experiment.row_ids.multiple import MultipleIDProvider
from arekit.common.experiment.storages.base import BaseRowsStorage
from arekit.common.utils import join_dir_with_subfolder_name


Expand Down
37 changes: 16 additions & 21 deletions arekit/contrib/bert/output/google_bert_provider.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,24 @@
from arekit.common.experiment import const
from arekit.common.experiment.output.providers.tsv import TsvBaseOutputProvider
from arekit.common.experiment.storages.base import BaseRowsStorage


class GoogleBertOutputProvider(TsvBaseOutputProvider):
""" This output assumes to be provided with only labels
by default proposed here:
class GoogleBertOutputStorage(BaseRowsStorage):
""" This output assumes to be provided with only labels by default proposed here:
https://github.com/google-research/bert
In addition to such output we provide the following parameters via samples_view instance:
- id -- is a row identifier, which is compatible with row_inds in serialized opinions.
- news_id -- is a related news_id towards which the related output corresponds to.
"""

def __init__(self, samples_view, has_output_header):
super(GoogleBertOutputProvider, self).__init__(has_output_header=has_output_header)
self.__samples_view = samples_view

# region protected methods

def _csv_to_dataframe(self, filepath):
df = super(GoogleBertOutputProvider, self)._csv_to_dataframe(filepath=filepath)
def apply_samples_view(self, samples_view):
"""
In addition to such output we provide the following parameters via samples_view instance:
- id -- is a row identifier, which is compatible with row_inds in serialized opinions.
- news_id -- is a related news_id towards which the related output corresponds to.
"""
row_ids = samples_view.extract_ids()
news_ids = samples_view.extract_news_ids()

# Exporting such information from samples.
row_ids = self.__samples_view.extract_ids()
news_ids = self.__samples_view.extract_news_ids()
assert(len(row_ids) == len(news_ids) == len(self.DataFrame))

assert(len(row_ids) == len(news_ids) == len(df))
df = self.DataFrame

# Providing the latter into output.
df.insert(0, const.ID, row_ids)
Expand All @@ -36,6 +29,8 @@ def _csv_to_dataframe(self, filepath):

df.columns = [str(c) for c in df.columns]

return df
@classmethod
def from_tsv(cls, filepath, sep='\t', compression='infer', encoding='utf-8', header=None):
return super(GoogleBertOutputStorage, cls).from_tsv(filepath=filepath, sep=sep, compression=compression, encoding=encoding, header=header)

# endregion
Empty file.
15 changes: 8 additions & 7 deletions arekit/contrib/bert/run_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@

from arekit.common.experiment.api.ctx_training import TrainingData
from arekit.common.experiment.engine import ExperimentEngine
from arekit.common.experiment.output.formatters.multiple import MulticlassOutputFormatter
from arekit.common.experiment.output.opinions.converter import OutputToOpinionCollectionsConverter
from arekit.common.experiment.output.views.multiple import MulticlassOutputView
from arekit.common.labels.scaler import BaseLabelScaler
from arekit.common.labels.str_fmt import StringLabelsFormatter
from arekit.common.model.labeling.modes import LabelCalculationMode
from arekit.common.utils import join_dir_with_subfolder_name
from arekit.contrib.bert.callback import Callback
from arekit.contrib.bert.output.eval_helper import EvalHelper
from arekit.contrib.bert.output.google_bert_provider import GoogleBertOutputProvider
from arekit.contrib.bert.output.google_bert_provider import GoogleBertOutputStorage

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
Expand Down Expand Up @@ -108,14 +108,15 @@ def _handle_iteration(self, iter_index):
self._log_info("\nStarting evaluation for: {}".format(result_filepath),
forced=True)

# Initialize storage.
storage = GoogleBertOutputStorage.from_tsv(filepath=result_filepath, header=None)
storage.apply_samples_view(samples_view=exp_io.create_samples_view(self.__data_type))

# We utilize google bert format, where every row
# consist of label probabilities per every class
output = MulticlassOutputFormatter(
output = MulticlassOutputView(
labels_scaler=self.__label_scaler,
output_provider=GoogleBertOutputProvider(
samples_view=exp_io.create_samples_view(self.__data_type),
has_output_header=False))
output.load(source=result_filepath)
storage=storage)

# iterate opinion collections.
collections_iter = OutputToOpinionCollectionsConverter.iter_opinion_collections(
Expand Down
2 changes: 1 addition & 1 deletion arekit/contrib/bert/run_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from arekit.common.experiment.input.providers.rows.opinions import BaseOpinionsRowProvider
from arekit.common.experiment.input.repositories.opinions import BaseInputOpinionsRepository
from arekit.common.experiment.input.repositories.sample import BaseInputSamplesRepository
from arekit.common.experiment.input.storages.base import BaseRowsStorage
from arekit.common.experiment.storages.base import BaseRowsStorage
from arekit.common.labels.str_fmt import StringLabelsFormatter
from arekit.contrib.bert.samplers.factory import create_bert_sample_provider

Expand Down
File renamed without changes.
Loading

0 comments on commit 3050956

Please sign in to comment.