Refactored FilesystemStore to use fsspec to support additional remote filesystems (horovod#2927)

Signed-off-by: Travis Addair <tgaddair@gmail.com>

Co-authored-by: Uma Shankar <8177685+umashankark@users.noreply.github.com>
tgaddair and umashankark committed Jun 2, 2021
1 parent ef31a42 · commit 7a69711
Showing 7 changed files with 116 additions and 95 deletions.
3 changes: 3 additions & 0 deletions docs/mocks.py
@@ -37,6 +37,9 @@ def _dummy():
'h5py',
'psutil',

'fsspec',
'fsspec.core',

'pyarrow',
'pyarrow.parquet',

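The two `fsspec` entries above join the list of modules that the documentation build stubs out, so the API reference can be generated without installing every optional dependency. Below is a hypothetical illustration of how such a mock list is commonly consumed in a Sphinx setup; the `MOCK_MODULES` name and the `MagicMock` approach are assumptions, not the actual contents of docs/mocks.py:

    # Illustration only: a common Sphinx pattern for stubbing optional imports.
    import sys
    from unittest import mock

    MOCK_MODULES = ['fsspec', 'fsspec.core', 'pyarrow', 'pyarrow.parquet']  # hypothetical list

    for mod_name in MOCK_MODULES:
        # Register a stand-in so `import fsspec` succeeds during the docs build.
        sys.modules[mod_name] = mock.MagicMock()
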
183 changes: 98 additions & 85 deletions horovod/spark/common/store.py
@@ -16,6 +16,7 @@
import contextlib
import errno
import os
import pathlib
import re
import shutil
import tempfile
@@ -26,6 +27,9 @@
import pyarrow as pa
import pyarrow.parquet as pq

import fsspec
from fsspec.core import split_protocol

from horovod.spark.common.util import is_databricks


@@ -155,26 +159,28 @@ def create(prefix_path, *args, **kwargs):
elif is_databricks() and DBFSLocalStore.matches_dbfs(prefix_path):
return DBFSLocalStore(prefix_path, *args, **kwargs)
else:
return LocalStore(prefix_path, *args, **kwargs)
return FilesystemStore(prefix_path, *args, **kwargs)


class FilesystemStore(Store):
class AbstractFilesystemStore(Store):
"""Abstract class for stores that use a filesystem for underlying storage."""

def __init__(self, prefix_path, train_path=None, val_path=None, test_path=None, runs_path=None, save_runs=True):
def __init__(self, prefix_path, train_path=None, val_path=None, test_path=None,
runs_path=None, save_runs=True, storage_options=None, **kwargs):
self.prefix_path = self.get_full_path(prefix_path)
self._train_path = self._get_full_path_or_default(train_path, 'intermediate_train_data')
self._val_path = self._get_full_path_or_default(val_path, 'intermediate_val_data')
self._test_path = self._get_full_path_or_default(test_path, 'intermediate_test_data')
self._runs_path = self._get_full_path_or_default(runs_path, 'runs')
self._save_runs = save_runs
super(FilesystemStore, self).__init__()
self.storage_options = storage_options
super().__init__()

def exists(self, path):
return self.get_filesystem().exists(self.get_localized_path(path))
return self.fs.exists(self.get_localized_path(path)) or self.fs.isdir(path)

def read(self, path):
with self.get_filesystem().open(self.get_localized_path(path), 'rb') as f:
with self.fs.open(self.get_localized_path(path), 'rb') as f:
return f.read()

def read_serialized_keras_model(self, ckpt_path, model, custom_objects):
@@ -193,7 +199,7 @@ def read_serialized_keras_model(self, ckpt_path, model, custom_objects):
return codec.dumps_base64(model_bytes)

def write_text(self, path, text):
with self.get_filesystem().open(self.get_localized_path(path), 'w') as f:
with self.fs.open(self.get_localized_path(path), 'w') as f:
f.write(text)

def is_parquet_dataset(self, path):
@@ -204,7 +210,7 @@ def is_parquet_dataset(self, path):
return False

def get_parquet_dataset(self, path):
return pq.ParquetDataset(self.get_localized_path(path), filesystem=self.get_filesystem())
return pq.ParquetDataset(self.get_localized_path(path), filesystem=self.fs)

def get_train_data_path(self, idx=None):
return '{}.{}'.format(self._train_path, idx) if idx is not None else self._train_path
@@ -237,7 +243,7 @@ def get_checkpoint_path(self, run_id):

def get_checkpoints(self, run_id, suffix='.ckpt'):
checkpoint_dir = self.get_localized_path(self.get_checkpoint_path(run_id))
filenames = self.get_filesystem().ls(checkpoint_dir)
filenames = self.fs.ls(checkpoint_dir)
return sorted([name for name in filenames if name.endswith(suffix)])

def get_logs_path(self, run_id):
@@ -250,23 +256,6 @@ def get_checkpoint_filename(self):
def get_logs_subdir(self):
return 'logs'

def get_full_path(self, path):
if not self.matches(path):
return self.path_prefix() + path
return path

def get_localized_path(self, path):
if self.matches(path):
return path[len(self.path_prefix()):]
return path

def get_full_path_fn(self):
prefix = self.path_prefix()

def get_path(path):
return prefix + path
return get_path

def _get_full_path_or_default(self, path, default_key):
if path is not None:
return self.get_full_path(path)
@@ -275,66 +264,90 @@ def _get_full_path_or_default(self, path, default_key):
def _get_path(self, key):
return os.path.join(self.prefix_path, key)

def path_prefix(self):
raise NotImplementedError()
def get_local_output_dir_fn(self, run_id):
@contextlib.contextmanager
def local_run_path():
with tempfile.TemporaryDirectory() as tmpdir:
yield tmpdir
return local_run_path

def get_filesystem(self):
def get_localized_path(self, path):
raise NotImplementedError()

@classmethod
def matches(cls, path):
return path.startswith(cls.filesystem_prefix())
def get_full_path(self, path):
raise NotImplementedError()

@classmethod
def filesystem_prefix(cls):
def get_full_path_fn(self):
raise NotImplementedError()

@property
def fs(self):
raise NotImplementedError()

class LocalStore(FilesystemStore):
"""Uses the local filesystem as a store of intermediate data and training artifacts."""

FS_PREFIX = 'file://'
class FilesystemStore(AbstractFilesystemStore):
"""Concrete filesystems store that delegates to `fsspec`."""

def __init__(self, prefix_path, *args, **kwargs):
self._fs = pa.LocalFileSystem()
super(LocalStore, self).__init__(prefix_path, *args, **kwargs)
self.storage_options = kwargs['storage_options'] if 'storage_options' in kwargs else {}
self.prefix_path = prefix_path
self._fs, self.protocol = self._get_fs_and_protocol()
std_params = ['train_path', 'val_path', 'test_path', 'runs_path', 'save_runs', 'storage_options']
params = dict((k, kwargs[k]) for k in std_params if k in kwargs)
super().__init__(prefix_path, *args, **params)

def path_prefix(self):
return self.FS_PREFIX
def sync_fn(self, run_id):
run_path = self.get_run_path(run_id)

def fn(local_run_path):
self.fs.put(local_run_path, run_path, recursive=True, overwrite=True)

return fn

def get_filesystem(self):
return self._fs
return self.fs

def get_local_output_dir_fn(self, run_id):
run_path = self.get_localized_path(self.get_run_path(run_id))
def get_localized_path(self, path):
_, lpath = split_protocol(path)
return lpath

@contextlib.contextmanager
def local_run_path():
if not os.path.exists(run_path):
try:
os.makedirs(run_path, mode=0o755)
except OSError as e:
# Race condition from workers on the same host: ignore
if e.errno != errno.EEXIST:
raise
yield run_path
def get_full_path(self, path):
return self.get_full_path_fn()(path)

return local_run_path
def get_full_path_fn(self):
def get_path(path):
protocol, _ = split_protocol(path)
if protocol is not None:
return path
return pathlib.Path(os.path.abspath(path)).as_uri()
return get_path

def sync_fn(self, run_id):
run_path = self.get_localized_path(self.get_run_path(run_id))
@property
def fs(self):
return self._fs

def fn(local_run_path):
# No-op for LocalStore since the `local_run_path` will be the same as the run path
assert run_path == local_run_path
return fn
#@staticmethod
def _get_fs_and_protocol(self):
protocol, path = split_protocol(self.prefix_path)
fs = fsspec.filesystem(protocol, **self.storage_options)
return fs, protocol

@classmethod
def filesystem_prefix(cls):
return cls.FS_PREFIX
def matches(cls, path):
return True


class HDFSStore(FilesystemStore):
class LocalStore(FilesystemStore):
"""Uses the local filesystem as a store of intermediate data and training artifacts.
This class is deprecated and now just resolves to FilesystemStore.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)


class HDFSStore(AbstractFilesystemStore):
"""Uses HDFS as a store of intermediate data and training artifacts.
Initialized from a `prefix_path` that can take one of the following forms:
@@ -358,9 +371,7 @@ class HDFSStore(FilesystemStore):

def __init__(self, prefix_path,
host=None, port=None, user=None, kerb_ticket=None,
driver='libhdfs', extra_conf=None, temp_dir=None, *args, **kwargs):
self._temp_dir = temp_dir

driver='libhdfs', extra_conf=None, *args, **kwargs):
prefix, url_host, url_port, path, path_offset = self.parse_url(prefix_path)
self._check_url(prefix_path, prefix, path)
self._url_prefix = prefix_path[:path_offset] if prefix else self.FS_PREFIX
@@ -391,24 +402,21 @@ def parse_url(self, url):
path_offset = match.start(4)
return prefix, host, port, path, path_offset

def path_prefix(self):
return self._url_prefix

def get_filesystem(self):
return self._hdfs
def get_full_path(self, path):
if not self.matches(path):
return self._url_prefix + path
return path

def get_local_output_dir_fn(self, run_id):
temp_dir = self._temp_dir
def get_full_path_fn(self):
prefix = self._url_prefix

@contextlib.contextmanager
def local_run_path():
dirpath = tempfile.mkdtemp(dir=temp_dir)
try:
yield dirpath
finally:
shutil.rmtree(dirpath)
def get_path(path):
return prefix + path
return get_path

return local_run_path
@property
def fs(self):
return self._hdfs

def sync_fn(self, run_id):
class SyncState(object):
@@ -464,13 +472,18 @@ def _check_url(self, url, prefix, path):

if not path:
raise ValueError('Failed to parse path from URL: {}'.format(url))

def get_localized_path(self, path):
if self.matches(path):
return path[len(self._url_prefix):]
return path

@classmethod
def filesystem_prefix(cls):
return cls.FS_PREFIX
def matches(cls, path):
return path.startswith(cls.FS_PREFIX)


class DBFSLocalStore(LocalStore):
class DBFSLocalStore(FilesystemStore):
"""Uses Databricks File System (DBFS) local file APIs as a store of intermediate data and
training artifacts.
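
With the changes above, Store.create now falls back to the new FilesystemStore, which resolves the prefix path through fsspec and forwards `storage_options` to `fsspec.filesystem()`. A minimal usage sketch follows; the s3:// bucket and the credential values are placeholders, and it assumes the corresponding fsspec backend (here s3fs) is installed:

    from horovod.spark.common.store import Store

    # Sketch only: bucket, credentials, and local path are placeholders.
    store = Store.create(
        's3://my-bucket/horovod-artifacts',
        storage_options={'key': '<access-key>', 'secret': '<secret-key>'},  # passed to fsspec.filesystem()
    )

    # Intermediate data paths are joined onto the protocol-qualified prefix.
    print(store.get_train_data_path())   # e.g. s3://my-bucket/horovod-artifacts/intermediate_train_data

    # A plain local path is normalized to a file:// URI by get_full_path().
    print(store.get_full_path('/tmp/output'))   # e.g. file:///tmp/output

The same pattern applies to any protocol fsspec knows about; LocalStore is kept only as a deprecated alias that resolves to FilesystemStore.
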
2 changes: 2 additions & 0 deletions horovod/spark/keras/remote.py
@@ -218,6 +218,7 @@ def train(serialized_model, train_rows, val_rows, avg_row_size):
hdfs_driver=PETASTORM_HDFS_DRIVER,
schema_fields=schema_fields,
transform_spec=transform_spec,
storage_options=store.storage_options,
**reader_factory_kwargs) as train_reader:
with reader_factory(remote_store.val_data_path,
num_epochs=1,
@@ -228,6 +229,7 @@ def train(serialized_model, train_rows, val_rows, avg_row_size):
hdfs_driver=PETASTORM_HDFS_DRIVER,
schema_fields=schema_fields,
transform_spec=transform_spec,
storage_options=store.storage_options,
**reader_factory_kwargs) \
if should_validate else empty_batch_reader() as val_reader:

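The `storage_options=store.storage_options` argument added above (and mirrored in the lightning and torch trainers below) forwards the store's fsspec credentials to the Petastorm reader so it can open the intermediate Parquet data on the same remote filesystem. A hedged sketch of the equivalent direct call, assuming petastorm>=0.11.0 (pinned in setup.py below) accepts `storage_options`, with a placeholder dataset URL and credentials:

    from petastorm import make_batch_reader

    # Sketch only: URL and credentials are placeholders, and most of the
    # reader_factory keyword arguments used by the trainer are omitted here.
    with make_batch_reader(
            's3://my-bucket/horovod-artifacts/intermediate_train_data',
            num_epochs=1,
            storage_options={'key': '<access-key>', 'secret': '<secret-key>'},
    ) as train_reader:
        for batch in train_reader:
            pass  # hand each batch to the training loop
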
11 changes: 7 additions & 4 deletions horovod/spark/lightning/remote.py
@@ -76,14 +76,16 @@ def RemoteTrainer(estimator, metadata, ckpt_bytes, run_id, dataset_idx, train_ro
schema_fields.append(sample_weight_col)

data_loader_cls = _create_dataloader(feature_columns, input_shapes, metadata, data_loader_cls)
set_data_loader = _set_data_loader_fn(transformation, schema_fields,
batch_size, calculate_shuffle_buffer_size,
data_loader_cls, loader_num_epochs)

# Storage
store = estimator.getStore()
remote_store = store.to_remote(run_id, dataset_idx)

set_data_loader = _set_data_loader_fn(transformation, schema_fields,
batch_size, calculate_shuffle_buffer_size,
data_loader_cls, loader_num_epochs, store)


def train(serialized_model):
import horovod.torch as hvd
# Horovod: initialize library.
@@ -212,7 +214,7 @@ def on_sanity_check_end(self, trainer, model):
return [ResetCallback()]


def _set_data_loader_fn(transformation, schema_fields, batch_size, calculate_shuffle_buffer_size, data_loader_cls, num_epochs):
def _set_data_loader_fn(transformation, schema_fields, batch_size, calculate_shuffle_buffer_size, data_loader_cls, num_epochs, store):

@contextlib.contextmanager
def set_data_loader(model, data_path, dataloader_attr, reader_worker_count, reader_pool_type, should_read=True):
@@ -254,6 +256,7 @@ def set_data_loader(model, data_path, dataloader_attr, reader_worker_count, read
hdfs_driver=PETASTORM_HDFS_DRIVER,
schema_fields=schema_fields,
transform_spec=transform_spec,
storage_options=store.storage_options,
**reader_factory_kwargs) as reader:
def dataloader_fn():
return data_loader_cls(reader=reader, batch_size=batch_size,
2 changes: 2 additions & 0 deletions horovod/spark/torch/remote.py
@@ -236,6 +236,7 @@ def save_checkpoint():
hdfs_driver=PETASTORM_HDFS_DRIVER,
schema_fields=schema_fields,
transform_spec=transform_spec,
storage_options=store.storage_options,
**reader_factory_kwargs) as train_reader:
with reader_factory(remote_store.val_data_path,
num_epochs=None,
@@ -246,6 +247,7 @@ def save_checkpoint():
hdfs_driver=PETASTORM_HDFS_DRIVER,
schema_fields=schema_fields,
transform_spec=transform_spec,
storage_options=store.storage_options,
**reader_factory_kwargs) \
if should_validate else empty_batch_reader() as val_reader:

2 changes: 1 addition & 1 deletion setup.py
@@ -112,7 +112,7 @@ def build_extensions(self):
pyspark_require_list = ['pyspark>=2.3.2;python_version<"3.8"',
'pyspark>=3.0.0;python_version>="3.8"']
# Pin h5py: https://github.com/h5py/h5py/issues/1732
spark_require_list = ['h5py<3', 'numpy', 'petastorm>=0.11.0', 'pyarrow>=0.15.0']
spark_require_list = ['h5py<3', 'numpy', 'petastorm>=0.11.0', 'pyarrow>=0.15.0', 'fsspec']
ray_require_list = ['ray']
pytorch_spark_require_list = pytorch_require_list + \
spark_require_list + \
