Skip to content

Commit

Permalink
Refactored LocalStore into FilesystemStore
Browse files Browse the repository at this point in the history
Signed-off-by: Travis Addair <tgaddair@gmail.com>
  • Loading branch information
tgaddair committed May 21, 2021
1 parent c38e32e commit bdaa5f2
Showing 1 changed file with 77 additions and 87 deletions.
164 changes: 77 additions & 87 deletions horovod/spark/common/store.py
Expand Up @@ -16,6 +16,7 @@
import contextlib
import errno
import os
import pathlib
import re
import shutil
import tempfile
Expand Down Expand Up @@ -157,30 +158,27 @@ def create(prefix_path, *args, **kwargs):
return HDFSStore(prefix_path, *args, **kwargs)
elif is_databricks() and DBFSLocalStore.matches_dbfs(prefix_path):
return DBFSLocalStore(prefix_path, *args, **kwargs)
elif LocalStore.matches(prefix_path):
return LocalStore(prefix_path, *args, **kwargs)
else:
return FilesystemStore(prefix_path, *args, **kwargs)


class FilesystemStore(Store):
class AbstractFilesystemStore(Store):
"""Abstract class for stores that use a filesystem for underlying storage."""

def __init__(self, prefix_path, train_path=None, val_path=None, test_path=None, runs_path=None, save_runs=True):
self.fs, self.protocol = FilesystemStore._get_fs_and_protocol(prefix_path)
self.prefix_path = self.get_full_path(prefix_path)
self._train_path = self._get_full_path_or_default(train_path, 'intermediate_train_data')
self._val_path = self._get_full_path_or_default(val_path, 'intermediate_val_data')
self._test_path = self._get_full_path_or_default(test_path, 'intermediate_test_data')
self._runs_path = self._get_full_path_or_default(runs_path, 'runs')
self._save_runs = save_runs
super(FilesystemStore, self).__init__()
super().__init__()

def exists(self, path):
return self.get_filesystem().exists(self.get_localized_path(path))
return self.fs.exists(self.get_localized_path(path))

def read(self, path):
with self.get_filesystem().open(self.get_localized_path(path), 'rb') as f:
with self.fs.open(self.get_localized_path(path), 'rb') as f:
return f.read()

def read_serialized_keras_model(self, ckpt_path, model, custom_objects):
Expand Down Expand Up @@ -256,23 +254,6 @@ def get_checkpoint_filename(self):
def get_logs_subdir(self):
return 'logs'

def get_full_path(self, path):
if not self.matches(path):
return self.path_prefix() + path
return path

def get_localized_path(self, path):
if self.matches(path):
return path[len(self.path_prefix()):]
return path

def get_full_path_fn(self):
prefix = self.path_prefix()

def get_path(path):
return prefix + path
return get_path

def _get_full_path_or_default(self, path, default_key):
if path is not None:
return self.get_full_path(path)
Expand All @@ -281,69 +262,83 @@ def _get_full_path_or_default(self, path, default_key):
def _get_path(self, key):
return os.path.join(self.prefix_path, key)

def path_prefix(self):
def get_local_output_dir_fn(self, run_id):
    """Return a factory for a context manager that yields a scratch directory.

    Each call to the returned function creates a fresh temporary directory,
    yields its path, and deletes the directory when the ``with`` block exits.
    ``run_id`` is accepted as part of the method signature but is not used here.
    """
    @contextlib.contextmanager
    def local_run_path():
        scratch = tempfile.TemporaryDirectory()
        try:
            yield scratch.name
        finally:
            # Remove the directory on both normal exit and error.
            scratch.cleanup()

    return local_run_path

def get_localized_path(self, path):
    """Translate ``path`` into the form passed to ``self.fs``.

    ``exists`` and ``read`` call this before touching the filesystem.
    This base version always raises ``NotImplementedError``; subclasses
    are expected to override it.
    """
    raise NotImplementedError

@classmethod
def matches(cls, path):
return True
def get_full_path(self, path):
    """Return the fully qualified form of ``path`` for this store.

    Used when resolving the prefix path and the train/val/test/runs paths.
    This base version always raises ``NotImplementedError``; subclasses
    are expected to override it.
    """
    raise NotImplementedError

@staticmethod
def _get_fs_and_protocol(url):
protocol, path = split_protocol(url)
fs = fsspec.filesystem(protocol)
return fs, protocol
def get_full_path_fn(self):
    """Return a callable that maps a path to its fully qualified form.

    This base version always raises ``NotImplementedError``; subclasses
    are expected to override it.
    """
    raise NotImplementedError

@property
def fs(self):
    """Filesystem object used by ``exists`` and ``read``.

    This base version always raises ``NotImplementedError``; subclasses
    are expected to override the property.
    """
    raise NotImplementedError

class LocalStore(FilesystemStore):
"""Uses the local filesystem as a store of intermediate data and training artifacts."""

FS_PREFIX = 'file://'
class FilesystemStore(AbstractFilesystemStore):
"""Concrete filesystem store that delegates to `fsspec`."""

def __init__(self, prefix_path, *args, **kwargs):
self._fs = pa.LocalFileSystem()
super(LocalStore, self).__init__(prefix_path, *args, **kwargs)
self._fs, self.protocol = FilesystemStore._get_fs_and_protocol(prefix_path)
super().__init__(prefix_path, *args, **kwargs)

def path_prefix(self):
return self.FS_PREFIX
def sync_fn(self, run_id):
run_path = self.get_run_path(run_id)

def get_filesystem(self):
return self._fs
def fn(local_run_path):
self.fs.put(local_run_path, run_path, recursive=True)

def get_local_output_dir_fn(self, run_id):
run_path = self.get_localized_path(self.get_run_path(run_id))
return fn

@contextlib.contextmanager
def local_run_path():
if not os.path.exists(run_path):
try:
os.makedirs(run_path, mode=0o755)
except OSError as e:
# Race condition from workers on the same host: ignore
if e.errno != errno.EEXIST:
raise
yield run_path
def get_localized_path(self, path):
    """Return ``path`` with its protocol component removed.

    The split is delegated to ``split_protocol``; only the second element
    of its result (the path part) is returned.
    """
    return split_protocol(path)[1]

return local_run_path
def get_full_path(self, path):
    """Return the fully qualified form of ``path``.

    Delegates to the callable produced by ``get_full_path_fn`` so that both
    entry points apply the same rule.
    """
    to_full_path = self.get_full_path_fn()
    return to_full_path(path)

def sync_fn(self, run_id):
run_path = self.get_localized_path(self.get_run_path(run_id))
def get_full_path_fn(self):
    """Return a function that converts a path into a fully qualified URL.

    The returned function leaves a path unchanged when ``split_protocol``
    reports a protocol for it. Otherwise it makes the path absolute and
    renders it as a ``file://`` URI via ``pathlib.Path.as_uri``.
    """
    def get_path(path):
        protocol, _ = split_protocol(path)
        if protocol is None:
            # No protocol: treat as a local path and build a file URI.
            absolute = os.path.abspath(path)
            return pathlib.Path(absolute).as_uri()
        return path

    return get_path

def fn(local_run_path):
# No-op for LocalStore since the `local_run_path` will be the same as the run path
assert run_path == local_run_path
return fn
@property
def fs(self):
    """Filesystem object stored on this instance as ``_fs``."""
    return self._fs

@staticmethod
def _get_fs_and_protocol(url):
    """Return ``(filesystem, protocol)`` for ``url``.

    The protocol is the first element returned by ``split_protocol``; the
    filesystem is whatever ``fsspec.filesystem`` returns for that protocol.
    """
    protocol = split_protocol(url)[0]
    return fsspec.filesystem(protocol), protocol

@classmethod
def matches(cls, path):
return path.startswith(cls.FS_PREFIX)
return True

@classmethod
def filesystem_prefix(cls):
return cls.FS_PREFIX

class LocalStore(FilesystemStore):
    """Deprecated name kept for callers that still reference ``LocalStore``.

    The class adds no behavior of its own: construction is forwarded
    unchanged to ``FilesystemStore``, which should be used directly instead.
    """

    def __init__(self, *args, **kwargs):
        # Pure pass-through to the parent constructor.
        super().__init__(*args, **kwargs)


class HDFSStore(FilesystemStore):
class HDFSStore(AbstractFilesystemStore):
"""Uses HDFS as a store of intermediate data and training artifacts.
Initialized from a `prefix_path` that can take one of the following forms:
Expand All @@ -367,9 +362,7 @@ class HDFSStore(FilesystemStore):

def __init__(self, prefix_path,
host=None, port=None, user=None, kerb_ticket=None,
driver='libhdfs', extra_conf=None, temp_dir=None, *args, **kwargs):
self._temp_dir = temp_dir

driver='libhdfs', extra_conf=None, *args, **kwargs):
prefix, url_host, url_port, path, path_offset = self.parse_url(prefix_path)
self._check_url(prefix_path, prefix, path)
self._url_prefix = prefix_path[:path_offset] if prefix else self.FS_PREFIX
Expand Down Expand Up @@ -400,24 +393,21 @@ def parse_url(self, url):
path_offset = match.start(4)
return prefix, host, port, path, path_offset

def path_prefix(self):
return self._url_prefix

def get_filesystem(self):
return self._hdfs
def get_full_path(self, path):
    """Return ``path`` qualified with this store's URL prefix.

    A path for which ``self.matches`` is true is returned unchanged;
    any other path has ``self._url_prefix`` prepended.
    """
    return path if self.matches(path) else self._url_prefix + path

def get_local_output_dir_fn(self, run_id):
temp_dir = self._temp_dir
def get_full_path_fn(self):
prefix = self._url_prefix

@contextlib.contextmanager
def local_run_path():
dirpath = tempfile.mkdtemp(dir=temp_dir)
try:
yield dirpath
finally:
shutil.rmtree(dirpath)
def get_path(path):
return prefix + path
return get_path

return local_run_path
@property
def fs(self):
    """Filesystem object stored on this instance as ``_hdfs``."""
    return self._hdfs

def sync_fn(self, run_id):
class SyncState(object):
Expand Down Expand Up @@ -475,11 +465,11 @@ def _check_url(self, url, prefix, path):
raise ValueError('Failed to parse path from URL: {}'.format(url))

@classmethod
def filesystem_prefix(cls):
return cls.FS_PREFIX
def matches(cls, path):
return path.startswith(cls.FS_PREFIX)


class DBFSLocalStore(LocalStore):
class DBFSLocalStore(FilesystemStore):
"""Uses Databricks File System (DBFS) local file APIs as a store of intermediate data and
training artifacts.
Expand Down

0 comments on commit bdaa5f2

Please sign in to comment.