/
multiprocess_parallel_updater.py
480 lines (398 loc) · 15.6 KB
/
multiprocess_parallel_updater.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
import multiprocessing
import warnings
import six
from chainer.backends import cuda
from chainer.dataset import convert
from chainer import reporter
from chainer.training.updaters import standard_updater
try:
from cupy.cuda import nccl
_available = True
except ImportError:
_available = False
import numpy
class _Worker(multiprocessing.Process):
    """Process that computes gradients for one non-master device.

    The worker is driven by ``(job, data)`` messages received on ``pipe``.
    On ``'update'`` it runs one forward/backward pass on a batch drawn from
    its own iterator, takes part in the NCCL gradient reduction towards
    rank 0 and then receives the parameters broadcast from rank 0.
    On ``'finalize'`` it synchronizes its device and leaves the loop.
    Any other job name is ignored.

    Args:
        proc_id (int): Index of this worker in ``master._devices``; it is
            also used as the NCCL rank.
        pipe: Connection end on which messages are received.
        master: Updater object providing ``converter``, ``_master``,
            ``_devices`` and ``_mpu_iterators``.
    """

    def __init__(self, proc_id, pipe, master):
        super(_Worker, self).__init__()
        # Identity of this worker and its channel to the master.
        self.proc_id = proc_id
        self.pipe = pipe
        # Everything else is taken from the master updater.
        self.converter = master.converter
        self.model = master._master
        self.device = master._devices[proc_id]
        self.iterator = master._mpu_iterators[proc_id]
        self.n_devices = len(master._devices)

    def setup(self):
        """Create the NCCL communicator and move the model to the device."""
        # The first message carries the NCCL unique id; its tag is unused.
        _, comm_id = self.pipe.recv()
        self.comm = nccl.NcclCommunicator(
            self.n_devices, comm_id, self.proc_id)

        self.model.to_gpu(self.device)

        # Local reporter: observations are collected into a dict that is
        # created per update in ``run`` and never sent back to the master.
        self.reporter = reporter.Reporter()
        self.reporter.add_observer('main', self.model)
        self.reporter.add_observers(
            'main', self.model.namedlinks(skipself=True))

    def run(self):
        """Serve ``'update'`` messages until ``'finalize'`` is received."""
        device = cuda.Device(self.device)
        device.use()
        self.setup()
        params = None
        while True:
            job, _payload = self.pipe.recv()
            if job == 'finalize':
                device.synchronize()
                break
            if job != 'update':
                # Unknown jobs are silently skipped.
                continue

            # For reducing memory
            self.model.cleargrads()

            batch = self.converter(self.iterator.next(), self.device)
            observation = {}
            with self.reporter.scope(observation):
                loss = _calc_loss(self.model, batch)

            self.model.cleargrads()
            loss.backward()
            del loss

            # Sum the flattened gradients into rank 0.
            grads = gather_grads(self.model)
            grads_type = _get_nccl_data_type(grads.dtype)
            stream = cuda.Stream.null
            self.comm.reduce(grads.data.ptr, grads.data.ptr, grads.size,
                             grads_type, nccl.NCCL_SUM, 0,
                             stream.ptr)
            del grads

            # Receive the parameters broadcast from rank 0 and copy them
            # back into the model.
            self.model.cleargrads()
            params = gather_params(self.model)
            params_type = _get_nccl_data_type(params.dtype)
            self.comm.bcast(params.data.ptr, params.size, params_type, 0,
                            stream.ptr)
            scatter_params(self.model, params)
            # Drop the reference to the flat buffer before the next recv.
            params = None
class MultiprocessParallelUpdater(standard_updater.StandardUpdater):

    """Implementation of a multiprocess parallel GPU Updater.

    This is an implementation of :class:`Updater` that uses multiple GPUs
    with multi-process data parallelism. It uses Nvidia NCCL for communication
    between multiple GPUs.

    It behaves similarly to
    :class:`~chainer.training.updaters.StandardUpdater`.
    The update routine is modified to support data-parallel
    computation on multiple GPUs in one machine.
    It is based on synchronous parallel SGD: it
    parallelizes the gradient computation over a mini-batch, and updates the
    parameters only in the main device.

    It does not transfer the values collected by :class:`Reporter` in the sub
    devices to the main device. So you can only see the reported values in
    the main device.

    Args:
        iterators: List of dataset iterator for the training dataset. The
            number of the iterators must be same to the number of GPUs you
            use, and all their datasets must have the same length. The
            first iterator is used on the master device; the i-th iterator
            is used by the worker for the i-th device.
        optimizer: Optimizer to update parameters. The model should be
            attached to the optimizer. Note that ``eps`` or ``lr`` of the
            optimizer is rescaled by the number of devices (a warning is
            emitted when this happens).
        converter: Converter function to build input arrays. On every
            device it is called with the batch drawn from that device's
            iterator and the device id.
            :func:`~chainer.dataset.concat_examples` is used by default.
        devices: Dictionary or list of devices to which the training data is
            sent. The master device will be the first one in the list or the
            value attached to the key ``'main'``. A dictionary passed here
            is not modified.

    Raises:
        Exception: If NCCL is not available.
        RuntimeError: If the CUDA context has already been initialized.
        ValueError: If ``devices`` is ``None``, is not a dict, list or
            tuple, or contains ``None``.

    """

    def __init__(self, iterators, optimizer, converter=convert.concat_examples,
                 devices=None):
        if not MultiprocessParallelUpdater.available():
            raise Exception(
                'NCCL is not enabled. MultiprocessParallelUpdater '
                'requires NCCL.\n'
                'Please reinstall CuPy after you install NCCL.\n'
                '(see https://docs-cupy.chainer.org/en/latest/install.html)')
        try:
            cuda.cupy.cuda.driver.ctxGetCurrent()
            _cuda_initialized = True
        except cuda.cupy.cuda.driver.CUDADriverError:
            # The context is not initialized, it will be fine.
            _cuda_initialized = False
        if _cuda_initialized:
            raise RuntimeError(
                'The CUDA context has been already initialized. '
                'MultiprocessParallelUpdater assumes the context is '
                'uninitialized. Please do not call CUDA API before '
                'MultiprocessParallelUpdater creates processes.')

        # Validate and normalize ``devices`` before anything else touches
        # it, so that a missing argument yields a clear ValueError instead
        # of a TypeError from ``len(None)``.
        if devices is None:
            raise ValueError('must specify GPU devices')
        if isinstance(devices, dict):
            # Work on a copy: popping 'main' must not alter the caller's
            # dictionary.
            devices = dict(devices)
            main = devices.pop('main')
            devices = [main] + list(six.itervalues(devices))
        elif isinstance(devices, (list, tuple)):
            devices = list(devices)
        else:
            raise ValueError(
                'devices argument should be either dict, list or tuple,'
                ' but {} was given.'.format(type(devices)))
        if any(device is None for device in devices):
            raise ValueError('must specify GPU devices')

        assert len(iterators) == len(devices), \
            'the number of iterators must equal the number of devices'
        for iterator in iterators[1:]:
            assert len(iterator.dataset) == len(iterators[0].dataset), \
                'all iterators must use datasets of the same length'

        # Correct optimizer parameters for new minibatch size.
        # This is done only after all arguments have been validated so that
        # a failed construction leaves the optimizer untouched.
        optim = optimizer.__class__.__name__
        if optim in ('Adam', 'AdaGrad', 'RMSprop'):
            optimizer.eps *= len(devices)
            warnings.warn('optimizer.eps is changed to {} '
                          'by MultiprocessParallelUpdater for new batch size.'.
                          format(optimizer.eps))
        elif optim in ('RMSpropGraves', 'AdaDelta'):
            optimizer.eps *= len(devices) ** 2  # not quite right for AdaDelta
            warnings.warn('optimizer.eps is changed to {} '
                          'by MultiprocessParallelUpdater for new batch size.'.
                          format(optimizer.eps))
        elif hasattr(optimizer, 'lr'):
            optimizer.lr /= len(devices)
            warnings.warn('optimizer.lr is changed to {} '
                          'by MultiprocessParallelUpdater for new batch size.'.
                          format(optimizer.lr))

        super(MultiprocessParallelUpdater, self).__init__(
            iterator=iterators[0],
            optimizer=optimizer,
            converter=converter
        )

        self._master = optimizer.target
        self._devices = devices
        self._mpu_iterators = iterators
        self._initialized = False

        # One pipe and one worker process per non-master device; filled in
        # lazily by ``setup_workers``.
        self._pipes = []
        self._workers = []
        # NCCL communicator of the master; stays None with a single device.
        self.comm = None

    @staticmethod
    def available():
        """Return whether NCCL could be imported from CuPy."""
        return _available

    def _send_message(self, message):
        """Send ``message`` to every worker process."""
        for pipe in self._pipes:
            pipe.send(message)

    def setup_workers(self):
        """Start the worker processes and set up NCCL (first call only)."""
        if self._initialized:
            return
        self._initialized = True

        self._master.cleargrads()
        # Device 0 is handled by this process; spawn workers for the rest.
        for i in six.moves.range(1, len(self._devices)):
            pipe, worker_end = multiprocessing.Pipe()
            worker = _Worker(i, worker_end, self)
            worker.start()
            self._workers.append(worker)
            self._pipes.append(pipe)

        with cuda.Device(self._devices[0]):
            self._master.to_gpu(self._devices[0])
            if len(self._devices) > 1:
                # Share one NCCL unique id with all workers, then join the
                # communicator as rank 0.
                comm_id = nccl.get_unique_id()
                self._send_message(("set comm_id", comm_id))
                self.comm = nccl.NcclCommunicator(len(self._devices),
                                                  comm_id, 0)

    def update_core(self):
        """Run one synchronous update over all devices."""
        self.setup_workers()

        # Ask every worker to compute gradients on its own batch.
        self._send_message(('update', None))
        with cuda.Device(self._devices[0]):
            # For reducing memory
            self._master.cleargrads()

            optimizer = self.get_optimizer('main')
            batch = self.get_iterator('main').next()
            batch = self.converter(batch, self._devices[0])

            loss = _calc_loss(self._master, batch)

            self._master.cleargrads()
            loss.backward()

            # NCCL: reduce grads
            null_stream = cuda.Stream.null
            if self.comm is not None:
                gg = gather_grads(self._master)
                nccl_data_type = _get_nccl_data_type(gg.dtype)
                self.comm.reduce(gg.data.ptr, gg.data.ptr, gg.size,
                                 nccl_data_type, nccl.NCCL_SUM,
                                 0, null_stream.ptr)
                # Write the summed gradients back into the master model.
                scatter_grads(self._master, gg)
                del gg
            optimizer.update()
            if self.comm is not None:
                # Broadcast the updated parameters from rank 0.
                gp = gather_params(self._master)
                nccl_data_type = _get_nccl_data_type(gp.dtype)
                self.comm.bcast(gp.data.ptr, gp.size, nccl_data_type,
                                0, null_stream.ptr)

    def finalize(self):
        """Tell all workers to stop and wait for them to exit."""
        self._send_message(('finalize', None))

        for worker in self._workers:
            worker.join()
def _calc_loss(model, in_arrays):
    """Call ``model`` on a converted batch and return its result.

    A tuple is expanded into positional arguments, a dict into keyword
    arguments, and anything else is passed as a single argument.
    """
    if isinstance(in_arrays, tuple):
        return model(*in_arrays)
    if isinstance(in_arrays, dict):
        return model(**in_arrays)
    return model(in_arrays)
def size_num_grads(link):
    """Count total size of all gradient arrays of a given link

    Parameters whose ``size`` is zero are left out of both counts.

    Args:
        link (chainer.link.Link): Target link object.

    Return:
        tuple: ``(size, num)`` where ``size`` is the summed ``size`` of the
        counted parameters and ``num`` is how many were counted.
    """
    sizes = [param.size for param in link.params() if param.size != 0]
    return sum(sizes), len(sizes)
def _memcpy_gather():
    """Build the elementwise kernel that packs many arrays into one.

    Kernel inputs (all ``raw``):
        * ``ptrs``: address of each source array; a NULL entry makes the
          corresponding output elements zero.
        * ``dtypes``: per-source code, ``0`` means float32, any other value
          is read as float16.
        * ``info``: ``info[0]`` is the number of sources and, for
          ``id > 0``, ``info[id]`` is the offset of source ``id`` in the
          flat output.

    For each output index ``i`` the kernel binary-searches ``info`` for the
    source containing ``i`` (starting from the source found for the previous
    index) and writes that element, converted to float32, to ``dst[i]``.
    """
    in_params = 'raw T ptrs, raw X dtypes, raw Y info'
    out_params = 'raw float32 dst'
    kernel_body = '''
int id_min = id_pre;
int id_max = num_src;
while (id_max - id_min > 1) {
int id = (id_max + id_min) / 2;
if (i < info[id]) id_max = id;
else id_min = id;
}
int id = id_min;
int i_dst = i;
int i_src = i;
if (id > 0) i_src -= info[id];
dst[i_dst] = 0;
if (ptrs[id] != NULL) {
if (dtypes[id] == 0) { // fp32
float *src = reinterpret_cast<float *>(ptrs[id]);
dst[i_dst] = src[i_src];
}
else { // fp16
float16 *src = reinterpret_cast<float16 *>(ptrs[id]);
dst[i_dst] = static_cast<float>(src[i_src]);
}
}
id_pre = id;
'''
    kernel_prep = '''
int num_src = info[0];
int id_pre = 0;
'''
    return cuda.elementwise(
        in_params,
        out_params,
        kernel_body,
        '_memcpy_gather',
        loop_prep=kernel_prep)
def _gather(link, target):
    """Pack one attribute of every non-empty parameter into a flat array.

    Args:
        link (chainer.link.Link): Link whose parameters are gathered.
        target (str): Name of the parameter attribute to read, e.g.
            ``"grad"`` or ``"data"``. Parameters for which it is ``None``
            are passed to the kernel as NULL pointers.

    Return:
        The output of the ``_memcpy_gather`` kernel launched with the total
        parameter size.
    """
    total_size, n_arrays = size_num_grads(link)

    addresses = numpy.empty(n_arrays, dtype=numpy.uint64)
    type_codes = numpy.empty(n_arrays, dtype=numpy.int8)
    offsets = numpy.empty(n_arrays + 1, dtype=numpy.int32)
    # Slot 0 acts as the zero base of the running offsets while the table
    # is built; it is overwritten with the array count afterwards.
    offsets[0] = 0

    # Parameters are visited in name order.
    non_empty = (
        param for _, param in sorted(link.namedparams())
        if param.size != 0)
    for idx, param in enumerate(non_empty):
        source = getattr(param, target)
        # 0 stands for the NULL pointer.
        addresses[idx] = 0 if source is None else source.data.ptr
        # 1 marks fp16, 0 marks fp32.
        type_codes[idx] = 1 if param.dtype == numpy.float16 else 0
        offsets[idx + 1] = offsets[idx] + param.size

    offsets[0] = n_arrays

    addresses = cuda.to_gpu(addresses)
    type_codes = cuda.to_gpu(type_codes)
    offsets = cuda.to_gpu(offsets)

    return _memcpy_gather()(addresses, type_codes, offsets, size=total_size)
def gather_grads(link):
    """Put together all gradient arrays and make a single array

    Args:
        link (chainer.link.Link): Target link object.

    Return:
        cupy.ndarray

    Raises:
        RuntimeError: If ``link.xp`` is NumPy, i.e. the link is not on GPU.
    """
    on_cpu = link.xp is numpy
    if on_cpu:
        raise RuntimeError('gather_grads works only on GPU.')
    return _gather(link, target="grad")
def gather_params(link):
    """Put together all parameter data arrays and make a single array

    Args:
        link (chainer.link.Link): Target link object.

    Return:
        cupy.ndarray

    Raises:
        RuntimeError: If ``link.xp`` is NumPy, i.e. the link is not on GPU.
    """
    on_cpu = link.xp is numpy
    if on_cpu:
        raise RuntimeError('Link.gather_params works only on GPU.')
    return _gather(link, target="data")
def _memcpy_scatter():
    """Build the elementwise kernel that unpacks a flat array.

    Kernel inputs (all ``raw``):
        * ``ptrs``: address of each destination array; elements belonging
          to a NULL entry are skipped.
        * ``dtypes``: per-destination code, ``0`` means float32, any other
          value is written as float16.
        * ``info``: ``info[0]`` is the number of destinations and, for
          ``id > 0``, ``info[id]`` is the offset of destination ``id`` in
          the flat input.
        * ``array``: the flat float32 input.

    For each input index ``i`` the kernel binary-searches ``info`` for the
    destination containing ``i`` (starting from the one found for the
    previous index) and stores ``array[i]`` there. The kernel declares no
    output parameters.
    """
    in_params = 'raw T ptrs, raw X dtypes, raw Y info, raw float32 array'
    out_params = ''
    kernel_body = '''
int id_min = id_pre;
int id_max = num_src;
while (id_max - id_min > 1) {
int id = (id_max + id_min) / 2;
if (i < info[id]) id_max = id;
else id_min = id;
}
int id = id_min;
int i_src = i;
int i_dst = i;
if (id > 0) i_dst -= info[id];
if (ptrs[id] != NULL) {
if (dtypes[id] == 0) { // fp32
float *dst = reinterpret_cast<float *>(ptrs[id]);
dst[i_dst] = array[i_src];
}
else { // fp16
float16 *dst = reinterpret_cast<float16 *>(ptrs[id]);
dst[i_dst] = static_cast<float16>(array[i_src]);
}
}
id_pre = id;
'''
    kernel_prep = '''
int num_src = info[0];
int id_pre = 0;
'''
    return cuda.elementwise(
        in_params,
        out_params,
        kernel_body,
        '_memcpy_scatter',
        loop_prep=kernel_prep)
def _scatter(link, array, target):
    """Copy a flat array back into one attribute of every parameter.

    Args:
        link (chainer.link.Link): Link whose parameters are written.
        array (cupy.ndarray): Flat array laid out as produced by
            ``_gather`` (parameters in name order, empty ones skipped).
        target (str): Name of the parameter attribute to write, e.g.
            ``"grad"`` or ``"data"``. If it is ``None`` on a parameter, a
            zero-filled array of the parameter's shape and dtype is
            created and assigned first.

    Return:
        The result of launching the ``_memcpy_scatter`` kernel.

    Raises:
        RuntimeError: If the number of non-empty parameters visited differs
            from the count reported by ``size_num_grads``.
    """
    size, num = size_num_grads(link)

    ptrs = numpy.zeros(num, dtype=numpy.uint64)
    dtypes = numpy.zeros(num, dtype=numpy.int8)
    info = numpy.zeros(num + 1, dtype=numpy.int32)
    # info[0] is the zero base of the running offsets while the table is
    # built; it is replaced by the array count below.
    info[0] = 0
    i = 0
    for _, param in sorted(link.namedparams()):
        if param.size == 0:
            continue
        ptrs[i] = 0  # NULL pointer
        d = getattr(param, target)
        if d is None:
            # Allocate a destination so the kernel has somewhere to write.
            d = cuda.cupy.zeros(param.shape, dtype=param.dtype)
            setattr(param, target, d)
        ptrs[i] = d.data.ptr
        dtypes[i] = 0  # fp32
        if param.dtype == numpy.float16:
            dtypes[i] = 1  # fp16
        info[i + 1] = info[i] + param.size
        i += 1
    if i != num:
        # The original ``raise()`` tried to raise an empty tuple, which
        # itself fails with an unrelated TypeError; report the real problem.
        raise RuntimeError(
            'Inconsistent number of parameters: expected {} non-empty '
            'parameters but found {}.'.format(num, i))
    info[0] = num

    ptrs = cuda.to_gpu(ptrs)
    dtypes = cuda.to_gpu(dtypes)
    info = cuda.to_gpu(info)

    return _memcpy_scatter()(ptrs, dtypes, info, array, size=size)
def scatter_grads(link, array):
    """Put back contents of the specified array to the related gradient arrays

    Args:
        link (chainer.link.Link): Target link object.
        array (cupy.ndarray): gathered array created by gather_grads()

    Return:
        Whatever the underlying ``_scatter`` call returns.
    """
    return _scatter(link, array, target="grad")
def scatter_params(link, array):
    """Put back contents of the specified array to the related parameter data

    Args:
        link (chainer.link.Link): Target link object.
        array (cupy.ndarray): gathered array created by gather_params()

    Return:
        Whatever the underlying ``_scatter`` call returns.
    """
    return _scatter(link, array, target="data")
def _get_nccl_data_type(dtype):
    """Get data type for NCCL

    Args:
        dtype: Array dtype; float32, float16 and float64 are supported.

    Return:
        The matching ``nccl.NCCL_*`` data type constant.

    Raises:
        RuntimeError: If ``dtype`` is none of the supported types.
    """
    if dtype == numpy.float32:
        return nccl.NCCL_FLOAT
    if dtype == numpy.float16:
        return nccl.NCCL_HALF
    if dtype == numpy.float64:
        return nccl.NCCL_DOUBLE
    raise RuntimeError('Unexpected data type:{}'.format(dtype))