Skip to content

Commit

Permalink
#282 related. Removing usage and serialization of opinions
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Aug 1, 2022
1 parent 0b458eb commit 5db2391
Show file tree
Hide file tree
Showing 11 changed files with 82 additions and 112 deletions.
20 changes: 20 additions & 0 deletions arekit/common/experiment/api/base_samples_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
class BaseSamplesIO(object):
    """ Represents base experiment utils for input/output for:
        samples -- data that is utilized for experiments;
        results -- evaluation of experiments.

        Every method of this base class raises NotImplementedError,
        so implementations are expected to override all of them.
    """

    def create_view(self, target):
        """ For viewing/reading.

            target: location of the serialized samples to be viewed,
                i.e. the value returned by `create_target`.
        """
        raise NotImplementedError()

    def create_writer(self):
        """ For serialization.
        """
        raise NotImplementedError()

    def create_target(self, data_type, data_folding):
        """ Path for reading/viewing.

            data_type: type of the data the target is created for.
            data_folding: folding the target is created for.
        """
        raise NotImplementedError()
31 changes: 0 additions & 31 deletions arekit/common/experiment/api/io_utils.py

This file was deleted.

21 changes: 6 additions & 15 deletions arekit/contrib/bert/pipelines/items/serializer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from arekit.common.experiment.api.io_utils import BaseIOUtils
from arekit.common.experiment.api.base_samples_io import BaseSamplesIO
from arekit.common.experiment.data_type import DataType
from arekit.common.folding.base import BaseDataFolding
from arekit.common.pipeline.context import PipelineContext
Expand All @@ -9,8 +9,7 @@

class BertExperimentInputSerializerPipelineItem(BasePipelineItem):

def __init__(self, sample_rows_provider, exp_io, save_labels_func,
balance_func, keep_opinions_repo=False):
def __init__(self, sample_rows_provider, samples_io, save_labels_func, balance_func):
""" sample_rows_formatter:
how we format input texts for a BERT model, for example:
- single text
Expand All @@ -19,14 +18,13 @@ def __init__(self, sample_rows_provider, exp_io, save_labels_func,
save_labels_func: function
data_type -> bool
"""
assert(isinstance(exp_io, BaseIOUtils))
assert(isinstance(samples_io, BaseSamplesIO))
super(BertExperimentInputSerializerPipelineItem, self).__init__()

self.__sample_rows_provider = sample_rows_provider
self.__balance_func = balance_func
self.__exp_io = exp_io
self.__samples_io = samples_io
self.__save_labels_func = save_labels_func
self.__keep_opinions_repo = keep_opinions_repo

# region private methods

Expand All @@ -37,23 +35,16 @@ def __serialize_iteration(self, data_type, pipeline, data_folding):
"sample": InputDataSerializationHelper.create_samples_repo(
keep_labels=self.__save_labels_func(data_type),
rows_provider=self.__sample_rows_provider),
"opinion": InputDataSerializationHelper.create_opinion_repo()
}

writer_and_targets = {
"sample": (self.__exp_io.create_samples_writer(),
self.__exp_io.create_samples_writer_target(
"sample": (self.__samples_io.create_writer(),
self.__samples_io.create_target(
data_type=data_type, data_folding=data_folding)),
"opinion": (self.__exp_io.create_opinions_writer(),
self.__exp_io.create_opinions_writer_target(
data_type=data_type, data_folding=data_folding))
}

for description, repo in repos.items():

if description == "opinion" and not self.__keep_opinions_repo:
continue

InputDataSerializationHelper.fill_and_write(
repo=repo,
pipeline=pipeline,
Expand Down
2 changes: 1 addition & 1 deletion arekit/contrib/networks/core/embedding_io.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
class BaseEmbeddingIOUtils(object):
class BaseEmbeddingIO(object):
""" API for loading and saving embedding and vocabulary related data.
"""

Expand Down
15 changes: 3 additions & 12 deletions arekit/contrib/networks/pipelines/items/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
class NetworksInputSerializerPipelineItem(BasePipelineItem):

def __init__(self, vectorizers, save_labels_func, str_entity_fmt, exp_ctx,
samples_io, emb_io, balance_func, save_embedding, keep_opinions_repos=False):
samples_io, emb_io, balance_func, save_embedding):
""" This pipeline item allows to perform a data preparation for neural network models.
considering a list of the whole data_types with the related pipelines,
Expand Down Expand Up @@ -64,7 +64,6 @@ def __init__(self, vectorizers, save_labels_func, str_entity_fmt, exp_ctx,
self.__save_embedding = save_embedding
self.__save_labels_func = save_labels_func
self.__balance_func = balance_func
self.__keep_opinions_repo = keep_opinions_repos

self.__term_embedding_pairs = collections.OrderedDict()

Expand Down Expand Up @@ -100,23 +99,15 @@ def __serialize_iteration(self, data_type, pipeline, rows_provider, data_folding
"sample": InputDataSerializationHelper.create_samples_repo(
keep_labels=self.__save_labels_func(data_type),
rows_provider=rows_provider),
"opinion": InputDataSerializationHelper.create_opinion_repo()
}

writer_and_targets = {
"sample": (self.__samples_io.create_samples_writer(),
self.__samples_io.create_samples_writer_target(
"sample": (self.__samples_io.create_writer(),
self.__samples_io.create_target(
data_type=data_type, data_folding=data_folding)),
"opinion": (self.__samples_io.create_opinions_writer(),
self.__samples_io.create_opinions_writer_target(
data_type=data_type, data_folding=data_folding))
}

for description, repo in repos.items():

if description == "opinion" and not self.__keep_opinions_repo:
continue

InputDataSerializationHelper.fill_and_write(
repo=repo,
pipeline=pipeline,
Expand Down
4 changes: 2 additions & 2 deletions arekit/contrib/networks/pipelines/items/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ def __handle_iteration(self, data_folding, data_type):
inference_ctx = InferenceContext.create_empty()
inference_ctx.initialize(
dtypes=data_folding.iter_supported_data_types(),
create_samples_view_func=lambda data_type: self.__samples_io.create_samples_view(
data_type=data_type, data_folding=data_folding),
create_samples_view_func=lambda data_type: self.__samples_io.create_view(
self.__samples_io.create_target(data_type=data_type, data_folding=data_folding)),
has_model_predefined_state=self.__model_io.IsPretrainedStateProvided,
labels_count=self.__labels_count,
vocab=self.__emb_io.load_vocab(data_folding),
Expand Down
4 changes: 2 additions & 2 deletions arekit/contrib/utils/io_utils/embedding.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from os.path import join

from arekit.common.folding.base import BaseDataFolding
from arekit.contrib.networks.core.embedding_io import BaseEmbeddingIOUtils
from arekit.contrib.networks.core.embedding_io import BaseEmbeddingIO
from arekit.contrib.utils.io_utils.utils import check_targets_existence
from arekit.contrib.utils.np_utils.embedding import NpzEmbeddingHelper
from arekit.contrib.utils.utils_folding import experiment_iter_index


class NpzEmbeddingIOUtils(BaseEmbeddingIOUtils):
class NpzEmbeddingIOUtils(BaseEmbeddingIO):
""" Npz-based IO utils for vocabulary and embedding.
This format represents an archived version of the numpy math data, i.e. vectors, numbers, etc.
Expand Down
34 changes: 34 additions & 0 deletions arekit/contrib/utils/io_utils/opinions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from os.path import join

from arekit.common.data.storages.base import BaseRowsStorage
from arekit.contrib.utils.data.views.opinions import BaseOpinionStorageView
from arekit.contrib.utils.io_utils.utils import filename_template


class OpinionsIOUtils(object):
    """ Opinions IO utils.
        Composes filepaths of the opinion targets located in `target_dir`
        and creates views for reading opinions from such targets.
    """

    def __init__(self, target_dir, target_extension=".tsv.gz"):
        """ target_dir: directory in which the opinion targets are located.
            target_extension: extension appended to every target filename.
        """
        self.__target_dir = target_dir
        self.__target_extension = target_extension

    def create_view(self, target):
        """ Reads the tsv storage by the `target` filepath and wraps it into opinions view.
        """
        return BaseOpinionStorageView(BaseRowsStorage.from_tsv(filepath=target))

    def create_writer_target(self, data_type, data_folding):
        """ Filepath of the opinions target for the given `data_type` and `data_folding`.
        """
        return self.__get_input_opinions_target(data_type=data_type, data_folding=data_folding)

    def __get_input_opinions_target(self, data_type, data_folding):
        # Filename core depends on the data type and folding state.
        name_core = filename_template(data_type=data_type, data_folding=data_folding)
        return self.__get_filepath(out_dir=self.__target_dir,
                                   template=name_core,
                                   prefix="opinion",
                                   extension=self.__target_extension)

    @staticmethod
    def __get_filepath(out_dir, template, prefix, extension):
        """ Composes <out_dir>/<prefix>-<template><extension>.
        """
        for name_part in (template, prefix, extension):
            assert(isinstance(name_part, str))
        filename = "{prefix}-{template}{extension}".format(
            prefix=prefix, template=template, extension=extension)
        return join(out_dir, filename)
43 changes: 9 additions & 34 deletions arekit/contrib/utils/io_utils/samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,13 @@
from arekit.common.data.row_ids.multiple import MultipleIDProvider
from arekit.common.data.storages.base import BaseRowsStorage
from arekit.common.data.views.samples import BaseSampleStorageView
from arekit.common.experiment.api.io_utils import BaseIOUtils
from arekit.common.experiment.data_type import DataType
from arekit.contrib.utils.data.views.opinions import BaseOpinionStorageView
from arekit.contrib.utils.io_utils.utils import filename_template, check_targets_existence

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


class SamplesIOUtils(BaseIOUtils):
class SamplesIOUtils(BaseSamplesIO):
""" Samples default IO utils for samples.
Sample is a text part which include pair of attitude participants.
This class provides a saver and a loader for such entries, dubbed samples.
Expand All @@ -24,46 +21,31 @@ class SamplesIOUtils(BaseIOUtils):

def __init__(self, target_dir,
samples_writer=TsvWriter(write_header=True),
prefix="sample",
target_extension=".tsv.gz"):
assert(isinstance(samples_writer, BaseWriter))
self.__target_dir = target_dir
self.__samples_writer = samples_writer
self.__target_extension = target_extension
self.__prefix = prefix

# region public methods

def create_samples_view(self, data_type, data_folding):
assert(isinstance(data_type, DataType))
storage = BaseRowsStorage.from_tsv(
filepath=self.__get_input_sample_target(data_type=data_type, data_folding=data_folding))
return BaseSampleStorageView(storage=storage,
def create_view(self, target):
return BaseSampleStorageView(storage=BaseRowsStorage.from_tsv(filepath=target),
row_ids_provider=MultipleIDProvider())

def create_opinions_view(self, target):
storage = BaseRowsStorage.from_tsv(filepath=target)
return BaseOpinionStorageView(storage)

def create_opinions_writer(self):
return self.__samples_writer

def create_samples_writer(self):
def create_writer(self):
return self.__samples_writer

def create_target_extension(self):
return self.__target_extension

def create_opinions_writer_target(self, data_type, data_folding):
return self.__get_input_opinions_target(data_type, data_folding=data_folding)

def create_samples_writer_target(self, data_type, data_folding):
def create_target(self, data_type, data_folding):
return self.__get_input_sample_target(data_type, data_folding=data_folding)

def check_targets_existed(self, data_types_iter, data_folding):
for data_type in data_types_iter:

targets = [
self.__get_input_sample_target(data_type=data_type, data_folding=data_folding),
# self.__get_input_opinions_target(data_type=data_type, data_folding=data_folding),
]

if not check_targets_existence(targets=targets):
Expand All @@ -72,19 +54,12 @@ def check_targets_existed(self, data_types_iter, data_folding):

# endregion

def __get_input_opinions_target(self, data_type, data_folding):
template = filename_template(data_type=data_type, data_folding=data_folding)
return self.__get_filepath(out_dir=self.__target_dir,
template=template,
prefix="opinion",
extension=self.create_target_extension())

def __get_input_sample_target(self, data_type, data_folding):
template = filename_template(data_type=data_type, data_folding=data_folding)
return self.__get_filepath(out_dir=self.__target_dir,
template=template,
prefix="sample",
extension=self.create_target_extension())
prefix=self.__prefix,
extension=self.__target_extension)

# endregion

Expand Down
10 changes: 5 additions & 5 deletions arekit/contrib/utils/pipelines/items/to_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from arekit.common.pipeline.items.base import BasePipelineItem
from arekit.common.pipeline.items.handle import HandleIterPipelineItem
from arekit.contrib.utils.data.views.linkages.multilabel import MultilableOpinionLinkagesView
from arekit.contrib.utils.io_utils.samples import SamplesIOUtils
from arekit.contrib.utils.io_utils.opinions import OpinionsIOUtils
from arekit.contrib.utils.utils_folding import folding_iter_states, experiment_iter_index
from arekit.contrib.utils.pipelines.opinion_collections import \
text_opinion_linkages_to_opinion_collections_pipeline_part
Expand All @@ -24,7 +24,7 @@ def __init__(self, opinion_samples_io, create_opinion_collection_func,
""" create_opinion_collection_func: func
func () -> OpinionCollection (empty)
"""
assert(isinstance(opinion_samples_io, SamplesIOUtils))
assert(isinstance(opinion_samples_io, OpinionsIOUtils))
assert(callable(create_opinion_collection_func))
assert(isinstance(label_scaler, BaseLabelScaler))
assert(isinstance(labels_formatter, StringLabelsFormatter))
Expand Down Expand Up @@ -52,12 +52,12 @@ def __convert(self, data_folding, output_storage, target_func, data_type):
linkages_view = MultilableOpinionLinkagesView(labels_scaler=self.__label_scaler,
storage=output_storage)

target = self.__opinion_samples_io.create_opinions_writer_target(data_type=data_type,
data_folding=data_folding)
target = self.__opinion_samples_io.create_writer_target(data_type=data_type,
data_folding=data_folding)

converter_part = text_opinion_linkages_to_opinion_collections_pipeline_part(
iter_opinion_linkages_func=lambda doc_id: linkages_view.iter_opinion_linkages(
doc_id=doc_id, opinions_view=self.__opinion_samples_io.create_opinions_view(target)),
doc_id=doc_id, opinions_view=self.__opinion_samples_io.create_view(target)),
doc_ids_set=set(data_folding.fold_doc_ids_set()[data_type]),
create_opinion_collection_func=self.__create_opinion_collection_func,
labels_scaler=self.__label_scaler,
Expand Down
10 changes: 0 additions & 10 deletions arekit/contrib/utils/serializer.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
import collections
import logging

from arekit.common.data.input.providers.columns.opinion import OpinionColumnsProvider
from arekit.common.data.input.providers.columns.sample import SampleColumnsProvider
from arekit.common.data.input.providers.opinions import InputTextOpinionProvider
from arekit.common.data.input.providers.rows.base import BaseRowProvider
from arekit.common.data.input.providers.rows.opinions import BaseOpinionsRowProvider
from arekit.common.data.input.repositories.base import BaseInputRepository
from arekit.common.data.input.repositories.opinions import BaseInputOpinionsRepository
from arekit.common.data.input.repositories.sample import BaseInputSamplesRepository
from arekit.common.data.storages.base import BaseRowsStorage
from arekit.common.pipeline.base import BasePipeline
Expand All @@ -27,13 +24,6 @@ def create_samples_repo(keep_labels, rows_provider):
rows_provider=rows_provider,
storage=BaseRowsStorage())

@staticmethod
def create_opinion_repo():
return BaseInputOpinionsRepository(
columns_provider=OpinionColumnsProvider(),
rows_provider=BaseOpinionsRowProvider(),
storage=BaseRowsStorage())

@staticmethod
def fill_and_write(pipeline, repo, target, writer, doc_ids_iter, desc="", do_balance=False):
assert(isinstance(pipeline, BasePipeline))
Expand Down

0 comments on commit 5db2391

Please sign in to comment.