
#282 related refactoring. Adopt embedding IO and sample-based IO instead of the model-related one.
nicolay-r committed Jul 31, 2022
1 parent e6a98ae commit c4f942e
Showing 9 changed files with 295 additions and 374 deletions.
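
In short, the commit replaces the single experiment-wide IO object (DefaultBertIOUtils / DefaultNetworkIOUtils) with two narrower collaborators: a samples IO for sample/opinion rows and an embedding IO for the embedding matrix and vocabulary. The stand-in classes below are a minimal sketch of that division of responsibilities, not the real AREkit implementations; only the method names are taken from the diffs that follow.

# Hypothetical stand-ins sketching the new IO split, for illustration only;
# method names mirror the SamplesIOUtils / NpzEmbeddingIOUtils calls in this commit.

class SamplesIOSketch(object):
    """Sample- and opinion-related IO: writers, writer targets, sample views."""

    def create_samples_writer(self):
        raise NotImplementedError()

    def create_samples_writer_target(self, data_type, data_folding):
        raise NotImplementedError()

    def create_opinions_writer(self):
        raise NotImplementedError()

    def create_opinions_writer_target(self, data_type, data_folding):
        raise NotImplementedError()

    def create_samples_view(self, data_type, data_folding):
        raise NotImplementedError()

    def check_targets_existed(self, data_types_iter, data_folding):
        raise NotImplementedError()


class EmbeddingIOSketch(object):
    """Embedding matrix and vocabulary IO (npz-backed in the real NpzEmbeddingIOUtils)."""

    def save_embedding(self, data, data_folding):
        raise NotImplementedError()

    def save_vocab(self, data, data_folding):
        raise NotImplementedError()

    def load_embedding(self, data_folding):
        raise NotImplementedError()

    def load_vocab(self, data_folding):
        raise NotImplementedError()

    def check_targets_existed(self, data_folding):
        raise NotImplementedError()
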
10 changes: 5 additions & 5 deletions arekit/contrib/experiment_rusentrel/pipelines/items/bert_eval.py
@@ -6,21 +6,21 @@
 from arekit.contrib.experiment_rusentrel.bert.output_provider import GoogleBertOutputStorage
 from arekit.contrib.experiment_rusentrel.eval_helper import EvalHelper
 from arekit.contrib.experiment_rusentrel.utils import create_result_opinion_collection_target
-from arekit.contrib.utils.io_utils.bert import DefaultBertIOUtils
+from arekit.contrib.utils.io_utils.samples import SamplesIOUtils
 from arekit.contrib.utils.pipelines.items.to_output import TextOpinionLinkagesToOpinionConverterPipelineItem


 class ModelEvaluationPipelineItem(TextOpinionLinkagesToOpinionConverterPipelineItem):

-    def __init__(self, exp_io, eval_helper, create_opinion_collection_func,
+    def __init__(self, samples_io, eval_helper, create_opinion_collection_func,
                  original_target_dir, output_target_dir, max_epochs_count, label_scaler,
                  labels_formatter, iteration_index, opinion_collection_writer):
-        assert(isinstance(exp_io, DefaultBertIOUtils))
+        assert(isinstance(samples_io, SamplesIOUtils))
         assert(isinstance(output_target_dir, str))
         assert(isinstance(iteration_index, int))

         super(ModelEvaluationPipelineItem, self).__init__(
-            exp_io=exp_io, create_opinion_collection_func=create_opinion_collection_func,
+            opinion_samples_io=samples_io, create_opinion_collection_func=create_opinion_collection_func,
             opinion_collection_writer=opinion_collection_writer,
             label_scaler=label_scaler, labels_formatter=labels_formatter)

@@ -80,7 +80,7 @@ def __get_output_storage(epoch_index, iter_index, eval_helper, original_target_d

     def apply_core(self, input_data, pipeline_ctx):

-        if not self.__exp_io.check_targets_existed():
+        if not self.__opinion_samples_io.check_targets_existed():
             return

         super(ModelEvaluationPipelineItem, self).apply_core(
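
For downstream code, the visible effect of this file's change is a keyword rename at the call site: constructors that used to receive exp_io now receive samples_io. A hypothetical before/after sketch (stand-in classes, not the real pipeline item; only the keyword names come from the diff above):

# Hypothetical call-site change; EvalItemOld/EvalItemNew are illustrative stand-ins.
class EvalItemOld(object):
    def __init__(self, exp_io, eval_helper):      # old: one catch-all IO object
        self._exp_io = exp_io

class EvalItemNew(object):
    def __init__(self, samples_io, eval_helper):  # new: samples-focused IO only
        self._samples_io = samples_io

# Callers switch the keyword argument accordingly:
EvalItemOld(exp_io="io", eval_helper=None)
EvalItemNew(samples_io="io", eval_helper=None)
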
24 changes: 14 additions & 10 deletions arekit/contrib/networks/pipelines/items/serializer.py
@@ -14,15 +14,17 @@
 from arekit.contrib.networks.core.input.providers.text import NetworkSingleTextProvider
 from arekit.contrib.networks.core.input.terms_mapping import StringWithEmbeddingNetworkTermMapping
 from arekit.contrib.networks.embedding import Embedding
-from arekit.contrib.utils.io_utils.tf_networks import DefaultNetworkIOUtils
+from arekit.contrib.utils.io_utils.embedding import NpzEmbeddingIOUtils
+
+from arekit.contrib.utils.io_utils.samples import SamplesIOUtils
 from arekit.contrib.utils.utils_folding import folding_iter_states
 from arekit.contrib.utils.serializer import InputDataSerializationHelper


 class NetworksInputSerializerPipelineItem(BasePipelineItem):

     def __init__(self, vectorizers, save_labels_func, str_entity_fmt, exp_ctx,
-                 exp_io, balance_func, save_embedding, keep_opinions_repos=False):
+                 samples_io, emb_io, balance_func, save_embedding, keep_opinions_repos=False):
         """ This pipeline item allows to perform a data preparation for neural network models.
             considering a list of the whole data_types with the related pipelines,
@@ -48,15 +50,17 @@ def __init__(self, vectorizers, save_labels_func, str_entity_fmt, exp_ctx,
             save embedding and all the related information to it.
         """
         assert(isinstance(exp_ctx, NetworkSerializationContext))
-        assert(isinstance(exp_io, DefaultNetworkIOUtils))
+        assert(isinstance(samples_io, SamplesIOUtils))
+        assert(isinstance(emb_io, NpzEmbeddingIOUtils))
         assert(isinstance(str_entity_fmt, StringEntitiesFormatter))
         assert(isinstance(vectorizers, dict))
         assert(isinstance(save_embedding, bool))
         assert(callable(save_labels_func))
         assert(callable(balance_func))
         super(NetworksInputSerializerPipelineItem, self).__init__()

-        self.__exp_io = exp_io
+        self.__emb_io = emb_io
+        self.__samples_io = samples_io
         self.__save_embedding = save_embedding
         self.__save_labels_func = save_labels_func
         self.__balance_func = balance_func
@@ -100,11 +104,11 @@ def __serialize_iteration(self, data_type, pipeline, rows_provider, data_folding
         }

         writer_and_targets = {
-            "sample": (self.__exp_io.create_samples_writer(),
-                       self.__exp_io.create_samples_writer_target(
+            "sample": (self.__samples_io.create_samples_writer(),
+                       self.__samples_io.create_samples_writer_target(
                            data_type=data_type, data_folding=data_folding)),
-            "opinion": (self.__exp_io.create_opinions_writer(),
-                        self.__exp_io.create_opinions_writer_target(
+            "opinion": (self.__samples_io.create_opinions_writer(),
+                        self.__samples_io.create_opinions_writer_target(
                            data_type=data_type, data_folding=data_folding))
         }

@@ -146,8 +150,8 @@ def __handle_iteration(self, data_type_pipelines, data_folding):
         vocab = list(TermsEmbeddingOffsets.extract_vocab(words_embedding=term_embedding))

         # Save embedding matrix
-        self.__exp_io.save_embedding(data=embedding_matrix, data_folding=data_folding)
-        self.__exp_io.save_vocab(data=vocab, data_folding=data_folding)
+        self.__emb_io.save_embedding(data=embedding_matrix, data_folding=data_folding)
+        self.__emb_io.save_vocab(data=vocab, data_folding=data_folding)

         del embedding_matrix

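
The serializer now routes the tabular outputs (sample and opinion writers and their targets) through samples_io, while the embedding matrix and vocabulary go through emb_io. Below is a rough sketch of that routing using in-memory stand-ins; the class and function names are hypothetical, and only the method names match the diff above.

# Sketch of the serializer-side routing with hypothetical in-memory stand-ins.
class InMemorySamplesIO(object):
    def create_samples_writer(self):
        return "samples-writer"

    def create_samples_writer_target(self, data_type, data_folding):
        return "samples-{}".format(data_type)

class InMemoryEmbeddingIO(object):
    def __init__(self):
        self.saved = {}

    def save_embedding(self, data, data_folding):
        self.saved["embedding"] = data

    def save_vocab(self, data, data_folding):
        self.saved["vocab"] = data

def serialize_sketch(samples_io, emb_io, embedding_matrix, vocab, data_type, data_folding):
    # Tabular rows: handled by the samples IO.
    writer = samples_io.create_samples_writer()
    target = samples_io.create_samples_writer_target(data_type=data_type, data_folding=data_folding)
    # Embedding-related data: handled by the embedding IO.
    emb_io.save_embedding(data=embedding_matrix, data_folding=data_folding)
    emb_io.save_vocab(data=vocab, data_folding=data_folding)
    return writer, target

serialize_sketch(InMemorySamplesIO(), InMemoryEmbeddingIO(),
                 embedding_matrix=[[0.0]], vocab=["<unk>"],
                 data_type="train", data_folding=None)
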
31 changes: 23 additions & 8 deletions arekit/contrib/networks/pipelines/items/training.py
@@ -18,17 +18,19 @@
 from arekit.contrib.networks.core.pipeline.item_predict_labeling import EpochLabelsCollectorPipelineItem
 from arekit.contrib.networks.shapes import NetworkInputShapes
 from arekit.contrib.networks.utils import rm_dir_contents
-from arekit.contrib.utils.io_utils.tf_networks import DefaultNetworkIOUtils
+from arekit.contrib.utils.io_utils.embedding import NpzEmbeddingIOUtils
+from arekit.contrib.utils.io_utils.samples import SamplesIOUtils
 from arekit.contrib.utils.utils_folding import folding_iter_states


 class NetworksTrainingPipelineItem(BasePipelineItem):

-    def __init__(self, bags_collection_type, model_io, exp_io,
+    def __init__(self, bags_collection_type, model_io, samples_io, emb_io,
                  load_model, config, create_network_func, training_epochs,
                  labels_count, network_callbacks, prepare_model_root=True, seed=None):
         assert(callable(create_network_func))
-        assert(isinstance(exp_io, DefaultNetworkIOUtils))
+        assert(isinstance(samples_io, SamplesIOUtils))
+        assert(isinstance(emb_io, NpzEmbeddingIOUtils))
         assert(isinstance(config, DefaultNetworkConfig))
         assert(issubclass(bags_collection_type, BagsCollection))
         assert(isinstance(load_model, bool))
@@ -40,7 +42,8 @@ def __init__(self, bags_collection_type, model_io, exp_io,
         super(NetworksTrainingPipelineItem, self).__init__()

         self.__logger = self.__create_logger()
-        self.__exp_io = exp_io
+        self.__samples_io = samples_io
+        self.__emb_io = emb_io
         self.__clear_model_root_before_experiment = prepare_model_root
         self.__config = config
         self.__create_network_func = create_network_func
@@ -77,30 +80,42 @@ def __prepare_model(self):
         # Notify other subscribers that initialization process has been completed.
         self.__config.init_initializers()

+    def __check_targets_existed(self, data_types_iter, data_folding):
+        """ Check that all the required resources existed.
+        """
+
+        if not self.__samples_io.check_targets_existed(data_types_iter=data_types_iter, data_folding=data_folding):
+            return False
+
+        if not self.__emb_io.check_targets_existed(data_folding=data_folding):
+            return False
+
+        return True
+
     def __handle_iteration(self, data_folding, data_type):
         assert(isinstance(data_folding, BaseDataFolding))
         assert(isinstance(data_type, DataType))

-        targets_existed = self.__exp_io.check_targets_existed(
+        targets_existed = self.__samples_io.check_targets_existed(
             data_types_iter=data_folding.iter_supported_data_types(),
             data_folding=data_folding)

         if not targets_existed:
             raise Exception("Data has not been initialized/serialized!")

         # Reading embedding.
-        embedding_data = self.__exp_io.load_embedding(data_folding)
+        embedding_data = self.__emb_io.load_embedding(data_folding)
         self.__config.set_term_embedding(embedding_data)

         # Performing samples reading process.
         inference_ctx = InferenceContext.create_empty()
         inference_ctx.initialize(
             dtypes=data_folding.iter_supported_data_types(),
-            create_samples_view_func=lambda data_type: self.__exp_io.create_samples_view(
+            create_samples_view_func=lambda data_type: self.__samples_io.create_samples_view(
                 data_type=data_type, data_folding=data_folding),
             has_model_predefined_state=self.__model_io.IsPretrainedStateProvided,
             labels_count=self.__labels_count,
-            vocab=self.__exp_io.load_vocab(data_folding),
+            vocab=self.__emb_io.load_vocab(data_folding),
             bags_collection_type=self.__bags_collection_type,
             input_shapes=NetworkInputShapes(iter_pairs=[
                 (NetworkInputShapes.FRAMES_PER_CONTEXT, self.__config.FramesPerContext),
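
The newly added private helper __check_targets_existed makes the pre-flight check explicit: training proceeds only if both the serialized samples and the embedding/vocabulary targets exist. The reduced sketch below mirrors that guard with hypothetical stand-in arguments rather than the real IO classes; only the check_targets_existed method names come from the diff above.

# Hypothetical pre-flight guard mirroring __check_targets_existed in this commit.
def check_targets_existed(samples_io, emb_io, data_types_iter, data_folding):
    # Serialized sample/opinion targets must be present for every data type...
    if not samples_io.check_targets_existed(data_types_iter=data_types_iter,
                                            data_folding=data_folding):
        return False
    # ...and so must the embedding matrix and vocabulary for this folding.
    if not emb_io.check_targets_existed(data_folding=data_folding):
        return False
    return True

# Usage sketch: fail early when serialization has not been performed yet.
# if not check_targets_existed(samples_io, emb_io, dtypes, folding):
#     raise Exception("Data has not been initialized/serialized!")
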
128 changes: 0 additions & 128 deletions arekit/contrib/utils/io_utils/bert.py

This file was deleted.

