Added reshard hook for frozen params in backward
awgu committed Jan 12, 2024
1 parent 3b7cc24 commit a4f02ef
Showing 2 changed files with 159 additions and 11 deletions.
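
For context on the mechanism this commit relies on: a frozen parameter never receives a gradient, so FSDP's usual gradient-triggered post-backward hook never fires for it, and the unsharded ("full") parameter would stay allocated through backward. The new code instead registers a hook on the gradient-requiring forward inputs via torch.autograd.graph.register_multi_grad_hook, which fires once all of those input gradients have been computed. Below is a minimal standalone sketch of that PyTorch API (assuming PyTorch >= 2.0; the tensors and callback are illustrative, not fairscale code):

```python
import torch
from torch.autograd.graph import register_multi_grad_hook

x = torch.randn(4, requires_grad=True)
y = torch.randn(4, requires_grad=True)
w_frozen = torch.randn(4, requires_grad=False)  # stands in for a frozen parameter

def reshard_stub(grads):
    # Called once gradients w.r.t. *all* watched tensors (x, y) are computed;
    # grads holds one (possibly None) gradient per watched tensor. In FSDP this
    # is the earliest safe point in backward to free a frozen full parameter.
    print("all input grads ready:", [tuple(g.shape) for g in grads if g is not None])

handle = register_multi_grad_hook((x, y), reshard_stub)
((x + y) * w_frozen).sum().backward()  # hook fires during this backward pass
handle.remove()
```
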
74 changes: 63 additions & 11 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -9,6 +9,7 @@
from dataclasses import dataclass
from enum import Enum, auto
import functools
import itertools
import logging
from math import inf
import os
@@ -47,7 +48,6 @@
from fairscale.utils.containers import apply_to_tensors
from fairscale.utils.parallel import (
ProcessGroupName,
chunk_and_pad,
enable_pytorch_sync_bn,
get_process_group_cached,
validate_process_group,
@@ -1457,6 +1457,7 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
# Register backward hooks to reshard params and reduce-scatter grads.
# These need to be re-registered every forward pass.
self._register_post_backward_hooks()
self._register_post_backward_reshard_hooks(args, kwargs)

outputs = self.module(*args, **kwargs)

@@ -1655,6 +1656,37 @@ def _register_post_backward_hooks(self) -> None:
p._shard_bwd_hooks.append((grad_acc, handle))
# p._shard_bwd_hook = (grad_acc, handle)

def _register_post_backward_reshard_hooks(
self, args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> None:
if not hasattr(torch.autograd.graph, "register_multi_grad_hook"):
return # unsupported
if not torch.is_grad_enabled():
return
from torch.utils._pytree import tree_flatten
from torch.autograd.graph import register_multi_grad_hook
# Construct `inp_tensors` lazily to avoid CPU overhead in typical case
# where each parameter requires gradient
inp_tensors: Optional[List[torch.Tensor]] = None
for param in self.params:
# Only register for parameters that do not require gradient
if param.requires_grad:
continue
if inp_tensors is None:
args_list, _ = tree_flatten(args)
kwargs_list, _ = tree_flatten(kwargs)
inp_tensors = [
obj
for obj in itertools.chain(args_list, kwargs_list)
if torch.is_tensor(obj) and obj.requires_grad
]
hook_handle = register_multi_grad_hook(
inp_tensors, functools.partial(self._post_backward_reshard_hook, param)
)
if not hasattr(param, "_shard_bwd_hooks"):
param._shard_bwd_hooks = []
param._shard_bwd_hooks.append((hook_handle,))

@torch.no_grad()
def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
"""
@@ -1697,12 +1729,8 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
if param.grad.requires_grad:
raise RuntimeError("FSDP only works with gradients that don't require gradients")

if self._require_backward_grad_sync or self.reshard_after_forward:
# Free full params. As a special case, we don't free the full params
# when in a ``no_sync`` context (as inversely indicated by
# ``self._require_backward_grad_sync``), since the params will not
# get updated before the next forward. This saves networking
# bandwidth but uses more GPU memory.
if self._should_free_in_backward():
# Free full params.
self._free_full_params([param])

if self.mixed_precision:
@@ -1829,6 +1857,22 @@ def _post_reduction_hook(self, param: Parameter, reduced_grad: torch.Tensor) ->
# Don't let this memory get reused until after the transfer.
reduced_grad.data.record_stream(torch.cuda.current_stream())

@torch.no_grad()
def _post_backward_reshard_hook(self, param: Parameter, *unused: Any) -> None:
if self._should_free_in_backward():
self._free_full_params([param])
if self.mixed_precision:
self._free_fp16_param_shard([param])
self._use_fp32_param_shard([param])

def _should_free_in_backward(self) -> bool:
# As a special case, we don't free the full params
# when in a ``no_sync`` context (as inversely indicated by
# ``self._require_backward_grad_sync``), since the params will not
# get updated before the next forward. This saves networking
# bandwidth but uses more GPU memory.
return self._require_backward_grad_sync or self.reshard_after_forward

def _queue_wait_for_post_backward(self) -> None:
"""Try to queue a `wait_for_post_backward` callback.
@@ -1878,16 +1922,24 @@ def _wait_for_post_backward(self) -> None:
def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None:
"""Helper used below on all fsdp modules."""
for p in fsdp_module.params:
if not p.requires_grad:
continue
if hasattr(p, "_shard_bwd_hook"):
p_assert(len(p._shard_bwd_hook) == 2, f"WFPB: incorrect hook num: {len(p._shard_bwd_hook)}")
# p._shard_bwd_hook[1].remove()
# delattr(p, "_shard_bwd_hook")
if hasattr(p, "_shard_bwd_hooks") and self._require_backward_grad_sync:
for _, handle in p._shard_bwd_hooks:
handle.remove()
for hook_state in p._shard_bwd_hooks:
if len(hook_state) == 1:
hook_state[0].remove()
elif len(hook_state) == 2:
hook_state[1].remove()
p._shard_bwd_hooks.clear()
if not p.requires_grad:
# For the 1st layer, if the forward inputs did not require
# gradient, then we cannot run a reshard hook for it, and
# we instead free here.
if p._full_param_padded.untyped_storage().size() > 0:
fsdp_module._post_backward_reshard_hook(p)
continue

# Leave the gradient accumulation state as-is if not synchronizing this pass. This ensures p.grad
# remains the unsharded gradient accumulated from prior no-sync passes, and p._saved_grad_shard
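
An aside on the `p._full_param_padded.untyped_storage().size() > 0` check added to `_finalize_parameters` above: FSDP "frees" a full parameter by shrinking the tensor's backing storage to zero bytes, so a non-zero storage size means the unsharded parameter is still materialized and should be resharded. A minimal standalone sketch of that storage trick (plain tensors, not FSDP internals; assumes PyTorch >= 2.0 for `untyped_storage`):

```python
import torch

full_param = torch.randn(1024)                   # stands in for an unsharded ("full") parameter
assert full_param.untyped_storage().size() > 0   # still materialized

# "Freeing" the full parameter: release the backing memory without dropping the
# tensor object, the same kind of trick FSDP's _free_full_params relies on.
full_param.untyped_storage().resize_(0)
assert full_param.untyped_storage().size() == 0  # freed / resharded

# Re-materializing is just resizing the storage back up (in bytes); the memory
# is uninitialized until refilled, e.g. by an all-gather.
full_param.untyped_storage().resize_(1024 * full_param.element_size())
```
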
96 changes: 96 additions & 0 deletions tests/nn/data_parallel/test_fsdp_freezing_weights.py
@@ -12,6 +12,8 @@

from enum import Enum
from itertools import product
from unittest import mock
import copy
import tempfile

import pytest
@@ -275,3 +277,97 @@ def test_freezing_weights(temp_files, nested_trunk):
nprocs=world_size,
)
temp_file_idx += 3


@skip_if_single_gpu
def test_reshard_frozen_weights():
world_size = 2
for flatten_parameters, reshard_after_forward, inp_requires_grad in product(
[False, True], [False, True], [False, True]
):
print(
"Testing FSDP reshard frozen weights with "
f"flatten_parameters={flatten_parameters}, "
f"reshard_after_forward={reshard_after_forward}, "
f"inp_requires_grad={inp_requires_grad}"
)
mp.spawn(
_distributed_worker_reshard,
(world_size, flatten_parameters, reshard_after_forward, inp_requires_grad),
nprocs=world_size,
)


def _distributed_worker_reshard(
rank: int,
world_size: int,
flatten_parameters: bool,
reshard_after_forward: bool,
inp_requires_grad: bool,
):
import os
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
torch.cuda.set_device(rank)
torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size)

torch.manual_seed(0)

num_linears = 6
modules = []
for _ in range(num_linears):
modules += [nn.Linear(5, 5, device="cuda"), nn.ReLU()]
model = nn.Sequential(*modules)
# Freeze every other linear
for i in range(num_linears):
if i % 2 == 0:
for param in model[i * 2].parameters(recurse=False):
param.requires_grad = False
num_frozen_linears = num_linears // 2

ref_model = DistributedDataParallel(copy.deepcopy(model), device_ids=[rank])
ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2)

for i, module in enumerate(model):
if isinstance(module, nn.Linear):
model[i] = FSDP(
module,
flatten_parameters=flatten_parameters,
reshard_after_forward=reshard_after_forward,
)
fsdp_model = FSDP(
model,
flatten_parameters=flatten_parameters,
reshard_after_forward=reshard_after_forward,
)
fsdp_optim = torch.optim.AdamW(fsdp_model.parameters(), lr=1e-2)

orig_post_backward_reshard_hook = FSDP._post_backward_reshard_hook
reshard_hook_count = 0

def post_backward_reshard_hook_with_count(*args, **kwargs):
nonlocal reshard_hook_count
reshard_hook_count += 1
return orig_post_backward_reshard_hook(*args, **kwargs)

with mock.patch(
"fairscale.nn.data_parallel.FullyShardedDataParallel._post_backward_reshard_hook",
post_backward_reshard_hook_with_count,
):
inp = torch.randn((8, 5), device="cuda", requires_grad=inp_requires_grad)
for i in range(6):
losses = []
for model, optim in ((fsdp_model, fsdp_optim), (ref_model, ref_optim)):
optim.zero_grad()
loss = model(inp).sum()
losses.append(loss)
loss.backward()
optim.step()
expected_reshard_hook_count = num_frozen_linears
if not flatten_parameters:
expected_reshard_hook_count *= 2 # weight and bias per linear
assert (
reshard_hook_count == expected_reshard_hook_count
), f"Expected {expected_reshard_hook_count} but got {reshard_hook_count}"
assert losses[0].eq(losses[1]).all().item(), f"Expected {losses[1]} but got {losses[0]}"
reshard_hook_count = 0
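
The call-counting trick in the test above (patching `_post_backward_reshard_hook` with a wrapper that increments a counter and then delegates to the original) is a generic pattern; here is a minimal standalone sketch with an illustrative class that is not part of fairscale:

```python
from unittest import mock

class Worker:
    def step(self) -> str:
        return "did work"

orig_step = Worker.step  # keep a reference to the real implementation
call_count = 0

def step_with_count(*args, **kwargs):
    global call_count
    call_count += 1
    return orig_step(*args, **kwargs)  # delegate so behavior is unchanged

with mock.patch.object(Worker, "step", step_with_count):
    w = Worker()
    w.step()
    w.step()

assert call_count == 2  # the wrapper saw every call made inside the patch context
```
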
