[FSDP] Add gradient predivide factor to avoid overflow/underflow with large world size #565

Merged: 4 commits, Apr 3, 2021
fairscale/nn/data_parallel/fully_sharded_data_parallel.py (13 additions, 2 deletions)
@@ -187,6 +187,8 @@ def __init__(
         self.buffer_dtype = buffer_dtype or self.compute_dtype
         self.move_grads_to_cpu = cpu_offload if move_grads_to_cpu is None else move_grads_to_cpu
         self.bucket_cap_mb = bucket_cap_mb
+        self.gradient_predivide_factor: int = self.get_gradient_predivide_factor(self.world_size)
+        self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor

         self.numel_padded_per_param: List[int] = []
         self.compute_device = compute_device
@@ -252,6 +254,12 @@ def __init__(
         # user explicitly requests the local state dict via local_state_dict().
         self._return_full_state_dict = True

+    def get_gradient_predivide_factor(self, world_size: int) -> int:
+        factor = 1
+        while world_size % factor == 0 and world_size / factor > factor:
+            factor = factor * 2
+        return factor
+
     @property
     def module(self) -> nn.Module:
         return self._fsdp_wrapped_module  # note: may be a FlattenParamsWrapper instance
@@ -1069,9 +1077,9 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
             # Cast grad to FP32.
             param.grad.data = param.grad.data.to(param.dtype)

-        if self.world_size > 1:
+        if self.gradient_predivide_factor > 1:
             # Average grad by world_size for consistency with PyTorch DDP.
-            param.grad.data.div_(self.world_size)
+            param.grad.data.div_(self.gradient_predivide_factor)

         callback_fn = functools.partial(self._post_reduction_hook, param)
         if param._is_sharded:
@@ -1099,6 +1107,9 @@ def _post_reduction_hook(self, param: Parameter, reduced_grad: torch.Tensor) ->
         assert param.grad is not None
         self.assert_state(TrainingState.BACKWARD_POST)
         param.grad.data = reduced_grad
+        if self.gradient_postdivide_factor > 1:
+            # Average grad by world_size for consistency with PyTorch DDP.
+            param.grad.data.div_(self.gradient_postdivide_factor)
         # Cast grad to param's dtype (typically FP32). Note: we do this
         # before the move_grads_to_cpu step so that this entire hook remains
         # non-blocking. The downside is a bit more D2H transfer in that case.
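
For intuition: the helper picks a factor that is roughly the square root of the world size (for power-of-two world sizes, the smallest power of two at or above it), and the pre- and post-division always multiply out to exactly world_size, so the gradient is still averaged the same way DDP averages it, just in two smaller steps. Dividing by roughly sqrt(world_size) before the reduction guards against overflow in the FP16 sum, and deferring the remaining division until after the reduction guards against underflow from dividing small gradients by a very large world size. A minimal standalone sketch (plain Python, no torch; the demo loop and print are illustrative, not part of the PR):

def get_gradient_predivide_factor(world_size: int) -> int:
    # Same logic as the method added in this PR: keep doubling while the
    # factor still divides world_size and is smaller than world_size / factor,
    # i.e. stop around sqrt(world_size).
    factor = 1
    while world_size % factor == 0 and world_size / factor > factor:
        factor = factor * 2
    return factor

for world_size in (1, 2, 8, 64, 128, 1024):
    pre = get_gradient_predivide_factor(world_size)
    post = world_size / pre
    # grad.div_(pre) happens before the reduction and grad.div_(post) after it,
    # so the combined divisor is still exactly world_size (DDP-style averaging)
    # while each individual division is only by roughly sqrt(world_size).
    assert pre * post == world_size
    print(f"world_size={world_size}: predivide={pre}, postdivide={post}")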