Fix fsdp+pp+te WPS decreasing issue #1139
Changes from 1 commit
@@ -557,6 +557,7 @@ def __init__(
         self.dont_wait_current_stream_for_post_all_gather = False
         self._all_gather_free_event_queue = _FreeEventQueue() if limit_all_gather_events else None
         self._reduce_scatter_free_event_queue = _FreeEventQueue() if limit_reduce_scatter_events else None
+        self._module_fqn = None

     def _get_gradient_predivide_factor(self, world_size: int) -> float:
         factor: int = 1
@@ -1220,6 +1221,9 @@ def _lazy_init(self) -> None:
         self._set_is_root()
         self._setup_streams()
         self._setup_output_hook_list()
+        for module_name, module in self.named_modules():
+            if isinstance(module, FullyShardedDataParallel):
+                module._module_fqn = module_name

         if self._is_root:
             # Buffers stay on GPU, and don't get sharded. Since _cast_buffers
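For context, a minimal sketch (separate from the FSDP code above) of how `named_modules()` produces the fully qualified names that `_module_fqn` records; the `Block` class here is hypothetical:

```python
# Standalone sketch: named_modules() yields the fully qualified name (FQN)
# of every submodule, which is what _module_fqn stores per FSDP instance.
import torch.nn as nn

class Block(nn.Module):  # hypothetical submodule for illustration
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

model = nn.Sequential(Block(), Block())

for module_name, module in model.named_modules():
    # The root module has an empty name; children get dotted paths like "0.linear".
    print(repr(module_name), type(module).__name__)
```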
@@ -1650,6 +1654,9 @@ def _register_post_backward_hooks(self) -> None:
                 assert p_tmp.grad_fn is not None
                 grad_acc = p_tmp.grad_fn.next_functions[0][0]  # Gets its GradAccumulation object.
                 handle = grad_acc.register_hook(functools.partial(self._post_backward_hook, p))
+                if not hasattr(p, "_shard_bwd_hooks"):
+                    p._shard_bwd_hooks = []
+                p._shard_bwd_hooks.append((grad_acc, handle))
                 p._shard_bwd_hook = (grad_acc, handle)

     @torch.no_grad()

Review comment (on `p._shard_bwd_hook = (grad_acc, handle)`): This should be deleted? See P841842878 CC @awgu
Reply: Commented this line following the style in this file.
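For context, a standalone sketch of the `AccumulateGrad` hook mechanism this hunk builds on, assuming only plain PyTorch: `p.expand_as(p)` yields a non-leaf view whose `grad_fn.next_functions[0][0]` is the parameter's grad accumulator node, and the handle returned by `register_hook` can later be removed. The hook body here is illustrative.

```python
import torch

p = torch.nn.Parameter(torch.randn(3))

# expand_as(p) gives a non-leaf view of p, so it has a grad_fn whose first
# next_function is p's AccumulateGrad node.
p_tmp = p.expand_as(p)
grad_acc = p_tmp.grad_fn.next_functions[0][0]

def post_backward_hook(*unused):
    print("fired after grad accumulation for p")

handle = grad_acc.register_hook(post_backward_hook)

(p * 2).sum().backward()   # hook fires here
handle.remove()            # hook no longer fires on later backwards
(p * 2).sum().backward()
```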
@@ -1710,6 +1717,13 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:

         # Switch to FP32 shard after backward.
         self._use_fp32_param_shard([param])
+        if self.mixed_precision and self.fp32_reduce_scatter:
+            if getattr(param, "main_grad", None) is None:
+                param.main_grad = param.grad.to(torch.float32)
+            else:
+                param.main_grad.add_(param.grad.data)
+
+            param.grad = None

         if not self._require_backward_grad_sync:
             return

Review comment (on `if self.mixed_precision and self.fp32_reduce_scatter:`): Sorry missed this comment before. Wonder if we can have a separate commit for main_grad related changes from the changes for wps decrease fix (P841842878).
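A minimal sketch of the `main_grad` pattern the added lines implement, outside of FSDP: low-precision grads are folded into a persistent FP32 buffer and `param.grad` is dropped. The bfloat16 dtype and the two-step loop are assumptions for illustration.

```python
import torch

# Accumulate low-precision grads into a persistent FP32 buffer, then drop .grad.
param = torch.nn.Parameter(torch.randn(4, dtype=torch.bfloat16))

for _ in range(2):  # e.g. two microbatches in a pipeline-parallel schedule
    param.grad = torch.randn_like(param)  # stand-in for a backward pass

    if getattr(param, "main_grad", None) is None:
        param.main_grad = param.grad.to(torch.float32)
    else:
        param.main_grad.add_(param.grad.data)

    param.grad = None  # the low-precision grad is no longer needed

print(param.main_grad.dtype)  # torch.float32
```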
@@ -1718,23 +1732,31 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
         # reductions in post_backward stream.
         self._streams["post_backward"].wait_stream(torch.cuda.current_stream())
         with torch.cuda.stream(self._streams["post_backward"]):
-            orig_grad_data = param.grad.data

             if self.fp32_reduce_scatter:
                 # Cast grad to FP32.
-                param.grad.data = param.grad.data.float()
+                orig_grad_data = param.grad.data.float()
+            else:
+                orig_grad_data = param.grad.data

             if self.gradient_predivide_factor > 1:
                 # Average grad by world_size for consistency with PyTorch DDP.
-                param.grad.data.div_(self.gradient_predivide_factor)
+                if getattr(param, "main_grad", None) is not None:
+                    param.main_grad.data.div_(self.gradient_predivide_factor)
+                else:
+                    param.grad.data.div_(self.gradient_predivide_factor)

             if param._is_sharded:
                 assert self._reducer is not None
                 # Save the unsharded grad for reduction. We will asynchronously accumulate the reduced gradient into
                 # param._saved_grad_shard. If this FSDP module was called multiple times it's possible that multiple
                 # gradient reductions will happen in an undefined order. But addition commutes, so this order doesn't
                 # matter, neglecting rounding.
-                grad = param.grad.data
+                if getattr(param, "main_grad", None) is not None:
+                    grad = param.main_grad.data
+                    param.main_grad = None
+                else:
+                    grad = param.grad.data
                 # Clear grad on the tensor, so any repeated gradient computations do not interfere with this reduction.
                 #
                 # The effect on memory consumption is not usually significant. No extra memory is allocated if this

Review comment (on the fp32_reduce_scatter change): We should keep param.grad.data = param.grad.data.float() so something like
Reply: Updated.
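For context on `gradient_predivide_factor`, a sketch of the pre/post-divide split used when averaging gradients over `world_size` ranks: divide by roughly the square root of `world_size` before the reduce-scatter and by the remaining factor afterwards, which lowers the over/underflow risk in reduced precision. The helper below mirrors the idea behind `_get_gradient_predivide_factor` shown in the first hunk; the exact library behavior may differ.

```python
# Pick a power-of-two pre-divide factor close to sqrt(world_size); the
# post-divide factor is whatever remains so that pre * post == world_size.
def gradient_predivide_factor(world_size: int) -> float:
    factor = 1
    while world_size % factor == 0 and world_size / factor > factor:
        factor *= 2
    return float(factor)

world_size = 64
pre = gradient_predivide_factor(world_size)
post = world_size / pre
assert pre * post == world_size
print(pre, post)  # 8.0 8.0
```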
@@ -1860,6 +1882,9 @@ def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None:
                     p_assert(len(p._shard_bwd_hook) == 2, f"WFPB: incorrect hook num: {len(p._shard_bwd_hook)}")
                     # p._shard_bwd_hook[1].remove()
                     # delattr(p, "_shard_bwd_hook")
+                if hasattr(p, "_shard_bwd_hooks") and self._require_backward_grad_sync:
+                    for _, handle in p._shard_bwd_hooks:
+                        handle.remove()

                 # Leave the gradient accumulation state as-is if not synchronizing this pass. This ensures p.grad
                 # remains the unsharded gradient accumulated from prior no-sync passes, and p._saved_grad_shard
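A standalone sketch (plain PyTorch, illustrative names) of why this cleanup matters: hooks registered on a parameter's grad accumulator persist across iterations, so re-registering every step without removing the old handles makes each backward fire an ever-growing number of hooks, which is consistent with the WPS decrease this PR targets.

```python
import torch

p = torch.nn.Parameter(torch.randn(3))
calls = []

def make_hook(i):
    return lambda *unused: calls.append(i)

# Registering a new hook every "iteration" without removing the previous one
# makes every old hook fire again on each backward, adding overhead per step.
handles = []
for step in range(3):
    grad_acc = p.expand_as(p).grad_fn.next_functions[0][0]
    handles.append(grad_acc.register_hook(make_hook(step)))
    (p * 2).sum().backward()

print(len(calls))  # 6 == 1 + 2 + 3 hook invocations, growing each step

# The fix: drop the handles once the backward pass is finalized.
for handle in handles:
    handle.remove()
```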
@@ -1876,7 +1901,10 @@ def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None:
                         p.device == p._saved_grad_shard.device,
                         f"WFPB: incorrect saved_grad_shard device {p.device} vs {p._saved_grad_shard.device}",
                     )
-                    p.grad = p._saved_grad_shard
+                    if p._saved_grad_shard.dtype != p.dtype:
+                        p.grad = p._saved_grad_shard.to(p.dtype)
+                    else:
+                        p.grad = p._saved_grad_shard

                 if hasattr(p, "_saved_grad_shard"):
                     delattr(p, "_saved_grad_shard")
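A minimal sketch of why the dtype check is needed, assuming a bfloat16 parameter whose reduced grad shard was kept in FP32: PyTorch rejects a `.grad` assignment whose dtype differs from the parameter, so the shard has to be cast back first.

```python
import torch

p = torch.nn.Parameter(torch.randn(4, dtype=torch.bfloat16))
saved_grad_shard = torch.randn(4, dtype=torch.float32)  # e.g. reduced in FP32

try:
    p.grad = saved_grad_shard  # dtype mismatch with the parameter
except RuntimeError as e:
    print("direct assignment fails:", e)

# Cast back to the parameter's dtype only when they differ.
p.grad = saved_grad_shard.to(p.dtype) if saved_grad_shard.dtype != p.dtype else saved_grad_shard
print(p.grad.dtype)  # torch.bfloat16
```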
Review comment: These only needed for debugging (when _FSDP_DEBUG is set in P841842878)
Reply: Deleted.