FullyShardedDataParallel: only return full state dict on rank 0 (#885)
* FullyShardedDataParallel: only return full state dict on rank 0

* Add flag and make rank 0 only optional

* Add tests

* Add docs

* address comments

* update comments

* update torch nightly version

* update torchvision number for torch nightly dependence

* add changelog

* Update CHANGELOG.md

* Update CHANGELOG.md
four4fish committed Jan 6, 2022
1 parent c5e471b commit d3417ce
Showing 4 changed files with 36 additions and 3 deletions.
2 changes: 1 addition & 1 deletion .circleci/config.yml
@@ -117,7 +117,7 @@ install_dep_pytorch_nightly: &install_dep_pytorch_nightly
# check if we have restored venv cache (/home/circleci/venv) correctly, if so, just skip
if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.10 && exit 0; fi
# start installing
pip install --progress-bar off --pre torch==1.11.0.dev20211101+cu111 torchvision==0.12.0.dev20211101+cu111 -f https://download.pytorch.org/whl/nightly/cu111/torch_nightly.html
pip install --progress-bar off --pre torch==1.11.0.dev20211231+cu111 torchvision==0.12.0.dev20211231+cu111 -f https://download.pytorch.org/whl/nightly/cu111/torch_nightly.html
pip install --progress-bar off -r requirements-dev.txt
pip install --progress-bar off -r requirements-benchmarks.txt
python -c 'import torch; print("Torch version:", torch.__version__)'
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [0.4.5] - TBD

### Added
- FSDP: Added a state_dict_on_rank_0_only flag that lets the user choose to return the full state dict on rank 0 and an empty dict on non-zero ranks, to prevent OOM [#844]

### Changed

22 changes: 20 additions & 2 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -279,6 +279,11 @@ class FullyShardedDataParallel(nn.Module):
The `OffloadConfig` object is used to specify the type of offload (i.e SSD, CPU) and
other required knobs when offloading parameters from GPU. Currently the OffloadConfig
only supports specifying SSD offload as an option. Note: This is an experimental feature.
state_dict_on_rank_0_only (bool):
When set to ``True``, ``model.state_dict()`` will only return the full state dict on
rank 0 and return an empty dict on non-zero ranks, which allows FullyShardedDataParallel
to skip the GPU -> CPU copy on non-zero ranks altogether and prevents OOM.
Default: False
"""

def __init__(
@@ -302,6 +307,7 @@ def __init__(
verbose: bool = False,
cpu_offload: bool = False,
offload_config: Optional[OffloadConfig] = None,
state_dict_on_rank_0_only: bool = False,
):
init_start = time.time()
super().__init__()
@@ -324,6 +330,7 @@ def __init__(
self.clear_autocast_cache = clear_autocast_cache
self.force_input_to_fp32 = force_input_to_fp32
self.verbose = verbose
self.state_dict_on_rank_0_only = state_dict_on_rank_0_only
# Experimental feature for now. Use at your own risk.
self.ssd_offload = True if offload_config and offload_config.offload_type == "ssd_offload" else False

@@ -418,7 +425,7 @@ def __init__(

# Register hook after state_dict() to remove the "_fsdp_wrapped_module."
# prefix and before load_state_dict() to add it back.
self._register_state_dict_hook(_post_state_dict_hook)
self._register_state_dict_hook(functools.partial(_post_state_dict_hook, self.state_dict_on_rank_0_only))
self._register_load_state_dict_pre_hook(_pre_load_state_dict_hook)

# Flag to indicate whether state_dict() should automatically summon the
@@ -2353,8 +2360,19 @@ def alloc_storage_(data: torch.Tensor, size: torch.Size) -> None:


def _post_state_dict_hook(
module: FullyShardedDataParallel, state_dict: "OrderedDict[str, torch.Tensor]", prefix: str, *args: Any
state_dict_on_rank_0_only: bool,
module: FullyShardedDataParallel,
state_dict: "OrderedDict[str, torch.Tensor]",
prefix: str,
*args: Any,
) -> "OrderedDict[str, torch.Tensor]":
# When state_dict_on_rank_0_only is ``True``, ``model.state_dict()`` will only
# return the full state dict on rank 0 and an empty dict on non-zero ranks,
# which allows FullyShardedDataParallel to skip the GPU -> CPU copy on
# non-zero ranks altogether and prevents OOM.
if state_dict_on_rank_0_only and dist.get_rank() != 0:
state_dict.clear()
return state_dict
# Assuming we are in a ``summon_full_params()`` context, we need to clone
# each tensor so that it does not get freed (in-place) when the context
# exits. At the same time, this hook can be called multiple times
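For illustration, a minimal usage sketch of the new flag follows (not part of the commit). It assumes a distributed process group has already been initialized and each process has selected its CUDA device; the `nn.Linear` model and the checkpoint path are placeholders.

```python
import torch
import torch.distributed as dist
import torch.nn as nn

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# Assumes dist.init_process_group(...) has already been called and
# torch.cuda.set_device(...) has picked this process's GPU.
model = FSDP(nn.Linear(8, 8).cuda(), state_dict_on_rank_0_only=True)

# ... training loop ...

state_dict = model.state_dict()
if dist.get_rank() == 0:
    # Only rank 0 materializes the full (unsharded) state dict.
    torch.save(state_dict, "checkpoint.pt")
else:
    # Every other rank receives an empty dict, skipping the GPU -> CPU copy.
    assert len(state_dict) == 0
```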
14 changes: 14 additions & 0 deletions tests/nn/data_parallel/test_fsdp.py
@@ -138,6 +138,13 @@ def _test_identical_outputs(
shard_loss = shard_loss.cuda()
shard_state_dict = model.state_dict()

if config.get("state_dict_on_rank_0_only", False):
if torch.distributed.get_rank() != 0:
assert shard_state_dict == {}
# The rank 0 shard_state_dict case is covered by the test that follows.
# The early return is needed because, with state_dict_on_rank_0_only=True,
# the asserts below would fail on rank != 0.
return

try:
torch.testing.assert_allclose(ref_loss, shard_loss)
assert objects_are_equal(ref_state_dict, shard_state_dict, raise_exception=True)
@@ -361,6 +368,13 @@ def test_delayed_reduce_scatter(self):
test_fn = functools.partial(self._test_identical_outputs, model_fn, config)
spawn_and_init(test_fn)

@parameterized.expand([[True], [False]], name_func=rename_test)
def test_state_dict_on_rank_0_only(self, state_dict_on_rank_0_only):
config = {"state_dict_on_rank_0_only": state_dict_on_rank_0_only}
model_fn = functools.partial(TransformerWithSharedParams)
test_fn = functools.partial(self._test_identical_outputs, model_fn, config)
spawn_and_init(test_fn)

@parameterized.expand([[{"checkpoint_act": False}], [{"checkpoint_act": True}]], name_func=rename_test)
def test_mixture_of_experts(self, moe_config):
fsdp_config = {"mixed_precision": True}
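As an aside, the hook wiring in the main diff uses `functools.partial` to bind the flag ahead of the arguments that `nn.Module._register_state_dict_hook` supplies. Below is a self-contained sketch of that pattern, with a plain boolean standing in for the rank check so it runs without a process group; the names here are illustrative, not from the commit.

```python
import functools
from collections import OrderedDict
from typing import Any

import torch
import torch.nn as nn


def _drop_state_hook(
    drop_state: bool,
    module: nn.Module,
    state_dict: "OrderedDict[str, torch.Tensor]",
    prefix: str,
    *args: Any,
) -> "OrderedDict[str, torch.Tensor]":
    # The bound flag arrives first; module, state_dict, prefix and
    # local_metadata are passed in by nn.Module.state_dict().
    if drop_state:
        state_dict.clear()
    return state_dict


module = nn.Linear(4, 4)
# Bind the flag up front, mirroring the functools.partial call in the diff.
module._register_state_dict_hook(functools.partial(_drop_state_hook, True))
print(module.state_dict())  # OrderedDict() -- the hook emptied it
```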
