Skip to content

Commit

Permalink
#282 related refactoring.
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Jul 31, 2022
1 parent 3b564ba commit c2fd9d7
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 34 deletions.
2 changes: 1 addition & 1 deletion arekit/contrib/networks/pipelines/items/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def __handle_iteration(self, data_folding, data_type):
dtypes=data_folding.iter_supported_data_types(),
create_samples_view_func=lambda data_type: self.__exp_io.create_samples_view(
data_type=data_type, data_folding=data_folding),
has_model_predefined_state=self.__exp_io.has_model_predefined_state(),
has_model_predefined_state=self.__model_io.IsPretrainedStateProvided,
labels_count=self.__labels_count,
vocab=self.__exp_io.load_vocab(data_folding),
bags_collection_type=self.__bags_collection_type,
Expand Down
14 changes: 10 additions & 4 deletions arekit/contrib/utils/model_io/bert.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
from os.path import join, exists

from arekit.common.data.input.writers.base import BaseWriter
from arekit.common.data.input.writers.tsv import TsvWriter
from arekit.common.data.row_ids.multiple import MultipleIDProvider
from arekit.common.data.storages.base import BaseRowsStorage
Expand All @@ -20,9 +21,14 @@ class DefaultBertIOUtils(BaseIOUtils):
for BERT-related data preparation.
"""

def __init__(self, exp_ctx):
def __init__(self, exp_ctx,
samples_writer=TsvWriter(write_header=True),
target_extension=".tsv.gz"):
assert(isinstance(exp_ctx, ExperimentContext))
assert(isinstance(samples_writer, BaseWriter))
self.__exp_ctx = exp_ctx
self.__samples_writer = samples_writer
self.__target_extension = target_extension

def _get_experiment_sources_dir(self):
""" Provides directory for samples.
Expand Down Expand Up @@ -89,13 +95,13 @@ def create_samples_writer_target(self, data_type, data_folding):
return self.__get_input_sample_filepath(data_type, data_folding=data_folding)

def create_target_extension(self):
return ".tsv.gz"
return self.__target_extension

def create_samples_writer(self):
return TsvWriter(write_header=True)
return self.__samples_writer

def create_opinions_writer(self):
return TsvWriter(write_header=False)
return self.__samples_writer

# endregion

Expand Down
59 changes: 30 additions & 29 deletions arekit/contrib/utils/model_io/tf_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
from os.path import join, exists

from arekit.common.data.input.writers.base import BaseWriter
from arekit.common.data.input.writers.tsv import TsvWriter
from arekit.common.data.row_ids.multiple import MultipleIDProvider
from arekit.common.data.storages.base import BaseRowsStorage
Expand All @@ -10,7 +11,6 @@
from arekit.common.experiment.api.io_utils import BaseIOUtils
from arekit.common.experiment.data_type import DataType
from arekit.common.folding.base import BaseDataFolding
from arekit.contrib.networks.core.model_io import NeuralNetworkModelIO
from arekit.contrib.utils.data.views.opinions import BaseOpinionStorageView
from arekit.contrib.utils.model_io.utils import join_dir_with_subfolder_name, filename_template
from arekit.contrib.utils.utils_folding import experiment_iter_index
Expand All @@ -34,9 +34,14 @@ class DefaultNetworkIOUtils(BaseIOUtils):
TERM_EMBEDDING_FILENAME_TEMPLATE = 'term_embedding-{cv_index}'
VOCABULARY_FILENAME_TEMPLATE = "vocab-{cv_index}.txt"

def __init__(self, exp_ctx):
def __init__(self, exp_ctx,
samples_writer=TsvWriter(write_header=True),
target_extension=".tsv.gz"):
assert(isinstance(exp_ctx, ExperimentContext))
assert(isinstance(samples_writer, BaseWriter))
self.__exp_ctx = exp_ctx
self.__samples_writer = samples_writer
self.__target_extension = target_extension

# region public methods

Expand All @@ -58,20 +63,22 @@ def create_opinions_view(self, target):
return BaseOpinionStorageView(storage)

def create_opinions_writer(self):
return TsvWriter(write_header=False)
return self.__samples_writer

def create_samples_writer(self):
return TsvWriter(write_header=True)
return self.__samples_writer

def create_target_extension(self):
return ".tsv.gz"
return self.__target_extension

def create_opinions_writer_target(self, data_type, data_folding):
return self.__get_input_opinions_target(data_type, data_folding=data_folding)

def create_samples_writer_target(self, data_type, data_folding):
return self.__get_input_sample_target(data_type, data_folding=data_folding)

# region Embedding-related data

def save_vocab(self, data, data_folding):
assert(isinstance(data_folding, BaseDataFolding))
target = self.__get_default_vocab_filepath(data_folding)
Expand All @@ -90,9 +97,7 @@ def load_embedding(self, data_folding):
source = self.__get_term_embedding_source(data_folding)
return NpzEmbeddingHelper.load_embedding(source)

def has_model_predefined_state(self):
model_io = self.__exp_ctx.ModelIO
return self.__model_is_pretrained_state_provided(model_io)
# endregion

def check_targets_existed(self, data_types_iter, data_folding):
for data_type in data_types_iter:
Expand All @@ -109,22 +114,6 @@ def check_targets_existed(self, data_types_iter, data_folding):

# endregion

# region private methods

def __get_model_parameter(self, default_value, get_value_func):
assert(default_value is not None)
assert(callable(get_value_func))

model_io = self.__exp_ctx.ModelIO

if model_io is None:
return default_value

predefined_value = get_value_func(model_io) if \
self.__model_is_pretrained_state_provided(model_io) else None

return default_value if predefined_value is None else predefined_value

def __get_input_opinions_target(self, data_type, data_folding):
template = filename_template(data_type=data_type, data_folding=data_folding)
return self.__get_filepath(out_dir=self._get_target_dir(),
Expand All @@ -139,14 +128,11 @@ def __get_input_sample_target(self, data_type, data_folding):
prefix="sample",
extension=self.create_target_extension())

# region embedding-related data

def __get_term_embedding_target(self, data_folding):
return self.__get_default_embedding_filepath(data_folding)

@staticmethod
def __model_is_pretrained_state_provided(model_io):
assert(isinstance(model_io, NeuralNetworkModelIO))
return model_io.IsPretrainedStateProvided

def ___get_vocab_source(self, data_folding):
""" It is possible to load a predefined embedding from another experiment
using the related filepath provided by model_io.
Expand All @@ -161,6 +147,21 @@ def __get_term_embedding_source(self, data_folding):
return self.__get_model_parameter(default_value=self.__get_default_embedding_filepath(data_folding),
get_value_func=lambda model_io: model_io.get_model_embedding_filepath())

def __get_model_parameter(self, default_value, get_value_func):
assert(default_value is not None)
assert(callable(get_value_func))

model_io = self.__exp_ctx.ModelIO

if model_io is None:
return default_value

predefined_value = get_value_func(model_io) if model_io.IsPretrainedStateProvided else None

return default_value if predefined_value is None else predefined_value

# end region

def __get_experiment_folder_name(self):
return "{name}".format(name=self.__exp_ctx.Name)

Expand Down

0 comments on commit c2fd9d7

Please sign in to comment.