Commit 07bc73c: #476 done
nicolay-r committed May 18, 2023
1 parent e844ec6
Showing 3 changed files with 106 additions and 166 deletions.
89 changes: 89 additions & 0 deletions arekit/contrib/utils/pipelines/items/sampling/base.py
@@ -0,0 +1,89 @@
from arekit.common.data.input.providers.rows.samples import BaseSampleRowProvider
from arekit.common.data.storages.base import BaseRowsStorage
from arekit.common.experiment.api.base_samples_io import BaseSamplesIO
from arekit.common.experiment.data_type import DataType
from arekit.common.folding.base import BaseDataFolding
from arekit.common.pipeline.base import BasePipeline
from arekit.common.pipeline.context import PipelineContext
from arekit.common.pipeline.items.base import BasePipelineItem
from arekit.contrib.utils.serializer import InputDataSerializationHelper
from arekit.contrib.utils.utils_folding import folding_iter_states


class BaseSerializerPipelineItem(BasePipelineItem):

def __init__(self, rows_provider, samples_io, save_labels_func, balance_func, storage):
""" sample_rows_formatter:
how we format input texts for a BERT model, for example:
- single text
- two sequences, separated by [SEP] token
save_labels_func: function
data_type -> bool
"""
assert(isinstance(rows_provider, BaseSampleRowProvider))
assert(isinstance(samples_io, BaseSamplesIO))
assert(callable(save_labels_func))
assert(callable(balance_func))
assert(isinstance(storage, BaseRowsStorage))

        super(BaseSerializerPipelineItem, self).__init__()

        self._rows_provider = rows_provider
self._balance_func = balance_func
self._samples_io = samples_io
self._save_labels_func = save_labels_func
self._storage = storage

def _serialize_iteration(self, data_type, pipeline, data_folding):
        assert(isinstance(data_type, DataType))
        assert(isinstance(pipeline, BasePipeline))

repos = {
"sample": InputDataSerializationHelper.create_samples_repo(
keep_labels=self._save_labels_func(data_type),
rows_provider=self._rows_provider,
storage=self._storage),
}

writer_and_targets = {
"sample": (self._samples_io.Writer,
self._samples_io.create_target(
data_type=data_type, data_folding=data_folding)),
}

for description, repo in repos.items():
InputDataSerializationHelper.fill_and_write(
repo=repo,
pipeline=pipeline,
doc_ids_iter=data_folding.fold_doc_ids_set()[data_type],
do_balance=self._balance_func(data_type),
desc="{desc} [{data_type}]".format(desc=description, data_type=data_type),
writer=writer_and_targets[description][0],
target=writer_and_targets[description][1])

def _handle_iteration(self, data_type_pipelines, data_folding):
""" Performing data serialization for a particular iteration
"""
assert(isinstance(data_type_pipelines, dict))
assert(isinstance(data_folding, BaseDataFolding))
for data_type, pipeline in data_type_pipelines.items():
self._serialize_iteration(data_type=data_type, pipeline=pipeline, data_folding=data_folding)

def apply_core(self, input_data, pipeline_ctx):
"""
data_type_pipelines: dict of, for example:
{
DataType.Train: BasePipeline,
DataType.Test: BasePipeline
}
pipeline: doc_id -> parsed_news -> annot -> opinion linkages
for example, function: sentiment_attitude_extraction_default_pipeline
"""
        assert(isinstance(pipeline_ctx, PipelineContext))
        assert("data_type_pipelines" in pipeline_ctx)
        assert("data_folding" in pipeline_ctx)

data_folding = pipeline_ctx.provide("data_folding")
for _ in folding_iter_states(data_folding):
self._handle_iteration(data_type_pipelines=pipeline_ctx.provide("data_type_pipelines"),
data_folding=data_folding)
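
For orientation, a minimal usage sketch of the new base item follows. The rows_provider, samples_io, storage, per-type pipelines, and data_folding below are hypothetical placeholders standing in for concrete AREkit objects, and the sketch assumes PipelineContext accepts a dict of named parameters, matching the provide() calls above.

from arekit.common.experiment.data_type import DataType
from arekit.common.pipeline.context import PipelineContext
from arekit.contrib.utils.pipelines.items.sampling.base import BaseSerializerPipelineItem

# Hypothetical pre-constructed dependencies; see the asserts in __init__
# above for the expected base types.
item = BaseSerializerPipelineItem(
    rows_provider=rows_provider,        # BaseSampleRowProvider instance
    samples_io=samples_io,              # BaseSamplesIO instance
    save_labels_func=lambda data_type: data_type != DataType.Test,
    balance_func=lambda data_type: data_type == DataType.Train,
    storage=storage)                    # BaseRowsStorage instance

# The two context keys below are exactly the ones apply_core() asserts on.
ctx = PipelineContext({
    "data_type_pipelines": {DataType.Train: train_pipeline,
                            DataType.Test: test_pipeline},
    "data_folding": data_folding,       # BaseDataFolding instance
})
item.apply_core(input_data=None, pipeline_ctx=ctx)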
88 changes: 3 additions & 85 deletions arekit/contrib/utils/pipelines/items/sampling/bert.py
@@ -1,87 +1,5 @@
from arekit.common.experiment.api.base_samples_io import BaseSamplesIO
from arekit.common.experiment.data_type import DataType
from arekit.common.folding.base import BaseDataFolding
from arekit.common.pipeline.base import BasePipeline
from arekit.common.pipeline.context import PipelineContext
from arekit.common.pipeline.items.base import BasePipelineItem
from arekit.contrib.utils.utils_folding import folding_iter_states
from arekit.contrib.utils.serializer import InputDataSerializationHelper
from arekit.contrib.utils.pipelines.items.sampling.base import BaseSerializerPipelineItem


class BertExperimentInputSerializerPipelineItem(BasePipelineItem):

def __init__(self, rows_provider, samples_io, save_labels_func, balance_func, storage):
""" sample_rows_formatter:
how we format input texts for a BERT model, for example:
- single text
- two sequences, separated by [SEP] token
save_labels_func: function
data_type -> bool
"""
assert(isinstance(samples_io, BaseSamplesIO))
super(BertExperimentInputSerializerPipelineItem, self).__init__()

self.__rows_provider = rows_provider
self.__balance_func = balance_func
self.__samples_io = samples_io
self.__save_labels_func = save_labels_func
self.__storage = storage

# region private methods

def __serialize_iteration(self, data_type, pipeline, data_folding):
assert(isinstance(data_type, DataType))
assert(isinstance(pipeline, BasePipeline))

repos = {
"sample": InputDataSerializationHelper.create_samples_repo(
keep_labels=self.__save_labels_func(data_type),
rows_provider=self.__rows_provider,
storage=self.__storage),
}

writer_and_targets = {
"sample": (self.__samples_io.Writer,
self.__samples_io.create_target(
data_type=data_type, data_folding=data_folding)),
}

for description, repo in repos.items():
InputDataSerializationHelper.fill_and_write(
repo=repo,
pipeline=pipeline,
doc_ids_iter=data_folding.fold_doc_ids_set()[data_type],
do_balance=self.__balance_func(data_type),
desc="{desc} [{data_type}]".format(desc=description, data_type=data_type),
writer=writer_and_targets[description][0],
target=writer_and_targets[description][1])

def __handle_iteration(self, data_type_pipelines, data_folding):
""" Performing data serialization for a particular iteration
"""
assert(isinstance(data_type_pipelines, dict))
assert(isinstance(data_folding, BaseDataFolding))
for data_type, pipeline in data_type_pipelines.items():
self.__serialize_iteration(data_type=data_type, pipeline=pipeline, data_folding=data_folding)

# endregion

def apply_core(self, input_data, pipeline_ctx=None):
""" data_type_pipelines: dict of, for example:
{
DataType.Train: BasePipeline,
DataType.Test: BasePipeline
}
pipeline: doc_id -> parsed_news -> annot -> opinion linkages
for example, function: sentiment_attitude_extraction_default_pipeline
"""
assert(isinstance(pipeline_ctx, PipelineContext))
assert("data_type_pipelines" in pipeline_ctx)
assert("data_folding" in pipeline_ctx)

data_folding = pipeline_ctx.provide("data_folding")
for _ in folding_iter_states(data_folding):
self.__handle_iteration(data_type_pipelines=pipeline_ctx.provide("data_type_pipelines"),
data_folding=data_folding)
class BertExperimentInputSerializerPipelineItem(BaseSerializerPipelineItem):
pass
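
With the shared logic hoisted into base.py, the BERT item is now a pure alias: it keeps the public class name for existing imports while adding no behavior of its own. A sketch of equivalent usage, with the same hypothetical placeholder arguments as in the base.py sketch above:

from arekit.common.experiment.data_type import DataType
from arekit.contrib.utils.pipelines.items.sampling.bert import BertExperimentInputSerializerPipelineItem

# Constructor contract is identical to BaseSerializerPipelineItem, so
# existing callers continue to work unchanged.
bert_item = BertExperimentInputSerializerPipelineItem(
    rows_provider=rows_provider,
    samples_io=samples_io,
    save_labels_func=lambda data_type: True,    # keep labels for every data_type
    balance_func=lambda data_type: data_type == DataType.Train,
    storage=storage)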
95 changes: 14 additions & 81 deletions arekit/contrib/utils/pipelines/items/sampling/networks.py
@@ -1,19 +1,13 @@
from arekit.common.experiment.api.base_samples_io import BaseSamplesIO
from arekit.common.experiment.data_type import DataType
from arekit.common.folding.base import BaseDataFolding
from arekit.common.pipeline.base import BasePipeline
from arekit.common.pipeline.context import PipelineContext
from arekit.common.pipeline.items.base import BasePipelineItem
from arekit.contrib.networks.input.embedding.matrix import create_term_embedding_matrix
from arekit.contrib.networks.input.embedding.offsets import TermsEmbeddingOffsets
from arekit.contrib.networks.embedding import Embedding
from arekit.contrib.networks.input.providers.sample import NetworkSampleRowProvider
from arekit.contrib.utils.io_utils.embedding import NpEmbeddingIO
from arekit.contrib.utils.utils_folding import folding_iter_states
from arekit.contrib.utils.serializer import InputDataSerializationHelper
from arekit.contrib.utils.pipelines.items.sampling.base import BaseSerializerPipelineItem


class NetworksInputSerializerPipelineItem(BasePipelineItem):
class NetworksInputSerializerPipelineItem(BaseSerializerPipelineItem):

def __init__(self, save_labels_func, rows_provider, samples_io,
emb_io, balance_func, storage, save_embedding=True):
@@ -23,76 +17,37 @@ def __init__(self, save_labels_func, rows_provider, samples_io,
                which are supported and required in a handler. It is necessary to know the
                data_types in advance, since this allows creating a complete vocabulary of
                input terms with the related embeddings.
            balance_func: function
                data_type -> bool, declares whether there is a need to balance the samples
            save_labels_func: function
                data_type -> bool
            save_embedding: bool
                whether to save the embedding and all the information related to it.
"""
assert(isinstance(samples_io, BaseSamplesIO))
assert(isinstance(emb_io, NpEmbeddingIO))
assert(isinstance(rows_provider, NetworkSampleRowProvider))
assert(isinstance(save_embedding, bool))
assert(callable(save_labels_func))
assert(callable(balance_func))
super(NetworksInputSerializerPipelineItem, self).__init__()
super(NetworksInputSerializerPipelineItem, self).__init__(
rows_provider=rows_provider,
samples_io=samples_io,
save_labels_func=save_labels_func,
balance_func=balance_func,
storage=storage)

self.__emb_io = emb_io
self.__samples_io = samples_io
self.__save_embedding = save_embedding
self.__save_labels_func = save_labels_func
self.__balance_func = balance_func
self.__storage = storage
self.__rows_provider = rows_provider

def __serialize_iteration(self, data_type, pipeline, data_folding):
assert(isinstance(data_type, DataType))
assert(isinstance(pipeline, BasePipeline))

repos = {
"sample": InputDataSerializationHelper.create_samples_repo(
keep_labels=self.__save_labels_func(data_type),
rows_provider=self.__rows_provider,
storage=self.__storage),
}

writer_and_targets = {
"sample": (self.__samples_io.Writer,
self.__samples_io.create_target(
data_type=data_type, data_folding=data_folding)),
}

for description, repo in repos.items():
InputDataSerializationHelper.fill_and_write(
repo=repo,
pipeline=pipeline,
doc_ids_iter=data_folding.fold_doc_ids_set()[data_type],
do_balance=self.__balance_func(data_type),
desc="{desc} [{data_type}]".format(desc=description, data_type=data_type),
writer=writer_and_targets[description][0],
target=writer_and_targets[description][1])

def __handle_iteration(self, data_type_pipelines, data_folding):
def _handle_iteration(self, data_type_pipelines, data_folding):
""" Performing data serialization for a particular iteration
"""
assert(isinstance(data_type_pipelines, dict))
assert(isinstance(data_folding, BaseDataFolding))

# Prepare for the present iteration.
self.__rows_provider.clear_embedding_pairs()
self._rows_provider.clear_embedding_pairs()

for data_type, pipeline in data_type_pipelines.items():
self.__serialize_iteration(data_type=data_type, pipeline=pipeline, data_folding=data_folding)
super(NetworksInputSerializerPipelineItem, self)._handle_iteration(
data_type_pipelines=data_type_pipelines, data_folding=data_folding)

if not (self.__save_embedding and self.__rows_provider.HasEmbeddingPairs):
if not (self.__save_embedding and self._rows_provider.HasEmbeddingPairs):
return

# Save embedding information additionally.
term_embedding = Embedding.from_word_embedding_pairs_iter(self.__rows_provider.iter_term_embedding_pairs())
term_embedding = Embedding.from_word_embedding_pairs_iter(self._rows_provider.iter_term_embedding_pairs())
embedding_matrix = create_term_embedding_matrix(term_embedding=term_embedding)
vocab = list(TermsEmbeddingOffsets.extract_vocab(words_embedding=term_embedding))

@@ -101,25 +56,3 @@ def __handle_iteration(self, data_type_pipelines, data_folding):
self.__emb_io.save_vocab(data=vocab, data_folding=data_folding)

del embedding_matrix

# endregion

def apply_core(self, input_data, pipeline_ctx):
"""
data_type_pipelines: dict of, for example:
{
DataType.Train: BasePipeline,
DataType.Test: BasePipeline
}
pipeline: doc_id -> parsed_news -> annot -> opinion linkages
for example, function: sentiment_attitude_extraction_default_pipeline
"""
assert(isinstance(pipeline_ctx, PipelineContext))
assert("data_type_pipelines" in pipeline_ctx)
assert("data_folding" in pipeline_ctx)

data_folding = pipeline_ctx.provide("data_folding")
for _ in folding_iter_states(data_folding):
self.__handle_iteration(data_type_pipelines=pipeline_ctx.provide("data_type_pipelines"),
data_folding=data_folding)
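
The networks item now keeps only its embedding-specific behavior: it clears the embedding pairs before each iteration, delegates per-data_type serialization to the base class, and optionally writes the embedding matrix and vocabulary afterwards. A minimal construction sketch, using the same hypothetical placeholders as above plus an NpEmbeddingIO instance:

from arekit.common.experiment.data_type import DataType
from arekit.contrib.utils.pipelines.items.sampling.networks import NetworksInputSerializerPipelineItem

# Same contract as the base item, extended with embedding IO; per the asserts
# above, emb_io must be an NpEmbeddingIO and rows_provider a
# NetworkSampleRowProvider.
networks_item = NetworksInputSerializerPipelineItem(
    rows_provider=network_rows_provider,
    samples_io=samples_io,
    emb_io=emb_io,
    save_labels_func=lambda data_type: data_type != DataType.Test,
    balance_func=lambda data_type: data_type == DataType.Train,
    storage=storage,
    save_embedding=True)    # also save the matrix and vocab via emb_io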
