From c2fd9d7a7f83f5c035fe6c77e0010e9c4568c057 Mon Sep 17 00:00:00 2001 From: Nicolay Rusnachenko Date: Sun, 31 Jul 2022 14:13:31 +0300 Subject: [PATCH] #282 related refactoring. --- .../networks/pipelines/items/training.py | 2 +- arekit/contrib/utils/model_io/bert.py | 14 +++-- arekit/contrib/utils/model_io/tf_networks.py | 59 ++++++++++--------- 3 files changed, 41 insertions(+), 34 deletions(-) diff --git a/arekit/contrib/networks/pipelines/items/training.py b/arekit/contrib/networks/pipelines/items/training.py index 961193b1..64a626a3 100644 --- a/arekit/contrib/networks/pipelines/items/training.py +++ b/arekit/contrib/networks/pipelines/items/training.py @@ -99,7 +99,7 @@ def __handle_iteration(self, data_folding, data_type): dtypes=data_folding.iter_supported_data_types(), create_samples_view_func=lambda data_type: self.__exp_io.create_samples_view( data_type=data_type, data_folding=data_folding), - has_model_predefined_state=self.__exp_io.has_model_predefined_state(), + has_model_predefined_state=self.__model_io.IsPretrainedStateProvided, labels_count=self.__labels_count, vocab=self.__exp_io.load_vocab(data_folding), bags_collection_type=self.__bags_collection_type, diff --git a/arekit/contrib/utils/model_io/bert.py b/arekit/contrib/utils/model_io/bert.py index e7ff24af..ff69d94e 100644 --- a/arekit/contrib/utils/model_io/bert.py +++ b/arekit/contrib/utils/model_io/bert.py @@ -1,6 +1,7 @@ import logging from os.path import join, exists +from arekit.common.data.input.writers.base import BaseWriter from arekit.common.data.input.writers.tsv import TsvWriter from arekit.common.data.row_ids.multiple import MultipleIDProvider from arekit.common.data.storages.base import BaseRowsStorage @@ -20,9 +21,14 @@ class DefaultBertIOUtils(BaseIOUtils): for BERT-related data preparation. 
""" - def __init__(self, exp_ctx): + def __init__(self, exp_ctx, + samples_writer=TsvWriter(write_header=True), + target_extension=".tsv.gz"): assert(isinstance(exp_ctx, ExperimentContext)) + assert(isinstance(samples_writer, BaseWriter)) self.__exp_ctx = exp_ctx + self.__samples_writer = samples_writer + self.__target_extension = target_extension def _get_experiment_sources_dir(self): """ Provides directory for samples. @@ -89,13 +95,13 @@ def create_samples_writer_target(self, data_type, data_folding): return self.__get_input_sample_filepath(data_type, data_folding=data_folding) def create_target_extension(self): - return ".tsv.gz" + return self.__target_extension def create_samples_writer(self): - return TsvWriter(write_header=True) + return self.__samples_writer def create_opinions_writer(self): - return TsvWriter(write_header=False) + return self.__samples_writer # endregion diff --git a/arekit/contrib/utils/model_io/tf_networks.py b/arekit/contrib/utils/model_io/tf_networks.py index 38c6e128..f4094713 100644 --- a/arekit/contrib/utils/model_io/tf_networks.py +++ b/arekit/contrib/utils/model_io/tf_networks.py @@ -2,6 +2,7 @@ import logging from os.path import join, exists +from arekit.common.data.input.writers.base import BaseWriter from arekit.common.data.input.writers.tsv import TsvWriter from arekit.common.data.row_ids.multiple import MultipleIDProvider from arekit.common.data.storages.base import BaseRowsStorage @@ -10,7 +11,6 @@ from arekit.common.experiment.api.io_utils import BaseIOUtils from arekit.common.experiment.data_type import DataType from arekit.common.folding.base import BaseDataFolding -from arekit.contrib.networks.core.model_io import NeuralNetworkModelIO from arekit.contrib.utils.data.views.opinions import BaseOpinionStorageView from arekit.contrib.utils.model_io.utils import join_dir_with_subfolder_name, filename_template from arekit.contrib.utils.utils_folding import experiment_iter_index @@ -34,9 +34,14 @@ class 
DefaultNetworkIOUtils(BaseIOUtils): TERM_EMBEDDING_FILENAME_TEMPLATE = 'term_embedding-{cv_index}' VOCABULARY_FILENAME_TEMPLATE = "vocab-{cv_index}.txt" - def __init__(self, exp_ctx): + def __init__(self, exp_ctx, + samples_writer=TsvWriter(write_header=True), + target_extension=".tsv.gz"): assert(isinstance(exp_ctx, ExperimentContext)) + assert(isinstance(samples_writer, BaseWriter)) self.__exp_ctx = exp_ctx + self.__samples_writer = samples_writer + self.__target_extension = target_extension # region public methods @@ -58,13 +63,13 @@ def create_opinions_view(self, target): return BaseOpinionStorageView(storage) def create_opinions_writer(self): - return TsvWriter(write_header=False) + return self.__samples_writer def create_samples_writer(self): - return TsvWriter(write_header=True) + return self.__samples_writer def create_target_extension(self): - return ".tsv.gz" + return self.__target_extension def create_opinions_writer_target(self, data_type, data_folding): return self.__get_input_opinions_target(data_type, data_folding=data_folding) @@ -72,6 +77,8 @@ def create_opinions_writer_target(self, data_type, data_folding): def create_samples_writer_target(self, data_type, data_folding): return self.__get_input_sample_target(data_type, data_folding=data_folding) + # region Embedding-related data + def save_vocab(self, data, data_folding): assert(isinstance(data_folding, BaseDataFolding)) target = self.__get_default_vocab_filepath(data_folding) @@ -90,9 +97,7 @@ def load_embedding(self, data_folding): source = self.__get_term_embedding_source(data_folding) return NpzEmbeddingHelper.load_embedding(source) - def has_model_predefined_state(self): - model_io = self.__exp_ctx.ModelIO - return self.__model_is_pretrained_state_provided(model_io) + # endregion def check_targets_existed(self, data_types_iter, data_folding): for data_type in data_types_iter: @@ -109,22 +114,6 @@ def check_targets_existed(self, data_types_iter, data_folding): # endregion - # region private 
methods - - def __get_model_parameter(self, default_value, get_value_func): - assert(default_value is not None) - assert(callable(get_value_func)) - - model_io = self.__exp_ctx.ModelIO - - if model_io is None: - return default_value - - predefined_value = get_value_func(model_io) if \ - self.__model_is_pretrained_state_provided(model_io) else None - - return default_value if predefined_value is None else predefined_value - def __get_input_opinions_target(self, data_type, data_folding): template = filename_template(data_type=data_type, data_folding=data_folding) return self.__get_filepath(out_dir=self._get_target_dir(), @@ -139,14 +128,11 @@ def __get_input_sample_target(self, data_type, data_folding): prefix="sample", extension=self.create_target_extension()) + # region embedding-related data + def __get_term_embedding_target(self, data_folding): return self.__get_default_embedding_filepath(data_folding) - @staticmethod - def __model_is_pretrained_state_provided(model_io): - assert(isinstance(model_io, NeuralNetworkModelIO)) - return model_io.IsPretrainedStateProvided - def ___get_vocab_source(self, data_folding): """ It is possible to load a predefined embedding from another experiment using the related filepath provided by model_io. 
@@ -161,6 +147,21 @@ def __get_term_embedding_source(self, data_folding): return self.__get_model_parameter(default_value=self.__get_default_embedding_filepath(data_folding), get_value_func=lambda model_io: model_io.get_model_embedding_filepath()) + def __get_model_parameter(self, default_value, get_value_func): + assert(default_value is not None) + assert(callable(get_value_func)) + + model_io = self.__exp_ctx.ModelIO + + if model_io is None: + return default_value + + predefined_value = get_value_func(model_io) if model_io.IsPretrainedStateProvided else None + + return default_value if predefined_value is None else predefined_value + + # endregion + def __get_experiment_folder_name(self): return "{name}".format(name=self.__exp_ctx.Name)