Skip to content

Commit

Permalink
#282 related fix. Removed non utlized method. (#373 related.)
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Jul 30, 2022
1 parent 02a423d commit 3d3280b
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 16 deletions.
7 changes: 0 additions & 7 deletions arekit/common/experiment/api/io_utils.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,9 @@
from arekit.common.experiment.api.ctx_base import ExperimentContext


class BaseIOUtils(object):
""" Represents base experiment utils for input/output for:
samples -- data that utilized for experiments;
results -- evaluation of experiments.
"""

def __init__(self, exp_ctx):
assert(isinstance(exp_ctx, ExperimentContext))
self._exp_ctx = exp_ctx

# region abstract methods

def get_target_dir(self):
Expand Down
4 changes: 0 additions & 4 deletions arekit/contrib/experiment_rusentrel/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,6 @@ def create_annotated_collection_target(doc_id, data_type, target_dir, labels_cou
return target


def __get_experiment_folder_name(self):
return "{name}_{scale}l".format(name=self._exp_ctx.Name, scale=str(self._exp_ctx.LabelsCount))


def __get_annotator_dir(target_dir, labels_count):
return join_dir_with_subfolder_name(dir=target_dir, subfolder_name=__get_annotator_name(labels_count))

Expand Down
9 changes: 7 additions & 2 deletions arekit/contrib/utils/model_io/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from arekit.common.data.row_ids.multiple import MultipleIDProvider
from arekit.common.data.storages.base import BaseRowsStorage
from arekit.common.data.views.samples import BaseSampleStorageView
from arekit.common.experiment.api.ctx_base import ExperimentContext
from arekit.common.experiment.api.io_utils import BaseIOUtils
from arekit.contrib.utils.data.views.opinions import BaseOpinionStorageView
from arekit.contrib.utils.model_io.utils import join_dir_with_subfolder_name, filename_template
Expand All @@ -19,6 +20,10 @@ class DefaultBertIOUtils(BaseIOUtils):
for BERT-related data preparation.
"""

def __init__(self, exp_ctx):
assert(isinstance(exp_ctx, ExperimentContext))
self.__exp_ctx = exp_ctx

def _get_experiment_sources_dir(self):
""" Provides directory for samples.
"""
Expand Down Expand Up @@ -58,10 +63,10 @@ def __get_target_dir(self):
subfolder_name=self.__get_experiment_folder_name(),
dir=self._get_experiment_sources_dir())

return join(default_dir, self._exp_ctx.ModelIO.get_model_name())
return join(default_dir, self.__exp_ctx.ModelIO.get_model_name())

def __get_experiment_folder_name(self):
return "{name}".format(name=self._exp_ctx.Name)
return "{name}".format(name=self.__exp_ctx.Name)

# endregion

Expand Down
11 changes: 8 additions & 3 deletions arekit/contrib/utils/model_io/tf_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from arekit.common.data.row_ids.multiple import MultipleIDProvider
from arekit.common.data.storages.base import BaseRowsStorage
from arekit.common.data.views.samples import BaseSampleStorageView
from arekit.common.experiment.api.ctx_base import ExperimentContext
from arekit.common.experiment.api.io_utils import BaseIOUtils
from arekit.common.experiment.data_type import DataType
from arekit.common.folding.base import BaseDataFolding
Expand Down Expand Up @@ -33,6 +34,10 @@ class DefaultNetworkIOUtils(BaseIOUtils):
TERM_EMBEDDING_FILENAME_TEMPLATE = 'term_embedding-{cv_index}'
VOCABULARY_FILENAME_TEMPLATE = "vocab-{cv_index}.txt"

def __init__(self, exp_ctx):
assert(isinstance(exp_ctx, ExperimentContext))
self.__exp_ctx = exp_ctx

# region public methods

def get_target_dir(self):
Expand Down Expand Up @@ -86,7 +91,7 @@ def load_embedding(self, data_folding):
return NpzEmbeddingHelper.load_embedding(source)

def has_model_predefined_state(self):
model_io = self._exp_ctx.ModelIO
model_io = self.__exp_ctx.ModelIO
return self.__model_is_pretrained_state_provided(model_io)

def check_targets_existed(self, data_types_iter, data_folding):
Expand All @@ -110,7 +115,7 @@ def __get_model_parameter(self, default_value, get_value_func):
assert(default_value is not None)
assert(callable(get_value_func))

model_io = self._exp_ctx.ModelIO
model_io = self.__exp_ctx.ModelIO

if model_io is None:
return default_value
Expand Down Expand Up @@ -157,7 +162,7 @@ def __get_term_embedding_source(self, data_folding):
get_value_func=lambda model_io: model_io.get_model_embedding_filepath())

def __get_experiment_folder_name(self):
return "{name}".format(name=self._exp_ctx.Name)
return "{name}".format(name=self.__exp_ctx.Name)

@staticmethod
def __check_targets_existence(targets, logger):
Expand Down

0 comments on commit 3d3280b

Please sign in to comment.