# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
import copy
from enum import Enum, auto
import functools
from math import inf
import traceback
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, NamedTuple, Optional, Set, Tuple, Union
import torch
from torch.autograd import Variable
import torch.distributed as dist
from torch.distributed import ProcessGroup
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F
from fairscale.nn.misc import FlattenParamsWrapper
from fairscale.nn.wrap import auto_wrap, default_auto_wrap_policy, enable_wrap
from fairscale.optim.utils import broadcast_object, calc_grad_norm, recursive_copy_to_device
from fairscale.utils.containers import apply_to_tensors
from fairscale.utils.parallel import chunk_and_pad, enable_pytorch_sync_bn, validate_process_group
from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer
from fairscale.utils.state_dict import replace_by_prefix_
from . import fsdp_optim_utils as ou
if TYPE_CHECKING:
from collections import OrderedDict # noqa: F401
class TrainingState(Enum):
"""
Simple enum to indicate what state FSDP is in. Used for asserting
to make sure APIs are called in the correct state.
.. note::
BACKWARD_PRE and BACKWARD_POST states are used to ensure we
receive backward hooks in the correct order. They are used to catch
an unexpected ordering of hooks being called (likely due to changes in
our hook registration logic or in the autograd engine).
TODO (Min): It would be nice to capture the stepping state as well.
Maybe we can use the model.zero_grad() call, but it is unclear whether
that is called when optim.zero_grad() is used instead.
It would also be nice to make the state transitions explicit, e.g.:
zero_grad -> fwd -> bwd -> optionally accumulate grad by repeating
fwd/bwd -> stepping -> loop back to zero_grad
"""
IDLE = auto()
FORWARD = auto()
BACKWARD_PRE = auto()
BACKWARD_POST = auto()
SUMMON_FULL_PARAMS = auto()
class FullyShardedDataParallel(nn.Module):
"""
A wrapper for sharding Module parameters across data parallel workers. This
is inspired by `Xu et al.`_ as well as the ZeRO Stage 3 from DeepSpeed_.
.. _`Xu et al.`: https://arxiv.org/abs/2004.13336
.. _DeepSpeed: https://www.deepspeed.ai/
Usage::
import torch
from fairscale.nn.data_parallel import FullyShardedDataParallel
torch.cuda.set_device(device_id)
sharded_module = FullyShardedDataParallel(my_module)
optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
x = sharded_module(x, y=3, z=torch.Tensor([1]))
loss = x.sum()
loss.backward()
optim.step()
It is also possible to shard individual layers separately and have an outer
wrapper handle any leftover parameters. This can be helpful to further
reduce GPU memory usage, reduce system memory usage when initializing large
models and to improve training speed by overlapping the all-gather step
across the forward pass. For example::
import torch
from fairscale.nn.wrap import enable_wrap, auto_wrap, wrap
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
fsdp_params = dict(wrapper_cls=FSDP, mixed_precision=True, flatten_parameters=True)
with enable_wrap(**fsdp_params):
# Wraps layer in FSDP by default if within context
self.l1 = wrap(torch.nn.Linear(5, 5))
assert isinstance(self.l1, FSDP)
# Separately wraps child modules with more than 1e8 params
large_tfmr = torch.nn.Transformer(d_model=2048, num_encoder_layers=12, num_decoder_layers=12)
self.l2 = auto_wrap(large_tfmr, min_num_params=1e8)
assert isinstance(self.l2, FSDP)
.. warning::
The optimizer must be initialized *after* the module has been wrapped,
since FSDP will shard parameters in-place and this will break any
previously initialized optimizers.
.. warning::
If you wrap every parameter inside nested FSDP instances and leave the
outer FSDP without any parameters of its own, activation checkpointing
may trigger an assert during the backward pass. The solution is to leave
some parameters for the outer FSDP.
Args:
module (nn.Module):
module to be wrapped with FSDP and sharded
process_group (Optional):
process group for sharding
reshard_after_forward (bool, Optional):
if ``True``, reshard parameters after the forward pass. This saves
memory but slows training. This is only relevant when wrapping
individual layers with nested FSDP instances.
mixed_precision (bool, Optional):
if ``True``, inputs, activations and gradients will be kept in FP16;
computation and communication will occur in FP16; and a (sharded)
master copy of the model weights will be maintained in FP32.
fp32_reduce_scatter (bool, Optional):
if ``True``, then reduce-scatter gradients in FP32. This is only
relevant when *``mixed_precision``* is ``True``.
flatten_parameters (bool, Optional):
if ``True``, flatten parameters into a single contiguous tensor,
which improves training speed.
cpu_offload (bool, Optional):
if ``True``, offload FP32 params to CPU. This is only relevant when
*``mixed_precision``* is ``True``.
compute_dtype (torch.dtype, Optional):
dtype for full parameters for computation. This defaults to
``torch.float32`` unless *``mixed_precision``* is set, in which case
it defaults to ``torch.float16``.
buffer_dtype (torch.dtype, Optional):
dtype for buffers for computation. This defaults to ``compute_dtype``.
move_grads_to_cpu (bool, Optional):
move gradient shard to CPU after reduction. This is useful when
combined with CPU-based optimizers. It defaults to the value of
*``cpu_offload``*.
bucket_cap_mb (int, Optional):
FSDP will bucket parameters so that gradient reduction can
potentially overlap with backward computation. bucket_cap_mb
controls the bucket size in MegaBytes (MB). Buckets are sub-divided
based on world_size, so the max shard size is roughly
``bucket_cap_mb / world_size``. Values <= 0 disable bucketing.
Default: 25.
compute_device (torch.device, Optional):
device for computation. If not given and module params are on a CUDA
device, the param's device will be used. If not given and module
params are on CPU, then the current CUDA device (as indicated by
``torch.cuda.current_device()``) will be used.
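A purely illustrative example combining several of the memory-saving
options above (argument values are arbitrary)::
    sharded_module = FullyShardedDataParallel(
        my_module, mixed_precision=True, cpu_offload=True)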
"""
def __init__(
self,
module: nn.Module,
process_group: Optional[ProcessGroup] = None,
reshard_after_forward: bool = True,
mixed_precision: bool = False,
fp32_reduce_scatter: bool = False,
flatten_parameters: bool = True,
cpu_offload: bool = False,
compute_dtype: Optional[torch.dtype] = None,
buffer_dtype: Optional[torch.dtype] = None,
move_grads_to_cpu: Optional[bool] = None,
bucket_cap_mb: int = 25,
compute_device: Optional[torch.device] = None,
):
super().__init__()
self.process_group = process_group or dist.new_group()
self.rank = self.process_group.rank()
self.world_size = self.process_group.size()
self.reshard_after_forward = reshard_after_forward
self.mixed_precision = mixed_precision
self.fp32_reduce_scatter = fp32_reduce_scatter
self.flatten_parameters = flatten_parameters
self.cpu_offload = cpu_offload
self.compute_dtype = compute_dtype or (torch.float16 if mixed_precision else torch.float32)
self.buffer_dtype = buffer_dtype or self.compute_dtype
self.move_grads_to_cpu = cpu_offload if move_grads_to_cpu is None else move_grads_to_cpu
self.bucket_cap_mb = bucket_cap_mb
self.gradient_predivide_factor = self.get_gradient_predivide_factor(self.world_size)
self.numel_padded_per_param: List[int] = []
self.compute_device = compute_device
if self.fp32_reduce_scatter and not self.mixed_precision:
raise ValueError("fp32_reduce_scatter requires mixed_precision=True")
if self.cpu_offload and not self.mixed_precision:
raise ValueError("cpu_offload requires mixed_precision=True")
if self.compute_device is None:
# Try to infer CUDA device from module parameters.
self.compute_device = next(module.parameters()).device
if self.compute_device.type != "cuda":
# Fall back to current CUDA device.
self.compute_device = torch.device("cuda")
validate_process_group(self.compute_device, self.process_group)
enable_pytorch_sync_bn(module)
# Only handle params which are not already sharded. This enables
# sharding individual layers of a Module, with an outer wrapper to
# shard any leftover parameters.
params = list(p for p in module.parameters() if not hasattr(p, "_is_sharded"))
self._has_params = len(params) > 0
if not self._has_params:
self.flatten_parameters = False
if self.flatten_parameters:
self._fsdp_wrapped_module: nn.Module = FlattenParamsWrapper(module, param_list=params)
del module # free original module in case it helps garbage collection
self.params = [self._fsdp_wrapped_module.flat_param]
else:
self._fsdp_wrapped_module = module
self.params = params
# Shard module parameters in place
self._shard_parameters_()
# Make sure all parameters are sharded.
for n, p in self.named_parameters():
assert hasattr(p, "_is_sharded"), f"found unsharded parameter: {n} ; {p.size()}"
self._reset_lazy_init()
# Flag to indicate if we require gradient reduction in the backward
# pass. This will be False when inside the no_sync context manager.
self._require_backward_grad_sync: bool = True
# Enum to indicate if we're in the forward/backward pass, idle, etc.
self.training_state = TrainingState.IDLE
# Flag to indicate if the full params are gathered.
self.has_full_params: bool = False
# Register hook after state_dict() to remove the "_fsdp_wrapped_module."
# prefix and before load_state_dict() to add it back.
self._register_state_dict_hook(_post_state_dict_hook)
self._register_load_state_dict_pre_hook(_pre_load_state_dict_hook)
# Flag to indicate whether state_dict() should automatically summon the
# full params. This defaults to True, but may be set to False if the
# user explicitly requests the local state dict via local_state_dict().
self._return_full_state_dict = True
def get_gradient_predivide_factor(self, world_size: int) -> int:
factor = 1
while world_size % factor == 0 and world_size / factor > factor:
factor = factor * 2
return factor
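# Illustrative values traced through the loop above (not a separate API):
#   world_size = 1  -> factor = 1
#   world_size = 4  -> factor = 2
#   world_size = 8  -> factor = 4
#   world_size = 64 -> factor = 8
# For power-of-two world sizes this is roughly sqrt(world_size), rounded up
# to a power of two. Judging by the attribute name it is assigned to
# (``gradient_predivide_factor``), gradients are presumably divided by this
# factor before reduction and by the remaining factor afterwards, to reduce
# the risk of overflow/underflow when reducing low-precision gradients.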
@property
def module(self) -> nn.Module:
return self._fsdp_wrapped_module # note: may be a FlattenParamsWrapper instance
def apply(self, fn: Callable[[nn.Module], None]) -> "FullyShardedDataParallel":
"""
Applies ``fn`` recursively to every submodule (as returned by
``.children()``) as well as self. Typical use includes initializing the
parameters of a model.
Compared to ``torch.nn.Module.apply``, this version additionally gathers
the full parameters before applying ``fn``. It should not be called from
within another ``summon_full_params`` context.
Args:
fn (Callable[[nn.Module], None]): function to be applied to each submodule
Returns:
Module: self
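Example (an illustrative sketch; ``init_weights`` and ``sharded_module``
are hypothetical names, not part of this API)::
    def init_weights(m):
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
    sharded_module.apply(init_weights)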
"""
is_uninitialized = self._is_root is None
self.assert_state(TrainingState.IDLE)
with self.summon_full_params(recurse=False):
return_value = super().apply(fn)
# summon_full_params will call _lazy_init, which sets _is_root. However,
# apply() may be called directly on children instances to do weight
# init, so we should reset the _is_root flag in this case.
if is_uninitialized and self._is_root:
for module in self.modules():
if isinstance(module, FullyShardedDataParallel):
module._reset_lazy_init()
return return_value
def _cast_buffers(
self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, memo: Optional[Set] = None
) -> None:
"""Move all buffers to the given *device* and *dtype*.
If *device* or *dtype* are not given, then they will default to
``self.compute_device`` and ``self.buffer_dtype``, respectively. In the
case of nested FSDP instances, we will respect the child instance's
``compute_device`` and ``buffer_dtype`` configuration.
Args:
device (torch.device, Optional):
device to cast buffers to (defaults to compute_device)
dtype (torch.dtype, Optional):
dtype to cast buffers to (defaults to buffer_dtype)
memo (Set, Optional):
set of modules that have already been processed
"""
if memo is None:
memo = set()
for module in self.modules():
if module is not self and isinstance(module, FullyShardedDataParallel):
# Allow any child FSDP instances to handle their own buffers.
module._cast_buffers(device=device, dtype=dtype, memo=memo)
elif module not in memo:
memo.add(module)
for name, buf in module.named_buffers(recurse=False):
if buf is None:
continue
buf = buf.to(device=device or self.compute_device)
if torch.is_floating_point(buf):
buf = buf.to(dtype=dtype or self.buffer_dtype)
setattr(module, name, buf)
@property
def params_with_grad(self) -> List[Parameter]:
"""[p for p in self.parameters() if p.grad is not None] """
return [p for p in self.parameters() if p.grad is not None]
@torch.no_grad()
def clip_grad_norm_(
self,
max_norm: Union[float, int],
norm_type: Union[float, int] = 2.0,
# filter_params_fn: Callable[[Any], Any] = None,
) -> torch.Tensor:
"""
Clip all gradients at this point in time. The norm is computed over all
gradients together, as if they were concatenated into a single vector.
Gradients are modified in-place.
Args:
max_norm (float or int): max norm of the gradients
norm_type (float or int): type of the used p-norm. Can be ``'inf'``
for infinity norm.
Returns:
Total norm of the parameters (viewed as a single vector).
.. note:: This is analogous to `torch.nn.utils.clip_grad_norm_` but
handles the partitioning and multiple devices per rank under the
hood. The default torch util is not applicable here, because each
rank only has a partial view of all the grads in the model, so
calling it in the OSS context would lead to different scaling being
applied per subset of model parameters.
.. warning:: This needs to be called on all ranks, since synchronization
primitives will be used.
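Example (an illustrative sketch; ``sharded_module``, ``optim`` and ``x``
follow the class-level usage example)::
    loss = sharded_module(x).sum()
    loss.backward()
    sharded_module.clip_grad_norm_(max_norm=1.0)
    optim.step()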
"""
# We don't call torch.cuda.synchronize() here, since clipping can be
# inside the train loop and we probably don't want to force a GPU-CPU sync.
# _lazy_init should be sufficient, since it will force the other streams
# to sync with the default stream (via _wait_for_previous_optim_step).
self._lazy_init()
assert self._is_root, "clip_grad_norm should only be called on the root (parent) instance"
self.assert_state(TrainingState.IDLE)
max_norm = float(max_norm)
norm_type = float(norm_type)
params_with_grad = self.params_with_grad
if not self.children_share_process_group:
raise NotImplementedError(
"clip_grad_norm requires that all params share one process group. clip_grad_by_value_ should work"
)
# Compute the norm over this shard's gradients and sync across workers
local_norm = calc_grad_norm(params_with_grad, norm_type).cuda()
if norm_type == inf:
total_norm = local_norm
dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=self.process_group)
else:
total_norm = local_norm ** norm_type
dist.all_reduce(total_norm, group=self.process_group)
total_norm = total_norm ** (1.0 / norm_type)
if self.move_grads_to_cpu:
total_norm = total_norm.cpu()
# Now multiply each grad by (max_norm / total_norm), same as torch 1.7 (https://tinyurl.com/3wtxhhqq)
clip_coef = torch.tensor(max_norm, dtype=total_norm.dtype, device=total_norm.device) / (total_norm + 1e-6)
if clip_coef < 1:
# multiply by clip_coef
for p in params_with_grad:
p.grad.detach().mul_(clip_coef.to(p.grad.device)) # type: ignore
return total_norm
@torch.no_grad()
def _shard_parameters_(self) -> None:
"""
At initialization we wrap a module with full parameters and shard the
parameters in-place. Sharding is implemented by viewing each parameter
as a 1D Tensor and retaining only a single slice, where the slice size
is determined by the number of data parallel workers.
Wrapping modules with many small parameters (or with a very large data
parallel world size) will result in many small parameter shards and slow
performance. In this case it's better to set *``flatten_parameters``* to
``True``, so that all of the small parameters in the module are combined
into a single contiguous Tensor and sharded once.
After this initial sharding is complete, the user can initialize a
``torch.optim.Optimizer`` in the usual way, i.e.:
.. code-block:: python
optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
The optimizer will see only a single slice of parameters and will thus
allocate less memory for optimizer state, avoiding redundancy across
data parallel workers.
"""
self.numel_padded_per_param = []
for p in self.params:
assert not hasattr(p, "_is_sharded")
assert p.is_floating_point()
if self.mixed_precision:
assert p.dtype == torch.float32
# If world_size is 1, then we all-reduce grads instead of sharding.
p._is_sharded = self.world_size > 1
p._orig_size = p.data.size()
if not p._is_sharded:
self.numel_padded_per_param.append(0)
continue
p._is_sharded = True
# Replace p.data with the relevant shard.
orig_data = p.data
p.data, num_padded = self._get_shard(p.data)
self.numel_padded_per_param.append(num_padded)
free_storage_(orig_data)
assert len(self.numel_padded_per_param) == len(self.params)
def _get_shard(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, int]:
"""Return the local shard of a full tensor."""
# Shard using torch.chunk to match all-gather/reduce-scatter.
chunks = list(torch.flatten(tensor).chunk(self.world_size))
while len(chunks) < self.world_size:
chunks.append(chunks[0].new_empty(0))
# Determine number of padding elements.
num_to_pad = chunks[0].numel() - chunks[self.rank].numel()
assert num_to_pad >= 0, num_to_pad
shard = chunks[self.rank].clone()
if num_to_pad > 0:
shard = F.pad(shard, [0, num_to_pad])
return shard, num_to_pad
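# Illustrative example (not executed): with a flattened tensor of 10 elements
# and world_size=4, ``torch.chunk`` yields chunks of sizes [3, 3, 3, 1].
# Rank 3 therefore pads its shard with 2 zeros so that every rank holds
# exactly 3 elements, and num_to_pad=2 is returned so the padding can be
# accounted for later.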
def extra_repr(self) -> str:
return (
f"rank={self.rank}, world_size={self.world_size}, "
f"reshard_after_forward={self.reshard_after_forward}, "
f"mixed_precision={self.mixed_precision}, "
f"fp32_reduce_scatter={self.fp32_reduce_scatter}, "
f"flatten_parameters={self.flatten_parameters}, "
f"cpu_offload={self.cpu_offload}, "
f"compute_dtype={self.compute_dtype}, "
f"buffer_dtype={self.buffer_dtype}, "
f"move_grads_to_cpu={self.move_grads_to_cpu}, "
f"bucket_cap_mb={self.bucket_cap_mb}, "
f"compute_device={self.compute_device}"
)
def __getattr__(self, name: str) -> Any:
"""Forward missing attributes to wrapped module."""
try:
return super().__getattr__(name) # defer to nn.Module's logic
except AttributeError:
return getattr(self.module, name)
def __getstate__(self) -> Dict[str, str]:
"""Serialize the state of the current FullyShardedDataParallel instance.
Some properties are not serializable (e.g., process groups, streams), so
we remove them and try to reconstruct them in :func:`__setstate__`.
"""
state = copy.copy(self.__dict__)
state["is_sharded"] = [p._is_sharded for p in self.params]
state["orig_sizes"] = [p._orig_size for p in self.params]
if state["process_group"] is not None:
state["process_group"] = "MISSING" # process_group isn't pickleable
self._reset_lazy_init()
return state
def __setstate__(self, state: Dict[str, Any]) -> None:
"""Intercept state setting and perform needed changes on params."""
super().__setstate__(state)
def fixup(p: Parameter, is_sharded: bool, size: torch.Size) -> Parameter:
assert isinstance(p, Parameter)
p.data = p.data.clone() # move tensors out of shared memory
p._is_sharded = is_sharded
p._orig_size = size
return p
self.params = [
fixup(p, is_sharded, size) for p, is_sharded, size in zip(self.params, self.is_sharded, self.orig_sizes)
]
del self.is_sharded
del self.orig_sizes
self._reset_lazy_init()
# TODO (Min): figure out how to do typing for this overloaded function.
def state_dict(self, *args: Any, **kwargs: Any) -> "OrderedDict[str, torch.Tensor]": # type: ignore
"""
Returns the whole (unsharded) state of the module. Parameters are not
sharded, so the resulting state_dict can be loaded directly by the
wrapped Module without any sharding-specific logic. Returned tensors
will be full precision (e.g., FP32).
.. warning:: This needs to be called on all ranks, since synchronization
primitives will be used.
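Example (an illustrative sketch; ``sharded_module`` follows the
class-level usage example and the checkpoint path is hypothetical)::
    full_sd = sharded_module.state_dict()  # must run on every rank
    if torch.distributed.get_rank() == 0:
        torch.save(full_sd, "/tmp/full_checkpoint.pt")  # hypothetical path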
"""
torch.cuda.synchronize()
self._lazy_init()
if self.mixed_precision:
# Cast buffers to full precision so their dtype matches the returned parameters.
self._cast_buffers(dtype=torch.float32)
if self._return_full_state_dict:
if self.training_state != TrainingState.SUMMON_FULL_PARAMS:
with self.summon_full_params(volatile=True):
state_dict = super().state_dict(*args, **kwargs)
else:
state_dict = super().state_dict(*args, **kwargs)
else:
if self.flatten_parameters:
assert isinstance(self.module, FlattenParamsWrapper)
state_dict = self.module.flat_state_dict(*args, **kwargs)
else:
state_dict = super().state_dict(*args, **kwargs)
if self.cpu_offload:
for k in state_dict.keys():
state_dict[k] = state_dict[k].cpu()
if self.mixed_precision:
# In case we are in mixed precision, restore buffers back to buffer_dtype.
self._cast_buffers()
return state_dict
# TODO (Min): figure out how to do typing for this overloaded function.
def local_state_dict(self, *args, **kwargs): # type: ignore
"""
Returns the local (sharded) state of the module. Parameters are sharded,
so the resulting state_dict can only be loaded after the Module has been
wrapped with FullyShardedDataParallel.
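Example (an illustrative sketch; the per-rank path is hypothetical)::
    local_sd = sharded_module.local_state_dict()
    torch.save(local_sd, f"/tmp/shard_rank{sharded_module.rank}.pt")  # hypothetical path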
"""
with contextlib.ExitStack() as stack:
# Tell any nested FSDP instances not to auto summon full params.
for module in self.modules(): # includes self
if isinstance(module, FullyShardedDataParallel):
stack.enter_context(module._no_return_full_state_dict())
return self.state_dict(*args, **kwargs)
@contextlib.contextmanager
def _no_return_full_state_dict(self) -> Generator:
backup = self._return_full_state_dict
self._return_full_state_dict = False
try:
yield
finally:
self._return_full_state_dict = backup
def load_state_dict(
self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True
) -> NamedTuple:
"""
Load a whole (unsharded) state_dict.
.. warning:: This needs to be called on all ranks, since synchronization
primitives will be used.
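Example (an illustrative sketch mirroring the ``state_dict`` example;
the checkpoint path is hypothetical)::
    full_sd = torch.load("/tmp/full_checkpoint.pt")  # hypothetical path
    sharded_module.load_state_dict(full_sd)  # must run on every rank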
"""
if self._return_full_state_dict:
with self.summon_full_params():
return self.module.load_state_dict(state_dict, strict)
else:
torch.cuda.synchronize()
self._lazy_init()
return self.module.load_state_dict(state_dict, strict)
def load_local_state_dict(
self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True
) -> NamedTuple:
"""Load a local (sharded) state_dict."""
with contextlib.ExitStack() as stack:
# Tell any nested FSDP instances not to auto summon full params.
for module in self.modules(): # includes self
if isinstance(module, FullyShardedDataParallel):
stack.enter_context(module._no_return_full_state_dict())
output = self.load_state_dict(state_dict, strict)
return output
@contextlib.contextmanager
def no_sync(self) -> Generator:
"""
A context manager to disable gradient synchronization across FSDP
processes. Within this context, gradients will be accumulated on module
variables, which will later be synchronized in the first
forward-backward pass after exiting the context.
.. note:: This may result in higher memory usage because we will
accumulate the full model gradients (instead of gradient shards)
until the eventual sync.
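Example (an illustrative sketch of gradient accumulation over two
micro-batches; ``sharded_module``, ``optim``, ``batch1`` and ``batch2``
are assumed to exist)::
    with sharded_module.no_sync():
        sharded_module(batch1).sum().backward()  # grads accumulate, no reduction
    sharded_module(batch2).sum().backward()  # grads are synchronized here
    optim.step()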
"""
self._lazy_init()
assert self._is_root, "no_sync on inner FSDP is not supported"
self.assert_state(TrainingState.IDLE)
# This instance may wrap other FullyShardedDataParallel instances and we
# need to set all of them to accumulate gradients.
old_flags = []
for m in self.modules(): # includes self
if isinstance(m, FullyShardedDataParallel):
old_flags.append((m, m._require_backward_grad_sync))
m._require_backward_grad_sync = False
try:
yield
finally:
for m, old_flag in old_flags:
m._require_backward_grad_sync = old_flag
@contextlib.contextmanager
def summon_full_params(self, recurse: bool = True, volatile: bool = False) -> Generator:
"""
A context manager to expose full params for the current FSDP instance.
Can be useful *after* forward/backward for a model to get the params for
additional processing or checking. Parameters will be gathered in full
precision (e.g., FP32).
.. note:: This can be used on inner FSDPs.
.. note:: This can *not* be used within a forward or backward pass. Nor
can forward and backward be started from within this context.
.. note:: The full parameters will be freed after the context manager
exits; it is up to the caller to clone them if needed.
.. note:: The full parameters can be modified, but only the portion
corresponding to the local param shard will persist after the
context manager exits (unless ``volatile=True``, in which case there
are no guarantees about persistence).
Args:
recurse (bool, Optional): recursively summon all params for nested
FSDP instances (default: True)
volatile (bool, Optional): if ``True``, modifications to params are
not guaranteed to persist after the context manager exits;
enabling this can be slightly more efficient (default: False)
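Example (an illustrative sketch; ``sharded_module`` follows the
class-level usage example)::
    with sharded_module.summon_full_params():
        for p in sharded_module.parameters():
            # Full (unsharded) parameters are visible here; clone them
            # if they need to outlive the context.
            full_copy = p.detach().clone()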
"""
if recurse:
with contextlib.ExitStack() as stack:
# Summon all params for any nested FSDP instances.
for module in self.modules():
if isinstance(module, FullyShardedDataParallel):
stack.enter_context(module.summon_full_params(recurse=False, volatile=volatile))
# Yield to the caller, with full params in all nested instances.
yield
# Exiting from the ExitStack will re-shard params.
return
else:
torch.cuda.synchronize()
self._lazy_init()
self.assert_state(TrainingState.IDLE)
# Set the state so that we assert when trying to go into
# forward/backward.
self.training_state = TrainingState.SUMMON_FULL_PARAMS
full_tensors = self._rebuild_full_params(force_full_precision=True)
assert full_tensors is not None
with contextlib.ExitStack() as stack:
if self.flatten_parameters and self.module.is_flattened:
# Update flattened views to point to fully-sized tensors. We
# use self.params[0] instead of full_tensors since the
# latter may contain padding.
assert len(self.params) == 1
assert isinstance(self.module, FlattenParamsWrapper)
stack.enter_context(self.module.unflatten_params(recurse=False, flat_param=self.params[0]))
try:
yield
finally:
stack.close()
assert len(full_tensors) == len(self.params)
for p, (full_tensor, safe_to_free) in zip(self.params, full_tensors):
if not volatile:
# Copy any changes made to the full params back into
# the corresponding local shards.
local_shard, _ = self._get_shard(full_tensor)
p._fp32_shard.copy_(local_shard.view_as(p._fp32_shard))
if safe_to_free:
free_storage_(full_tensor)
self.has_full_params = False
self._use_fp32_param_shard()
self.training_state = TrainingState.IDLE
def _reset_lazy_init(self) -> None:
"""Reset instance so :func:`_lazy_init` will run on the next forward."""
self._is_root: Optional[bool] = None
self._queue_wait_for_post_backward_closure: Optional[Callable] = None
self._streams: Dict[str, torch.cuda.Stream] = {}
self._reducer: Optional[ReduceScatterBucketer] = None
for p in self.params:
if hasattr(p, "_fp32_shard"):
del p._fp32_shard # reset _init_param_attributes
def _lazy_init(self) -> None:
"""Initialization steps that should happen lazily, typically right
before the first forward pass.
"""
# Initialize param attributes lazily, in case the param's dtype or
# device changes after __init__.
for p in self.params:
self._init_param_attributes(p)
# Initialize _is_root and set up streams. These steps would ideally
# happen in __init__, but _is_root can only be determined after the
# entire model hierarchy is set up, so we run it lazily.
if self._is_root is None:
self._set_is_root()
self._setup_streams()
if self._is_root:
# Buffers stay on GPU, and don't get sharded. Since _cast_buffers
# applies recursively, we only call this from the root instance.
self._cast_buffers()
# Don't free the full params for the outer-most (root) instance,
# since those params will be needed immediately after for the
# backward pass.
self.reshard_after_forward = False
# Due to the use of streams, we need to make sure the previous
# ``optim.step()`` is done before we all-gather parameters.
self._wait_for_previous_optim_step()
@torch.no_grad()
def _init_param_attributes(self, p: Parameter) -> None:
"""
We manage several attributes on each Parameter instance. The first two
are set by :func:`_shard_parameters_`:
``_is_sharded``: ``True`` if the Parameter is sharded or ``False``
if the Parameter is intentionally not sharded (in which case we
will all-reduce grads for this param).
``_orig_size``: the size of the original Parameter (before sharding)
The remaining attributes are set here:
``_fp32_shard``: a single shard of the parameters in full precision
(typically FP32, but this is dependent on the dtype of the model
as it's passed in by the user). This can be on CPU or GPU
depending on the value of *``cpu_offload``*.
``_fp16_shard``: if *``mixed_precision``* is ``True``, this will be
a single shard of the parameters in FP16, used for all-gather.
``_full_param_padded``: the full weight (padded to be evenly
divisible by ``world_size``), used for computation in the
forward and backward pass. This will be resized in place and
only materialized (via all-gather) as needed.
"""
assert hasattr(p, "_is_sharded") and hasattr(p, "_orig_size")
if hasattr(p, "_fp32_shard"):
return
# A single shard of the parameters in full precision.
p._fp32_shard = p.data
if self.mixed_precision:
assert p._fp32_shard.dtype == torch.float32
if self.cpu_offload:
assert p._fp32_shard.device == torch.device("cpu")
# If we plan to keep the FP32 parameters on CPU, then pinning
# memory allows us to later use non-blocking transfers when moving
# the FP32 param shard to compute_device.
p._fp32_shard = p._fp32_shard.pin_memory()
p.data = p._fp32_shard
# In mixed precision mode, we maintain a reduced precision
# (typically FP16) parameter shard on compute_device for performing
# the computation in the forward/backward pass. We resize the
# storage to size 0 at init (here) and re-materialize (by copying
# from _fp32_shard) as needed.
p._fp16_shard = torch.zeros_like(p._fp32_shard, device=self.compute_device, dtype=self.compute_dtype)
free_storage_(p._fp16_shard)
else:
p._fp16_shard = None # use _fp32_shard
# We also maintain a full-sized parameter of type self.compute_dtype
# (FP16 for mixed_precision or FP32 otherwise). We resize the
# storage to size 0 at init (here) and only materialize as needed. The
# storage may contain padding elements so that it is evenly divisible by
# world_size, although these padding elements will be removed before the
# relevant computation.
if p._is_sharded:
p._full_param_padded = torch.zeros(
p.data.numel() * self.world_size, device=self.compute_device, dtype=self.compute_dtype
)
free_storage_(p._full_param_padded)
if self.move_grads_to_cpu:
# We can optionally move the grad shard to CPU during the backward
# pass. In this case, it's important to pre-allocate the CPU grad
# shard in pinned memory so that we can do a non-blocking transfer.
p._cpu_grad = torch.zeros_like(p.data, device="cpu").pin_memory()
def _set_is_root(self) -> None:
"""If ``True``, implies that no other :class:`FullyShardedDataParallel`
instance wraps this one. Called once by :func:`_lazy_init`.
Also sets self.children_share_process_group = True if all child
instances share the same process group. If some child instances use a
different process group, self.clip_grad_norm_ will raise an error.
"""
if self._is_root is not None:
return
# No FullyShardedDataParallel instance wraps this, else _is_root would be set to False.
self._is_root = True
assert self._queue_wait_for_post_backward_closure is None
self._queue_wait_for_post_backward_closure = self._queue_wait_for_post_backward
# As the root, we now set all child instances' _is_root to False and
# give them a closure that can queue a wait_for_post_backward.
self.children_share_process_group = True
for n, m in self.named_modules():
# `n != ""` excludes self.
if n != "" and isinstance(m, FullyShardedDataParallel):
assert m._is_root is None
m._is_root = False
# When the root instance doesn't have params, allow child instances
# to queue the post_backward hook.
#
# TODO (Min): we should consider whether we can have an empty param at the
# root so that the root always has a callback on the backward graph.
if not self._has_params:
assert m._queue_wait_for_post_backward_closure is None
m._queue_wait_for_post_backward_closure = self._queue_wait_for_post_backward
if m.process_group != self.process_group:
self.children_share_process_group = False
def _setup_streams(self) -> None:
"""Create streams to overlap data transfer and computation."""
if len(self._streams) > 0 or not self._is_root:
return
# Stream to move main FP32 params (may be on CPU) to FP16 for forward.
self._streams["fp32_to_fp16"] = torch.cuda.Stream()
# Stream for all-gathering parameters.
self._streams["all_gather"] = torch.cuda.Stream()
# Stream for overlapping grad reduction with the backward pass.
self._streams["post_backward"] = torch.cuda.Stream()
# Helper for bucketing reduce-scatter ops. This is also shared with
# child instances to improve bucket utilization.
self._reducer = ReduceScatterBucketer(self.bucket_cap_mb)
# We share streams with all child instances, which allows them to
# overlap transfers across the forward pass without synchronizing with
# the default stream.
for n, m in self.named_modules():
if n != "" and isinstance(m, FullyShardedDataParallel):
m._streams = self._streams
m._reducer = self._reducer
def _wait_for_previous_optim_step(self) -> None:
"""
The outer-most :class:`FullyShardedDataParallel` instance (i.e., the root
instance) needs to synchronize with the default stream to ensure the
previous optimizer step is done.
"""
if self.mixed_precision:
self._streams["fp32_to_fp16"].wait_stream(torch.cuda.current_stream())
else:
self._streams["all_gather"].wait_stream(torch.cuda.current_stream())
def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
self._lazy_init()
# Start of a forward pass.
self.training_state = TrainingState.FORWARD
if self._is_root and self.mixed_precision:
args, kwargs = cast_inputs_to_fp16(*args, **kwargs)
# All-gather full parameters. This will also transfer FP32 parameters to
# ``self.compute_dtype`` (e.g., FP16 if *mixed_precision* is ``True``).
self._rebuild_full_params()
# Register backward hooks to reshard params and reduce-scatter grads.
# These need to be re-registered every forward pass.
self._register_post_backward_hooks()
outputs = self.module(*args, **kwargs)
if self.reshard_after_forward:
self._free_full_params()
if self.mixed_precision:
self._free_fp16_param_shard()
# Switch to main FP32 param shard. We maintain this invariant throughout
# the code, i.e., ``p.data == p._fp32_shard`` after each function. This
# also ensures that after the first forward, the optimizer state will be
# initialized with the correct dtype and (sharded) size, since optimizer
# state is typically initialized lazily in ``optim.step()``.
self._use_fp32_param_shard()
# Register pre-backward hooks to all-gather the params for the backward
# pass (if needed).
outputs = self._register_pre_backward_hooks(outputs)
# Done with a forward pass.
self.training_state = TrainingState.IDLE
return outputs
def _register_pre_backward_hooks(self, outputs: Any) -> Any:
"""Register pre-backward hook to run before the wrapped module's
backward. Hooks should be attached to all outputs from the forward."""
if not torch.is_grad_enabled():
return outputs # don't register hooks if grad isn't enabled
pre_backward_hook_has_run = [False]
def _pre_backward_hook(*unused: Any) -> None:
if pre_backward_hook_has_run[0]:
return # only run once
pre_backward_hook_has_run[0] = True
# Start of a backward pass.
self.assert_state([TrainingState.IDLE, TrainingState.BACKWARD_PRE])
self.training_state = TrainingState.BACKWARD_PRE
# All-gather full parameters.
if self.reshard_after_forward:
self._rebuild_full_params()
else:
self._use_full_params()
# Make sure p.grad has the correct size/device (or set it to None).
self._prep_grads_for_backward()
def _register_hook(t: torch.Tensor) -> torch.Tensor:
if t.requires_grad:
t.register_hook(_pre_backward_hook)
return t
# Attach hooks to Tensor outputs.
outputs = apply_to_tensors(_register_hook, outputs)
return outputs
def _register_post_backward_hooks(self) -> None:
"""
Register backward hooks to reshard params and reduce-scatter grads.
This is called during the forward pass. The goal is to attach a hook
to each parameter's gradient accumulation function (``grad_acc``
below) so that the hook is called *after* all gradients for that
param are computed.
Goals:
1. We want the hook to fire once and only once *after* all gradients
are accumulated for a param.
2. If it fires more than once, we end up sharding the grad multiple
times, which is incorrect (the resulting dimension could be too small).
3. If it fires too early, or never fires, we leave gradients unsharded
(the resulting dimension could be too large).
Because a module can be used multiple times in a single forward pass,
this function can be called on the same parameter several times within
that pass. If we register the hook multiple times, it ends up being
called multiple times. We could try to register a new hook on every
call and delete the previously registered one. However, for reasons we
have not been able to determine (despite extensive debugging), in mixed
precision mode we get two different ``grad_acc`` objects below across
different calls of this function (within the same forward pass). If we
keep only the last one, the hook ends up firing too early. In full
precision mode, we luckily get the *same* ``grad_acc`` object, so
deleting and re-registering still ensures the hook fires once after all
gradients are generated.
Empirically, keeping the first hook registered per forward pass seems
to work best. We do need to remove the hook at the end of the backward
pass; otherwise, the next forward pass will not register a new hook,
which is needed for the next forward pass.
"""
if not torch.is_grad_enabled():
return # don't register grad hooks if grad isn't enabled
if self._is_root:
# This actually means that only the root instance has this field
# defined. Accidentally accessing this field will assert on all
# other instances, giving us a nice bug checker.
self._post_backward_callback_queued = False