[fix] FSDP: EMA related fixes (#922)
* add an ignore file

* [fix] FSDP: handle the lazy_init better

- when state_dict and load_state_dict are called, don't let them change
  the lazy_init state.

* changelog

* longer timeout

* Revert "longer timeout"

This reverts commit 00cc145.

* testing

* adding the failed test

* fix the global to local id

* formatting

* more complete fix and test

* minor fix for an assert

* update changelog

* remove an extra line

* Update fairscale/nn/data_parallel/fsdp_optim_utils.py

Co-authored-by: anj-s <32556631+anj-s@users.noreply.github.com>

* Update fairscale/nn/data_parallel/fsdp_optim_utils.py

Co-authored-by: anj-s <32556631+anj-s@users.noreply.github.com>

* Update fairscale/nn/data_parallel/fsdp_optim_utils.py

Co-authored-by: anj-s <32556631+anj-s@users.noreply.github.com>

* addressed review comments

Co-authored-by: Min Xu <min.xu.public@gmail.com>
Co-authored-by: anj-s <32556631+anj-s@users.noreply.github.com>
3 people committed Mar 3, 2022
1 parent 2ca4f0e commit 9f347f3
Showing 6 changed files with 147 additions and 14 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -19,6 +19,7 @@ test-results/
# Coverage reports
.coverage
.coverage.*
./coverage.xml

# Environments
.env
7 changes: 7 additions & 0 deletions CHANGELOG.md
@@ -18,6 +18,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
left after children were wrapped. [#930]
- FSDP: Add support for saving optimizer state when using expert replicas with FSDP.

### Fixed
- FSDP: fixed handling of internal state in the state_dict and load_state_dict
  functions so that they don't change the lazy init state if training hasn't started. [#922]
- FSDP: added support for optimizer state handling when some of the parameters are
  not used. An example is a model with an EMA copy that is never trained
  but still needs to be sharded. [#922]

## [0.4.5] - 2022-01-14

### Added
68 changes: 60 additions & 8 deletions fairscale/nn/data_parallel/fsdp_optim_utils.py
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
"""These functions are used by FullyShardedDataParallel to help consolidate and shard optimizer states."""
import copy
from itertools import groupby
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Tuple, cast

import torch
@@ -20,9 +21,12 @@
def flatten_optim_state_dict(sd: Dict) -> Dict:
"""Shard a full optimizer state dict (called by FSDP.get_shard_from_optim_state_dict)"""
param_id_map = sd["param_id_map"]
num_local_params = len(set(param_id_map.values()))
# Get a set of local ids, like {0, None, 2}, then we remove None from it.
local_ids = set(param_id_map.values())
if None in local_ids:
local_ids.remove(None)
if sd["state"]:
new_state: Dict = {local_id: {} for local_id in range(num_local_params)}
new_state: Dict = {local_id: {} for local_id in local_ids}
singleton_state: Dict = copy.deepcopy(new_state)
else:
new_state = {}
@@ -55,7 +59,10 @@ def flatten_optim_state_dict(sd: Dict) -> Dict:

# add pointers from the `params` dict.
for pg_id, _ in enumerate(sd["param_groups"]):
# TODO: this list could be huge. Can we avoid materializing?
# The values() list may look like [0,0,None,None,2,2]. We use
# groupby to remove the duplicates and then count the length of
# the resulting iterator.
num_local_params = sum(1 for _ in groupby(param_id_map.values()))
new_sd["param_groups"][pg_id]["params"] = list(range(num_local_params))

return new_sd
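
A minimal sketch of the groupby-based count used above (hypothetical ids; assumes, as in the comment, that equal local ids appear in contiguous runs in `param_id_map.values()`):

```python
from itertools import groupby

# Hypothetical global-to-local id map; None marks params whose flat buffer
# carries no optimizer state (e.g. an untrained EMA copy).
param_id_map = {0: 0, 1: 0, 2: None, 3: None, 4: 2, 5: 2}

# Count the distinct runs without materializing a potentially huge list.
num_local_params = sum(1 for _ in groupby(param_id_map.values()))
assert num_local_params == 3  # runs: [0, 0], [None, None], [2, 2]
```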
@@ -91,7 +98,18 @@ def _unflatten_optim_state(
world_pad_info: List[List[List[int]]],
singleton_state: Dict[int, Dict],
) -> Tuple[Dict[int, Dict], Dict[int, int]]:
"""Convert optimizer state for flattened parameters into original, unflatten ones."""
"""Convert optimizer state for flattened parameters into original, unflattened ones.
Args:
combined_state: all-gathered state with tensors
instance_list: list of FSDP wrapper object instances
world_pad_info: [param_id][fsdp_instance_id][bytes_padded_per_rank]
singleton_state: all-gathered dimensionless tensors
Returns:
state: unflattened state dict
idx_mapping: a mapping from global ID to local ID
"""
# Local ids are the keys in the current state (combined_state); there are usually fewer of them
# than global ids, which will be the keys in the unflattened state.
next_global_id = 0 # gets incremented
@@ -100,7 +118,7 @@

# non_tensor_state refers to entries in sd[state][param_id] that are not tensors, like "step".
# we check that these are identical across workers and then take the first
non_tensor_state = [_extract_non_tensor_state(combined_state, id) for id in combined_state]
non_tensor_state = {id: _extract_non_tensor_state(combined_state, id) for id in combined_state}

# Local corresponds to flattened, global corresponds to unflattened.
# Casting needed only for mypy.
@@ -114,20 +132,41 @@
global_to_local_id = {}
for local_id, num_unflat in enumerate(num_global_params):
for _ in range(num_unflat):
global_to_local_id[next_global_id] = local_id
# Some params could be unused, which means the optimizer
# hasn't created their state. Therefore, `local_id` obtained
# by enumerating the params above could be out of the range
# of keys in `combined_state` above. Here is an example:
#
# global local notes
# 0 0 FC1's weight, first flat buffer
# 1 0 FC1's bias, first flat buffer
# 2 None FC2's weight, no flat state
# 3 None FC2's bias, no flat state
# 4 2 FC3's weight, second flat buffer (but with id 2)
# 5 2 FC3's bias, second flat buffer (but with id 2)
global_to_local_id[next_global_id] = local_id if local_id in local_ids else None
next_global_id += 1
if not combined_state:
return {}, global_to_local_id

# copy non tensor state (like the "step" count) to all global entries
unflat_state = {i: copy.deepcopy(non_tensor_state[0]) for i in range(sum(num_global_params))}

# remove the global entries that don't have optim state because pytorch
# optimizer's state_dict() function returns a state_dict without the missing
# param, so we shouldn't have things like "1:{}" for missing params.
for g, l in global_to_local_id.items():
if l is None:
del unflat_state[g]

if non_tensor_state[0].keys() == combined_state[0].keys():
# Early return if there are no tensors in the state dict.
return unflat_state, global_to_local_id

local_to_global: Dict[int, List] = {i: [] for i in local_ids}
for g, l in global_to_local_id.items():
local_to_global[l].append(g)
if l is not None:
local_to_global[l].append(g)
# loop over parameters in state.
# Tensor state will be padded, concatenated, and restored to original shape with FlattenParamsWrapper.get_views
# get_views returns multiple tensors, each of which is a new parameter with a new "global" id.
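
As a rough illustration of the global-to-local mapping described in the comment above, here is a self-contained sketch (hypothetical helper `build_global_to_local` and made-up values, not the FSDP internals themselves):

```python
from typing import Dict, List, Optional, Set

def build_global_to_local(num_global_params: List[int], local_ids: Set[int]) -> Dict[int, Optional[int]]:
    """Map each unflattened (global) param id to its flattened (local) buffer id,
    or to None when that buffer has no optimizer state (unused/untrained params)."""
    mapping: Dict[int, Optional[int]] = {}
    next_global_id = 0
    for local_id, num_unflat in enumerate(num_global_params):
        for _ in range(num_unflat):
            mapping[next_global_id] = local_id if local_id in local_ids else None
            next_global_id += 1
    return mapping

# Three flat buffers with two unflattened params each; buffer 1 (say, an EMA copy)
# has no optimizer state, so its global ids map to None.
assert build_global_to_local([2, 2, 2], {0, 2}) == {0: 0, 1: 0, 2: None, 3: None, 4: 2, 5: 2}
```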
@@ -165,7 +204,20 @@ def build_unflat_state_dict(
uncollected_opt_state: Dict[int, Dict],
param_groups: List[Dict],
) -> Dict:
"""Build an unflattened optimizer state dict given a list of flattened optimizer state dicts from each rank."""
"""Build an unflattened optimizer state dict given a list of flattened optimizer state dicts
from each rank. This is only called on rank 0.
Args:
instance_list: list of FSDP wrapper objects
world_pad_info: [param_id][fsdp_instance_id][bytes_padded_per_rank]
state: all-gathered combined/local/flatten state_dict
singleton_state: all-gathered singleton_state (dimensionless tensors)
uncollected_opt_state: non-tensor and not-gathered state
param_groups: the original rank 0's sd["param_groups"]
Returns:
dict: an unflattened, nonsharded optimizer state, as if FSDP was not there.
"""
assert all(len(s) == len(instance_list) for s in world_pad_info)
assert all(len(s[0]) == 1 for s in world_pad_info)

27 changes: 22 additions & 5 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -796,6 +796,7 @@ def extra_repr(self) -> str:
)
if self.verbose:
repr = (
f"self={id(self)} is_root={self._is_root}, "
f"rank={self.rank}, " + repr + f"reshard_after_forward={self.reshard_after_forward}, "
f"compute_dtype={self.compute_dtype}, "
f"buffer_dtype={self.buffer_dtype}, "
@@ -907,6 +908,7 @@ def state_dict(self, *args: Any, **kwargs: Any) -> Any:
"""
if torch.cuda.is_available():
torch.cuda.synchronize()
is_uninitialized = self._is_root is None # See comment below on why we use this.
self._lazy_init()

def maybe_cast_buffers(dtype: Optional[torch.dtype] = None) -> None:
@@ -931,6 +933,12 @@ def maybe_cast_buffers(dtype: Optional[torch.dtype] = None) -> None:

# In case we are in mixed precision, restore buffers back to buffer_dtype.
maybe_cast_buffers()
# We shouldn't change the init state in case this was an inner module and
# users simply wanted to get state_dict before training.
if is_uninitialized and self._is_root:
for module in self.modules():
if isinstance(module, FullyShardedDataParallel):
module._reset_lazy_init()
return state_dict

@typing.overload
@@ -999,7 +1007,15 @@ def _load_state_dict(
def load_state_dict(
self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True
) -> NamedTuple:
return self._load_state_dict(state_dict, strict)
is_uninitialized = self._is_root is None # See comment below on why we use this.
sd = self._load_state_dict(state_dict, strict)
# We shouldn't change the init state in case this was an inner module and
# users simply wanted to load_state_dict before training.
if is_uninitialized and self._is_root:
for module in self.modules():
if isinstance(module, FullyShardedDataParallel):
module._reset_lazy_init()
return sd

def load_local_state_dict(
self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True
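
A usage sketch of the behavior the state_dict and load_state_dict changes above protect (illustrative only; assumes a process group has already been initialized via `torch.distributed` and a GPU is available):

```python
import torch
from fairscale.nn import FullyShardedDataParallel as FSDP

model = FSDP(torch.nn.Linear(4, 4).cuda())

# Saving and re-loading a checkpoint before the first forward pass should not
# freeze lazy init: inner FSDP instances reset themselves so the real root is
# still discovered correctly once training starts.
sd = model.state_dict()
model.load_state_dict(sd)

out = model(torch.rand(2, 4).cuda())  # lazy init happens here, as usual
```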
@@ -1297,7 +1313,7 @@ def _set_is_root(self) -> None:
if n != "" and isinstance(m, FullyShardedDataParallel):
# We relax the assert for non-root instances, when the nested initialized module is wrapped
# again in FSDP later, for example after training to run inference.
assert m._is_root is None or not m._is_root
assert m._is_root is None or not m._is_root, f"offending FSDP instance is {id(m)}, {m}"
if m._is_root is None:
m._is_root = False
if m.process_group != self.process_group:
@@ -2363,9 +2379,10 @@ def get_shard_from_optim_state_dict(self, full_optim_state_dict: Dict[str, Any])
ids_not_to_shard = copy.deepcopy(full_optim_state_dict["uncollected_local_ids"])
if self.flatten_parameters:
full_optim_state_dict = ou.flatten_optim_state_dict(full_optim_state_dict)
assert len(full_optim_state_dict["state"]) in (
0,
len(instance_list),
# Due to unused params, the length of the state can be anywhere between
# 0 and the number of params/FSDP instances.
assert len(full_optim_state_dict["state"]) <= len(
instance_list
), f'{len(full_optim_state_dict["state"])}, {len(instance_list)}'

# get the portion of dict associated with the shard, in place
4 changes: 3 additions & 1 deletion fairscale/nn/misc/flatten_params_wrapper.py
@@ -300,7 +300,9 @@ def _init_flatten_params(
assert (
len(set(p.dtype for p in params)) == 1
), f"expects all parameters to have same dtype: fp32: {fp32_msg} \n fp16: {fp16_msg} "
assert len(set(p.requires_grad for p in params)) == 1, "expects all parameters to have same requires_grad"
assert (
len(set(p.requires_grad for p in params)) == 1
), f"expects all parameters to have same requires_grad {p_set}"
assert len(params) == len(set(params)), "params list should not have dups"
return params, param_infos, shared_param_infos

54 changes: 54 additions & 0 deletions tests/nn/data_parallel/test_fsdp_optimizer_utils.py
@@ -8,6 +8,7 @@

from parameterized import parameterized
import torch
from torch import nn
from torch.optim import SGD, Adadelta, Adam # type: ignore

from fairscale.nn import FullyShardedDataParallel
@@ -225,6 +226,34 @@ def _test_consolidated_optimizer(
assert objects_are_equal(shard_sd["state"], original_shard_sd["state"])
assert objects_are_equal({k: shard_sd[k] for k in original_shard_sd}, original_shard_sd)

@parameterized.expand(
[(True,), (False,)],
name_func=rename_test,
)
def test_model_with_unused_params(self, wrap_l2):
"""Test handling of model with unused params by gather_full_optim_state_dict()"""
test_fn = functools.partial(self._test_model_with_unused_params, wrap_l2=wrap_l2)
spawn_and_init(test_fn, world_sizes=[2])

@classmethod
def _test_model_with_unused_params(self, rank, pg, wrap_l2):
model = ModelWithUnusedParams(wrap_l2).cuda()
data = torch.rand(4).cuda().requires_grad_(True)
model = FullyShardedDataParallel(model)
optim = SGD(model.parameters(), momentum=0.9, lr=0.1)
out = model(data).sum()
out.backward()
optim.step()
model.zero_grad(set_to_none=True)
sd = model.gather_full_optim_state_dict(optim)
if rank == 0:
shard_sd = model.get_shard_from_optim_state_dict(sd)
orig_sd = optim.state_dict()
orig_sd = recursive_copy_to_device(orig_sd, non_blocking=False, device="cpu")
objects_are_equal(shard_sd, orig_sd, raise_exception=True)
else:
assert sd is None, sd

def test_named_params_ordering(self):
"""Test assumption of consolidate_optimizer_state_dict"""
group = DummyProcessGroup(0, 1)
@@ -234,8 +263,33 @@ def test_named_params_ordering(self):
assert objects_are_equal(p, named_pars[i])

def test_is_singleton_tensor(self):
"""Test is_singleton_tensor function"""
assert is_singleton_tensor(torch.tensor(4.0))
assert not is_singleton_tensor(torch.tensor([4.0]))
assert not is_singleton_tensor(torch.tensor([4.0, 5.0]))
assert not is_singleton_tensor([4.0])
assert not is_singleton_tensor(4.0)


class ModelWithUnusedParams(nn.Module):
def __init__(self, wrap_l2):
super().__init__()
self.l = nn.Linear(4, 4)
# unused param must be wrapped, otherwise, due to flatten, it
# is always used.
self.not_trained = nn.Linear(4, 4).requires_grad_(False)
self.not_trained = FullyShardedDataParallel(self.not_trained)
# optionally testing a used param after the unused one by
# wrapping it.
self.l2 = nn.Linear(4, 4)
if wrap_l2:
# When wrapping happens, the unused param will be in the middle
# of the param list (for optimizer state dict), not at the
# end. This way, we can test the handling code in more corner
# cases.
self.l2 = FullyShardedDataParallel(self.l2)

def forward(self, x):
with torch.no_grad():
y = self.not_trained(x)
return self.l2(self.l(x)) - y
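
For context, a hedged end-to-end sketch of the EMA scenario from the changelog (illustrative module and variable names; assumes a distributed setup like the one the test harness above provides):

```python
import torch
from torch import nn
from fairscale.nn import FullyShardedDataParallel as FSDP

class ModelWithEMA(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(4, 4)
        # The EMA copy is never trained but should still be sharded, so it gets
        # its own FSDP wrapper with requires_grad disabled.
        self.ema = FSDP(nn.Linear(4, 4).requires_grad_(False))

    def forward(self, x):
        with torch.no_grad():
            _ = self.ema(x)
        return self.net(x)

model = FSDP(ModelWithEMA().cuda())
optim = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
model(torch.rand(2, 4).cuda()).sum().backward()
optim.step()

# Only rank 0 receives the consolidated dict; other ranks get None.
full_sd = model.gather_full_optim_state_dict(optim)
if full_sd is not None:
    shard_sd = model.get_shard_from_optim_state_dict(full_sd)
```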
