# Copyright 2019 Uber Technologies, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import contextlib
import errno
import os
import pathlib
import re
import shutil
import tempfile
import warnings

from distutils.version import LooseVersion

import pyarrow as pa
import pyarrow.parquet as pq

import fsspec
from fsspec.core import split_protocol

from horovod.spark.common.util import is_databricks


class Store(object):
    """
    Storage layer for intermediate files (materialized DataFrames) and training artifacts (checkpoints, logs).

    Store provides an abstraction over a filesystem (e.g., local vs HDFS) or blob store. It provides the
    basic semantics for reading and writing objects, and conventions for where those objects are located.

    The store exposes a generic interface that is not coupled to a specific DataFrame, model, or runtime. Every run
    of an Estimator should result in a separate run directory containing checkpoints and logs, and every variation
    in dataset should produce a separate intermediate data path.

    In order to allow for caching but to prevent overuse of disk space on intermediate data, intermediate datasets
    are named in a deterministic sequence. When a dataset is done being used for training, the intermediate files
    can be reclaimed to free up disk space, but will not be automatically removed so that they can be reused as
    needed. This is to support both parallel training processes using the same store on multiple DataFrames, as well
    as iterative training using the same DataFrame on different model variations.
    """

    def __init__(self):
        self._train_data_to_key = {}
        self._val_data_to_key = {}

    def is_parquet_dataset(self, path):
        """Returns True if the path is the root of a Parquet dataset."""
        raise NotImplementedError()

    def get_parquet_dataset(self, path):
        """Returns a :py:class:`pyarrow.parquet.ParquetDataset` from the path."""
        raise NotImplementedError()

    def get_train_data_path(self, idx=None):
        """Returns the path to the training dataset."""
        raise NotImplementedError()

    def get_val_data_path(self, idx=None):
        """Returns the path to the validation dataset."""
        raise NotImplementedError()

    def get_test_data_path(self, idx=None):
        """Returns the path to the test dataset."""
        raise NotImplementedError()

    def saving_runs(self):
        """Returns True if run output should be saved during training."""
        raise NotImplementedError()

    def get_runs_path(self):
        """Returns the base path for all runs."""
        raise NotImplementedError()

    def get_run_path(self, run_id):
        """Returns the path to the run with the given ID."""
        raise NotImplementedError()

    def get_checkpoint_path(self, run_id):
        """Returns the path to the checkpoint file(s) for the given run."""
        raise NotImplementedError()

    def get_checkpoints(self, run_id, suffix='.ckpt'):
        """Returns a list of paths for all checkpoints saved in this run."""
        raise NotImplementedError()

    def get_logs_path(self, run_id):
        """Returns the path to the log directory for the given run."""
        raise NotImplementedError()

    def get_checkpoint_filename(self):
        """Returns the basename of the saved checkpoint file."""
        raise NotImplementedError()

    def get_logs_subdir(self):
        """Returns the subdirectory name for the logs directory."""
        raise NotImplementedError()

    def exists(self, path):
        """Returns True if the path exists in the store."""
        raise NotImplementedError()

    def read(self, path):
        """Returns the contents of the path as bytes."""
        raise NotImplementedError()

    def write_text(self, path, text):
        """Writes the text to a file at the path."""
        raise NotImplementedError()

    def get_local_output_dir_fn(self, run_id):
        """Returns a function returning a context manager that yields a local directory for staging run output."""
        raise NotImplementedError()

    def sync_fn(self, run_id):
        """Returns a function that recursively synchronizes a given local path into the run path for `run_id`."""
        raise NotImplementedError()

    def to_remote(self, run_id, dataset_idx):
        """Returns a view of the store that can execute in a remote environment without Horovod deps."""
        attrs = self._remote_attrs(run_id, dataset_idx)

        class RemoteStore(object):
            def __init__(self):
                for name, attr in attrs.items():
                    setattr(self, name, attr)

        return RemoteStore()

    def _remote_attrs(self, run_id, dataset_idx):
        return {
            'train_data_path': self.get_train_data_path(dataset_idx),
            'val_data_path': self.get_val_data_path(dataset_idx),
            'test_data_path': self.get_test_data_path(dataset_idx),
            'saving_runs': self.saving_runs(),
            'runs_path': self.get_runs_path(),
            'run_path': self.get_run_path(run_id),
            'checkpoint_path': self.get_checkpoint_path(run_id),
            'logs_path': self.get_logs_path(run_id),
            'checkpoint_filename': self.get_checkpoint_filename(),
            'logs_subdir': self.get_logs_subdir(),
            'get_local_output_dir': self.get_local_output_dir_fn(run_id),
            'sync': self.sync_fn(run_id),
        }
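
    # A hedged illustration (not part of the public API surface): the object returned
    # by `to_remote()` is a plain attribute bag, so remote workers can use it without
    # importing Horovod. The run id below is hypothetical.
    #
    #     remote_store = store.to_remote(run_id='run_01', dataset_idx=None)
    #     with remote_store.get_local_output_dir() as run_output_dir:
    #         ...  # write checkpoints/logs under run_output_dir
    #         remote_store.sync(run_output_dir)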

    @staticmethod
    def create(prefix_path, *args, **kwargs):
        if HDFSStore.matches(prefix_path):
            return HDFSStore(prefix_path, *args, **kwargs)
        elif is_databricks() and DBFSLocalStore.matches_dbfs(prefix_path):
            return DBFSLocalStore(prefix_path, *args, **kwargs)
        else:
            return FilesystemStore(prefix_path, *args, **kwargs)
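
# A minimal sketch of how `Store.create` dispatches on the path prefix; the paths
# below are illustrative examples, not defaults:
#
#     store = Store.create('hdfs:///user/test/horovod')   # -> HDFSStore
#     store = Store.create('/tmp/horovod')                 # -> FilesystemStore (local fs via fsspec)
#     store = Store.create('/dbfs/horovod')                # -> DBFSLocalStore, only when on Databricks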


class AbstractFilesystemStore(Store):
    """Abstract class for stores that use a filesystem for underlying storage."""

    def __init__(self, prefix_path, train_path=None, val_path=None, test_path=None,
                 runs_path=None, save_runs=True, storage_options=None, **kwargs):
        self.prefix_path = self.get_full_path(prefix_path)
        self._train_path = self._get_full_path_or_default(train_path, 'intermediate_train_data')
        self._val_path = self._get_full_path_or_default(val_path, 'intermediate_val_data')
        self._test_path = self._get_full_path_or_default(test_path, 'intermediate_test_data')
        self._runs_path = self._get_full_path_or_default(runs_path, 'runs')
        self._save_runs = save_runs
        self.storage_options = storage_options
        super().__init__()

    def exists(self, path):
        return self.fs.exists(self.get_localized_path(path)) or self.fs.isdir(path)

    def read(self, path):
        with self.fs.open(self.get_localized_path(path), 'rb') as f:
            return f.read()

    def read_serialized_keras_model(self, ckpt_path, model, custom_objects):
        """Reads the checkpoint file of the keras model into model bytes and returns the
        base64-encoded model bytes.

        :param ckpt_path: The path to the checkpoint file.
        :param model: A keras model. Used by DBFSLocalStore.read_serialized_keras_model()
            when the checkpoint at ckpt_path only contains model weights.
        :param custom_objects: Used by DBFSLocalStore.read_serialized_keras_model()
            when loading the keras model.
        :return: the base64-encoded model bytes of the checkpoint model.
        """
        from horovod.runner.common.util import codec
        model_bytes = self.read(ckpt_path)
        return codec.dumps_base64(model_bytes)

    def write_text(self, path, text):
        with self.fs.open(self.get_localized_path(path), 'w') as f:
            f.write(text)

    def is_parquet_dataset(self, path):
        try:
            dataset = self.get_parquet_dataset(path)
            return dataset is not None
        except Exception:
            # A bare `except:` here would also swallow KeyboardInterrupt/SystemExit.
            return False

    def get_parquet_dataset(self, path):
        return pq.ParquetDataset(self.get_localized_path(path), filesystem=self.fs)

    def get_train_data_path(self, idx=None):
        return '{}.{}'.format(self._train_path, idx) if idx is not None else self._train_path

    def get_val_data_path(self, idx=None):
        return '{}.{}'.format(self._val_path, idx) if idx is not None else self._val_path

    def get_test_data_path(self, idx=None):
        return '{}.{}'.format(self._test_path, idx) if idx is not None else self._test_path

    def get_data_metadata_path(self, path):
        localized_path = self.get_localized_path(path)
        if localized_path.endswith('/'):
            localized_path = localized_path[:-1]  # Remove the trailing slash if there is one
        metadata_cache = localized_path + "_cached_metadata.pkl"
        return metadata_cache
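
    # Illustrative path shapes (prefix is whatever `get_full_path` resolves for the store):
    #     get_train_data_path(1)  -> '<prefix>/intermediate_train_data.1'
    #     get_data_metadata_path(get_train_data_path(1))
    #                             -> '<localized prefix>/intermediate_train_data.1_cached_metadata.pkl'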

    def saving_runs(self):
        return self._save_runs

    def get_runs_path(self):
        return self._runs_path

    def get_run_path(self, run_id):
        return os.path.join(self.get_runs_path(), run_id)

    def get_checkpoint_path(self, run_id):
        return os.path.join(self.get_run_path(run_id), self.get_checkpoint_filename()) \
            if self._save_runs else None

    def get_checkpoints(self, run_id, suffix='.ckpt'):
        checkpoint_dir = self.get_localized_path(self.get_checkpoint_path(run_id))
        filenames = self.fs.ls(checkpoint_dir)
        return sorted([name for name in filenames if name.endswith(suffix)])

    def get_logs_path(self, run_id):
        return os.path.join(self.get_run_path(run_id), self.get_logs_subdir()) \
            if self._save_runs else None

    def get_checkpoint_filename(self):
        return 'checkpoint'

    def get_logs_subdir(self):
        return 'logs'

    def _get_full_path_or_default(self, path, default_key):
        if path is not None:
            return self.get_full_path(path)
        return self._get_path(default_key)

    def _get_path(self, key):
        return os.path.join(self.prefix_path, key)

    def get_local_output_dir_fn(self, run_id):
        @contextlib.contextmanager
        def local_run_path():
            with tempfile.TemporaryDirectory() as tmpdir:
                yield tmpdir

        return local_run_path
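
    # Sketch of the intended call pattern (names are illustrative): training code
    # stages output in a throwaway local directory, then syncs it into the run path.
    #
    #     get_local_output_dir = store.get_local_output_dir_fn(run_id)
    #     sync = store.sync_fn(run_id)
    #     with get_local_output_dir() as local_dir:
    #         ...  # write checkpoints/logs into local_dir
    #         sync(local_dir)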

    def get_localized_path(self, path):
        """Returns the path as seen by the underlying filesystem (e.g., with any protocol prefix stripped)."""
        raise NotImplementedError()

    def get_full_path(self, path):
        """Returns the fully-qualified path (e.g., with protocol, host, and port)."""
        raise NotImplementedError()

    def get_full_path_fn(self):
        """Returns a function mapping a path to its fully-qualified form."""
        raise NotImplementedError()

    @property
    def fs(self):
        """The underlying filesystem object."""
        raise NotImplementedError()


class FilesystemStore(AbstractFilesystemStore):
    """Concrete filesystem store that delegates to `fsspec`."""

    def __init__(self, prefix_path, *args, **kwargs):
        self.storage_options = kwargs.get('storage_options', {})
        self.prefix_path = prefix_path
        self._fs, self.protocol = self._get_fs_and_protocol()
        std_params = ['train_path', 'val_path', 'test_path', 'runs_path', 'save_runs', 'storage_options']
        params = {k: kwargs[k] for k in std_params if k in kwargs}
        super().__init__(prefix_path, *args, **params)

    def sync_fn(self, run_id):
        run_path = self.get_run_path(run_id)

        def fn(local_run_path):
            self.fs.put(local_run_path, run_path, recursive=True, overwrite=True)

        return fn

    def get_filesystem(self):
        return self.fs

    def get_localized_path(self, path):
        _, lpath = split_protocol(path)
        return lpath

    def get_full_path(self, path):
        return self.get_full_path_fn()(path)

    def get_full_path_fn(self):
        def get_path(path):
            protocol, _ = split_protocol(path)
            if protocol is not None:
                return path
            return pathlib.Path(os.path.abspath(path)).as_uri()

        return get_path
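
    # Paths without a protocol resolve to absolute `file://` URIs; qualified paths
    # pass through unchanged (illustrative values):
    #     get_full_path('data/train')        -> 'file:///<cwd>/data/train'
    #     get_full_path('s3://bucket/train') -> 's3://bucket/train'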

    @property
    def fs(self):
        return self._fs

    def _get_fs_and_protocol(self):
        protocol, path = split_protocol(self.prefix_path)
        fs = fsspec.filesystem(protocol, **self.storage_options)
        return fs, protocol

    @classmethod
    def matches(cls, path):
        return True
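
# A hedged sketch of the fsspec dispatch: `fsspec.filesystem(protocol, **storage_options)`
# returns the filesystem implementation registered for the protocol (a bare local path has
# protocol None, which fsspec resolves to the local filesystem). The bucket and credentials
# below are placeholders.
#
#     store = FilesystemStore('s3://my-bucket/horovod',
#                             storage_options={'key': '...', 'secret': '...'})
#     store.fs  # -> the fsspec filesystem for the 's3' protocol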


class LocalStore(FilesystemStore):
    """Uses the local filesystem as a store of intermediate data and training artifacts.

    This class is deprecated and now just resolves to FilesystemStore.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)


class HDFSStore(AbstractFilesystemStore):
    """Uses HDFS as a store of intermediate data and training artifacts.

    Initialized from a `prefix_path` that can take one of the following forms:

    1. "hdfs://namenode01:8020/user/test/horovod"
    2. "hdfs:///user/test/horovod"
    3. "/user/test/horovod"

    The full path (including prefix, host, and port) will be used for all reads and writes to HDFS through Spark.
    If host and port are not provided in the URL, they will be omitted from the full path. If the prefix is not
    provided (case 3), `hdfs://` will be prepended to the path.

    The localized path (without prefix, host, and port) will be used for interaction with PyArrow. Host and port
    parsed from the URL will be used to initialize the PyArrow `HadoopFilesystem`, unless they are provided through
    the `host` and `port` arguments to this initializer. These parameters default to `default` and `0` if neither
    the path URL nor the arguments provide this information.
    """

    FS_PREFIX = 'hdfs://'
    URL_PATTERN = '^(?:(.+://))?(?:([^/:]+))?(?:[:]([0-9]+))?(?:(.+))?$'

    def __init__(self, prefix_path,
                 host=None, port=None, user=None, kerb_ticket=None,
                 driver='libhdfs', extra_conf=None, *args, **kwargs):
        prefix, url_host, url_port, path, path_offset = self.parse_url(prefix_path)
        self._check_url(prefix_path, prefix, path)
        self._url_prefix = prefix_path[:path_offset] if prefix else self.FS_PREFIX

        host = host or url_host or 'default'
        port = port or url_port or 0
        self._hdfs_kwargs = dict(host=host,
                                 port=port,
                                 user=user,
                                 kerb_ticket=kerb_ticket,
                                 extra_conf=extra_conf)
        if LooseVersion(pa.__version__) < LooseVersion('0.17.0'):
            # The `driver` argument is only accepted by pyarrow before 0.17.0.
            self._hdfs_kwargs['driver'] = driver
        self._hdfs = self._get_filesystem_fn()()

        super(HDFSStore, self).__init__(prefix_path, *args, **kwargs)

    def parse_url(self, url):
        match = re.search(self.URL_PATTERN, url)
        prefix = match.group(1)
        host = match.group(2)
        port = match.group(3)
        if port is not None:
            port = int(port)
        path = match.group(4)
        path_offset = match.start(4)
        return prefix, host, port, path, path_offset
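
    # `parse_url` on the three documented forms (illustrative; offset is the index
    # where the path component begins):
    #     'hdfs://namenode01:8020/user/test/horovod' -> ('hdfs://', 'namenode01', 8020, '/user/test/horovod', offset)
    #     'hdfs:///user/test/horovod'                -> ('hdfs://', None, None, '/user/test/horovod', offset)
    #     '/user/test/horovod'                       -> (None, None, None, '/user/test/horovod', offset)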

    def get_full_path(self, path):
        if not self.matches(path):
            return self._url_prefix + path
        return path

    def get_full_path_fn(self):
        prefix = self._url_prefix

        def get_path(path):
            return prefix + path

        return get_path

    @property
    def fs(self):
        return self._hdfs

    def sync_fn(self, run_id):
        class SyncState(object):
            def __init__(self):
                self.fs = None
                self.uploaded = {}

        state = SyncState()
        get_filesystem = self._get_filesystem_fn()
        hdfs_root_path = self.get_run_path(run_id)

        def fn(local_run_path):
            if state.fs is None:
                state.fs = get_filesystem()

            hdfs = state.fs
            uploaded = state.uploaded

            # We need to swap this prefix from the local path with the absolute path, +1 due to
            # including the trailing slash
            prefix = len(local_run_path) + 1

            for local_dir, dirs, files in os.walk(local_run_path):
                hdfs_dir = os.path.join(hdfs_root_path, local_dir[prefix:])
                for file in files:
                    local_path = os.path.join(local_dir, file)
                    modified_ts = int(os.path.getmtime(local_path))

                    if local_path in uploaded:
                        last_modified_ts = uploaded.get(local_path)
                        if modified_ts <= last_modified_ts:
                            # Skip files that have not changed since the last sync.
                            continue

                    hdfs_path = os.path.join(hdfs_dir, file)
                    with open(local_path, 'rb') as f:
                        hdfs.upload(hdfs_path, f)
                    uploaded[local_path] = modified_ts

        return fn

    def _get_filesystem_fn(self):
        hdfs_kwargs = self._hdfs_kwargs

        def fn():
            return pa.hdfs.connect(**hdfs_kwargs)

        return fn

    def _check_url(self, url, prefix, path):
        if prefix is not None and prefix != self.FS_PREFIX:
            raise ValueError('Mismatched HDFS namespace for URL: {}. Found {} but expected {}'
                             .format(url, prefix, self.FS_PREFIX))

        if not path:
            raise ValueError('Failed to parse path from URL: {}'.format(url))

    def get_localized_path(self, path):
        if self.matches(path):
            return path[len(self._url_prefix):]
        return path

    @classmethod
    def matches(cls, path):
        return path.startswith(cls.FS_PREFIX)
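
    # Round trip between full and localized paths (illustrative, for a store created
    # with prefix_path 'hdfs://namenode01:8020/user/test/horovod'):
    #     get_full_path('/user/test/horovod/runs')
    #         -> 'hdfs://namenode01:8020/user/test/horovod/runs'
    #     get_localized_path('hdfs://namenode01:8020/user/test/horovod/runs')
    #         -> '/user/test/horovod/runs'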


class DBFSLocalStore(FilesystemStore):
    """Uses Databricks File System (DBFS) local file APIs as a store of intermediate data and
    training artifacts.

    Initialized from a `prefix_path` that starts with `/dbfs/...`, `file:///dbfs/...` or `dbfs:/...`, see
    https://docs.databricks.com/data/databricks-file-system.html#local-file-apis.
    """

    def __init__(self, prefix_path, *args, **kwargs):
        prefix_path = self.normalize_path(prefix_path)
        if not prefix_path.startswith("/dbfs/"):
            warnings.warn("The provided prefix_path might be ephemeral: {}. Please provide a "
                          "`prefix_path` starting with `/dbfs/...`".format(prefix_path))
        super(DBFSLocalStore, self).__init__(prefix_path, *args, **kwargs)

    @classmethod
    def matches_dbfs(cls, path):
        return path.startswith("dbfs:/") or path.startswith("/dbfs/") or path.startswith("file:///dbfs/")

    @staticmethod
    def normalize_path(path):
        """
        Normalize the path to the form `/dbfs/...`
        """
        if path.startswith("dbfs:/"):
            return "/dbfs" + path[5:]
        if path.startswith("file:///dbfs/"):
            return path[7:]
        return path
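
    # Normalization examples:
    #     normalize_path('dbfs:/horovod')        -> '/dbfs/horovod'
    #     normalize_path('file:///dbfs/horovod') -> '/dbfs/horovod'
    #     normalize_path('/dbfs/horovod')        -> '/dbfs/horovod' (unchanged)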

    def get_checkpoint_filename(self):
        # Use the default TensorFlow SavedModel format in TF 2.x. In TF 1.x, the SavedModel
        # format is used by providing `save_weights_only=True` to the ModelCheckpoint()
        # callback.
        return 'checkpoint.tf'

    def read_serialized_keras_model(self, ckpt_path, model, custom_objects):
        """
        Returns the serialized keras model. The parameter `model` provides the model
        structure when the checkpoint file only contains model weights.
        """
        import tensorflow
        from tensorflow import keras
        from horovod.spark.keras.util import TFKerasUtil

        if LooseVersion(tensorflow.__version__) < LooseVersion("2.0.0"):
            model.load_weights(ckpt_path)
        else:
            with keras.utils.custom_object_scope(custom_objects):
                model = keras.models.load_model(ckpt_path)
        return TFKerasUtil.serialize_model(model)
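
    # Usage sketch (hypothetical paths): unlike the base-class implementation, which
    # returns the raw checkpoint bytes base64-encoded, this override loads the model
    # (or its weights, on TF 1.x) and re-serializes it via TFKerasUtil.serialize_model.
    #
    #     store = DBFSLocalStore('/dbfs/horovod')
    #     serialized = store.read_serialized_keras_model(
    #         '/dbfs/horovod/runs/run_01/checkpoint.tf', model, custom_objects={})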