added option to do backward AG over smaller set of gpus instead of full DDP
Naman Goyal committed May 20, 2023
1 parent ba38cf3 commit 0b77de4
Showing 1 changed file with 180 additions and 2 deletions.
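For context, here is a minimal usage sketch of the new zero2_process_group argument (not part of this commit; the 8-rank sub-group layout, the NCCL backend, and the Linear module are illustrative assumptions): parameters stay sharded across the full data-parallel world, but with reshard_after_forward the backward all-gather runs over the smaller group instead of the whole world.

import torch
import torch.distributed as dist
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

dist.init_process_group(backend="nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank % torch.cuda.device_count())

# Build one sub-group per block of 8 ranks (e.g., one node). Every rank must
# take part in every new_group call, but keeps only the group it belongs to.
zero2_group_size = 8  # assumed group size for illustration
assert world_size % zero2_group_size == 0
zero2_group = None
for start in range(0, world_size, zero2_group_size):
    ranks = list(range(start, start + zero2_group_size))
    group = dist.new_group(ranks=ranks)
    if rank in ranks:
        zero2_group = group

model = FSDP(
    torch.nn.Linear(1024, 1024).cuda(),
    mixed_precision=True,
    reshard_after_forward=True,   # the zero2 path is only taken when resharding after forward
    zero2_process_group=zero2_group,
)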
182 changes: 180 additions & 2 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -332,6 +332,7 @@ def __init__(
offload_config: Optional[OffloadConfig] = None,
state_dict_on_rank_0_only: bool = False,
gradient_predivide_factor: Optional[float] = None,
zero2_process_group: Optional[ProcessGroup] = None,
):
try:
import torch._C
@@ -380,6 +381,9 @@ def __init__(
"parameter uses all the available ranks for the optimal performance."
)
self.reshard_after_forward = self._orig_reshard_after_forward = reshard_after_forward

self.zero2_process_group = zero2_process_group

self.disable_reshard_on_root = disable_reshard_on_root
self.mixed_precision = mixed_precision
self.fp32_reduce_scatter = fp32_reduce_scatter
@@ -518,6 +522,9 @@ def __init__(
if isinstance(m, FullyShardedDataParallel):
m._free_ssd_offload()

if self.zero2_process_group is not None:
assert not self.move_params_to_cpu, "zero2_process_group is not supported together with move_params_to_cpu"

def _get_gradient_predivide_factor(self, world_size: int) -> float:
factor: int = 1
while world_size % factor == 0 and world_size / factor > factor:
@@ -1419,7 +1426,10 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
outputs = self.module(*args, **kwargs)

if self.reshard_after_forward:
self._free_full_params()
if self.zero2_process_group is not None:
self._zero2_shard_to_smaller_group()
else:
self._free_full_params()
if self.mixed_precision or self.move_params_to_cpu:
self._free_fp16_param_shard()

@@ -1499,7 +1509,10 @@ def _pre_backward_hook(*unused: Any) -> None:
# idempotent. So in case they are called unnecessarily, they don't incur much
# overhead.
if self.reshard_after_forward:
self._rebuild_full_params()
if self.zero2_process_group is not None:
self._zero2_rebuild_full_params()
else:
self._rebuild_full_params()
if (
self.reshard_after_forward
and self._fsdp_forward_ordering is not None
@@ -2006,6 +2019,126 @@ def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None:
torch.cuda.current_stream().wait_stream(self._streams["all_gather"])
return output_tensors


@torch.no_grad()
def _zero2_rebuild_full_params(self, force_full_precision: bool = False, wait_for_all_gather: bool = True) -> Optional[List[Tuple[torch.Tensor, bool]]]:
"""
Gather all shards of params across the (smaller) zero2 process group.
Note, this is idempotent if full params are already gathered. Callers
assume the idempotency. So please keep it that way.
Args:
force_full_precision (bool, Optional): by default params will be gathered
in ``compute_dtype`` (e.g., FP16), unless *force_full_precision* is
``True``, in which case they will be gathered in full precision
(e.g., FP32), possibly in fresh storage. The parameter that's being
rebuilt will end up in full precision as well.
Returns:
A list of tuples, where the first element is the full-sized param
and the second element is a bool indicating if it's safe for the
caller to free the full-sized param. This will be ``None`` if
``force_full_precision=False`` and the full params are already gathered.
"""
output_tensors: List[Tuple[torch.Tensor, bool]] = []

def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None:
"""
Helper function to update p.data pointer.
Args:
custom_output_tensor (torch.Tensor, Optional): if not None, this
tensor contains the data we just gathered.
"""
if custom_output_tensor is not None:
assert p._is_sharded
p.data = custom_output_tensor
output_tensors.append((p.data, True))
elif not p._is_sharded:
if (self.mixed_precision or self.move_params_to_cpu) and not force_full_precision:
assert p._fp16_shard is not None
p.data = p._fp16_shard
output_tensors.append((p.data, True))
else:
# Here p.data == p._fp32_shard, so it's not safe to free.
output_tensors.append((p.data, False))
else:
p.data = p._full_param_padded
output_tensors.append((p.data, True))
# Trim any padding and reshape to match original size.
p.data = p.data[: p._orig_size.numel()].view(p._orig_size)

if self._has_shared_params:
# self.has_full_params flag can be out of sync if a shared param is
# sharded by another FSDP instance. An example is that in eval case
# with reshard_after_forward=False but the sharing instance has
# reshard_after_forward=True. Then, on the second forward, the
# other instance can shard the shared param and but this instance
# can mistakenly think the full param is already gathered from the
# has_full_params flag.
#
# Therefore, we update the flag accordingly here.
self.has_full_params = not any(p._full_param_padded.storage().size() == 0 for p in self.params)

# Early exit if we already have full params and don't need full precision.
if self.has_full_params and not force_full_precision:
if wait_for_all_gather:
torch.cuda.current_stream().wait_stream(self._streams["all_gather"])
for p in self.params:
update_p_data()
return output_tensors

self.has_full_params = True

with torch.cuda.stream(self._streams["all_gather"]):

for p in self.params:
if not p._is_sharded: # e.g., when world_size == 1
update_p_data()
else:
# Skip if already built. Only shared param can be rebuilt multiple times.
# A corner case is p._orig_size = (1,), which means the shape equality is
# not a perfect check. But we assume we don't share a param with shape (1,).
if p.data.shape == p._orig_size and hasattr(p, "_is_shared") and p._is_shared:
continue
# If self.move_params_to_cpu and force_full_precision, we need to cast
# the FP32 CPU param to CUDA for the all-gather.
p_data = p.data.to(p._full_param_padded.device, non_blocking=True)

p_size = p._full_param_padded.size()
assert p_size.numel() % self.world_size == 0
if self.mixed_precision and force_full_precision:
# Allocate fresh tensor in full precision since we are in
# mixed precision and full precision rebuild is asked.
output_tensor = p_data.new_zeros(p_size)
else:
if p._full_param_padded.storage().size() != p_size.numel():
# Allocate based on full size from all shards.
alloc_storage_(p._full_param_padded, size=p_size)
output_tensor = p._full_param_padded

# Fill output_tensor with (p._zero2_fp16_shard from each rank in the zero2 group)
if hasattr(dist, "_all_gather_base") and enable_nccl_base_collectives:
# Newer versions of PyTorch have _all_gather_base, which is faster than chunk followed by all_gather.
dist._all_gather_base(output_tensor, p._zero2_fp16_shard, group=self.zero2_process_group)
else:
# The list of chunks must match the size of the group being gathered over.
chunks = list(output_tensor.chunk(dist.get_world_size(self.zero2_process_group)))
dist.all_gather(chunks, p._zero2_fp16_shard, group=self.zero2_process_group)

# Set p.data = output_tensor (with padding trimmed)
update_p_data(output_tensor)

if (self.mixed_precision or self.move_params_to_cpu) and not force_full_precision:
self._free_zero2_param_shard([p])

if self.move_params_to_cpu and (self.params[0].dtype == self.compute_dtype):
self._free_zero2_param_shard([p])
if wait_for_all_gather:
torch.cuda.current_stream().wait_stream(self._streams["all_gather"])
return output_tensors


@torch.no_grad()
def _use_full_params(self) -> None:
"""
@@ -2074,6 +2207,38 @@ def _free_full_params(self, params: Optional[List[Parameter]] = None) -> None:
free_storage_(p._full_param_padded)
torch.cuda.current_stream().synchronize()


def _zero2_shard_to_smaller_group(self, params: Optional[List[Parameter]] = None) -> None:
if params is None:
params = self.params
self.has_full_params = False
current_stream = torch.cuda.current_stream()
for p in params:
if not p._is_sharded: # e.g., world_size == 1
if self.mixed_precision or self.move_params_to_cpu:
self._free_fp16_param_shard([p])
continue
# Case where the zero2 world size is > 1 but smaller than the full (zero3) world size.
zero2_world_size = dist.get_world_size(self.zero2_process_group)
zero2_rank = dist.get_rank(self.zero2_process_group)
chunks = p._full_param_padded.chunk(zero2_world_size)

# Keep a private copy of this rank's zero2 chunk so the padded full param storage can be freed below.
p._zero2_fp16_shard = torch.empty_like(chunks[zero2_rank])
p._zero2_fp16_shard.copy_(chunks[zero2_rank])

# Don't let PyTorch reuse this memory until all work in the current
# stream is complete.
p._full_param_padded.record_stream(current_stream)
# There may be external references to the Tensor Storage that we
# can't modify, such as references that are created by
# ctx.save_for_backward in the forward pass. Thus when we
# unshard parameters, we should reuse the original Tensor
# Storage object and unshard it in-place. For now, just resize
# the Storage to 0 to save memory.
free_storage_(p._full_param_padded)
torch.cuda.current_stream().synchronize()


def local_metadata_dict(self) -> Dict[str, Any]:
"""
Get the information needed to reconstruct the model from shards offline.
@@ -2238,6 +2403,19 @@ def _free_fp16_param_shard(self, params: Optional[List[Parameter]] = None) -> None:
p._fp16_shard.record_stream(current_stream)
free_storage_(p._fp16_shard)

@torch.no_grad()
def _free_zero2_param_shard(self, params: Optional[List[Parameter]] = None) -> None:
"""Free storage for FP16 shards for a list of params."""
if params is None:
params = self.params
current_stream = torch.cuda.current_stream()
for p in params:
if p._zero2_fp16_shard is not None:
# _zero2_fp16_shard may still be needed by work queued on the current
# stream (e.g., the backward all-gather), so record it before freeing its storage.
p._zero2_fp16_shard.record_stream(current_stream)
free_storage_(p._zero2_fp16_shard)

def assert_state(self, state: Union[TrainingState, List[TrainingState]]) -> None:
"""Assert we are in the given state."""
# Since assert can be turned off and this error checking
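To make the new data flow concrete, here is a single-process toy (an illustration only, not code from the commit) of the reshard/rebuild round trip that _zero2_shard_to_smaller_group and _zero2_rebuild_full_params perform: after the forward pass the padded full parameter is chunked into zero2_world_size pieces and each rank keeps its own chunk; before backward, an all-gather over the smaller zero2 group (emulated below with torch.cat) reconstructs the padded full parameter, and the padding is trimmed back to the original size.

import torch

zero2_world_size = 4
orig_numel = 10                         # original (unpadded) parameter size
full_param_padded = torch.arange(12.0)  # padded so it chunks evenly

# After forward (_zero2_shard_to_smaller_group): each zero2 rank keeps one chunk.
shards = list(full_param_padded.chunk(zero2_world_size))

# Before backward (_zero2_rebuild_full_params): all-gather over the zero2 group,
# emulated here by concatenating the shards back together.
rebuilt = torch.cat(shards)
assert torch.equal(rebuilt, full_param_padded)

# update_p_data() then trims the padding back to the original size.
p_data = rebuilt[:orig_numel]
print(p_data.shape)  # torch.Size([10])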
