Skip to content

Commit

Permalink
#282 simplified passed parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Jul 31, 2022
1 parent c4f942e commit ca1e647
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 38 deletions.
24 changes: 8 additions & 16 deletions arekit/contrib/utils/io_utils/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from arekit.common.experiment.api.io_utils import BaseIOUtils
from arekit.common.folding.base import BaseDataFolding
from arekit.contrib.utils.io_utils.utils import join_dir_with_subfolder_name, check_targets_existence
from arekit.contrib.utils.io_utils.utils import check_targets_existence
from arekit.contrib.utils.np_utils.embedding import NpzEmbeddingHelper
from arekit.contrib.utils.utils_folding import experiment_iter_index

Expand All @@ -19,7 +19,9 @@ class NpzEmbeddingIOUtils(BaseIOUtils):
TERM_EMBEDDING_FILENAME_TEMPLATE = 'term_embedding-{cv_index}'
VOCABULARY_FILENAME_TEMPLATE = "vocab-{cv_index}.txt"

def __init__(self, exp_ctx):
def __init__(self, target_dir, exp_ctx):
assert(isinstance(target_dir, str))
self.__target_dir = target_dir
self.__exp_ctx = exp_ctx

# region Embedding-related data
Expand All @@ -43,28 +45,18 @@ def load_embedding(self, data_folding):
return NpzEmbeddingHelper.load_embedding(source)

def check_targets_existed(self, data_folding):
filepaths = [
targets = [
self.__get_default_vocab_filepath(data_folding=data_folding),
self.__get_term_embedding_target(data_folding=data_folding)
]
return check_targets_existence(targets=filepaths)
return check_targets_existence(targets=targets)

# endregion

# region embedding-related data

def _get_experiment_sources_dir(self):
raise NotImplementedError()

def _get_target_dir(self):
""" Represents an experiment dir of specific label scale format,
defined by labels scaler.
"""
return join_dir_with_subfolder_name(subfolder_name=self.__exp_ctx.Name,
dir=self._get_experiment_sources_dir())

def __get_default_embedding_filepath(self, data_folding):
return join(self._get_target_dir(),
return join(self.__target_dir,
self.TERM_EMBEDDING_FILENAME_TEMPLATE.format(
cv_index=experiment_iter_index(data_folding)) + '.npz')

Expand Down Expand Up @@ -99,7 +91,7 @@ def __get_model_parameter(self, default_value, get_value_func):
return default_value if predefined_value is None else predefined_value

def __get_default_vocab_filepath(self, data_folding):
return join(self._get_target_dir(),
return join(self.__target_dir,
self.VOCABULARY_FILENAME_TEMPLATE.format(
cv_index=experiment_iter_index(data_folding)) + '.npz')

Expand Down
27 changes: 5 additions & 22 deletions arekit/contrib/utils/io_utils/samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@
from arekit.common.data.row_ids.multiple import MultipleIDProvider
from arekit.common.data.storages.base import BaseRowsStorage
from arekit.common.data.views.samples import BaseSampleStorageView
from arekit.common.experiment.api.ctx_base import ExperimentContext
from arekit.common.experiment.api.io_utils import BaseIOUtils
from arekit.common.experiment.data_type import DataType
from arekit.contrib.utils.data.views.opinions import BaseOpinionStorageView
from arekit.contrib.utils.io_utils.utils import join_dir_with_subfolder_name, filename_template, check_targets_existence
from arekit.contrib.utils.io_utils.utils import filename_template, check_targets_existence

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
Expand All @@ -23,20 +22,16 @@ class SamplesIOUtils(BaseIOUtils):
Samples required for machine learning training/inferring.
"""

def __init__(self, exp_ctx,
def __init__(self, target_dir,
samples_writer=TsvWriter(write_header=True),
target_extension=".tsv.gz"):
assert(isinstance(exp_ctx, ExperimentContext))
assert(isinstance(samples_writer, BaseWriter))
self.__exp_ctx = exp_ctx
self.__target_dir = target_dir
self.__samples_writer = samples_writer
self.__target_extension = target_extension

# region public methods

def get_target_dir(self):
return self._get_target_dir()

def create_samples_view(self, data_type, data_folding):
assert(isinstance(data_type, DataType))
storage = BaseRowsStorage.from_tsv(
Expand Down Expand Up @@ -79,14 +74,14 @@ def check_targets_existed(self, data_types_iter, data_folding):

def __get_input_opinions_target(self, data_type, data_folding):
template = filename_template(data_type=data_type, data_folding=data_folding)
return self.__get_filepath(out_dir=self._get_target_dir(),
return self.__get_filepath(out_dir=self.__target_dir,
template=template,
prefix="opinion",
extension=self.create_target_extension())

def __get_input_sample_target(self, data_type, data_folding):
template = filename_template(data_type=data_type, data_folding=data_folding)
return self.__get_filepath(out_dir=self._get_target_dir(),
return self.__get_filepath(out_dir=self.__target_dir,
template=template,
prefix="sample",
extension=self.create_target_extension())
Expand All @@ -95,18 +90,6 @@ def __get_input_sample_target(self, data_type, data_folding):

# region protected methods

def _get_experiment_sources_dir(self):
""" Provides directory for samples.
"""
raise NotImplementedError()

def _get_target_dir(self):
""" Represents an experiment dir of specific label scale format,
defined by labels scaler.
"""
return join_dir_with_subfolder_name(subfolder_name=self.__exp_ctx.Name,
dir=self._get_experiment_sources_dir())

@staticmethod
def __get_filepath(out_dir, template, prefix, extension):
assert(isinstance(template, str))
Expand Down

0 comments on commit ca1e647

Please sign in to comment.