[minor] make backward assert a bit better (#919)
* [minor] better assert in backward

* mypy

Co-authored-by: Min Xu <min.xu.public@gmail.com>
min-xu-ai and flying-x committed Jan 25, 2022
1 parent 5d8a505 commit 8ba649e
Showing 1 changed file with 22 additions and 6 deletions.
28 changes: 22 additions & 6 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -1716,7 +1716,8 @@ def _queue_wait_for_post_backward(self) -> None:
     @torch.no_grad()
     def _wait_for_post_backward(self) -> None:
         """Wait for post-backward to finish. Only called on root instance."""
-        assert self._is_root
+        # Note: the backward runtime swallows the assert error, so we use p_assert() here.
+        p_assert(self._is_root, "WFPB not called on root")
         # Check if the root module has params and if any of them has
         # the `requires_grad` field set. If `requires_grad=False` for
         # all the params, the post_backward hook will not fire and the
@@ -1729,7 +1730,8 @@ def _wait_for_post_backward(self) -> None:
         if self._require_backward_grad_sync:
             # Flush any unreduced buckets in the post_backward stream.
             with torch.cuda.stream(self._streams["post_backward"]):
-                assert self._reducer is not None
+                p_assert(self._reducer is not None, "WFPB: reducer is None")
+                assert self._reducer is not None  # make mypy happy
                 self._reducer.flush()
             torch.cuda.current_stream().wait_stream(self._streams["post_backward"])
             if self.move_grads_to_cpu:
@@ -1748,7 +1750,7 @@ def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None:
                 if not p.requires_grad:
                     continue
                 if hasattr(p, "_shard_bwd_hook"):
-                    assert len(p._shard_bwd_hook) == 2, len(p._shard_bwd_hook)
+                    p_assert(len(p._shard_bwd_hook) == 2, f"WFPB: incorrect hook num: {len(p._shard_bwd_hook)}")
                     p._shard_bwd_hook[1].remove()
                     delattr(p, "_shard_bwd_hook")

@@ -1761,10 +1763,13 @@ def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None:

                 # Parameter and gradient devices must match.
                 if hasattr(p, "_cpu_grad"):
-                    assert p.device == torch.device("cpu")
+                    p_assert(p.device == torch.device("cpu"), f"WFPB: incorrect cpu_grad device {p.device}")
                     p.grad = p._cpu_grad
                 elif hasattr(p, "_saved_grad_shard"):
-                    assert p.device == p._saved_grad_shard.device
+                    p_assert(
+                        p.device == p._saved_grad_shard.device,
+                        f"WFPB: incorrect saved_grad_shard device {p.device} vs {p._saved_grad_shard.device}",
+                    )
                     p.grad = p._saved_grad_shard

                 if hasattr(p, "_saved_grad_shard"):
@@ -1799,7 +1804,11 @@ def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None:
         # reset this flag for cases like "one forward pass + multiple backward passes"
         self._post_backward_callback_queued = False
         # clear this list for next iteration
-        assert self._output_pre_backward_hook_registered is not None
+        p_assert(
+            self._output_pre_backward_hook_registered is not None,
+            "WFPB: self._output_pre_backward_hook_registered should not be None",
+        )
+        assert self._output_pre_backward_hook_registered is not None  # make mypy happy
         self._output_pre_backward_hook_registered.clear()

     @torch.no_grad()
@@ -2355,6 +2364,13 @@ def cpu_offload(self) -> bool:
         return self.move_params_to_cpu


+def p_assert(cond: Any, s: Any) -> None:
+    """Used in backward context to make sure error is printed."""
+    if not cond:
+        print(s)
+        raise AssertionError
+
+
 def _get_default_cuda_device(module: nn.Module) -> torch.device:
     """Try to infer CUDA device from module parameters."""
     try:
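For context, here is a minimal, self-contained sketch of the p_assert pattern this commit adopts. The tensor and hook names (check_grad, x) are hypothetical and not part of the change; the point is that, as the commit message notes, the backward runtime can swallow the text of a failed assert, so p_assert prints the message before raising.

# Usage sketch of the p_assert pattern (hypothetical names, not from the diff).
from typing import Any

import torch


def p_assert(cond: Any, s: Any) -> None:
    """Print the message before raising, so it stays visible even if the
    backward runtime drops the AssertionError's text."""
    if not cond:
        print(s)
        raise AssertionError


def check_grad(grad: torch.Tensor) -> torch.Tensor:
    # Inside a backward hook, a bare `assert cond, msg` may fail without
    # surfacing `msg`; p_assert prints it first.
    p_assert(grad.shape == (2, 2), f"unexpected grad shape: {grad.shape}")
    return grad


x = torch.randn(2, 2, requires_grad=True)
x.register_hook(check_grad)   # hook runs during backward
x.sum().backward()            # passes here; a failure would print its message first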
