diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
index 09648894c..d9b20fca7 100644
--- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
+++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -9,6 +9,7 @@
 from dataclasses import dataclass
 from enum import Enum, auto
 import functools
+import itertools
 import logging
 from math import inf
 import os
@@ -47,7 +48,6 @@
 from fairscale.utils.containers import apply_to_tensors
 from fairscale.utils.parallel import (
     ProcessGroupName,
-    chunk_and_pad,
     enable_pytorch_sync_bn,
     get_process_group_cached,
     validate_process_group,
@@ -1457,6 +1457,7 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
         # Register backward hooks to reshard params and reduce-scatter grads.
         # These need to be re-registered every forward pass.
         self._register_post_backward_hooks()
+        self._register_post_backward_reshard_hooks(args, kwargs)

         outputs = self.module(*args, **kwargs)

@@ -1655,6 +1656,37 @@ def _register_post_backward_hooks(self) -> None:
                 p._shard_bwd_hooks.append((grad_acc, handle))
                 # p._shard_bwd_hook = (grad_acc, handle)

+    def _register_post_backward_reshard_hooks(
+        self, args: Tuple[Any, ...], kwargs: Dict[str, Any]
+    ) -> None:
+        if not hasattr(torch.autograd.graph, "register_multi_grad_hook"):
+            return  # unsupported
+        if not torch.is_grad_enabled():
+            return
+        from torch.utils._pytree import tree_flatten
+        from torch.autograd.graph import register_multi_grad_hook
+        # Construct `inp_tensors` lazily to avoid CPU overhead in typical case
+        # where each parameter requires gradient
+        inp_tensors: Optional[List[torch.Tensor]] = None
+        for param in self.params:
+            # Only register for parameters that do not require gradient
+            if param.requires_grad:
+                continue
+            if inp_tensors is None:
+                args_list, _ = tree_flatten(args)
+                kwargs_list, _ = tree_flatten(kwargs)
+                inp_tensors = [
+                    obj
+                    for obj in itertools.chain(args_list, kwargs_list)
+                    if torch.is_tensor(obj) and obj.requires_grad
+                ]
+            hook_handle = register_multi_grad_hook(
+                inp_tensors, functools.partial(self._post_backward_reshard_hook, param)
+            )
+            if not hasattr(param, "_shard_bwd_hooks"):
+                param._shard_bwd_hooks = []
+            param._shard_bwd_hooks.append((hook_handle,))
+
     @torch.no_grad()
     def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
         """
@@ -1697,12 +1729,8 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
         if param.grad.requires_grad:
             raise RuntimeError("FSDP only works with gradients that don't require gradients")

-        if self._require_backward_grad_sync or self.reshard_after_forward:
-            # Free full params. As a special case, we don't free the full params
-            # when in a ``no_sync`` context (as inversely indicated by
-            # ``self._require_backward_grad_sync``), since the params will not
-            # get updated before the next forward. This saves networking
-            # bandwidth but uses more GPU memory.
+        if self._should_free_in_backward():
+            # Free full params.
             self._free_full_params([param])

         if self.mixed_precision:
@@ -1829,6 +1857,22 @@ def _post_reduction_hook(self, param: Parameter, reduced_grad: torch.Tensor) ->
             # Don't let this memory get reused until after the transfer.
             reduced_grad.data.record_stream(torch.cuda.current_stream())

+    @torch.no_grad()
+    def _post_backward_reshard_hook(self, param: Parameter, *unused: Any) -> None:
+        if self._should_free_in_backward():
+            self._free_full_params([param])
+        if self.mixed_precision:
+            self._free_fp16_param_shard([param])
+        self._use_fp32_param_shard([param])
+
+    def _should_free_in_backward(self):
+        # As a special case, we don't free the full params
+        # when in a ``no_sync`` context (as inversely indicated by
+        # ``self._require_backward_grad_sync``), since the params will not
+        # get updated before the next forward. This saves networking
+        # bandwidth but uses more GPU memory.
+        return self._require_backward_grad_sync or self.reshard_after_forward
+
     def _queue_wait_for_post_backward(self) -> None:
         """Try to queue a `wait_for_post_backward` callback.

@@ -1878,16 +1922,24 @@ def _wait_for_post_backward(self) -> None:
         def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None:
             """Helper used below on all fsdp modules."""
             for p in fsdp_module.params:
-                if not p.requires_grad:
-                    continue
                 if hasattr(p, "_shard_bwd_hook"):
                     p_assert(len(p._shard_bwd_hook) == 2, f"WFPB: incorrect hook num: {len(p._shard_bwd_hook)}")
                     # p._shard_bwd_hook[1].remove()
                     # delattr(p, "_shard_bwd_hook")
                 if hasattr(p, "_shard_bwd_hooks") and self._require_backward_grad_sync:
-                    for _, handle in p._shard_bwd_hooks:
-                        handle.remove()
+                    for hook_state in p._shard_bwd_hooks:
+                        if len(hook_state) == 1:
+                            hook_state[0].remove()
+                        elif len(hook_state) == 2:
+                            hook_state[1].remove()
                     p._shard_bwd_hooks.clear()
+                if not p.requires_grad:
+                    # For the 1st layer, if the forward inputs did not require
+                    # gradient, then we cannot run a reshard hook for it, and
+                    # we instead free here.
+                    if p._full_param_padded.untyped_storage().size() > 0:
+                        fsdp_module._post_backward_reshard_hook(p)
+                    continue
                 # Leave the gradient accumulation state as-is if not synchronizing this pass. This ensures p.grad
                 # remains the unsharded gradient accumulated from prior no-sync passes, and p._saved_grad_shard
diff --git a/tests/nn/data_parallel/test_fsdp_freezing_weights.py b/tests/nn/data_parallel/test_fsdp_freezing_weights.py
index c6ad364f7..7baadc5d9 100644
--- a/tests/nn/data_parallel/test_fsdp_freezing_weights.py
+++ b/tests/nn/data_parallel/test_fsdp_freezing_weights.py
@@ -12,6 +12,8 @@
 from enum import Enum
 from itertools import product
+from unittest import mock
+import copy
 import tempfile

 import pytest
@@ -275,3 +277,97 @@ def test_freezing_weights(temp_files, nested_trunk):
         nprocs=world_size,
     )
     temp_file_idx += 3
+
+
+@skip_if_single_gpu
+def test_reshard_frozen_weights():
+    world_size = 2
+    for flatten_parameters, reshard_after_forward, inp_requires_grad in product(
+        [False, True], [False, True], [False, True]
+    ):
+        print(
+            "Testing FSDP reshard frozen weights with "
+            f"flatten_parameters={flatten_parameters}, "
+            f"reshard_after_forward={reshard_after_forward}, "
+            f"inp_requires_grad={inp_requires_grad}"
+        )
+        mp.spawn(
+            _distributed_worker_reshard,
+            (world_size, flatten_parameters, reshard_after_forward, inp_requires_grad),
+            nprocs=world_size,
+        )
+
+
+def _distributed_worker_reshard(
+    rank: int,
+    world_size: int,
+    flatten_parameters: bool,
+    reshard_after_forward: bool,
+    inp_requires_grad: bool,
+):
+    import os
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "12355"
+    torch.cuda.set_device(rank)
+    torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size)
+
+    torch.manual_seed(0)
+
+    num_linears = 6
+    modules = []
+    for _ in range(num_linears):
+        modules += [nn.Linear(5, 5, device="cuda"), nn.ReLU()]
+    model = nn.Sequential(*modules)
+    # Freeze every other linear
+    for i in range(num_linears):
+        if i % 2 == 0:
+            for param in model[i * 2].parameters(recurse=False):
+                param.requires_grad = False
+    num_frozen_linears = num_linears // 2
+
+    ref_model = DistributedDataParallel(copy.deepcopy(model), device_ids=[rank])
+    ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2)
+
+    for i, module in enumerate(model):
+        if isinstance(module, nn.Linear):
+            model[i] = FSDP(
+                module,
+                flatten_parameters=flatten_parameters,
+                reshard_after_forward=reshard_after_forward,
+            )
+    fsdp_model = FSDP(
+        model,
+        flatten_parameters=flatten_parameters,
+        reshard_after_forward=reshard_after_forward,
+    )
+    fsdp_optim = torch.optim.AdamW(fsdp_model.parameters(), lr=1e-2)
+
+    orig_post_backward_reshard_hook = FSDP._post_backward_reshard_hook
+    reshard_hook_count = 0
+
+    def post_backward_reshard_hook_with_count(*args, **kwargs):
+        nonlocal reshard_hook_count
+        reshard_hook_count += 1
+        return orig_post_backward_reshard_hook(*args, **kwargs)
+
+    with mock.patch(
+        "fairscale.nn.data_parallel.FullyShardedDataParallel._post_backward_reshard_hook",
+        post_backward_reshard_hook_with_count,
+    ):
+        inp = torch.randn((8, 5), device="cuda", requires_grad=inp_requires_grad)
+        for i in range(6):
+            losses = []
+            for model, optim in ((fsdp_model, fsdp_optim), (ref_model, ref_optim)):
+                optim.zero_grad()
+                loss = model(inp).sum()
+                losses.append(loss)
+                loss.backward()
+                optim.step()
+            expected_reshard_hook_count = num_frozen_linears
+            if not flatten_parameters:
+                expected_reshard_hook_count *= 2  # weight and bias per linear
+            assert (
+                reshard_hook_count == expected_reshard_hook_count
+            ), f"Expected {expected_reshard_hook_count} but got {reshard_hook_count}"
+            assert losses[0].eq(losses[1]).all().item(), f"Expected {losses[1]} but got {losses[0]}"
+            reshard_hook_count = 0
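Below is a minimal usage sketch of the behavior this patch exercises: a model with some frozen layers is wrapped in fairscale's FSDP, and after this change the full weights of the frozen layers are resharded during backward via the new multi-grad reshard hook rather than staying materialized. The sketch assumes a CUDA device and an already-initialized NCCL process group; the helper names (build_partially_frozen_model, run_one_step) are illustrative and not part of fairscale or this diff.

import torch
import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP


def build_partially_frozen_model() -> nn.Module:
    # Small model where the first linear is frozen (requires_grad=False).
    model = nn.Sequential(nn.Linear(5, 5), nn.ReLU(), nn.Linear(5, 5)).cuda()
    for p in model[0].parameters():
        p.requires_grad = False  # frozen layer: resharded in backward by the new hook
    return model


def run_one_step() -> None:
    # Assumes torch.distributed.init_process_group("nccl", ...) has already run.
    fsdp_model = FSDP(
        build_partially_frozen_model(),
        flatten_parameters=False,
        reshard_after_forward=True,
    )
    optim = torch.optim.AdamW(
        [p for p in fsdp_model.parameters() if p.requires_grad], lr=1e-2
    )
    inp = torch.randn(8, 5, device="cuda")
    loss = fsdp_model(inp).sum()
    loss.backward()  # frozen params' full weights are freed in backward, not left unsharded
    optim.step()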