Skip to content

Commit

Permalink
#378 allows modifying the extension
Browse files: browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Jul 29, 2022
1 parent ac46675 commit 8abc84b
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 16 deletions.
19 changes: 11 additions & 8 deletions arekit/contrib/utils/model_io/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ def create_opinions_writer_target(self, data_type, data_folding):
def create_samples_writer_target(self, data_type, data_folding):
return self.__get_input_sample_filepath(data_type, data_folding=data_folding)

def create_target_extension(self):
return ".tsv.gz"

def create_samples_writer(self):
return TsvWriter(write_header=True)

Expand All @@ -96,20 +99,20 @@ def create_opinions_writer(self):

def __get_input_opinions_filepath(self, data_type, data_folding):
template = filename_template(data_type=data_type, data_folding=data_folding)
return self.__get_filepath(out_dir=self.__get_target_dir(), template=template, prefix="opinion")
return self.__get_filepath(out_dir=self.__get_target_dir(),
template=template, prefix="opinion", extension=self.create_target_extension())

def __get_input_sample_filepath(self, data_type, data_folding):
template = filename_template(data_type=data_type, data_folding=data_folding)
return self.__get_filepath(out_dir=self.__get_target_dir(), template=template, prefix="sample")
return self.__get_filepath(out_dir=self.__get_target_dir(),
template=template, prefix="sample", extension=self.create_target_extension())

@staticmethod
def __get_filepath(out_dir, template, prefix):
def __get_filepath(out_dir, template, prefix, extension):
assert(isinstance(template, str))
assert(isinstance(prefix, str))
return join(out_dir, DefaultBertIOUtils.__generate_tsv_archive_filename(template=template, prefix=prefix))

@staticmethod
def __generate_tsv_archive_filename(template, prefix):
return "{prefix}-{template}.tsv.gz".format(prefix=prefix, template=template)
assert(isinstance(extension, str))
return join(out_dir, "{prefix}-{template}{extension}".format(
prefix=prefix, template=template, extension=extension))

# endregion
23 changes: 15 additions & 8 deletions arekit/contrib/utils/model_io/tf_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ def create_opinions_writer(self):
def create_samples_writer(self):
return TsvWriter(write_header=True)

def create_target_extension(self):
return ".tsv.gz"

def create_opinions_writer_target(self, data_type, data_folding):
return self.__get_input_opinions_target(data_type, data_folding=data_folding)

Expand Down Expand Up @@ -119,11 +122,17 @@ def __get_model_parameter(self, default_value, get_value_func):

def __get_input_opinions_target(self, data_type, data_folding):
template = filename_template(data_type=data_type, data_folding=data_folding)
return self._get_filepath(out_dir=self._get_target_dir(), template=template, prefix="opinion")
return self.__get_filepath(out_dir=self._get_target_dir(),
template=template,
prefix="opinion",
extension=self.create_target_extension())

def __get_input_sample_target(self, data_type, data_folding):
template = filename_template(data_type=data_type, data_folding=data_folding)
return self._get_filepath(out_dir=self._get_target_dir(), template=template, prefix="sample")
return self.__get_filepath(out_dir=self._get_target_dir(),
template=template,
prefix="sample",
extension=self.create_target_extension())

def __get_term_embedding_target(self, data_folding):
return self.__get_default_embedding_filepath(data_folding)
Expand Down Expand Up @@ -151,10 +160,6 @@ def __get_experiment_folder_name(self):
return "{name}_{scale}l".format(name=self._exp_ctx.Name,
scale=str(self._exp_ctx.LabelsCount))

@staticmethod
def __generate_tsv_archive_filename(template, prefix):
return "{prefix}-{template}.tsv.gz".format(prefix=prefix, template=template)

@staticmethod
def __check_targets_existence(targets, logger):
assert (isinstance(targets, collections.Iterable))
Expand Down Expand Up @@ -195,9 +200,11 @@ def _get_target_dir(self):
dir=self._get_experiment_sources_dir())

@staticmethod
def _get_filepath(out_dir, template, prefix):
def __get_filepath(out_dir, template, prefix, extension):
assert(isinstance(template, str))
assert(isinstance(prefix, str))
return join(out_dir, DefaultNetworkIOUtils.__generate_tsv_archive_filename(template=template, prefix=prefix))
assert(isinstance(extension, str))
return join(out_dir, "{prefix}-{template}{extension}".format(
prefix=prefix, template=template, extension=extension))

# endregion

0 comments on commit 8abc84b

Please sign in to comment.