Avoid calling _free_fp16_param_shard() too early with PR 1159
jiecaoyu committed Feb 21, 2024
1 parent a4f02ef commit f2bb56f
Showing 1 changed file with 3 additions and 3 deletions.
fairscale/nn/data_parallel/fully_sharded_data_parallel.py: 3 additions & 3 deletions
@@ -1733,7 +1733,7 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
         # Free full params.
         self._free_full_params([param])
 
-        if self.mixed_precision:
+        if self.mixed_precision and (self._require_backward_grad_sync or self.reshard_after_forward):
             # This is a no-op if reshard_after_forward is True, since we already
             # free the param shard when rebuilding the full params in the
             # pre_backward_hook.
@@ -1861,7 +1861,7 @@ def _post_reduction_hook(self, param: Parameter, reduced_grad: torch.Tensor) ->
     def _post_backward_reshard_hook(self, param: Parameter, *unused: Any) -> None:
         if self._should_free_in_backward():
             self._free_full_params([param])
-        if self.mixed_precision:
+        if self.mixed_precision and (self._require_backward_grad_sync or self.reshard_after_forward):
             self._free_fp16_param_shard([param])
         self._use_fp32_param_shard([param])
 
@@ -1937,7 +1937,7 @@ def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None:
                     # For the 1st layer, if the forward inputs did not require
                     # gradient, then we cannot run a reshard hook for it, and
                     # we instead free here.
-                    if p._full_param_padded.untyped_storage().size() > 0:
+                    if p._is_sharded and p._full_param_padded.untyped_storage().size() > 0:
                         fsdp_module._post_backward_reshard_hook(p)
                     continue

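For context, below is a minimal sketch (not part of the commit) of the kind of setup the new condition appears to target: fairscale's FullyShardedDataParallel with mixed_precision=True and reshard_after_forward=False, running gradient-accumulation passes under no_sync(), where _require_backward_grad_sync is False. With the previous check (if self.mixed_precision:), the post-backward hooks freed the fp16 parameter shard even on these no-sync passes; the new check keeps it allocated until a synchronized pass or a reshard. The module, sizes, and step counts are illustrative, and the script assumes torch.distributed is already initialized (e.g. via torchrun) with a CUDA device available.

import torch
import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# Illustrative setup only; assumes dist.init_process_group() has already run.
model = FSDP(
    nn.Linear(1024, 1024).cuda(),
    mixed_precision=True,          # keep a separate fp16 shard of the weights
    reshard_after_forward=False,   # fp16 shard is not freed after forward
)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
accumulation_steps = 4

for step in range(2 * accumulation_steps):
    batch = torch.randn(16, 1024, device="cuda")
    if (step + 1) % accumulation_steps != 0:
        # no_sync() turns off _require_backward_grad_sync; with the fixed
        # condition, _free_fp16_param_shard() is skipped in the backward hooks.
        with model.no_sync():
            model(batch).sum().backward()
    else:
        # Synchronized pass: gradients are reduced and the fp16 shard may be freed.
        model(batch).sum().backward()
        optimizer.step()
        optimizer.zero_grad()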
