[fix] better handling non-flatten in FSDP (#1072)
* [fix] better handling non-flatten in FSDP

- see the detailed comment about that backward firing case
- also minor debugging help in FSDP
- also minor fix in FPW's state dict

* [feat] disallow reset_parameters by default

* [feat] adding fsdp_instances API - useful for checking wrapping by user code

* [fix] one line fix but more than a day of debugging

* fixed the case of loading a combined checkpoint with empty fsdp instances

* fixed another state-loading bug with root/non-root module full-param caching, caused by not resharding after forward

* [feat] support .half and .float better

* fixed a bug where gathering optim state loses extra keys from the original state_dict

* fixed a test failure in mixed precision

* fixed another bug affecting no_sync grad acc

* fixed a bug and a test in fsdp optim state

* fixed another corner case

* added a comment

* skip ssd offload tests

* skip the fsdp one for ssd offload

Co-authored-by: Min Xu <min.xu.public@gmail.com>
min-xu-ai and flying-x committed Sep 23, 2022
1 parent 47ce21a commit 429f3d3
Showing 13 changed files with 332 additions and 134 deletions.
2 changes: 1 addition & 1 deletion fairscale/experimental/tooling/layer_memory_tracker.py
@@ -500,7 +500,7 @@ def _is_same_storage(x: torch.Tensor, y: torch.Tensor) -> bool:
 Indicate if x and y share the same storage, meaning that one of them
 is a view, reshape or stride of the other or from a common tensor
 """
-return x.storage().data_ptr() == y.storage().data_ptr() # type: ignore
+return x.storage().data_ptr() == y.storage().data_ptr()

 @staticmethod
 def _collect_tensors(module_io_tensors: Union[torch.Tensor, Sequence[torch.Tensor]]) -> List[torch.Tensor]:
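The helper above treats two tensors as aliases whenever their underlying storages start at the same address, which is exactly what happens for views, reshapes, and slices. A small standalone illustration of that check (plain PyTorch, not fairscale code):

import torch

def is_same_storage(x: torch.Tensor, y: torch.Tensor) -> bool:
    # Storages compare equal when one tensor is a view/reshape/slice of the other.
    return x.storage().data_ptr() == y.storage().data_ptr()

base = torch.arange(12.0)
view = base.view(3, 4)    # same memory, different shape
copy_ = base.clone()      # fresh storage

assert is_same_storage(base, view)
assert not is_same_storage(base, copy_)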
1 change: 1 addition & 0 deletions fairscale/nn/data_parallel/__init__.py
@@ -12,6 +12,7 @@
 OffloadConfig,
 TrainingState,
 auto_wrap_bn,
+get_fsdp_instances,
 no_pre_load_state_dict_hook,
 )

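The newly exported get_fsdp_instances is the "fsdp_instances API" mentioned in the commit message for checking wrapping from user code. A hedged usage sketch, assuming it takes a root module and returns the list of nested FSDP instances (the exact signature and any extra keyword arguments are not shown in this diff):

from typing import List

import torch.nn as nn

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.data_parallel import get_fsdp_instances

def check_wrapping(model: nn.Module, expected: int) -> List[FSDP]:
    """Verify that auto/manual wrapping produced the expected number of FSDP instances.

    Assumes `model` is already FSDP-wrapped under an initialized process group and
    that get_fsdp_instances(module) returns a list of the FSDP modules it contains.
    """
    instances = get_fsdp_instances(model)
    assert all(isinstance(m, FSDP) for m in instances), "unexpected non-FSDP entry"
    assert len(instances) == expected, f"expected {expected} FSDP instances, got {len(instances)}"
    return instances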
41 changes: 21 additions & 20 deletions fairscale/nn/data_parallel/fsdp_optim_utils.py
@@ -14,9 +14,6 @@
 if TYPE_CHECKING:
 from fairscale.nn.data_parallel import FullyShardedDataParallel

-# These return keys are used by fairseq. To change, add @sshleifer as a reviewer.
-UNFLAT_RETURN_KEYS = {"state", "param_groups", "uncollected_local_ids", "param_id_map"}
-
 # This function helps shard a full optimizer state dict
 def flatten_optim_state_dict(sd: Dict) -> Dict:
 """Shard a full optimizer state dict (called by FSDP.get_shard_from_optim_state_dict)"""
@@ -52,20 +49,24 @@ def flatten_optim_state_dict(sd: Dict) -> Dict:
 new_state[local_id][buffer_name] = torch.cat(tensors)
 new_state[local_id].update(non_tensor_state)
 new_state[local_id].update(singleton_state[local_id])
-new_sd = {"state": new_state, "param_groups": copy.deepcopy(sd["param_groups"])}
-for k in sd.keys(): # if there are extra keys, like loss_scale, don't delete them
-if k not in UNFLAT_RETURN_KEYS:
-new_sd[k] = copy.deepcopy(sd[k])

+# Now make a new param_groups copy and update it.
+new_sd_pg = copy.deepcopy(sd["param_groups"])
 # add pointers from the `params` dict.
 for pg_id, _ in enumerate(sd["param_groups"]):
 # The values() list may look like [0,0,None,None,2,2]. We use
 # groupby to remove the duplicates and then count the length of
 # resulting iter.
 num_local_params = sum(1 for _ in groupby(param_id_map.values()))
-new_sd["param_groups"][pg_id]["params"] = list(range(num_local_params))
+new_sd_pg[pg_id]["params"] = list(range(num_local_params))

-return new_sd
+# update the original sd so that we don't lose extra keys, like loss_scale.
+sd["state"] = new_state
+sd["param_groups"] = new_sd_pg
+# delete extra keys we have added to match the original state.
+del sd["uncollected_local_ids"]
+del sd["param_id_map"]
+return sd


 def check_param_counts_before_sharding(full_optim_state_dict: Dict, n_instances: int) -> None:
@@ -202,7 +203,7 @@ def build_unflat_state_dict(
 state: Dict[int, Dict[str, List[torch.Tensor]]],
 singleton_state: Dict[int, Dict[str, List[torch.Tensor]]],
 uncollected_opt_state: Dict[int, Dict],
-param_groups: List[Dict],
+original_sd: Dict,
 ) -> Dict:
 """Build an unflattened optimizer state dict given a list of flattened optimizer state dicts
 from each rank. This is only called on rank 0.
@@ -213,7 +214,7 @@
 state: all-gathered combined/local/flatten state_dict
 singleton_state: all-gathered singleton_state (dimensionless tensors)
 uncollected_opt_state: non-tensor and not-gathered state
-param_groups: the original rank 0's sd["param_groups"]
+original_sd: the original rank 0's sd
 Returns:
 dict: an unflattened, nonsharded optimizer state, as if FSDP was not there.
@@ -228,19 +229,19 @@
 singleton_state[local_id] = {buffer_name: [x] for buffer_name, x in v.items() if is_singleton_tensor(x)}
 # local ids are in the current state, global_ids will be in returned state.
 unflat_state, global_to_local_id = _unflatten_optim_state(state, instance_list, world_pad_info, singleton_state)
+
 # Since there are no tensors in param_groups, deepcopy is fine.
-param_groups = copy.deepcopy(param_groups)
+param_groups = copy.deepcopy(original_sd["param_groups"])
 # Casting needed only for mypy.
 num_params = sum([cast(int, m.num_params_managed) for m in instance_list])
 param_groups[0]["params"] = list(range(num_params))
-unflat_optim_state_dict = {
-"state": dict(sorted(unflat_state.items())), # NOTE: this is probably already sorted
-"param_id_map": global_to_local_id,
-"param_groups": param_groups,
-"uncollected_local_ids": list(uncollected_opt_state.keys()),
-}
-assert set(unflat_optim_state_dict.keys()) == UNFLAT_RETURN_KEYS
-return unflat_optim_state_dict
+
+# Update the original sd so we don't lose extra state like loss_scale.
+original_sd["state"] = dict(sorted(unflat_state.items())) # NOTE: this is probably already sorted
+original_sd["param_id_map"] = global_to_local_id
+original_sd["param_groups"] = param_groups
+original_sd["uncollected_local_ids"] = list(uncollected_opt_state.keys())
+return original_sd


 def is_singleton_tensor(x: Any) -> bool:
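Both changes in this file move toward the same pattern: update the caller's original optimizer state dict in place rather than assembling a fresh one from a fixed key set, so extra entries such as loss_scale pass through untouched. A standalone sketch of that pattern with plain dicts (not fairscale's actual data structures):

import copy
from typing import Any, Dict, List

def rebuild_preserving_extras(sd: Dict[str, Any], new_state: Dict[int, Any], new_param_groups: List[Dict]) -> Dict[str, Any]:
    """Overwrite the known optimizer-state keys in place and keep everything else."""
    sd["state"] = new_state
    sd["param_groups"] = copy.deepcopy(new_param_groups)
    # Drop bookkeeping keys that only exist in the gathered/flattened form.
    for helper_key in ("uncollected_local_ids", "param_id_map"):
        sd.pop(helper_key, None)
    return sd

original = {
    "state": {0: {"step": 1}},
    "param_groups": [{"lr": 0.1, "params": [0]}],
    "param_id_map": {0: 0},
    "uncollected_local_ids": [],
    "loss_scale": 128.0,  # extra key that could previously get lost
}

rebuilt = rebuild_preserving_extras(original, {0: {"step": 2}}, [{"lr": 0.1, "params": [0]}])
assert rebuilt["loss_scale"] == 128.0  # the extra key survives the rebuild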
