diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
index d9b20fca7..7d7fce9a4 100644
--- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
+++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -1733,7 +1733,7 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
         # Free full params.
         self._free_full_params([param])
 
-        if self.mixed_precision:
+        if self.mixed_precision and (self._require_backward_grad_sync or self.reshard_after_forward):
             # This is a no-op if reshard_after_forward is True, since we already
             # free the param shard when rebuilding the full params in the
             # pre_backward_hook.
@@ -1861,7 +1861,7 @@ def _post_reduction_hook(self, param: Parameter, reduced_grad: torch.Tensor) ->
 
     def _post_backward_reshard_hook(self, param: Parameter, *unused: Any) -> None:
         if self._should_free_in_backward():
            self._free_full_params([param])
-        if self.mixed_precision:
+        if self.mixed_precision and (self._require_backward_grad_sync or self.reshard_after_forward):
             self._free_fp16_param_shard([param])
         self._use_fp32_param_shard([param])
@@ -1937,7 +1937,7 @@ def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None:
                     # For the 1st layer, if the forward inputs did not require
                     # gradient, then we cannot run a reshard hook for it, and
                     # we instead free here.
-                    if p._full_param_padded.untyped_storage().size() > 0:
+                    if p._is_sharded and p._full_param_padded.untyped_storage().size() > 0:
                         fsdp_module._post_backward_reshard_hook(p)
                     continue
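
Illustrative sketch only, not part of the patch: the no_sync() gradient-accumulation pattern that exercises the guarded paths above (mixed_precision=True, reshard_after_forward=False, _require_backward_grad_sync=False inside no_sync()). It assumes a single CUDA device and a single-rank NCCL process group; the model, batch sizes, and learning rate are arbitrary placeholders.

# Sketch, assuming one CUDA device and a single-rank process group.
import os

import torch
import torch.distributed as dist
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="nccl", rank=0, world_size=1)
torch.cuda.set_device(0)

# With world_size == 1 the parameters stay unsharded (p._is_sharded is False),
# which is the case the added p._is_sharded check in _finalize_parameters covers.
model = FSDP(
    torch.nn.Linear(8, 8).cuda(),
    mixed_precision=True,
    reshard_after_forward=False,
)
optim = torch.optim.SGD(model.parameters(), lr=0.1)
micro_batches = [torch.randn(4, 8, device="cuda") for _ in range(3)]

# Inside no_sync(), _require_backward_grad_sync is False: gradients accumulate
# locally and are not reduced across ranks.
with model.no_sync():
    for x in micro_batches[:-1]:
        model(x).sum().backward()

# Final micro-batch outside no_sync(): gradients are synchronized as usual.
model(micro_batches[-1]).sum().backward()
optim.step()
optim.zero_grad()
dist.destroy_process_group()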