Fix fsdp+pp+te WPS decreasing issue #1139

Merged

Changes from 1 commit
38 changes: 33 additions & 5 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -557,6 +557,7 @@ def __init__(
self.dont_wait_current_stream_for_post_all_gather = False
self._all_gather_free_event_queue = _FreeEventQueue() if limit_all_gather_events else None
self._reduce_scatter_free_event_queue = _FreeEventQueue() if limit_reduce_scatter_events else None
self._module_fqn = None

def _get_gradient_predivide_factor(self, world_size: int) -> float:
factor: int = 1
@@ -1220,6 +1221,9 @@ def _lazy_init(self) -> None:
self._set_is_root()
self._setup_streams()
self._setup_output_hook_list()
for module_name, module in self.named_modules():
    if isinstance(module, FullyShardedDataParallel):
        module._module_fqn = module_name

These are only needed for debugging (when _FSDP_DEBUG is set in P841842878).

Member Author: Deleted.
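
A rough standalone sketch of how the recorded FQNs could back the debug logging the thread refers to. Only the named_modules() walk mirrors the diff; the stamp_fqns/debug_log helpers and the use of the _FSDP_DEBUG environment variable as a switch are assumptions for illustration.

    import os
    import torch.nn as nn

    def stamp_fqns(root: nn.Module, fsdp_cls: type) -> None:
        # Same walk as in _lazy_init: every nested wrapped module learns its own name.
        for module_name, module in root.named_modules():
            if isinstance(module, fsdp_cls):
                module._module_fqn = module_name

    def debug_log(module: nn.Module, msg: str) -> None:
        # Hypothetical helper: only print when the debug switch is set.
        if os.environ.get("_FSDP_DEBUG"):
            print(f"[FSDP:{getattr(module, '_module_fqn', '?')}] {msg}")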


if self._is_root:
# Buffers stay on GPU, and don't get sharded. Since _cast_buffers
@@ -1650,6 +1654,9 @@ def _register_post_backward_hooks(self) -> None:
assert p_tmp.grad_fn is not None
grad_acc = p_tmp.grad_fn.next_functions[0][0] # Gets its GradAccumulation object.
handle = grad_acc.register_hook(functools.partial(self._post_backward_hook, p))
if not hasattr(p, "_shard_bwd_hooks"):
p._shard_bwd_hooks = []
p._shard_bwd_hooks.append((grad_acc, handle))
p._shard_bwd_hook = (grad_acc, handle)

Should this be deleted? See P841842878. CC @awgu

Member Author: Commented out this line, following the style used in this file.
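
A small standalone sketch of the hook bookkeeping introduced here: each handle registered on a parameter's AccumulateGrad node is appended to p._shard_bwd_hooks so that every one of them can be removed during finalization (see _finalize_parameters further down). The toy parameter and the post_backward_hook stub are illustrative, not fairscale code.

    import functools
    import torch

    def post_backward_hook(param, *unused):
        pass  # stand-in for FSDP's reduce-scatter logic

    p = torch.nn.Parameter(torch.randn(4))
    p_tmp = p.expand_as(p)  # a view of p, so it carries a grad_fn
    grad_acc = p_tmp.grad_fn.next_functions[0][0]  # the AccumulateGrad node for p
    handle = grad_acc.register_hook(functools.partial(post_backward_hook, p))

    if not hasattr(p, "_shard_bwd_hooks"):
        p._shard_bwd_hooks = []
    p._shard_bwd_hooks.append((grad_acc, handle))

    # Later, when finalizing after backward, drop every registered hook.
    for _, h in p._shard_bwd_hooks:
        h.remove()
    p._shard_bwd_hooks.clear()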


@torch.no_grad()
@@ -1710,6 +1717,13 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:

# Switch to FP32 shard after backward.
self._use_fp32_param_shard([param])
if self.mixed_precision and self.fp32_reduce_scatter:

Sorry, I missed this comment before. I wonder if we can have a separate commit for the main_grad-related changes, apart from the changes for the WPS-decrease fix (P841842878).

Member Author: Split into two PRs (#1139 and #1140).

if getattr(param, "main_grad", None) is None:
param.main_grad = param.grad.to(torch.float32)
else:
param.main_grad.add_(param.grad.data)

param.grad = None

if not self._require_backward_grad_sync:
return
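
A minimal standalone sketch of the main_grad pattern above, using a plain FP16 parameter outside of FSDP: the first backward materializes an FP32 accumulator from param.grad, later backwards add into it, and param.grad is cleared each time so the FP32 copy stays the single source of truth for the later reduce-scatter.

    import torch

    param = torch.nn.Parameter(torch.randn(8).half())

    def accumulate_main_grad(param: torch.nn.Parameter) -> None:
        if param.grad is None:
            return
        if getattr(param, "main_grad", None) is None:
            # First backward: create the FP32 accumulator from the new grad.
            param.main_grad = param.grad.to(torch.float32)
        else:
            # Subsequent backwards (e.g. more micro-batches): accumulate in FP32.
            param.main_grad.add_(param.grad.data)
        # Free the low-precision grad so the next backward starts clean.
        param.grad = None

    # Simulate two micro-batch backwards.
    param.grad = torch.ones_like(param)
    accumulate_main_grad(param)
    param.grad = torch.ones_like(param)
    accumulate_main_grad(param)
    assert param.main_grad.dtype == torch.float32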
@@ -1718,23 +1732,31 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
# reductions in post_backward stream.
self._streams["post_backward"].wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self._streams["post_backward"]):
orig_grad_data = param.grad.data

if self.fp32_reduce_scatter:
# Cast grad to FP32.
param.grad.data = param.grad.data.float()
orig_grad_data = param.grad.data.float()
else:
orig_grad_data = param.grad.data

We should keep param.grad.data = param.grad.data.float(), so something like:

            if self.fp32_reduce_scatter:
                # Cast grad to FP32.
                param.grad.data = param.grad.data.float()

            orig_grad_data = param.grad.data

Member Author: Updated.


if self.gradient_predivide_factor > 1:
# Average grad by world_size for consistency with PyTorch DDP.
param.grad.data.div_(self.gradient_predivide_factor)
if getattr(param, "main_grad", None) is not None:
param.main_grad.data.div_(self.gradient_predivide_factor)
else:
param.grad.data.div_(self.gradient_predivide_factor)

if param._is_sharded:
assert self._reducer is not None
# Save the unsharded grad for reduction. We will asynchronously accumulate the reduced gradient into
# param._saved_grad_shard. If this FSDP module was called multiple times it's possible that multiple
# gradient reductions will happen in an undefined order. But addition commutes, so this order doesn't
# matter, neglecting rounding.
grad = param.grad.data
if getattr(param, "main_grad", None) is not None:
grad = param.main_grad.data
param.main_grad = None
else:
grad = param.grad.data
# Clear grad on the tensor, so any repeated gradient computations do not interfere with this reduction.
#
# The effect on memory consumption is not usually significant. No extra memory is allocated if this
@@ -1860,6 +1882,9 @@ def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None:
p_assert(len(p._shard_bwd_hook) == 2, f"WFPB: incorrect hook num: {len(p._shard_bwd_hook)}")
# p._shard_bwd_hook[1].remove()
# delattr(p, "_shard_bwd_hook")
if hasattr(p, "_shard_bwd_hooks") and self._require_backward_grad_sync:
for _, handle in p._shard_bwd_hooks:
handle.remove()

# Leave the gradient accumulation state as-is if not synchronizing this pass. This ensures p.grad
# remains the unsharded gradient accumulated from prior no-sync passes, and p._saved_grad_shard
@@ -1876,7 +1901,10 @@ def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None:
p.device == p._saved_grad_shard.device,
f"WFPB: incorrect saved_grad_shard device {p.device} vs {p._saved_grad_shard.device}",
)
p.grad = p._saved_grad_shard
if p._saved_grad_shard.dtype != p.dtype:
    p.grad = p._saved_grad_shard.to(p.dtype)
else:
    p.grad = p._saved_grad_shard

if hasattr(p, "_saved_grad_shard"):
delattr(p, "_saved_grad_shard")