-
Notifications
You must be signed in to change notification settings - Fork 2.2k
/
remote.py
365 lines (304 loc) · 15.6 KB
/
remote.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
# Copyright 2019 Uber Technologies, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import contextlib
import io
import math
import os
import h5py
import tensorflow as tf
from distutils.version import LooseVersion
from horovod.spark.common import constants
from horovod.spark.common.store import DBFSLocalStore
from horovod.spark.common.util import _get_assigned_gpu_or_default
from horovod.runner.common.util import codec
PETASTORM_HDFS_DRIVER = constants.PETASTORM_HDFS_DRIVER
TOTAL_BUFFER_MEMORY_CAP_GIB = constants.TOTAL_BUFFER_MEMORY_CAP_GIB
BYTES_PER_GIB = constants.BYTES_PER_GIB
def RemoteTrainer(estimator, metadata, keras_utils, run_id, dataset_idx):
    """Build the training closure for a Keras estimator.

    All configuration is read from ``estimator`` once, up front, and captured
    by the returned ``train`` function, which performs the actual Horovod
    training (presumably executed once per Horovod process -- confirm against
    the caller).

    Args:
        estimator: Object exposing the getters read below (label/feature
            columns, callbacks, batch sizes, epochs, store, model, ...).
        metadata: Forwarded unchanged to ``keras_utils.make_dataset_fn``.
        keras_utils: Provider of ``keras()``, ``horovod_fn()``, ``keras_fn()``,
            ``make_dataset_fn()``, ``fit_fn()`` and ``serialize_model()``.
        run_id: Forwarded, with ``dataset_idx``, to ``store.to_remote``.
        dataset_idx: Forwarded, with ``run_id``, to ``store.to_remote``.

    Returns:
        The ``train(serialized_model, train_rows, val_rows, avg_row_size)``
        function defined below.
    """
    # Estimator parameters
    label_columns = estimator.getLabelCols()
    feature_columns = estimator.getFeatureCols()
    user_callbacks = estimator.getCallbacks()
    batch_size = estimator.getBatchSize()
    # Fall back to the training batch size when no validation batch size is set.
    val_batch_size = estimator.getValBatchSize() if estimator.getValBatchSize() else batch_size
    epochs = estimator.getEpochs()
    train_steps_per_epoch = estimator.getTrainStepsPerEpoch()
    validation_steps_per_epoch = estimator.getValidationStepsPerEpoch()
    sample_weight_col = estimator.getSampleWeightCol()
    custom_objects = estimator.getCustomObjects()
    should_validate = estimator.getValidation()
    user_shuffle_buffer_size = estimator.getShufflingBufferSize()
    user_verbose = estimator.getVerbose()
    checkpoint_callback = estimator.getCheckpointCallback()

    # Data reader parameters
    train_reader_worker_count = estimator.getTrainReaderNumWorker()
    val_reader_worker_count = estimator.getValReaderNumWorker()
    reader_pool_type = estimator.getReaderPoolType()

    # Model parameters
    input_shapes, output_shapes = estimator.get_model_shapes()
    output_names = estimator.getModel().output_names
    label_shapes = estimator.getLabelShapes()

    # Keras implementation
    keras_module = keras_utils.keras()
    floatx = keras_module.backend.floatx()
    get_horovod = keras_utils.horovod_fn()
    get_keras = keras_utils.keras_fn()
    make_dataset = keras_utils.make_dataset_fn(
        feature_columns=feature_columns,
        label_columns=label_columns,
        sample_weight_col=sample_weight_col,
        metadata=metadata,
        input_shapes=input_shapes,
        # Use the model's output shapes when no explicit label shapes are given.
        label_shapes=label_shapes if label_shapes else output_shapes,
        output_names=output_names)
    fit = keras_utils.fit_fn(epochs)
    transformation_fn = estimator.getTransformationFn()
    # Normalize any falsy transformation function to None.
    transformation = transformation_fn if transformation_fn else None

    # Utility functions
    deserialize_keras_model = _deserialize_keras_model_fn()
    calculate_shuffle_buffer_size = _calculate_shuffle_buffer_size_fn()
    pin_gpu = _pin_gpu_fn()

    # Storage
    store = estimator.getStore()
    is_dbfs = isinstance(store, DBFSLocalStore)
    remote_store = store.to_remote(run_id, dataset_idx)

    def SyncCallback(root_path, sync_to_store_fn, keras):
        """Create a Keras callback that calls ``sync_to_store_fn(root_path)`` after every epoch.

        The callback class is declared here so that its base class comes from
        the ``keras`` module passed in by the caller.
        """
        class _SyncCallback(keras.callbacks.Callback):
            def on_epoch_end(self, epoch, logs=None):
                sync_to_store_fn(root_path)

        return _SyncCallback()

    @contextlib.contextmanager
    def empty_batch_reader():
        """Context manager yielding ``None``; stands in for the validation reader when validation is off."""
        yield None

    def train(serialized_model, train_rows, val_rows, avg_row_size):
        """Train the model with Horovod and return the result from rank 0.

        Args:
            serialized_model: Encoded model handed to ``deserialize_keras_model``.
            train_rows: Number of training rows; used to derive the default
                shuffle buffer size and ``steps_per_epoch``.
            val_rows: Number of validation rows; used to derive
                ``validation_steps`` when validation is enabled.
            avg_row_size: Average row size forwarded to
                ``calculate_shuffle_buffer_size`` (presumably in bytes -- confirm).

        Returns:
            On rank 0, the tuple ``(history.history, serialized_model, hvd.size())``.
            NOTE(review): other ranks never reach the ``return`` statement and
            therefore return ``None`` -- confirm callers only use rank 0's result.
        """
        from petastorm import TransformSpec, make_reader, make_batch_reader
        import horovod as _horovod
        k = get_keras()
        k.backend.set_floatx(floatx)

        hvd = get_horovod()
        hvd.init()
        pin_gpu(hvd, tf, k)

        # A user-supplied shuffle buffer size takes precedence over the computed one.
        if not user_shuffle_buffer_size:
            shuffle_buffer_size = calculate_shuffle_buffer_size(
                hvd, avg_row_size, train_rows / hvd.size())
        else:
            shuffle_buffer_size = user_shuffle_buffer_size

        # needs to be deserialized in the with scope
        with k.utils.custom_object_scope(custom_objects):
            model = deserialize_keras_model(
                serialized_model, lambda x: hvd.load_model(x))

        # Horovod: adjust learning rate based on number of processes.
        scaled_lr = k.backend.get_value(model.optimizer.lr) * hvd.size()
        k.backend.set_value(model.optimizer.lr, scaled_lr)

        # Verbose mode 1 will print a progress bar
        # Only rank 0 uses the user's verbosity; all other ranks are silent.
        verbose = user_verbose if hvd.rank() == 0 else 0

        transform_spec = None
        if transformation:
            transform_spec = TransformSpec(transformation)

        # The initial_lr needs to be set to the scaled learning rate in the
        # learning-rate schedule callbacks.
        for callback in user_callbacks:
            if isinstance(callback, _horovod._keras.callbacks.LearningRateScheduleCallbackImpl):
                callback.initial_lr = scaled_lr

        with remote_store.get_local_output_dir() as run_output_dir:
            callbacks = [
                # Horovod: broadcast initial variable states from rank 0 to all other processes.
                # This is necessary to ensure consistent initialization of all workers when
                # training is started with random weights or restored from a checkpoint.
                hvd.callbacks.BroadcastGlobalVariablesCallback(root_rank=0),

                # Horovod: average metrics among workers at the end of every epoch.
                #
                # Note: This callback must be in the list before the ReduceLROnPlateau,
                # TensorBoard, or other metrics-based callbacks.
                hvd.callbacks.MetricAverageCallback(),
            ]

            callbacks += user_callbacks

            # Horovod: save checkpoints only on the first worker to prevent other workers from
            # corrupting them.
            if hvd.rank() == 0:
                ckpt_file = os.path.join(run_output_dir, remote_store.checkpoint_filename)
                logs_dir = os.path.join(run_output_dir, remote_store.logs_subdir)

                # This callback checkpoints the model that ultimately is wrapped and returned after
                # Estimator.fit is called.
                _checkpoint_callback = checkpoint_callback
                if _checkpoint_callback:
                    # A user-provided checkpoint callback is redirected to ckpt_file.
                    _checkpoint_callback.filepath = ckpt_file
                else:
                    if is_dbfs and LooseVersion(tf.__version__) < LooseVersion("2.0.0"):
                        # Because DBFS local file APIs do not support random write which is
                        # required by h5 format, save_weights_only=True is needed for switching
                        # to the TensorFlow SavedModel format.
                        _checkpoint_callback = k.callbacks.ModelCheckpoint(ckpt_file,
                                                                           save_weights_only=True)
                    else:
                        _checkpoint_callback = k.callbacks.ModelCheckpoint(ckpt_file)
                callbacks.append(_checkpoint_callback)

                if remote_store.saving_runs:
                    callbacks.append(k.callbacks.TensorBoard(logs_dir))
                    callbacks.append(SyncCallback(run_output_dir, remote_store.sync, k))

            if train_steps_per_epoch is None:
                steps_per_epoch = int(math.ceil(train_rows / batch_size / hvd.size()))
            else:
                steps_per_epoch = train_steps_per_epoch

            if validation_steps_per_epoch is None:
                # math.ceil because if val_rows is smaller than val_batch_size we still get at
                # least one step. float(val_rows) because val_rows/val_batch_size evaluates to
                # zero before math.ceil
                validation_steps = int(math.ceil(float(val_rows) / val_batch_size / hvd.size())) \
                    if should_validate else None
            else:
                validation_steps = validation_steps_per_epoch

            # The concatenation builds a new list, so the append below does not
            # modify feature_columns or label_columns.
            schema_fields = feature_columns + label_columns
            if sample_weight_col:
                schema_fields.append(sample_weight_col)

            # In general, make_batch_reader is faster than make_reader for reading the dataset.
            # However, we found out that make_reader performs data transformations much faster than
            # make_batch_reader with parallel worker processes. Therefore, the default reader
            # we choose is make_batch_reader unless there are data transformations.
            reader_factory_kwargs = dict()
            if transform_spec:
                reader_factory = make_reader
                reader_factory_kwargs['pyarrow_serialize'] = True
                is_batch_reader = False
            else:
                reader_factory = make_batch_reader
                is_batch_reader = True

            # Petastorm: read data from the store with the correct shard for this rank
            # setting num_epochs=None will cause an infinite iterator
            # and enables ranks to perform training and validation with
            # unequal number of samples
            with reader_factory(remote_store.train_data_path,
                                num_epochs=None,
                                cur_shard=hvd.rank(),
                                reader_pool_type=reader_pool_type,
                                workers_count=train_reader_worker_count,
                                shard_count=hvd.size(),
                                hdfs_driver=PETASTORM_HDFS_DRIVER,
                                schema_fields=schema_fields,
                                transform_spec=transform_spec,
                                **reader_factory_kwargs) as train_reader:
                # When validation is disabled, val_reader is None (see empty_batch_reader).
                with reader_factory(remote_store.val_data_path,
                                    num_epochs=None,
                                    cur_shard=hvd.rank(),
                                    reader_pool_type=reader_pool_type,
                                    workers_count=val_reader_worker_count,
                                    shard_count=hvd.size(),
                                    hdfs_driver=PETASTORM_HDFS_DRIVER,
                                    schema_fields=schema_fields,
                                    transform_spec=transform_spec,
                                    **reader_factory_kwargs) \
                        if should_validate else empty_batch_reader() as val_reader:

                    train_data = make_dataset(train_reader, batch_size, shuffle_buffer_size,
                                              is_batch_reader, shuffle=True)
                    val_data = make_dataset(val_reader, val_batch_size, shuffle_buffer_size,
                                            is_batch_reader, shuffle=False) \
                        if val_reader else None

                    history = fit(model, train_data, val_data, steps_per_epoch,
                                  validation_steps, callbacks, verbose)

            # Dataset API usage currently displays a wall of errors upon termination.
            # This global model registration ensures clean termination.
            # Tracked in https://github.com/tensorflow/tensorflow/issues/24570
            globals()['_DATASET_FINALIZATION_HACK'] = model

            # Only rank 0 wrote ckpt_file, so only rank 0 reads it back and returns a result.
            if hvd.rank() == 0:
                if is_dbfs:
                    if LooseVersion(tf.__version__) < LooseVersion("2.0.0"):
                        # Matches the save_weights_only=True checkpoint created above.
                        model.load_weights(ckpt_file)
                    else:
                        # needs to be deserialized in the with scope
                        with k.utils.custom_object_scope(custom_objects):
                            model = k.models.load_model(ckpt_file)
                    serialized_model = keras_utils.serialize_model(model)
                else:
                    # Ship the raw checkpoint file bytes, base64-encoded.
                    with open(ckpt_file, 'rb') as f:
                        serialized_model = codec.dumps_base64(f.read())

                return history.history, serialized_model, hvd.size()

    return train
def _deserialize_keras_model_fn():
    """Return the helper that turns a base64-encoded HDF5 payload back into a model."""

    def deserialize_keras_model(model_bytes, load_model_fn):
        """Decode ``model_bytes`` from base64, open it as an in-memory HDF5 file and load it.

        Args:
            model_bytes: Base64-encoded model payload.
            load_model_fn: Callable that receives the opened ``h5py.File`` and
                returns the loaded model.

        Returns:
            Whatever ``load_model_fn`` returns.
        """
        in_memory_file = io.BytesIO(codec.loads_base64(model_bytes))
        # The HDF5 handle is only valid inside the ``with`` block, so the model
        # is loaded before the file is closed.
        with h5py.File(in_memory_file, 'r') as h5_file:
            loaded_model = load_model_fn(h5_file)
            return loaded_model

    return deserialize_keras_model
def _calculate_shuffle_buffer_size_fn():
    """Return the helper that computes the shuffle buffer size, in rows, for a worker."""

    def calculate_shuffle_buffer_size(hvd, avg_row_size, train_row_count_per_worker):
        """
        Compute the shuffle buffer size (a row count) from a memory budget.

        Each worker is given a budget of 1 GiB. When the most crowded machine
        runs more workers than TOTAL_BUFFER_MEMORY_CAP_GIB, the budget shrinks
        to TOTAL_BUFFER_MEMORY_CAP_GIB GiB divided by that worker count, so the
        workers of one machine never exceed the cap together. Because the
        budget is derived from the maximum over all gathered local sizes, it
        does not depend on which worker evaluates it. The budget is converted
        to rows with ``avg_row_size`` and finally limited to
        ``train_row_count_per_worker``.

        Budget per worker with a cap of 4 GiB:

        machine1: 8 workers, machine2: 3 workers -> 0.5 GiB
        machine1: 2 workers, machine2: 3 workers -> 1 GiB
        machine1: 2 workers, machine2: 8 workers, machine3: 5 workers -> 0.5 GiB
        """
        workers_per_machine = hvd.allgather([hvd.local_size()])
        most_crowded = int(max(workers_per_machine))

        if most_crowded <= TOTAL_BUFFER_MEMORY_CAP_GIB:
            # Few enough workers per machine: everyone gets the full 1 GiB.
            rows_in_budget = BYTES_PER_GIB / avg_row_size
        else:
            # Split the machine-wide cap between the workers of the busiest machine.
            rows_in_budget = \
                TOTAL_BUFFER_MEMORY_CAP_GIB * BYTES_PER_GIB / avg_row_size / most_crowded

        # Never buffer more rows than the worker actually has.
        limited_rows = min(rows_in_budget, train_row_count_per_worker)
        return int(limited_rows)

    return calculate_shuffle_buffer_size
def _pin_gpu_fn():
    """Return the GPU-pinning function that matches the installed TensorFlow version."""
    # Horovod: pin GPU to be used to process local rank (one GPU per process)
    if LooseVersion(tf.__version__) >= LooseVersion('2.0.0'):
        return _pin_gpu_tensorflow2_fn()
    return _pin_gpu_tensorflow1_fn()
def _pin_gpu_tensorflow2_fn():
    """Return the TensorFlow 2.x function that restricts a process to one GPU."""

    def fn(hvd, tf, keras):
        """Enable memory growth on all GPUs and make only the assigned one visible.

        Does nothing beyond listing the devices when no GPU is present.
        ``keras`` is accepted for signature parity and is not used.
        """
        device_api = tf.config.experimental
        gpus = device_api.list_physical_devices('GPU')
        if not gpus:
            return

        for device in gpus:
            device_api.set_memory_growth(device, True)

        # Default to the Horovod local rank when no GPU was explicitly assigned.
        assigned_index = _get_assigned_gpu_or_default(default=hvd.local_rank())
        device_api.set_visible_devices(gpus[assigned_index], 'GPU')

    return fn
def _pin_gpu_tensorflow1_fn():
    """Return the TensorFlow 1.x function that restricts a process to one GPU."""

    def fn(hvd, tf, keras):
        """Install a Keras session whose config exposes only the assigned GPU.

        The GPU index defaults to the Horovod local rank, and memory growth is
        enabled on the session config.
        """
        session_config = tf.ConfigProto()
        gpu_settings = session_config.gpu_options
        gpu_settings.allow_growth = True
        gpu_index = _get_assigned_gpu_or_default(default=hvd.local_rank())
        gpu_settings.visible_device_list = str(gpu_index)

        pinned_session = tf.Session(config=session_config)
        keras.backend.set_session(pinned_session)

    return fn
def _pin_cpu_fn():
    """Return the CPU-pinning function that matches the installed TensorFlow version."""
    if LooseVersion(tf.__version__) >= LooseVersion('2.0.0'):
        return _pin_cpu_tensorflow2_fn()
    return _pin_cpu_tensorflow1_fn()
def _pin_cpu_tensorflow2_fn():
    """Return the TensorFlow 2.x function that confines a process to single-threaded CPU use."""

    def fn(tf, keras):
        """Hide all CUDA devices and set both TensorFlow thread pools to one thread.

        ``keras`` is accepted for signature parity and is not used.
        """
        # '-1' matches no CUDA device, so none is visible to this process.
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

        thread_settings = tf.config.threading
        thread_settings.set_inter_op_parallelism_threads(1)
        thread_settings.set_intra_op_parallelism_threads(1)

    return fn
def _pin_cpu_tensorflow1_fn():
    """Return the TensorFlow 1.x function that confines a process to single-threaded CPU use."""

    def fn(tf, keras):
        """Install a Keras session configured with zero GPUs and one thread per pool."""
        cpu_only_config = tf.ConfigProto(device_count={'GPU': 0})
        # Both the inter-op and the intra-op pool are limited to one thread.
        for pool_option in ('inter_op_parallelism_threads',
                            'intra_op_parallelism_threads'):
            setattr(cpu_only_config, pool_option, 1)

        cpu_session = tf.Session(config=cpu_only_config)
        keras.backend.set_session(cpu_session)

    return fn