[bug] fix optim state gather when there are empty FSDP instances (#1071)
* [bug] fix optim state gather when there are empty FSDP instances

* fixes an assert and a test bug
min-xu-ai committed Sep 13, 2022
1 parent 203dd66 commit d8fc94d
Showing 2 changed files with 34 additions and 19 deletions.
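
The heart of the fix: _fsdp_instances changes from a property into a method with a skip_empty flag, and the optimizer-state gathering paths pass skip_empty=True so that the list of FSDP instances lines up with the entries in the optimizer's sd_state. Below is a minimal, self-contained sketch of that filtering idea; the _FakeFSDP class and fsdp_instances helper are toy stand-ins for illustration, not fairscale code.

from typing import List


class _FakeFSDP:
    """Toy stand-in for an FSDP instance; only what this sketch needs."""

    def __init__(self, name: str, n_non_shared: int) -> None:
        self.name = name
        self._n = n_non_shared

    def non_shared_params(self) -> List[str]:
        # The real method returns parameter objects; names are enough here.
        return [f"{self.name}.p{i}" for i in range(self._n)]


def fsdp_instances(modules: List[_FakeFSDP], skip_empty: bool = False) -> List[_FakeFSDP]:
    """Mimics the new _fsdp_instances(skip_empty=...) filtering."""
    if skip_empty:
        return [m for m in modules if len(m.non_shared_params()) > 0]
    return list(modules)


mods = [_FakeFSDP("root", 1), _FakeFSDP("empty_child", 0), _FakeFSDP("child", 1)]
print(len(fsdp_instances(mods)))                   # 3: includes the empty instance
print(len(fsdp_instances(mods, skip_empty=True)))  # 2: lines up with sd_state entries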
49 changes: 32 additions & 17 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -1343,11 +1343,10 @@ def _set_is_root(self) -> None:
for n, m in self.named_modules():
# `n != ""` excludes self.
if n != "" and isinstance(m, FullyShardedDataParallel):
# We relax the assert for non-root instance, when the nested initialized module is wrapped
# again in FSDP later, for example after training to run inference.
assert m._is_root is None or not m._is_root, f"offending FSDP instance is {id(m)}, {m}"
if m._is_root is None:
m._is_root = False
# We set inner FSDP instances to non-root, but they could have the value of True
# if, for example, a module is called first (like inference, EMA)
# and then later we call an outer FSDP for state dict load/save.
m._is_root = False
if m.process_group != self.process_group:
self.children_share_process_group = False
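
The new comment describes the scenario this relaxation covers: an inner FSDP module is used on its own first, so its lazy init marks it as root, and only later is it wrapped by an outer FSDP. A rough illustration of that sequence, assuming a distributed process group is already initialized; shapes and modules are made up for illustration.

import torch
import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# The inner module is wrapped and used on its own first (e.g. inference or
# an EMA copy), so its lazy init marks it as a root FSDP instance.
inner = FSDP(nn.Linear(8, 8).cuda())
_ = inner(torch.randn(2, 8).cuda())   # inner._is_root becomes True here

# Later the same module is wrapped again, e.g. for state dict save/load.
outer = FSDP(inner)
_ = outer(torch.randn(2, 8).cuda())   # the outer instance's lazy init now forces
                                      # inner._is_root to False instead of
                                      # tripping the removed assert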

@@ -2277,9 +2276,11 @@ def assert_state(self, state: Union[TrainingState, List[TrainingState]]) -> None
raise ValueError(msg)

def _broadcast_pad_info_to_r0(self) -> List[List[List[int]]]:
"""Collect [x.numel_padded_per_param for x in self._fsdp_instances] from each rank."""
"""Collect [x.numel_padded_per_param for x in self._fsdp_instances()] from each rank."""
world_pad_info: List[List[List[int]]] = [] # this will contain values from the whole world.
my_pad_info: List[List[int]] = [cast(List[int], m.numel_padded_per_param) for m in self._fsdp_instances]
my_pad_info: List[List[int]] = [
cast(List[int], m.numel_padded_per_param) for m in self._fsdp_instances(skip_empty=True)
]
for rank in range(self.world_size):
if rank == self.rank:
pad_info = my_pad_info
@@ -2296,24 +2297,31 @@ def _gather_optim_state(
"""For each value in state[i], if the value is a tensor, collect it from the world. Else use rank 0's entry."""
gathered_state: Dict[int, Dict[str, List[Any]]] = {}
singleton_state: Dict[int, Dict[str, List[Any]]] = {} # Dimensionless tensor

# Non-empty FSDP instance and sd_state item number must match.
fsdp_instances = self._fsdp_instances(skip_empty=True)
assert len(fsdp_instances) >= len(sd_state), f"{len(fsdp_instances)} vs. {len(sd_state)}"

for k, v in sd_state.items():
gathered_state[k] = {}
singleton_state[k] = {}
# For shared params, we are not flattening. We have only 1 non-shared
# param that has the optimizer state. So we handle it with the correct
# parameter list.
non_shared_params = cast(FullyShardedDataParallel, self._fsdp_instances[k]).non_shared_params()
non_shared_params = fsdp_instances[k].non_shared_params()

# This is the world size and process group of the FSDP submodule which can be
# different than the parent module. For example, when FSDP is used with MoE.
non_shared_world_size = self._fsdp_instances[k].world_size
non_shared_process_group = self._fsdp_instances[k].process_group
non_shared_world_size = fsdp_instances[k].world_size
non_shared_process_group = fsdp_instances[k].process_group

assert (
len(non_shared_params) == 1
), f"Only flatten param or a single non-shared param is supported: len={len(non_shared_params)}"
), f"Only flatten param or a single non-shared param is supported: len={len(non_shared_params)} FSDP={self}"
desired_buffer_size = non_shared_params[0]._full_param_padded.size()
buffer = None # for sharded tensors
singleton_buffer = None # for singleton tensors

for buffer_name, t in v.items():
if torch.is_tensor(t):
t = t.to(self.compute_device)
@@ -2370,16 +2378,23 @@ def gather_full_optim_state_dict(self, optim: torch.optim.Optimizer, **ignored:
return None
# Unify the shard states by concatenating tensors and unflattening params
new_state_dict = ou.build_unflat_state_dict(
self._fsdp_instances, pad_info, state, singleton_state, self.uncollected_opt_state, sd["param_groups"]
self._fsdp_instances(skip_empty=True),
pad_info,
state,
singleton_state,
self.uncollected_opt_state,
sd["param_groups"],
)
self.uncollected_opt_state = {}
assert "uncollected_local_ids" in new_state_dict
return new_state_dict

@property
def _fsdp_instances(self) -> List["FullyShardedDataParallel"]:
def _fsdp_instances(self, skip_empty: bool = False) -> List["FullyShardedDataParallel"]:
"""Returns all fsdp modules in self.modules() including self."""
return [m for m in self.modules() if isinstance(m, FullyShardedDataParallel)]
result = [m for m in self.modules() if isinstance(m, FullyShardedDataParallel)]
if skip_empty:
result = list(filter(lambda x: len(cast(FullyShardedDataParallel, x).non_shared_params()) > 0, result))
return result

def _remove_uncollectable_params_from_optim_state_dict(self, osd: Dict) -> Dict:
"""Return a new state dict filtering out the ones like MoE layers, which has
@@ -2396,7 +2411,7 @@ def _remove_uncollectable_params_from_optim_state_dict(self, osd: Dict) -> Dict:
if ou.is_singleton_tensor(bufs["step"]):
bufs["step"] = bufs["step"].item()
# Get uncollected_ids.
uncollected_ids = [i for i, m in enumerate(self._fsdp_instances) if m.no_broadcast_optim_state]
uncollected_ids = [i for i, m in enumerate(self._fsdp_instances()) if m.no_broadcast_optim_state]
new_dct = {"state": {k: v for k, v in osd["state"].items() if k not in uncollected_ids}}
if self.rank == 0:
# Save placeholders for uncollected opt state to keep the same unflat OSD format, and move them to CPU.
@@ -2423,7 +2438,7 @@ def get_shard_from_optim_state_dict(self, full_optim_state_dict: Dict[str, Any])
(dict): a shard of the optimizer state.
"""
# Assert nesting is the same as it was at save time
instance_list = self._fsdp_instances
instance_list = self._fsdp_instances()
ou.check_param_counts_before_sharding(full_optim_state_dict, len(instance_list))
ids_not_to_shard = copy.deepcopy(full_optim_state_dict["uncollected_local_ids"])
if self.flatten_parameters:
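For context, the gather/shard round trip these methods implement looks roughly like this at the call site. This is a usage sketch, not code from the commit: it assumes a process group is already initialized, uses a toy model, and skips the rank-0 synchronization a real checkpointing loop would need.

import torch
import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# Toy model; the fix is agnostic to the architecture.
model = FSDP(nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4)).cuda())
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

# ... training steps ...

# Consolidate the sharded optimizer state across the world; the full dict
# is returned on rank 0 and None on the other ranks.
full_osd = model.gather_full_optim_state_dict(optim)
if full_osd is not None:
    torch.save(full_osd, "optim_full.pt")  # in practice, barrier/broadcast after this

# On resume, each rank takes its own shard out of the consolidated dict.
full_osd = torch.load("optim_full.pt", map_location="cpu")
shard = model.get_shard_from_optim_state_dict(full_osd)
optim.load_state_dict(shard)
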
4 changes: 2 additions & 2 deletions tests/nn/data_parallel/test_fsdp_optimizer_utils.py
@@ -158,9 +158,9 @@ def _test_consolidated_optimizer(
unwrapped_sd = optim_unwrapped.state_dict()

if not transformer and not expert_group:
no_broadcast_children = [x for x in fsdp._fsdp_instances if x.no_broadcast_optim_state]
no_broadcast_children = [x for x in fsdp._fsdp_instances() if x.no_broadcast_optim_state]
assert len(no_broadcast_children) == 1, f"Length of non shared params {len(no_broadcast_children)}"
assert fsdp._fsdp_instances[-1].no_broadcast_optim_state
assert fsdp._fsdp_instances()[-1].no_broadcast_optim_state
torch.cuda.empty_cache()
cuda_gb_before = torch.cuda.memory_stats(fsdp.rank)["allocated_bytes.all.current"] / 1024**3
tstart = time()
