# Copyright 2019 Uber Technologies, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import contextlib
import io
import math
import os
import tempfile

from distutils.version import LooseVersion

import torch
import pytorch_lightning as pl
from pytorch_lightning import Trainer, Callback
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

from horovod.spark.common import constants
from horovod.spark.common.util import _get_assigned_gpu_or_default, to_list
from horovod.spark.common.store import DBFSLocalStore
from horovod.spark.lightning.util import deserialize_fn

PETASTORM_HDFS_DRIVER = constants.PETASTORM_HDFS_DRIVER
METRIC_PRINT_FREQUENCY = constants.METRIC_PRINT_FREQUENCY
TOTAL_BUFFER_MEMORY_CAP_GIB = constants.TOTAL_BUFFER_MEMORY_CAP_GIB
BYTES_PER_GIB = constants.BYTES_PER_GIB
CUSTOM_SPARSE = constants.CUSTOM_SPARSE


def RemoteTrainer(estimator, metadata, ckpt_bytes, run_id, dataset_idx, train_rows, val_rows, avg_row_size, is_legacy):
    # Estimator parameters
    input_shapes = estimator.getInputShapes()
    label_shapes = estimator.getLabelShapes()
    feature_columns = estimator.getFeatureCols()
    label_columns = estimator.getLabelCols()
    sample_weight_col = estimator.getSampleWeightCol()
    should_validate = estimator.getValidation()
    batch_size = estimator.getBatchSize()
    val_batch_size = estimator.getValBatchSize() if estimator.getValBatchSize() else batch_size
    epochs = estimator.getEpochs()
    user_shuffle_buffer_size = estimator.getShufflingBufferSize()
    transformation_fn = estimator.getTransformationFn()
    transformation = transformation_fn if transformation_fn else None
    inmemory_cache_all = estimator.getInMemoryCacheAll()
    callbacks = estimator.getCallbacks()
    train_steps_per_epoch = estimator.getTrainStepsPerEpoch()
    val_steps_per_epoch = estimator.getValidationStepsPerEpoch()
    num_gpus = estimator.getNumGPUs()
    logger = estimator.getLogger()
    log_every_n_steps = estimator.getLogEveryNSteps()

    # Data reader parameters
    train_reader_worker_count = estimator.getTrainReaderNumWorker()
    val_reader_worker_count = estimator.getValReaderNumWorker()
    reader_pool_type = estimator.getReaderPoolType()

    # Utility functions
    deserialize = deserialize_fn()
    calculate_shuffle_buffer_size = _calculate_shuffle_buffer_size_fn(
        train_rows, avg_row_size, user_shuffle_buffer_size)

    schema_fields = feature_columns + label_columns
    if sample_weight_col:
        schema_fields.append(sample_weight_col)

    dataloader_cls = _create_dataloader(feature_columns, input_shapes, metadata)
    make_petastorm_reader = _make_petastorm_reader_fn(transformation, schema_fields,
                                                      batch_size, calculate_shuffle_buffer_size,
                                                      dataloader_cls)

    # Storage
    store = estimator.getStore()
    remote_store = store.to_remote(run_id, dataset_idx)
    def train(serialized_model):
        import horovod.torch as hvd
        # Horovod: initialize library.
        hvd.init()

        with tempfile.TemporaryDirectory() as last_ckpt_dir, remote_store.get_local_output_dir() as run_output_dir:
            last_ckpt_file = os.path.join(last_ckpt_dir, 'last.ckpt')
            if ckpt_bytes:
                with open(last_ckpt_file, 'wb') as f:
                    f.write(ckpt_bytes)

            # TODO: Pass the logger from the estimator constructor.
            logs_path = os.path.join(run_output_dir, remote_store.logs_subdir)

            # Use the default logger if no logger is supplied.
            train_logger = logger
            if train_logger is None:
                train_logger = TensorBoardLogger(logs_path)

            # TODO: find a way to create ckpt_path from the remote store while ingesting
            # all other parameters from the estimator config:
            # ckpt_path = os.path.join(run_output_dir, remote_store.checkpoint_filename)
            # os.makedirs(ckpt_path, exist_ok=True)
            # model_checkpoint_callback = ModelCheckpoint(dirpath=ckpt_path)
            # callbacks.append(model_checkpoint_callback)

            is_model_checkpoint_callback_exist = False
            if callbacks is not None:
                for cb in callbacks:
                    if isinstance(cb, ModelCheckpoint):
                        is_model_checkpoint_callback_exist = True
                        break

            model = deserialize(serialized_model)

            _train_steps_per_epoch = train_steps_per_epoch if train_steps_per_epoch else 1.0
            _val_steps_per_epoch = val_steps_per_epoch if val_steps_per_epoch else 1.0

            cuda_available = torch.cuda.is_available()
            # Check that all ranks have the same device type for training:
            # Horovod doesn't support heterogeneous allreduce for gradients.
            cuda_avail_list = hvd.allgather_object(cuda_available, name='device type')
            if hvd.rank() == 0:
                assert cuda_avail_list.count(cuda_available) == hvd.size(), "All ranks don't have the same device type!"

            if cuda_available:
                # Horovod: pin GPU to the local rank or to the GPU assigned by Spark.
                torch.cuda.set_device(_get_assigned_gpu_or_default(default=hvd.local_rank()))
                # Move the model to GPU.
                model.cuda()

            _num_gpus = num_gpus
            if _num_gpus is None:
                _num_gpus = 1 if cuda_available else 0

            kwargs = {'accelerator': 'horovod',
                      'gpus': _num_gpus,
                      'callbacks': callbacks,
                      'max_epochs': epochs,
                      'limit_train_batches': _train_steps_per_epoch,
                      'limit_val_batches': _val_steps_per_epoch,
                      'logger': train_logger,
                      'log_every_n_steps': log_every_n_steps,
                      'resume_from_checkpoint': (last_ckpt_file if ckpt_bytes else None),
                      'checkpoint_callback': is_model_checkpoint_callback_exist,
                      'num_sanity_val_steps': 0,
                      'reload_dataloaders_every_epoch': False}
            print("Creating trainer with: \n ", kwargs)
            trainer = Trainer(**kwargs)

            print(f"pytorch_lightning version={pl.__version__}")

            # Debugging aid: print Parquet row group metadata for the training data.
            # pq_file = pq.ParquetFile(remote_store.train_data_path)
            # for rowgroup in range(pq_file.metadata.num_row_groups):
            #     row_group = pq_file.metadata.row_group(rowgroup)
            #     print(row_group)

            with make_petastorm_reader(model, remote_store.train_data_path, 'train_dataloader',
                                       train_reader_worker_count, reader_pool_type), \
                    make_petastorm_reader(model, remote_store.val_data_path, 'val_dataloader',
                                          val_reader_worker_count, reader_pool_type, should_validate):
                trainer.fit(model)

            serialized_checkpoint = io.BytesIO()
            module = model if not is_legacy else model._model
            # TODO: find a way to pass trainer.logged_metrics out.
            output = {'model': module.state_dict()}
            torch.save(output, serialized_checkpoint)
            serialized_checkpoint.seek(0)
            return serialized_checkpoint
    return train
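

# Illustrative sketch (hypothetical invocation, not part of this module): the
# closure returned by RemoteTrainer is shipped to each Horovod rank by the Spark
# estimator, which calls it with the serialized model and reads back a BytesIO
# holding the trained state_dict:
#
#   train_fn = RemoteTrainer(estimator, metadata, ckpt_bytes, run_id, dataset_idx,
#                            train_rows, val_rows, avg_row_size, is_legacy=False)
#   ckpt_io = train_fn(serialized_model)          # runs on every Horovod rank
#   state_dict = torch.load(ckpt_io)['model']     # matches the 'model' key saved above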


def _reset_loader(loader):
    from petastorm.pytorch import BatchedDataLoader
    from pytorch_lightning.trainer.supporters import CombinedLoader

    # A CombinedLoader wraps one Petastorm loader per dataset; reset each one.
    if isinstance(loader, CombinedLoader):
        for loader in loader.loaders:
            loader.reader.reset()
    else:
        loader.reader.reset()


# TODO: enable this when the petastorm loader supports reset before the epoch ends.
def _make_reset_callbacks():
    class ResetCallback(Callback):
        def on_train_end(self, trainer, model):
            _reset_loader(trainer.train_dataloader)

        def on_validation_end(self, trainer, model):
            for loader in trainer.val_dataloaders:
                loader.reader.reset()

        def on_sanity_check_end(self, trainer, model):
            for loader in trainer.val_dataloaders:
                _reset_loader(loader)

    return [ResetCallback()]


def _make_petastorm_reader_fn(transformation, schema_fields, batch_size, calculate_shuffle_buffer_size, dataloader_cls):

    @contextlib.contextmanager
    def make_petastorm_reader(model, data_path, dataloader_attr, reader_worker_count, reader_pool_type, should_read=True):
        from petastorm import TransformSpec, make_reader, make_batch_reader
        import horovod.torch as hvd

        is_loader_overridden = False
        if LooseVersion(pl.__version__) >= LooseVersion('1.0.0'):
            from pytorch_lightning.utilities.model_helpers import is_overridden
            is_loader_overridden = is_overridden(dataloader_attr, model)

        # Skip attaching a reader if reading is disabled (e.g. no validation)
        # or the user's model already overrides this dataloader.
        if not should_read or is_loader_overridden:
            yield
            return

        transform_spec = TransformSpec(transformation) if transformation else None

        # In general, make_batch_reader is faster than make_reader for reading the
        # dataset. However, we found that make_reader performs data transformations
        # much faster than make_batch_reader with parallel worker processes.
        # Therefore, the default is make_batch_reader unless there are data
        # transformations.
        reader_factory_kwargs = dict()
        if transform_spec:
            reader_factory = make_reader
            reader_factory_kwargs['pyarrow_serialize'] = True
        else:
            reader_factory = make_batch_reader

        # Petastorm: read data from the store with the correct shard for this rank.
        # (Setting num_epochs=None would create an infinite iterator, which lets
        # ranks train and validate with unequal numbers of samples; here each fit
        # reads a single pass over its shard.)
        with reader_factory(data_path,
                            num_epochs=1,
                            cur_shard=hvd.rank(),
                            shard_count=hvd.size(),
                            reader_pool_type=reader_pool_type,
                            workers_count=reader_worker_count,
                            hdfs_driver=PETASTORM_HDFS_DRIVER,
                            schema_fields=schema_fields,
                            transform_spec=transform_spec,
                            **reader_factory_kwargs) as reader:
            def dataloader_fn():
                return dataloader_cls(reader, batch_size=batch_size,
                                      shuffling_queue_capacity=calculate_shuffle_buffer_size())
            try:
                # Patch the model's dataloader attribute with the factory for the
                # duration of the fit, then restore it on exit.
                setattr(model, dataloader_attr, dataloader_fn)
                yield
            finally:
                setattr(model, dataloader_attr, None)
    return make_petastorm_reader
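

# Illustrative usage sketch (hypothetical values, not part of this module): the
# context manager above patches the given *_dataloader attribute on the
# LightningModule with a factory over the open Petastorm reader:
#
#   reader_fn = _make_petastorm_reader_fn(None, schema_fields, batch_size=32,
#                                         calculate_shuffle_buffer_size=lambda: 1000,
#                                         dataloader_cls=dataloader_cls)
#   with reader_fn(model, '/path/to/train_data', 'train_dataloader',
#                  reader_worker_count=4, reader_pool_type='thread'):
#       trainer.fit(model)   # model.train_dataloader() now yields Petastorm batches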


def _calculate_shuffle_buffer_size_fn(train_rows, avg_row_size, user_shuffle_buffer_size):
    def calculate_shuffle_buffer_size():
        """
        Determines the shuffling buffer size (in rows) such that each worker gets at
        most 1 GiB for its shuffling buffer, and such that across all the workers on
        a single machine at most TOTAL_BUFFER_MEMORY_CAP_GIB GiB are allocated for
        shuffling buffers. It also ensures that the buffer size is identical on all
        workers.

        Example 1:
            TOTAL_BUFFER_MEMORY_CAP_GIB = 4
            machine1: 8 workers
            machine2: 3 workers
            per-worker budget = 0.5 GiB

        Example 2:
            TOTAL_BUFFER_MEMORY_CAP_GIB = 4
            machine1: 2 workers
            machine2: 3 workers
            per-worker budget = 1 GiB

        Example 3:
            TOTAL_BUFFER_MEMORY_CAP_GIB = 4
            machine1: 2 workers
            machine2: 8 workers
            machine3: 5 workers
            per-worker budget = 0.5 GiB
        """
        import horovod.torch as hvd

        if user_shuffle_buffer_size:
            return user_shuffle_buffer_size

        local_size = hvd.local_size()
        local_sizes = hvd.allgather(torch.tensor([local_size]))
        max_local_size = torch.max(local_sizes).item()

        if max_local_size > TOTAL_BUFFER_MEMORY_CAP_GIB:
            shuffle_buffer_size = TOTAL_BUFFER_MEMORY_CAP_GIB * BYTES_PER_GIB / avg_row_size / max_local_size
        else:
            shuffle_buffer_size = BYTES_PER_GIB / avg_row_size

        # Never buffer more rows than this rank's share of the training data.
        return int(min(shuffle_buffer_size, train_rows / hvd.size()))
    return calculate_shuffle_buffer_size
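

# Worked sketch of the sizing arithmetic above, with illustrative numbers (the
# avg_row_size value is made up for the example). With TOTAL_BUFFER_MEMORY_CAP_GIB = 4
# and a largest machine running 8 workers (example 1 in the docstring), each worker's
# budget is 4 GiB / 8 = 0.5 GiB, and the buffer holds that budget divided by the
# average row size:
#
#   avg_row_size = 4096                                    # bytes per row, hypothetical
#   max_local_size = 8
#   rows = 4 * (1 << 30) / avg_row_size / max_local_size   # = 131072 rows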


def _create_dataloader(feature_columns, input_shapes, metadata):
    from petastorm.pytorch import BatchedDataLoader

    shape_dict = {col: shape for col, shape in zip(feature_columns, input_shapes)}
    prepare_data = _prepare_data_fn(metadata)

    class _DataLoader(BatchedDataLoader):
        def _yield_batches(self, keys):
            for batch in super()._yield_batches(keys):
                # Densify sparse columns and reshape feature columns to their
                # declared input shapes; pass other columns through unchanged.
                batch = {
                    k: prepare_data(k, v).reshape(shape_dict[k]) if k in shape_dict else v
                    for k, v in batch.items()
                }
                yield batch

    return _DataLoader


def _prepare_data_fn(metadata):
    def prepare_data(col_name, rows):
        if col_name not in metadata:
            return rows

        intermediate_format = metadata[col_name]['intermediate_format']
        if intermediate_format != CUSTOM_SPARSE:
            return rows

        # Each CUSTOM_SPARSE row is laid out as
        # [size, idx_0, ..., idx_{size-1}, val_0, ..., val_{size-1}];
        # scatter the values into a dense tensor of the column's declared shape.
        shape = metadata[col_name]['shape']
        num_rows = rows.shape[0]
        dense_rows = torch.zeros([num_rows, shape])
        for r in range(num_rows):
            size = rows[r][0].long()
            dense_rows[r][rows[r][1:size + 1].long()] = \
                rows[r][size + 1:2 * size + 1]
        return dense_rows
    return prepare_data
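

# Illustrative sketch (made-up values, not part of this module): how prepare_data
# densifies one CUSTOM_SPARSE row. With two non-zeros at indices 1 and 3:
#
#   metadata = {'features': {'intermediate_format': CUSTOM_SPARSE, 'shape': 5}}
#   prepare_data = _prepare_data_fn(metadata)
#   rows = torch.tensor([[2., 1., 3., 0.5, 0.7]])   # [size, idx_0, idx_1, val_0, val_1]
#   prepare_data('features', rows)
#   # -> tensor([[0.0000, 0.5000, 0.0000, 0.7000, 0.0000]])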