[FSDP] Add gradient predivide factor to avoid overflow/underflow with large world size #565

Merged (4 commits) on Apr 3, 2021. Showing changes from 3 commits.
11 changes: 9 additions & 2 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -187,6 +187,7 @@ def __init__(
self.buffer_dtype = buffer_dtype or self.compute_dtype
self.move_grads_to_cpu = cpu_offload if move_grads_to_cpu is None else move_grads_to_cpu
self.bucket_cap_mb = bucket_cap_mb
+ self.gradient_predivide_factor = self.get_gradient_predivide_factor(self.world_size)

self.numel_padded_per_param: List[int] = []
self.compute_device = compute_device
@@ -252,6 +253,12 @@ def __init__(
# user explicitly requests the local state dict via local_state_dict().
self._return_full_state_dict = True

+ def get_gradient_predivide_factor(self, world_size: int) -> int:
+     factor = 1
+     while world_size % factor == 0 and world_size / factor > factor:
+         factor = factor * 2
+     return factor

@property
def module(self) -> nn.Module:
return self._fsdp_wrapped_module # note: may be a FlattenParamsWrapper instance
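For context, not part of the diff: a minimal standalone sketch of what the new helper computes. The loop picks the largest power of two that divides world_size and is no larger than the remaining quotient, i.e. roughly sqrt(world_size) rounded down to a power-of-two divisor of world_size. The example values printed below are illustrative only.

def get_gradient_predivide_factor(world_size: int) -> int:
    # Same logic as the method added in this PR, pulled out as a free function.
    factor = 1
    while world_size % factor == 0 and world_size / factor > factor:
        factor = factor * 2
    return factor

for ws in (1, 2, 8, 64, 128, 1024):
    pre = get_gradient_predivide_factor(ws)
    post = ws // pre  # the matching post-reduction divisor
    # pre * post == ws, so dividing by pre before the reduction and by post
    # after it is equivalent to a single division by world_size.
    print(ws, pre, post)  # e.g. 64 -> (8, 8), 128 -> (16, 8), 1024 -> (32, 32)

Because pre * post always equals world_size, the combined scaling still matches PyTorch DDP's average; only where the division happens changes.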
@@ -1071,7 +1078,7 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:

if self.world_size > 1:
Review comment (Contributor): change to if self.gradient_predivide_factor > 1

# Average grad by world_size for consistency with PyTorch DDP.
- param.grad.data.div_(self.world_size)
+ param.grad.data.div_(self.gradient_predivide_factor)

callback_fn = functools.partial(self._post_reduction_hook, param)
if param._is_sharded:
@@ -1098,7 +1105,7 @@ def _post_reduction_hook(self, param: Parameter, reduced_grad: torch.Tensor) ->
assert torch.cuda.current_stream() == self._streams["post_backward"]
assert param.grad is not None
self.assert_state(TrainingState.BACKWARD_POST)
- param.grad.data = reduced_grad
+ param.grad.data = reduced_grad.div_(self.world_size/self.gradient_predivide_factor)
Review comment (Contributor): change to:

param.grad.data = reduced_grad
if self.gradient_postdivide_factor > 1:
    param.grad.data.div_(self.gradient_postdivide_factor)

# Cast grad to param's dtype (typically FP32). Note: we do this
# before the move_grads_to_cpu step so that this entire hook remains
# non-blocking. The downside is a bit more D2H transfer in that case.
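A short numerical sketch of why splitting the division helps (illustrative only: plain torch, no distributed setup, and the sequential fp16 additions below merely stand in for the fp16 accumulation done during the reduce-scatter/all-reduce). Dividing only after the reduction lets the fp16 sum overflow at large world sizes, while dividing by the full world_size before the reduction pushes small gradients toward fp16 underflow; dividing by roughly sqrt(world_size) on each side avoids both extremes.

import torch

world_size = 128
predivide = 16                        # get_gradient_predivide_factor(128)
postdivide = world_size // predivide  # 8, the reviewer's gradient_postdivide_factor

# One fp16 gradient shard per rank, all with a fairly large value.
grads = [torch.full((4,), 2000.0, dtype=torch.float16) for _ in range(world_size)]

def fp16_sum(tensors):
    # Accumulate sequentially in fp16, roughly mimicking an in-place reduction.
    total = torch.zeros_like(tensors[0])
    for t in tensors:
        total += t
    return total

# Post-divide only: the raw fp16 sum overflows to inf before we ever divide.
post_only = fp16_sum(grads) / world_size
print(post_only)  # inf

# Pre-divide by ~sqrt(world_size), reduce, then post-divide by the rest.
pre_sum = fp16_sum([g / predivide for g in grads])
averaged = pre_sum / postdivide
print(averaged)  # close to 2000.0, the true average (vs. inf above)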