[Fix][FSDP] Don't remove post backward hooks for multiple backward fix (#1079)

* tmp

* test again

* test again

* add new test

* clean up

* add test file to the testlist

* more comments

* add changelog

Co-authored-by: Min Xu <min.xu.public@gmail.com>
min-xu-ai and flying-x committed Sep 24, 2022
1 parent 8f8f8ef commit f4fcee7
Showing 4 changed files with 119 additions and 33 deletions.
8 changes: 7 additions & 1 deletion CHANGELOG.md
@@ -5,7 +5,12 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).


## [0.4.7] - TBD
## [0.4.11] - TBD

- cleaned up some old issues and fixed a few bugs in FSDP
- removed SSD offload to simplify the FSDP code

## [0.4.8]/[0.4.9]/[0.4.10]

### Added

@@ -48,6 +53,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
parameter internally.

### Fixed
- fixed some bugs in FSDP related to supporting data2vec EMA modules.


[0.4.6] - 2022-03-08
71 changes: 39 additions & 32 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -586,7 +586,7 @@ def append_shared_param(self, p: Parameter) -> None:
p must be already sharded by the owning module.
Check the corresponding unit test to see how it is used and tested.
Check the corresponding unit tests to see how it is used and tested.
In particular, the sharing FSDP wrappers are "siblings", not "parent"
and "child", of each other in the nested module structure.
@@ -1700,35 +1700,38 @@ def _register_post_backward_hooks(self) -> None:
3. If it fires once but too early or doesn't fire, we leave gradients
unsharded. (could lead to dimension too large)
Due to multiple-pass forward, this function can be called on
the same parameter multiple times in a single forward pass. If we register
the hook multiple times, it ends up being called multiple times. We
could try to get a new hook every time and delete the previously
registered one. However, for an *unknown reason* (I have debugged it for
a long time!), in mixed precision mode, we get two different ``grad_acc``
objects below during different calls of this function (in the same
forward pass). If we keep the last one, the hook ends up firing too
early. In full precision mode, we luckily get the *same* ``grad_acc``
object, so deleting and re-registering still ensures the hook fires
once after all gradients are generated.
Empirically, keeping the first hook registered per forward pass seems to
work the best. We do need to remove the hook at the end of the
backward pass. Otherwise, the next forward pass will not register
a new hook, which is needed for a new forward pass.
There are several cases here:
1. We can call the same module multiple times in a single outer forward
pass. We register multiple hooks, but autograd should fire the last
one after the total gradient is computed and accumulated. If it does
fire multiple times, we may crash because the gradient has already been
sharded, producing a shape mismatch.
On the other hand, due to _saved_grad_shard, this case may also work,
just with extra grad scatter-gather.
2. Activation checkpointing combined with case 1.
3. The same outer forward can be called multiple times before any backward
is called (within the no_sync context) for a special way of gradient
accumulation. (See test_fsdp_fwd_fwd_bwd_bwd.py.)
4. When a param is shared by multiple FSDP wrapper instances, this can
register multiple times. (See test_fsdp_shared_weights.py.)
It appears that registering the hook every time, letting the hooks fire, and
letting them be removed/freed automatically is the correct thing to do. But this
is purely based on experiments.
"""
if not torch.is_grad_enabled():
return # don't register grad hooks if grad isn't enabled
for p in self.params:
if p.requires_grad:
if hasattr(p, "_shard_bwd_hook"):
continue
# Register a hook on the first call; empirically, autograd
# fires it at the end for this param, which makes sense.
# Register a hook.
p_tmp = p.expand_as(p) # Get a grad_fn on p_tmp.
assert p_tmp.grad_fn is not None
grad_acc = p_tmp.grad_fn.next_functions[0][0] # Gets its GradAccumulation object.
handle = grad_acc.register_hook(functools.partial(self._post_backward_hook, p))
# Important: we need to save the hook; otherwise, it appears to be
# deleted/freed/unregistered.
# However, we don't free/unhook at the end of bwd (as we used to do
# in _finalize_parameters below). If we did, we might unregister the wrong hook.
p._shard_bwd_hook = (grad_acc, handle)

@torch.no_grad()
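For readers unfamiliar with the grad-accumulator hook mechanism this hunk changes, here is a minimal standalone sketch (not part of this diff) showing that a hook attached to a parameter's gradient accumulation node fires once per backward pass that reaches the parameter. The toy module lin, the list fired, and the sizes are made up for illustration.

import torch
import torch.nn as nn

# Standalone toy, only to show when an AccumulateGrad hook fires.
lin = nn.Linear(4, 4)
p = lin.weight
fired = []

p_tmp = p.expand_as(p)  # gives us a grad_fn on a view of p
grad_acc = p_tmp.grad_fn.next_functions[0][0]  # p's gradient accumulation node
handle = grad_acc.register_hook(lambda *unused: fired.append(True))

for _ in range(2):
    lin(torch.randn(2, 4)).sum().backward()

print(len(fired))  # 2: the hook fired once per backward pass
handle.remove()

Note that grad_acc is kept alive by a Python reference here, which mirrors the comment above about saving the hook so it is not freed.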
@@ -1756,20 +1759,26 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
# then subsequent hook callbacks will see POST state.
self.assert_state([TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST])
self.training_state = TrainingState.BACKWARD_POST
if param.grad is None:
return

if hasattr(param, "_linked_param"):
# This links to a shared param. We should finalize the linked param here.
assert param.shape == (1,), param.shape
# This links to a shared param. We should try to finalize the linked param here.
# This is done by module code to ensure correct gradient computation.
# p._is_shared and p._linked_param are closely related but not the same.
# See fairscale/experimental/nn/mevo.py.
assert param.shape == (1,), param.shape # This param should have this special dim.
# If the _is_shared flag is set, then this shared weight is indeed being
# shared between different FSDP wrappers. Otherwise, they are linked but
# likely in the same FSDP wrapper, which means we shouldn't finalize the
linked param.
if hasattr(param._linked_param, "_is_shared") and param._linked_param._is_shared:
# param._linked_param may or may not have .grad since this callback
# could happen multiple times to support #918. Since we check `if param.grad is None`
# below anyway, this is OK.
param = param._linked_param

assert param.grad is not None, param.shape
if param.grad is None:
return

if param.grad.requires_grad:
raise RuntimeError("FSDP only works with gradients that don't require gradients")

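For context on the linked-param indirection above, here is a minimal standalone sketch (not part of this diff, and much simplified relative to fairscale/experimental/nn/mevo.py) of the control flow the hunk now uses: redirect from the (1,)-shaped placeholder to the real shared weight, then return early if there is no gradient yet. The names toy_post_backward, link, and shared are hypothetical.

import torch
import torch.nn as nn

shared = nn.Parameter(torch.zeros(4))
shared._is_shared = True        # marks a weight truly shared across FSDP wrappers

link = nn.Parameter(torch.zeros(1))
link._linked_param = shared     # the (1,)-shaped placeholder points at it

def toy_post_backward(param, *unused):
    # Redirect from the placeholder to the real shared weight, then bail out
    # early if there is no gradient yet -- the hook may fire more than once.
    if hasattr(param, "_linked_param"):
        assert param.shape == (1,), param.shape
        if getattr(param._linked_param, "_is_shared", False):
            param = param._linked_param
    if param.grad is None:
        return
    # ... gradient reduction / sharding bookkeeping would happen here ...

toy_post_backward(link)  # no grad yet: returns quietly instead of asserting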
@@ -1950,10 +1959,6 @@ def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None:
for p in fsdp_module.params:
if not p.requires_grad:
continue
if hasattr(p, "_shard_bwd_hook"):
p_assert(len(p._shard_bwd_hook) == 2, f"WFPB: incorrect hook num: {len(p._shard_bwd_hook)}")
p._shard_bwd_hook[1].remove()
delattr(p, "_shard_bwd_hook")

# Leave the gradient accumulation state as-is if not synchronizing this pass. This ensures p.grad
# remains the unsharded gradient accumulated from prior no-sync passes, and p._saved_grad_shard
@@ -1969,11 +1974,13 @@ def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None:
elif hasattr(p, "_saved_grad_shard"):
p_assert(
p.device == p._saved_grad_shard.device,
f"WFPB: incorrect saved_grad_shard device {p.device} vs {p._saved_grad_shard.device}",
f"WFPB: incorrect saved_grad_shard device p.device={p.device} "
f"vs p._saved_grad_shard.device={p._saved_grad_shard.device}",
)
p_assert(
p.shape == p._saved_grad_shard.shape,
f"WFPB: incorrect saved_grad_shard shape {p.shape} vs {p._saved_grad_shard.shape}",
f"WFPB: incorrect saved_grad_shard shape p.shape={p.shape} "
f"vs p._saved_grad_shard.shape={p._saved_grad_shard.shape}",
)
p.grad = p._saved_grad_shard

1 change: 1 addition & 0 deletions tests/ci_test_list_2.txt
@@ -51,3 +51,4 @@ tests/nn/pipe/test_stream.py
tests/nn/moe/test_moe_layer.py
tests/nn/moe/test_top2gating.py
tests/nn/data_parallel/test_fsdp_offload.py
tests/nn/data_parallel/test_fsdp_fwd_fwd_bwd_bwd.py
72 changes: 72 additions & 0 deletions tests/nn/data_parallel/test_fsdp_fwd_fwd_bwd_bwd.py
@@ -0,0 +1,72 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn

from fairscale.fair_dev.testing.testing import skip_if_single_gpu, temp_files_ctx
from fairscale.nn import enable_wrap, wrap
from fairscale.nn.data_parallel import FullyShardedDataParallel


class FFN(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(10, 10)
self.fc2 = nn.Linear(10, 10)
self.relu = nn.ReLU()

def forward(self, x):
return self.fc2(self.relu(self.fc1(x)))


def main(rank, sync_file):
torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.cuda.set_device(rank)
torch.distributed.init_process_group(
backend="nccl",
init_method=f"file://{sync_file}",
world_size=2,
rank=rank,
)
ffn = FFN().cuda().half()

with enable_wrap(wrapper_cls=FullyShardedDataParallel):
model = wrap(
ffn,
process_group=torch.distributed.new_group(),
flatten_parameters=True,
compute_dtype=torch.float16,
)

model = model.train()

# We test this behavior because it might be used by pipelining.
# However, we don't check if the speed (compute/comm overlapping)
# and memory (necessary all-gather & free) are optimal.
losses = []
for _ in range(3):
x = torch.rand((10, 10)).cuda().half()
out = model(x)
loss = out.sum()
losses.append(loss)

# Only the last bwd can be outside of no_sync context.
with model.no_sync():
losses[0].backward()
losses[1].backward()
losses[2].backward()


@skip_if_single_gpu
def test_fwd_fwd_bwd_bwd():
with temp_files_ctx(num=1) as temp_files:
torch.multiprocessing.spawn(
fn=main,
nprocs=2,
args=(temp_files[0],),
join=True,
)
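As a side note on the pattern this test exercises, here is a minimal single-process sketch (not part of this commit; no FSDP or distributed setup) showing that several backwards issued after several forwards simply accumulate into .grad, which is roughly the local behavior the no_sync passes above build on. The toy module net and the sizes are made up.

import torch
import torch.nn as nn

net = nn.Linear(10, 10)

# Three forwards first, then three backwards; each backward walks its own
# graph and accumulates into net.weight.grad / net.bias.grad.
losses = [net(torch.rand(10, 10)).sum() for _ in range(3)]
for loss in losses:
    loss.backward()

print(net.weight.grad.shape)  # torch.Size([10, 10]): the sum of the three gradients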
