/
operations.py
531 lines (417 loc) · 18.3 KB
/
operations.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A set of basic tensor ops compatible with tpu, gpu, and multigpu
"""
from functools import update_wrapper
from typing import Any, Mapping
import torch
from torch.distributed import ReduceOp
from ..state import AcceleratorState
from .dataclasses import DistributedType, TensorInformation
from .imports import is_tpu_available
from .versions import is_torch_version
if is_tpu_available(check_device=False):
import torch_xla.core.xla_model as xm
def is_torch_tensor(tensor):
    """
    Tell whether `tensor` is an instance of `torch.Tensor`.
    """
    tensor_cls = torch.Tensor
    return isinstance(tensor, tensor_cls)
def is_tensor_information(tensor_info):
    """
    Tell whether `tensor_info` is an instance of `TensorInformation`.
    """
    expected_cls = TensorInformation
    return isinstance(tensor_info, expected_cls)
def honor_type(obj, generator):
    """
    Cast a generator to the same type as obj (list, tuple or namedtuple)

    Args:
        obj:
            The object whose type must be reproduced.
        generator:
            An iterable yielding the elements of the new object.

    Returns:
        A new object of type `type(obj)` built from the elements produced by `generator`.
    """
    # Consume the generator up front, outside of any `try`: an exception raised while producing an element (for
    # instance a `TypeError` coming from the function being mapped) must propagate instead of being mistaken for a
    # constructor failure and silently turned into an empty container.
    items = list(generator)
    container_type = type(obj)
    # Namedtuples take their fields as separate positional arguments. This has to be checked explicitly because a
    # single-field namedtuple would otherwise happily store the whole iterable as its only field.
    if isinstance(obj, tuple) and hasattr(obj, "_fields"):
        return container_type(*items)
    try:
        return container_type(items)
    except TypeError:
        # Some objects may not be able to instantiate from an iterable directly
        return container_type(*items)
def recursively_apply(func, data, *args, test_type=is_torch_tensor, error_on_other_type=False, **kwargs):
    """
    Recursively apply a function on a data structure that is a nested list/tuple/dictionary of objects accepted by
    `test_type`.

    Args:
        func (`callable`):
            The function to recursively apply.
        data (nested list/tuple/dictionary of objects accepted by `test_type`):
            The data on which to apply `func`
        *args:
            Positional arguments that will be passed to `func` when applied on the unpacked data.
        test_type (`callable`, *optional*, defaults to `is_torch_tensor`):
            A predicate taking an object and returning whether `func` should be applied to it.
        error_on_other_type (`bool`, *optional*, defaults to `False`):
            Whether to raise an error or not if after unpacking `data`, we get on an object that is not accepted by
            `test_type`. If `False`, the function will leave such objects unchanged.
        **kwargs:
            Keyword arguments that will be passed to `func` when applied on the unpacked data.

    Returns:
        The same data structure as `data` with `func` applied to every object accepted by `test_type`.

    Raises:
        `TypeError`: If `error_on_other_type=True` and a leaf object is not accepted by `test_type`.
    """
    if isinstance(data, (tuple, list)):
        return honor_type(
            data,
            (
                recursively_apply(
                    func, o, *args, test_type=test_type, error_on_other_type=error_on_other_type, **kwargs
                )
                for o in data
            ),
        )
    elif isinstance(data, Mapping):
        return type(data)(
            {
                k: recursively_apply(
                    func, v, *args, test_type=test_type, error_on_other_type=error_on_other_type, **kwargs
                )
                for k, v in data.items()
            }
        )
    elif test_type(data):
        return func(data, *args, **kwargs)
    elif error_on_other_type:
        # Not every callable has a `__name__` (e.g. `functools.partial` objects): fall back to the repr so the
        # intended `TypeError` is raised rather than an unrelated `AttributeError`.
        func_name = getattr(func, "__name__", repr(func))
        test_name = getattr(test_type, "__name__", repr(test_type))
        raise TypeError(
            f"Can't apply {func_name} on object of type {type(data)}, only of nested list/tuple/dicts of objects "
            f"that satisfy {test_name}."
        )
    return data
def send_to_device(tensor, device):
    """
    Recursively move every element of a nested list/tuple/dictionary that exposes a `to` method to `device`.

    Args:
        tensor (nested list/tuple/dictionary of `torch.Tensor`):
            The data to move.
        device (`torch.device`):
            The target device.

    Returns:
        The same data structure as `tensor`, where each element having a `to` attribute was replaced by the result
        of calling `.to(device)` on it. Other elements are left untouched.
    """

    def _supports_to(candidate):
        # Duck-typing: anything with a `to` attribute is moved, not only tensors.
        return hasattr(candidate, "to")

    def _move(candidate, target_device):
        return candidate.to(target_device)

    return recursively_apply(_move, tensor, device, test_type=_supports_to)
def get_data_structure(data):
    """
    Recursively collect the shape and dtype of every tensor in a nested list/tuple/dictionary.

    Args:
        data (nested list/tuple/dictionary of `torch.Tensor`):
            The data to analyze.

    Returns:
        The same data structure as `data` with a `TensorInformation` (shape and dtype) in place of each tensor.
    """

    def _describe(tensor):
        info = TensorInformation(shape=tensor.shape, dtype=tensor.dtype)
        return info

    return recursively_apply(_describe, data)
def initialize_tensors(data_structure):
    """
    Recursively initializes tensors from a nested list/tuple/dictionary of [`~utils.TensorInformation`].

    Args:
        data_structure (nested list/tuple/dictionary of [`~utils.TensorInformation`]):
            The description (shape and dtype) of the tensors to create.

    Returns:
        The same data structure as `data_structure` with uninitialized tensors (created with `torch.empty`) instead
        of [`~utils.TensorInformation`].
    """

    def _initialize_tensor(tensor_info):
        # Pass the shape as one tuple rather than unpacking it: unpacking an empty shape (0-dim tensor) would call
        # `torch.empty` without any size argument, which is rejected.
        return torch.empty(tuple(tensor_info.shape), dtype=tensor_info.dtype)

    return recursively_apply(_initialize_tensor, data_structure, test_type=is_tensor_information)
def find_batch_size(data):
    """
    Recursively finds the batch size in a nested list/tuple/dictionary of lists of tensors.

    The batch size is read from the first dimension of the first tensor reached by always following the first
    element of each container.

    Args:
        data (nested list/tuple/dictionary of `torch.Tensor`): The data from which to find the batch size.

    Returns:
        `int`: The batch size.

    Raises:
        `ValueError`: If an empty list/tuple/dictionary is reached while looking for the first tensor.
        `TypeError`: If the first leaf reached is not a `torch.Tensor`.
    """
    if isinstance(data, (tuple, list, Mapping)):
        # An empty container has no first element to inspect: fail with an explicit message instead of an
        # `IndexError` (sequences) or an `AttributeError` on `.shape` (mappings).
        if len(data) == 0:
            raise ValueError(f"Cannot find the batch size from empty {type(data)}.")
        if isinstance(data, Mapping):
            first = next(iter(data.values()))
        else:
            first = data[0]
        return find_batch_size(first)
    if not isinstance(data, torch.Tensor):
        raise TypeError(f"Can only find the batch size of tensors but got {type(data)}.")
    return data.shape[0]
def _tpu_gather(tensor, name="gather tensor"):
    """
    Recursively gather the tensors of a nested list/tuple/dictionary through `xm.mesh_reduce`.

    Args:
        tensor (nested list/tuple/dictionary of `torch.Tensor`):
            The data to gather.
        name (`str`, *optional*, defaults to `"gather tensor"`):
            Tag passed to `xm.mesh_reduce`; it is suffixed with the index/key of each nested element.

    Returns:
        The same data structure as `tensor` where each tensor is the result of `xm.mesh_reduce` with `torch.cat`.

    Raises:
        `TypeError`: If a leaf is not a `torch.Tensor`.
    """
    if isinstance(tensor, (list, tuple)):
        gathered_items = (_tpu_gather(item, name=f"{name}_{idx}") for idx, item in enumerate(tensor))
        return honor_type(tensor, gathered_items)
    if isinstance(tensor, Mapping):
        gathered_map = {key: _tpu_gather(value, name=f"{name}_{key}") for key, value in tensor.items()}
        return type(tensor)(gathered_map)
    if not isinstance(tensor, torch.Tensor):
        raise TypeError(f"Can't gather the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors.")
    # 0-dim tensors get a leading dimension so they can be concatenated.
    to_gather = tensor.clone()[None] if tensor.ndim == 0 else tensor
    return xm.mesh_reduce(name, to_gather, torch.cat)
def _gpu_gather(tensor):
    """
    Recursively gather the tensors of a nested list/tuple/dictionary with `torch.distributed.all_gather`.

    Returns:
        The same data structure as `tensor` where each tensor is the concatenation (on dimension 0) of the tensors
        received from all processes.

    Raises:
        `TypeError`: If a leaf is not a `torch.Tensor`.
    """

    def _gpu_gather_one(tensor):
        # 0-dim tensors get a leading dimension so they can be concatenated.
        local = tensor.clone()[None] if tensor.ndim == 0 else tensor
        world_size = torch.distributed.get_world_size()
        # One receive buffer per process, with the same shape/dtype/device as the local tensor.
        buffers = []
        for _ in range(world_size):
            buffers.append(local.clone())
        torch.distributed.all_gather(buffers, local)
        return torch.cat(buffers, dim=0)

    return recursively_apply(_gpu_gather_one, tensor, error_on_other_type=True)


_cpu_gather = _gpu_gather
def gather(tensor):
    """
    Recursively gather tensor in a nested list/tuple/dictionary of tensors from all devices.

    Args:
        tensor (nested list/tuple/dictionary of `torch.Tensor`):
            The data to gather.

    Returns:
        The same data structure as `tensor` with all tensors gathered from all processes. When the current
        distributed type is none of the handled ones, `tensor` is returned as is.
    """
    distributed_type = AcceleratorState().distributed_type
    if distributed_type == DistributedType.TPU:
        return _tpu_gather(tensor, name="accelerate.utils.gather")
    if distributed_type in (
        DistributedType.DEEPSPEED,
        DistributedType.MULTI_GPU,
        DistributedType.FSDP,
    ):
        return _gpu_gather(tensor)
    if distributed_type == DistributedType.MULTI_CPU:
        return _cpu_gather(tensor)
    return tensor
def _gpu_gather_object(object: Any):
    """
    Recursively gather the picklable objects of a nested list/tuple/dictionary with
    `torch.distributed.all_gather_object`.

    Args:
        object (nested list/tuple/dictionary of picklable object):
            The data to gather.

    Returns:
        The same data structure as `object` where each leaf is replaced by the list of the corresponding leaves
        received from all processes.
    """

    def _gpu_gather_object_one(object: Any):
        output_objects = [None for _ in range(AcceleratorState().num_processes)]
        torch.distributed.all_gather_object(output_objects, object)
        return output_objects

    def _is_leaf(_):
        # Every object left after unpacking lists/tuples/dictionaries must be gathered. The default predicate of
        # `recursively_apply` only accepts tensors, which would return any other picklable object ungathered.
        return True

    return recursively_apply(_gpu_gather_object_one, object, test_type=_is_leaf)


_cpu_gather_object = _gpu_gather_object
def gather_object(object: Any):
    """
    Recursively gather object in a nested list/tuple/dictionary of objects from all devices.

    Args:
        object (nested list/tuple/dictionary of picklable object):
            The data to gather.

    Returns:
        The result of the gather for the current distributed setup, or `object` itself when the current
        distributed type is none of the handled ones.

    Raises:
        `NotImplementedError`: On TPU, where gathering objects is not supported.
    """
    distributed_type = AcceleratorState().distributed_type
    if distributed_type == DistributedType.TPU:
        raise NotImplementedError("gather objects in TPU is not supported")
    if distributed_type in (
        DistributedType.DEEPSPEED,
        DistributedType.MULTI_GPU,
        DistributedType.FSDP,
    ):
        return _gpu_gather_object(object)
    if distributed_type == DistributedType.MULTI_CPU:
        return _cpu_gather_object(object)
    return object
def _gpu_broadcast(data, src=0):
    """
    Recursively broadcast the tensors of a nested list/tuple/dictionary with `torch.distributed.broadcast`.

    Args:
        data (nested list/tuple/dictionary of `torch.Tensor`):
            The data to broadcast.
        src (`int`, *optional*, defaults to 0):
            The rank of the process sending the data.

    Returns:
        The same data structure as `data`, holding the tensors that were passed to `torch.distributed.broadcast`.

    Raises:
        `TypeError`: If a leaf is not a `torch.Tensor`.
    """

    def _gpu_broadcast_one(t, src=0):
        # The tensor is handed to the collective and then returned so the structure can be rebuilt around it.
        torch.distributed.broadcast(t, src=src)
        return t

    broadcast_kwargs = {"error_on_other_type": True, "src": src}
    return recursively_apply(_gpu_broadcast_one, data, **broadcast_kwargs)
def _tpu_broadcast(tensor, src=0, name="broadcast tensor"):
    """
    Recursively broadcast the elements of a nested list/tuple/dictionary through `xm.mesh_reduce`.

    Args:
        tensor (nested list/tuple/dictionary of `torch.Tensor`):
            The data to broadcast.
        src (`int`, *optional*, defaults to 0):
            The index of the process whose value is kept.
        name (`str`, *optional*, defaults to `"broadcast tensor"`):
            Tag passed to `xm.mesh_reduce`; it is suffixed with the index/key of each nested element.

    Returns:
        The same data structure as `tensor` where each leaf is the value selected at index `src` by
        `xm.mesh_reduce`.
    """
    # `src` must be forwarded on every recursive call, otherwise nested elements would always be taken from
    # process 0 whatever source was requested.
    if isinstance(tensor, (list, tuple)):
        return honor_type(tensor, (_tpu_broadcast(t, src=src, name=f"{name}_{i}") for i, t in enumerate(tensor)))
    elif isinstance(tensor, Mapping):
        return type(tensor)({k: _tpu_broadcast(v, src=src, name=f"{name}_{k}") for k, v in tensor.items()})
    return xm.mesh_reduce(name, tensor, lambda x: x[src])
def broadcast(tensor, from_process: int = 0):
    """
    Recursively broadcast tensor in a nested list/tuple/dictionary of tensors to all devices.

    Args:
        tensor (nested list/tuple/dictionary of `torch.Tensor`):
            The data to broadcast.
        from_process (`int`, *optional*, defaults to 0):
            The process from which to send the data

    Returns:
        The same data structure as `tensor` with all tensors broadcasted from `from_process`. When the current
        distributed type is none of the handled ones, `tensor` is returned as is.
    """
    distributed_type = AcceleratorState().distributed_type
    if distributed_type == DistributedType.TPU:
        return _tpu_broadcast(tensor, src=from_process, name="accelerate.utils.broadcast")
    # The multi-CPU setup goes through the same `torch.distributed` implementation as the GPU-based ones.
    uses_torch_distributed = distributed_type in (
        DistributedType.DEEPSPEED,
        DistributedType.MULTI_GPU,
        DistributedType.FSDP,
    ) or distributed_type == DistributedType.MULTI_CPU
    if uses_torch_distributed:
        return _gpu_broadcast(tensor, src=from_process)
    return tensor
def broadcast_object_list(object_list, from_process: int = 0):
    """
    Broadcast a list of picklable objects from one process to the others.

    Args:
        object_list (list of picklable objects):
            The list of objects to broadcast. This list will be modified inplace.
        from_process (`int`, *optional*, defaults to 0):
            The process from which to send the data.

    Returns:
        The same list, containing the objects from process `from_process`. When the current distributed type is
        none of the handled ones, the list is returned untouched.
    """
    distributed_type = AcceleratorState().distributed_type
    if distributed_type == DistributedType.TPU:

        def _pick_source(values):
            return values[from_process]

        for position, item in enumerate(object_list):
            object_list[position] = xm.mesh_reduce("accelerate.utils.broadcast_object_list", item, _pick_source)
        return object_list

    # The multi-CPU setup goes through the same `torch.distributed` call as the GPU-based ones.
    uses_torch_distributed = distributed_type in (
        DistributedType.DEEPSPEED,
        DistributedType.MULTI_GPU,
        DistributedType.FSDP,
    ) or distributed_type == DistributedType.MULTI_CPU
    if uses_torch_distributed:
        torch.distributed.broadcast_object_list(object_list, src=from_process)
    return object_list
def slice_tensors(data, tensor_slice):
    """
    Recursively index every tensor of a nested list/tuple/dictionary with the same slice.

    Args:
        data (nested list/tuple/dictionary of `torch.Tensor`):
            The data to slice.
        tensor_slice (`slice`):
            The slice to take.

    Returns:
        The same data structure as `data` where each tensor `t` was replaced by `t[tensor_slice]`.
    """

    def _take(t, selection):
        return t[selection]

    return recursively_apply(_take, data, tensor_slice)
def concatenate(data, dim=0):
    """
    Recursively concatenate the tensors found at the same position in a list of identically-structured nested
    lists/tuples/dictionaries.

    Args:
        data (list of nested list/tuple/dictionary of `torch.Tensor`):
            The data to concatenate. The structure of the first element drives the recursion.
        dim (`int`, *optional*, defaults to 0):
            The dimension on which to concatenate.

    Returns:
        The same data structure as the elements of `data`, with the matching tensors concatenated with `torch.cat`.

    Raises:
        `TypeError`: If a leaf of the first element is not a `torch.Tensor`.
    """
    first = data[0]
    if isinstance(first, (tuple, list)):
        # Concatenate position by position, then rebuild a container of the same type as the first element.
        per_position = (concatenate([item[idx] for item in data], dim=dim) for idx in range(len(first)))
        return honor_type(first, per_position)
    if isinstance(first, Mapping):
        per_key = {key: concatenate([item[key] for item in data], dim=dim) for key in first.keys()}
        return type(first)(per_key)
    if isinstance(first, torch.Tensor):
        return torch.cat(data, dim=dim)
    raise TypeError(f"Can only concatenate tensors but got {type(first)}")
def pad_across_processes(tensor, dim=0, pad_index=0, pad_first=False):
    """
    Recursively pad the tensors in a nested list/tuple/dictionary of tensors from all devices to the same size so they
    can safely be gathered.

    Args:
        tensor (nested list/tuple/dictionary of `torch.Tensor`):
            The data to pad.
        dim (`int`, *optional*, defaults to 0):
            The dimension on which to pad. Negative values count from the last dimension. Tensors that do not have
            such a dimension are returned unchanged.
        pad_index (`int`, *optional*, defaults to 0):
            The value with which to pad.
        pad_first (`bool`, *optional*, defaults to `False`):
            Whether to pad at the beginning or the end.

    Returns:
        The same data structure as `tensor` where each tensor has, on dimension `dim`, the maximum size found for
        that tensor across processes.

    Raises:
        `TypeError`: If a leaf is not a `torch.Tensor`.
    """

    def _pad_across_processes(tensor, dim=0, pad_index=0, pad_first=False):
        ndim = len(tensor.shape)
        # Nothing to pad when the tensor does not have the requested dimension.
        if dim >= ndim or dim < -ndim:
            return tensor
        # Normalize negative dimensions, otherwise the `i == dim` comparison used to build the indices below never
        # matches and the copy fails with a shape mismatch.
        if dim < 0:
            dim += ndim

        # Gather all sizes
        size = torch.tensor(tensor.shape, device=tensor.device)[None]
        sizes = gather(size).cpu()
        # Then pad to the maximum size
        max_size = int(max(s[dim] for s in sizes))
        if max_size == tensor.shape[dim]:
            return tensor

        old_size = tensor.shape
        new_size = list(old_size)
        new_size[dim] = max_size
        new_tensor = tensor.new_zeros(tuple(new_size)) + pad_index
        # Region of the padded tensor receiving the original values: the end when padding first, the start
        # otherwise.
        if pad_first:
            dim_slice = slice(max_size - old_size[dim], max_size)
        else:
            dim_slice = slice(0, old_size[dim])
        indices = tuple(dim_slice if i == dim else slice(None) for i in range(len(new_size)))
        new_tensor[indices] = tensor
        return new_tensor

    return recursively_apply(
        _pad_across_processes, tensor, error_on_other_type=True, dim=dim, pad_index=pad_index, pad_first=pad_first
    )
def reduce(tensor, reduction="mean"):
    """
    Recursively reduce the tensors in a nested list/tuple/dictionary of lists of tensors across all processes by the
    mean of a given operation.

    Args:
        tensor (nested list/tuple/dictionary of `torch.Tensor`):
            The data to reduce.
        reduction (`str`, *optional*, defaults to `"mean"`):
            A reduction method. In a distributed setup the tensors are summed across processes and, for `"mean"`,
            divided by the number of processes; any other value leaves the sum unscaled. Without a distributed
            setup, `"sum"` returns the sum of the elements of each tensor and any other value returns their mean.

    Returns:
        The same data structure as `data` with all the tensors reduced.

    Raises:
        `TypeError`: If a leaf is not a `torch.Tensor`.
    """

    def _reduce_across_processes(tensor, reduction="mean"):
        state = AcceleratorState()
        # Work on a copy so the caller's tensor is never modified by the in-place collectives.
        cloned_tensor = tensor.clone()
        if state.distributed_type == DistributedType.TPU:
            # `xm.all_reduce` only works in place when given a list of tensors; with a bare tensor the reduced
            # value is returned and would be lost here.
            xm.all_reduce("sum", [cloned_tensor])
        elif state.distributed_type in [
            DistributedType.DEEPSPEED,
            DistributedType.MULTI_GPU,
            DistributedType.FSDP,
            # Multi-CPU also runs on `torch.distributed`, like for the gather and broadcast operations.
            DistributedType.MULTI_CPU,
        ]:
            torch.distributed.all_reduce(cloned_tensor, ReduceOp.SUM)
        else:
            if reduction == "sum":
                return cloned_tensor.sum()
            else:
                return cloned_tensor.mean()
        if reduction == "mean":
            # Out-of-place division: an in-place one is not possible on integer tensors.
            cloned_tensor = cloned_tensor / state.num_processes
        return cloned_tensor

    return recursively_apply(_reduce_across_processes, tensor, error_on_other_type=True, reduction=reduction)
def convert_to_fp32(tensor):
    """
    Recursively upcast the FP16/BF16 elements of a nested list/tuple/dictionary to FP32.

    Args:
        tensor (nested list/tuple/dictionary of `torch.Tensor`):
            The data to convert from FP16/BF16 to FP32.

    Returns:
        The same data structure as `tensor` where every element in FP16/BF16 precision was replaced by the result of
        calling `.float()` on it. Other elements are left untouched.
    """

    def _is_half_precision(candidate):
        if not hasattr(candidate, "dtype"):
            return False
        if candidate.dtype == torch.float16:
            return True
        # `torch.bfloat16` is only looked at when the installed torch version is at least 1.10.
        return is_torch_version(">=", "1.10") and candidate.dtype == torch.bfloat16

    def _upcast(candidate):
        return candidate.float()

    return recursively_apply(_upcast, tensor, test_type=_is_half_precision)
class ConvertOutputsToFp32:
    """
    Callable wrapper around a function returning tensors (like a model forward pass) that passes the outputs through
    `convert_to_fp32`, so that outputs in FP16/BF16 precision are converted back to FP32.

    This is a class rather than a function-based decorator because otherwise, the prepared model can no longer be
    pickled (issue #273).

    Args:
        model_forward (`Callable`):
            The function whose outputs we want to convert.

    Returns:
        A callable behaving like `model_forward` but with converted outputs.
    """

    def __init__(self, model_forward):
        self.model_forward = model_forward
        # Make the wrapper look like the wrapped function (name, docstring, ...).
        update_wrapper(self, model_forward)

    def __call__(self, *args, **kwargs):
        outputs = self.model_forward(*args, **kwargs)
        return convert_to_fp32(outputs)


convert_outputs_to_fp32 = ConvertOutputsToFp32
def find_device(data):
    """
    Return the device of the first tensor found in a nested dict/list/tuple (the tensors are assumed to all be on
    the same device).

    Args:
        data (nested list/tuple/dictionary of `torch.Tensor`): The data we want to know the device of.

    Returns:
        The `device` attribute of the first tensor encountered, or `None` when no tensor is found.
    """
    if isinstance(data, Mapping):
        children = data.values()
    elif isinstance(data, (tuple, list)):
        children = data
    elif isinstance(data, torch.Tensor):
        return data.device
    else:
        return None

    # Depth-first search: stop at the first child in which a tensor is found.
    for child in children:
        found = find_device(child)
        if found is not None:
            return found
    return None