Fix bugs in DBFSLocalStore #3510

Merged · 6 commits · Apr 20, 2022
horovod/spark/common/store.py (41 additions, 8 deletions)
@@ -532,34 +532,67 @@ def matches(cls, path):
         return path.startswith(cls.FS_PREFIX)
 
 
+# If `_DBFS_PREFIX_MAPPING` is not None, map `/dbfs/...` path to `{_DBFS_PREFIX_MAPPING}/...`
+# This is used in testing, and this mapping only applies to `DBFSLocalStore.get_localized_path`
+_DBFS_PREFIX_MAPPING = None
+
+
 class DBFSLocalStore(FilesystemStore):
     """Uses Databricks File System (DBFS) local file APIs as a store of intermediate data and
     training artifacts.
 
-    Initialized from a `prefix_path` starts with `/dbfs/...`, `file:///dbfs/...` or `dbfs:/...`, see
+    Initialized from a `prefix_path` starts with `/dbfs/...`, `file:///dbfs/...`, `file:/dbfs/...`
+    or `dbfs:/...`, see
     https://docs.databricks.com/data/databricks-file-system.html#local-file-apis.
     """
 
+    DBFS_PATH_FORMAT_ERROR = "The provided path is not a DBFS path: {}, Please provide a path " \
+                             "starting with `/dbfs/...` or `dbfs:/...` or `file:/dbfs/...` or " \
+                             "`file:///dbfs/...`."
+
     def __init__(self, prefix_path, *args, **kwargs):
         prefix_path = self.normalize_path(prefix_path)
-        if not prefix_path.startswith("/dbfs/"):
-            warnings.warn("The provided prefix_path might be ephemeral: {} Please provide a "
-                          "`prefix_path` starting with `/dbfs/...`".format(prefix_path))
+        if not DBFSLocalStore.matches_dbfs(prefix_path):
+            raise ValueError(DBFSLocalStore.DBFS_PATH_FORMAT_ERROR.format(prefix_path))
         super(DBFSLocalStore, self).__init__(prefix_path, *args, **kwargs)
 
     @classmethod
     def matches_dbfs(cls, path):
-        return path.startswith("dbfs:/") or path.startswith("/dbfs/") or path.startswith("file:///dbfs/")
+        return (path.startswith("dbfs:/") and not path.startswith("dbfs://")) or \
+            path.startswith("/dbfs/") or \
+            path.startswith("file:///dbfs/") or \
+            path.startswith("file:/dbfs/")
 
     @staticmethod
     def normalize_path(path):
         """
         Normalize the path to the form `/dbfs/...`
         """
-        if path.startswith("dbfs:/"):
+        if path.startswith("dbfs:/") and not path.startswith("dbfs://"):
             return "/dbfs" + path[5:]
-        if path.startswith("file:///dbfs/"):
+        elif path.startswith("/dbfs/"):
+            return path
+        elif path.startswith("file:///dbfs/"):
             return path[7:]
-        return path
+        elif path.startswith("file:/dbfs/"):
+            return path[5:]
+        else:
+            raise ValueError(DBFSLocalStore.DBFS_PATH_FORMAT_ERROR.format(path))
+
+    def exists(self, path):
+        localized_path = self.get_localized_path(path)
+        return self.fs.exists(localized_path)
+
+    def get_localized_path(self, path):
+        local_path = DBFSLocalStore.normalize_path(path)
+        if _DBFS_PREFIX_MAPPING:
+            # this is for testing.
+            return os.path.join(_DBFS_PREFIX_MAPPING, path[6:])
+        else:
+            return local_path
+
+    def get_full_path(self, path):
+        return "file://" + DBFSLocalStore.normalize_path(path)
 
     def get_checkpoint_filename(self):
         # Use the default Tensorflow SavedModel format in TF 2.x. In TF 1.x, the SavedModel format
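For quick reference, the path contract introduced above can be exercised directly: `normalize_path` is a staticmethod and `matches_dbfs` a classmethod, so no Spark session is needed. A minimal sketch, assuming `horovod[spark]` is installed; the expected values mirror the assertions added to `test_spark.py` below.

from horovod.spark.common.store import DBFSLocalStore

# All four accepted spellings collapse to the local-file-API form `/dbfs/...`.
for p in ["dbfs:/tmp/a1", "/dbfs/tmp/a1", "file:/dbfs/tmp/a1", "file:///dbfs/tmp/a1"]:
    assert DBFSLocalStore.matches_dbfs(p)
    assert DBFSLocalStore.normalize_path(p) == "/dbfs/tmp/a1"

# `dbfs://...` carries a URI authority component and is not a DBFS local path,
# so it is now rejected with a ValueError instead of being passed through.
assert not DBFSLocalStore.matches_dbfs("dbfs://tmp/a1")
try:
    DBFSLocalStore.normalize_path("dbfs://tmp/a1")
except ValueError as err:
    print("rejected:", err)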
test/integration/test_spark.py (23 additions, 5 deletions)
@@ -15,6 +15,9 @@
 
 import contextlib
 import copy
+import shutil
+import uuid
+
 from horovod.spark.common.store import FilesystemStore
 import io
 import itertools
@@ -1717,6 +1720,7 @@ def test_to_list(self):
         with pytest.raises(ValueError):
             util.to_list(['item1', 'item2'], 4)
 
+    @mock.patch("horovod.spark.common.store._DBFS_PREFIX_MAPPING", "/tmp")
     def test_dbfs_local_store(self):
         import h5py
         import io
@@ -1738,13 +1742,22 @@ def test_dbfs_local_store(self):
             assert isinstance(dbfs_local_store, DBFSLocalStore)
             dbfs_local_store = Store.create("file:///dbfs/tmp/test_local_dir3")
             assert isinstance(dbfs_local_store, DBFSLocalStore)
+            dbfs_local_store = Store.create("file:/dbfs/tmp/test_local_dir3")
+            assert isinstance(dbfs_local_store, DBFSLocalStore)
+            assert not DBFSLocalStore.matches_dbfs("dbfs://tmp/test_local_dir3")
         finally:
             if "DATABRICKS_RUNTIME_VERSION" in os.environ:
                 del os.environ["DATABRICKS_RUNTIME_VERSION"]
 
+        assert DBFSLocalStore.normalize_path("file:/dbfs/tmp/a1") == "/dbfs/tmp/a1"
+        assert DBFSLocalStore.normalize_path("file:///dbfs/tmp/a1") == "/dbfs/tmp/a1"
+        assert DBFSLocalStore.normalize_path("/dbfs/tmp/a1") == "/dbfs/tmp/a1"
+        assert DBFSLocalStore.normalize_path("dbfs:/tmp/a1") == "/dbfs/tmp/a1"
+        with pytest.raises(ValueError):
+            DBFSLocalStore.normalize_path("dbfs://tmp/a1")
+
         # test get_checkpoint_filename suffix
-        # Use a tmp path for testing.
-        dbfs_store = DBFSLocalStore("/tmp/test_dbfs_dir")
+        dbfs_store = DBFSLocalStore("/dbfs/test_dbfs_dir")
         dbfs_ckpt_name = dbfs_store.get_checkpoint_filename()
         assert dbfs_ckpt_name.endswith(".tf")
 
@@ -1772,8 +1785,12 @@ def deserialize_keras_model(serialized_model):
             assert reconstructed_model_dbfs.get_config() == model.get_config()
 
         # test local_store.read_serialized_keras_model
-        with tempdir() as tmp:
-            local_store = Store.create(tmp)
+        tmp_dir = "tmp_" + uuid.uuid4().hex
+        # The dbfs_dir "/dbfs/tmp_xxx" will be mapped to "/tmp/tmp_xxx" in testing
+        dbfs_dir = os.path.join("/dbfs", tmp_dir)
+        actual_dbfs_dir = os.path.join("/tmp", tmp_dir)
+        try:
+            local_store = Store.create(dbfs_dir)
             get_local_output_dir = local_store.get_local_output_dir_fn("0")
             with get_local_output_dir() as run_output_dir:
                 local_ckpt_path = run_output_dir + "/" + local_store.get_checkpoint_filename()
@@ -1786,7 +1803,8 @@ def deserialize_keras_model(serialized_model):
                 reconstructed_model_local = deserialize_keras_model(serialized_model_local)
                 if LooseVersion(tensorflow.__version__) >= LooseVersion("2.3.0"):
                     assert reconstructed_model_local.get_config() == model.get_config()
-
+        finally:
+            shutil.rmtree(actual_dbfs_dir, ignore_errors=True)
 
     def test_output_df_schema(self):
         label_cols = ['y1', 'y_embedding']
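The `mock.patch` decorator above is what lets these DBFS paths work on test machines without a mounted `/dbfs`: with `_DBFS_PREFIX_MAPPING` set, `get_localized_path` reroutes `/dbfs/...` under the mapped prefix, while `get_full_path` keeps reporting the canonical DBFS location. A standalone sketch of that mechanism, assuming `horovod[spark]` is installed and that `DBFSLocalStore` can be constructed outside a Databricks runtime (which the test itself relies on); the `/dbfs/my_run` path is hypothetical:

from unittest import mock

from horovod.spark.common.store import DBFSLocalStore

# Patch the module-level test hook exactly as the test above does.
with mock.patch("horovod.spark.common.store._DBFS_PREFIX_MAPPING", "/tmp"):
    store = DBFSLocalStore("/dbfs/my_run")  # hypothetical path, for illustration
    # The localized (local-file-API) path is rerouted under the mapped prefix:
    assert store.get_localized_path("/dbfs/my_run") == "/tmp/my_run"
    # The full path still reports the canonical DBFS form:
    assert store.get_full_path("/dbfs/my_run") == "file:///dbfs/my_run"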