[fix] FSDP forward pass overlap between compute and all-gather (#671)
* [fix] FSDP forward pass overlap between compute and all-gather

- many thanks to @cyanguwa for reporting it and to @QuentinDuval for debugging it
- a new unit test is added to check for this and ensure we detect
  issues with overlapping and CPU/GPU blocking wait calls

* fix

* fix

* fix

* better assertion outputs

* fix format and tune all_gather mb for CI

* more tuning with non_flatten

* undo an accidental change

* tuning all gather mb and del model

* Update + fix overlapping test to use patched all_gather w/ delay (#672)

* fixing get_cycles_per_ms

* add get_smi_memory

* update the docstring

Co-authored-by: Min Xu <min.xu@acm.org>
Co-authored-by: Myle Ott <myleott@fb.com>
3 people committed May 11, 2021
1 parent c8d32c3 commit 8a42a8e
Showing 7 changed files with 316 additions and 37 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -6,6 +6,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## NEXT - TBD
### Fixed
- FSDP: fix forward pass not overlapping compute and all-gather
- FSDP: improved frozen weight support
- FSDP: workaround AMP autocast cache issue with clear\_autocast\_cache flag
- setup.py: hide CUDA extensions behind BUILD_CUDA_EXTENSIONS envvar
41 changes: 25 additions & 16 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -1448,20 +1448,22 @@ def _free_full_params(self, params: Optional[List[Parameter]] = None) -> None:
        if params is None:
            params = self.params
        self.has_full_params = False
-        self._streams["all_gather"].wait_stream(torch.cuda.current_stream())
-        with torch.cuda.stream(self._streams["all_gather"]):
-            for p in params:
-                if not p._is_sharded:  # e.g., world_size == 1
-                    if self.mixed_precision:
-                        self._free_fp16_param_shard([p])
-                    continue
-                # There may be external references to the Tensor Storage that we
-                # can't modify, such as references that are created by
-                # ctx.save_for_backward in the forward pass. Thus when we
-                # unshard parameters, we should reuse the original Tensor
-                # Storage object and unshard it in-place. For now, just resize
-                # the Storage to 0 to save memory.
-                free_storage_(p._full_param_padded)
+        current_stream = torch.cuda.current_stream()
+        for p in params:
+            if not p._is_sharded:  # e.g., world_size == 1
+                if self.mixed_precision:
+                    self._free_fp16_param_shard([p])
+                continue
+            # Don't let PyTorch reuse this memory until all work in the current
+            # stream is complete.
+            p._full_param_padded.record_stream(current_stream)
+            # There may be external references to the Tensor Storage that we
+            # can't modify, such as references that are created by
+            # ctx.save_for_backward in the forward pass. Thus when we
+            # unshard parameters, we should reuse the original Tensor
+            # Storage object and unshard it in-place. For now, just resize
+            # the Storage to 0 to save memory.
+            free_storage_(p._full_param_padded)

    @torch.no_grad()
    def _use_fp32_param_shard(self, params: Optional[List[Parameter]] = None) -> None:
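
Editor's note on the pattern above: _full_param_padded is filled on the all-gather stream but consumed by compute on the current stream, so record_stream tells PyTorch's caching allocator not to hand that memory to new allocations until the current stream's pending work has finished. The snippet below is a minimal, self-contained sketch of the same pattern; it is illustrative only (not part of this commit) and assumes a CUDA device.

    import torch

    # Produce a buffer on a side stream (its memory is tied to that stream).
    side_stream = torch.cuda.Stream()
    with torch.cuda.stream(side_stream):
        buf = torch.empty(4096, device="cuda").normal_()

    # Consume it on the current (default) stream.
    torch.cuda.current_stream().wait_stream(side_stream)
    out = buf * 2

    # Mark the current stream as a user of buf's memory so that freeing buf
    # does not let the allocator recycle the block before the multiply above
    # has actually finished executing on the GPU.
    buf.record_stream(torch.cuda.current_stream())
    del buf
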
@@ -1692,16 +1694,23 @@ def _get_default_cuda_device(module: nn.Module) -> torch.device:
 def cast_floats_to_right_precision(to_fp16: bool, no_grad: bool, *args: Any, **kwargs: Any) -> Tuple[Any, Any]:
     """
     Cast floating point Tensors in *args or **kwargs to FP16 or FP32 if they are not.
+    We also retain the requires_grad flag so that casting doesn't affect the autograd graph.
     """

     def fn_fp16(x: torch.Tensor) -> torch.Tensor:
         if x.dtype is torch.float32:
-            return x.half()
+            y = x.half()
+            if x.is_leaf:
+                y.requires_grad = x.requires_grad
+            return y
         return x

     def fn_fp32(x: torch.Tensor) -> torch.Tensor:
         if x.dtype is torch.float16:
-            return x.float()
+            y = x.float()
+            if x.is_leaf:
+                y.requires_grad = x.requires_grad
+            return y
         return x

     fn = fn_fp16 if to_fp16 else fn_fp32
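
Editor's note: one reason the explicit requires_grad copy above matters is that this cast is often applied with autograd disabled (note the no_grad argument), and under torch.no_grad() the tensor returned by .half() or .float() does not inherit the flag from its input. A small illustrative snippet, not part of the diff:

    import torch

    x = torch.randn(4, requires_grad=True)   # leaf input that should receive gradients
    with torch.no_grad():
        y = x.half()                          # cast with autograd disabled
    assert y.requires_grad is False           # the flag was dropped by the cast
    if x.is_leaf:                             # only leaf tensors allow setting it directly
        y.requires_grad = x.requires_grad
    assert y.requires_grad is True
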
60 changes: 39 additions & 21 deletions fairscale/utils/testing.py
@@ -34,6 +34,7 @@
import multiprocessing
import os
import random
from statistics import mean
import subprocess
import sys
import tempfile
@@ -577,30 +578,35 @@ def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:

@functools.lru_cache()
def get_cycles_per_ms() -> float:
"""Approximate number of cycles per millisecond for torch.cuda._sleep
"""Measure and return approximate number of cycles per millisecond for torch.cuda._sleep
Copied from: github.com/pytorch/pytorch/blob/master/test/test_cuda.py
..note::
This doesn't seem to return consistent cycles on desktop GPUs, likely
due to frequency scaling.
>>> get_cycles_per_ms()
227.6441091140009
# new python process
>>> get_cycles_per_ms()
564.652154766248
# new python process
>>> get_cycles_per_ms()
245.56459442962856
"""
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
torch.cuda._sleep(1000000)
end.record()
end.synchronize()
cycles_per_ms = 1000000 / start.elapsed_time(end)
return cycles_per_ms

def measure() -> float:
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
torch.cuda._sleep(1000000)
end.record()
end.synchronize()
cycles_per_ms = 1000000 / start.elapsed_time(end)
return cycles_per_ms

# Get 10 values, remove the 2 max and 2 min, and return the average.
# This is to avoid system disturbances that skew the results, e.g.
# the very first cuda call likely does a bunch of init, which takes
# much longer than subsequent calls.
#
# Tested on Tesla V100, Quadro GP100, Titan RTX and RTX 3090 GPUs;
# it seems to return stable values. Therefore, we enable caching
# using the lru_cache decorator above.
num = 10
vals = []
for _ in range(num):
vals.append(measure())
vals = sorted(vals)
return mean(vals[2 : num - 2])
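
Editor's note (illustrative, not part of this commit): the returned rate is typically multiplied by a desired delay in milliseconds and fed to torch.cuda._sleep, for example to make a patched all_gather spin long enough that broken compute/communication overlap becomes measurable.

    # Spin the GPU for roughly delay_ms milliseconds (assumes a CUDA device).
    delay_ms = 5
    torch.cuda._sleep(int(delay_ms * get_cycles_per_ms()))
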


class DummyProcessGroup:
@@ -681,3 +687,15 @@ def dump_all_tensors(rank: int) -> None:
except Exception as e:
pass
print(torch.cuda.memory_summary())


def get_smi_memory() -> float:
"""Return process's GPU memory in MB."""
pid = os.getpid()
info_string = torch.cuda.list_gpu_processes()
for line in info_string.splitlines():
if str(pid) in line:
toks = line.split()
return float(toks[3])
# If the process is not in the list, we are not using the GPU.
return 0.0
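
Editor's note, a hypothetical usage sketch (the tensor size and names are illustrative): sample the driver-visible number before and after an allocation to confirm that this process is the one holding GPU memory.

    import torch

    before_mb = get_smi_memory()
    x = torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device="cuda")  # ~256 MB
    torch.cuda.synchronize()
    print(f"GPU memory grew by {get_smi_memory() - before_mb:.1f} MB")
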
3 changes: 3 additions & 0 deletions requirements-test.txt
@@ -13,3 +13,6 @@ pytest-cov == 2.10.0
pytest-timeout == 1.4.2
remote-pdb >= 2.1.0
parameterized >= 0.8.1

# For torch.cuda.list_gpu_processes()
pynvml == 8.0.4
2 changes: 2 additions & 0 deletions stubs/torch/cuda/__init__.pyi
@@ -1,5 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import torch
from typing import Optional, Tuple, Union, Dict, Any
import ctypes
from . import amp
@@ -48,6 +49,7 @@ def reset_max_memory_cached(device: Optional[_device_t]=...) -> None: ...
def memory_summary() -> str: ...
def cudart() -> ctypes.CDLL: ...
def find_cuda_windows_lib() -> Optional[ctypes.CDLL]: ...
def list_gpu_processes(device: Union[torch.device, str, None, int] = None) -> str: ...
#MODIFIED BY TORCHGPIPE
from .. import ByteTensor
def set_rng_state(new_state: ByteTensor, device: _device_t = ...) -> None: ...
1 change: 1 addition & 0 deletions tests/ci_test_list_2.txt
@@ -1,5 +1,6 @@
tests/nn/misc/test_checkpoint_activations.py
tests/nn/misc/test_checkpoint_activations_norm.py
tests/nn/data_parallel/test_fsdp_overlap.py
tests/nn/data_parallel/test_fsdp_multiple_forward.py
tests/nn/data_parallel/test_fsdp_apply.py
tests/nn/data_parallel/test_fsdp_state_dict.py
