Skip to content

Commit

Permalink
make file name an arg
Browse files Browse the repository at this point in the history
Signed-off-by: Peng Zhang <pengz@uber.com>
  • Loading branch information
irasit committed Feb 28, 2022
1 parent 8f9fd57 commit b0cc0fa
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions horovod/spark/common/store.py
Expand Up @@ -168,14 +168,16 @@ class AbstractFilesystemStore(Store):
"""Abstract class for stores that use a filesystem for underlying storage."""

def __init__(self, prefix_path, train_path=None, val_path=None, test_path=None,
runs_path=None, save_runs=True, storage_options=None, **kwargs):
runs_path=None, save_runs=True, storage_options=None, checkpoint_filename=None,
**kwargs):
self.prefix_path = self.get_full_path(prefix_path)
self._train_path = self._get_full_path_or_default(train_path, 'intermediate_train_data')
self._val_path = self._get_full_path_or_default(val_path, 'intermediate_val_data')
self._test_path = self._get_full_path_or_default(test_path, 'intermediate_test_data')
self._runs_path = self._get_full_path_or_default(runs_path, 'runs')
self._save_runs = save_runs
self.storage_options = storage_options
self.checkpoint_filename = checkpoint_filename if checkpoint_filename else 'checkpoint'
super().__init__()

def exists(self, path):
Expand Down Expand Up @@ -262,8 +264,7 @@ def get_logs_path(self, run_id):
if self._save_runs else None

def get_checkpoint_filename(self):
# default the store checkpoint name to use h5 format
return 'checkpoint.h5'
return self.checkpoint_filename

def get_logs_subdir(self):
return 'logs'
Expand Down

0 comments on commit b0cc0fa

Please sign in to comment.