[FSDP] Add an arg for FSDP __init__ (#926)
* [FSDP] Add an arg for FSDP __init__

Add an arg, disable_reshard_on_root, for FSDP __init__ to handle the following issue:
#878
In some cases (models wrapped by autowrap), the parameters of root modules need to be resharded, so reshard_after_forward should not be forced to False.
"disable_reshard_on_root" lets users choose whether reshard_after_forward of root modules is forced to False. A usage sketch follows these commit notes.

* Update fully_sharded_data_parallel.py

Modified the description of the feature to explain it more clearly.

* Update fairscale/nn/data_parallel/fully_sharded_data_parallel.py

Update the comments for disable_reshard_on_root

Co-authored-by: Min Xu <24926999+min-xu-ai@users.noreply.github.com>

* Modified the comments

Modified the comments of disable_reshard_on_root

Co-authored-by: Min Xu <24926999+min-xu-ai@users.noreply.github.com>
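
As a usage illustration (a minimal sketch, not part of this commit: the model, layer sizes, and names are hypothetical, and it assumes torch.distributed has already been initialized with a default process group), this is how a user could opt out of the root-module optimization when the FSDP root is only a submodule of a larger model:

import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

class MyModel(nn.Module):  # hypothetical wrapper model
    def __init__(self) -> None:
        super().__init__()
        # This FSDP instance is an FSDP root, but only a submodule of MyModel,
        # so its backward pass may not start right after its forward finishes.
        # disable_reshard_on_root=False keeps reshard_after_forward in effect
        # for the root, trading some speed for lower memory usage.
        self.encoder = FSDP(
            nn.Linear(1024, 1024),
            reshard_after_forward=True,
            disable_reshard_on_root=False,
        )
        self.head = nn.Linear(1024, 10)

    def forward(self, x):
        return self.head(self.encoder(x))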
foreveronehundred and min-xu-ai committed Feb 8, 2022
1 parent 7202115 commit 67bf5bf
Showing 1 changed file with 17 additions and 4 deletions.
21 changes: 17 additions & 4 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -204,6 +204,16 @@ class FullyShardedDataParallel(nn.Module):
if ``True``, reshard parameters after the forward pass. This saves
memory but slows training. This is only relevant when resharding
individual layers.
disable_reshard_on_root (bool, Optional):
If ``True``, ``reshard_after_forward`` will be set to ``False`` if the module is an
FSDP root module, to improve performance. In that case the full parameters of the
FSDP root module are not resharded, since those parameters are needed immediately
for the backward pass.
If ``False``, performance will be lower, but this setting helps to save memory.
Consider the case where an FSDP root module is a submodule of a larger model: its
backward pass may not start immediately after its forward pass finishes, so
resharding the parameters of the FSDP root module can save memory in this case.
Default: True.
mixed_precision (bool, Optional):
if ``True``, inputs, activations and gradients will be kept in FP16;
computation and communication will occur in FP16; and a (sharded)
@@ -303,6 +313,7 @@ def __init__(
# The type of process_group_reduce_scatter can only be either ProcessGroup or ProcessGroupName
process_group_reduce_scatter: Any = ProcessGroupName.reduce_scatter,
reshard_after_forward: bool = True,
disable_reshard_on_root: bool = True,
mixed_precision: bool = False,
fp32_reduce_scatter: bool = False,
flatten_parameters: bool = True,
@@ -365,6 +376,7 @@ def __init__(
"parameter uses all the available ranks for the optimized performance."
)
self.reshard_after_forward = self._orig_reshard_after_forward = reshard_after_forward
self.disable_reshard_on_root = disable_reshard_on_root
self.mixed_precision = mixed_precision
self.fp32_reduce_scatter = fp32_reduce_scatter
self.flatten_parameters = flatten_parameters
@@ -1150,10 +1162,11 @@ def _lazy_init(self) -> None:
# applies recursively, we only call this from the root instance.
self._cast_buffers()

# Don't free the full params for the outer-most (root) instance,
# since those params will be needed immediately after for the
# backward pass.
self.reshard_after_forward = False
if self.disable_reshard_on_root:
# Don't free the full params for the outer-most (root) instance,
# since those params will be needed immediately after for the
# backward pass.
self.reshard_after_forward = False

# Due to the use of streams, we need to make sure the previous
# ``optim.step()`` is done before we all-gather parameters.
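
For reference, a small standalone sketch (not part of the commit; the function name effective_root_reshard is illustrative) that mirrors the conditional added to _lazy_init above and summarizes the disable_reshard_on_root docstring:

def effective_root_reshard(reshard_after_forward: bool, disable_reshard_on_root: bool) -> bool:
    # Returns the reshard_after_forward value that a root FSDP instance ends up using.
    if disable_reshard_on_root:
        # Default: keep the root's full params after forward, since the
        # backward pass usually needs them immediately anyway.
        return False
    # Opt out: reshard the root too, which saves memory when the root's
    # backward does not start right after its forward (e.g. the FSDP root
    # is a submodule of a larger model).
    return reshard_after_forward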
