[chore] [cleanup]: pytest, pytorch new versions, fix tests (#933)
* update pytest versions

* [test] test-related changes

- upgrade to newer pytorch versions
- added a function to make tests more deterministic on A100 GPUs with TF32
- fixed some tests so that they are correctly skipped on a single-GPU system

* more fixes

* formatting overly long lines

* format

* better test without triggering a warning

* fix an optim state bug with newer pytorch

- the adam optimizer seems to return "step" as a singleton tensor now in the nightly build
- this fixes it by assuming a non-tensor value can still be loaded back by the optimizer

* improve oss.py

- using min_loss for regression checking is a bit more reliable
- also increased the number of epochs from 10 to 12

* small oss.py fix

* Update fairscale/nn/data_parallel/fully_sharded_data_parallel.py

Co-authored-by: Min Xu <min.xu.public@gmail.com>
min-xu-ai and flying-x committed Feb 14, 2022
1 parent 8527c58 commit fae2995
Showing 18 changed files with 117 additions and 60 deletions.
34 changes: 17 additions & 17 deletions .circleci/config.yml
@@ -79,14 +79,14 @@ setup_venv: &setup_venv
pip install --upgrade pip
# most recent LTS version
install_dep_1_8_1: &install_dep_1_8_1
install_dep_1_8_2: &install_dep_1_8_2
- run:
name: Install Dependencies with torch 1.8.1 (LTS)
name: Install Dependencies with torch 1.8.2 (LTS)
command: |
# check if we have restored venv cache (/home/circleci/venv) correctly, if so, just skip
if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.8 && exit 0; fi
# start installing
pip install --progress-bar off torch==1.8.1+cu102 torchvision==0.9.1+cu102 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
pip install --progress-bar off torch==1.8.2+cu102 torchvision==0.9.2+cu102 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
pip install --progress-bar off -r requirements-dev.txt
pip install --progress-bar off -r requirements-benchmarks.txt
python -c 'import torch; print("Torch version:", torch.__version__)'
@@ -95,14 +95,14 @@ install_dep_1_8_1: &install_dep_1_8_1
wget -O /home/circleci/venv/check_version.py https://raw.githubusercontent.com/min-xu-ai/check_verion/main/check_version.py
# most recent stable version
install_dep_1_10_0: &install_dep_1_10_0
install_dep_1_10_2: &install_dep_1_10_2
- run:
name: Install Dependencies with torch 1.10.0
name: Install Dependencies with torch 1.10.2
command: |
# check if we have restored venv cache (/home/circleci/venv) correctly, if so, just skip
if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.10 && exit 0; fi
# start installing
pip install --progress-bar off torch==1.10.0+cu111 torchvision==0.11.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html
pip install --progress-bar off torch==1.10.2+cu113 torchvision==0.11.3+cu113 -f https://download.pytorch.org/whl/torch_stable.html
pip install --progress-bar off -r requirements-dev.txt
pip install --progress-bar off -r requirements-benchmarks.txt
python -c 'import torch; print("Torch version:", torch.__version__)'
@@ -115,13 +115,13 @@ install_dep_pytorch_nightly: &install_dep_pytorch_nightly
name: Install Dependencies with a torch nightly preview build
command: |
# check if we have restored venv cache (/home/circleci/venv) correctly, if so, just skip
if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.10 && exit 0; fi
if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.12 && exit 0; fi
# start installing
pip install --progress-bar off --pre torch==1.11.0.dev20211231+cu111 torchvision==0.12.0.dev20211231+cu111 -f https://download.pytorch.org/whl/nightly/cu111/torch_nightly.html
pip install --progress-bar off --pre torch==1.12.0.dev20220210+cu113 torchvision==0.13.0.dev20220210+cu113 -f https://download.pytorch.org/whl/nightly/cu113/torch_nightly.html
pip install --progress-bar off -r requirements-dev.txt
pip install --progress-bar off -r requirements-benchmarks.txt
python -c 'import torch; print("Torch version:", torch.__version__)'
python -c 'import torch; assert torch.__version__.split(".")[:2] == ["1", "11"], "wrong torch version"'
python -c 'import torch; assert torch.__version__.split(".")[:2] == ["1", "12"], "wrong torch version"'
python -m torch.utils.collect_env
wget -O /home/circleci/venv/check_version.py https://raw.githubusercontent.com/min-xu-ai/check_verion/main/check_version.py
@@ -161,7 +161,7 @@ run_oss_benchmark: &run_oss_benchmark
name: Run OSS Benchmark
command: |
python benchmarks/oss.py --world_size 4 --epochs 2
python benchmarks/oss.py --check_regression --world_size 4 --optim_type oss_sharded_ddp
python benchmarks/oss.py --check_regression --world_size 4 --optim_type oss_sharded_ddp --epochs 12
run_oss_gloo: &run_oss_gloo
- run:
@@ -249,7 +249,7 @@ jobs:
keys:
- cache-key-cpu-py37-torch-1-10-0-{{.Environment.CACHE_VERSION }}-{{checksum "setup.py"}}-{{checksum "requirements-dev.txt"}}

- <<: *install_dep_1_10_0
- <<: *install_dep_1_10_2

- save_cache:
paths:
@@ -277,7 +277,7 @@ jobs:
- restore_cache:
keys:
- cache-key-cpu-py38-torch-1-10-0-{{.Environment.CACHE_VERSION }}-{{checksum "setup.py"}}-{{checksum "requirements-dev.txt"}}
- <<: *install_dep_1_10_0
- <<: *install_dep_1_10_2

- save_cache:
paths:
@@ -306,7 +306,7 @@ jobs:
keys:
- cache-key-cpu-py39-torch-1-10-0-{{.Environment.CACHE_VERSION }}-{{checksum "setup.py"}}-{{checksum "requirements-dev.txt"}}

- <<: *install_dep_1_10_0
- <<: *install_dep_1_10_2

- save_cache:
paths:
@@ -346,7 +346,7 @@ jobs:
keys:
- cache-key-py-3-9-7-gpu-torch-1-8-1-cuda-11-2-{{.Environment.CACHE_VERSION }}-{{checksum "setup.py"}}-{{checksum "requirements-dev.txt"}}

- <<: *install_dep_1_8_1
- <<: *install_dep_1_8_2

- save_cache:
paths:
@@ -389,7 +389,7 @@ jobs:
keys:
- cache-key-py-3-9-7-gpu-torch-1-10-0-cuda-11-2-{{.Environment.CACHE_VERSION }}-{{checksum "setup.py"}}-{{checksum "requirements-dev.txt"}}

- <<: *install_dep_1_10_0
- <<: *install_dep_1_10_2

- save_cache:
paths:
@@ -470,7 +470,7 @@ jobs:
keys:
- cache-key-benchmark-MNIST-{{.Environment.CACHE_VERSION }}-{{checksum "benchmarks/datasets/mnist.py"}}

- <<: *install_dep_1_10_0
- <<: *install_dep_1_10_2

- save_cache:
paths:
@@ -520,7 +520,7 @@ jobs:
keys:
- cache-key-benchmark-MNIST-{{.Environment.CACHE_VERSION }}-{{checksum "benchmarks/datasets/mnist.py"}}

- <<: *install_dep_1_10_0
- <<: *install_dep_1_10_2

- save_cache:
paths:
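
The nightly CI job above pins an expected torch major.minor and asserts on it right after install. A minimal standalone sketch of that guard, assuming only that torch is importable (the pinned version below is illustrative):

import torch

def assert_torch_major_minor(expected: tuple) -> None:
    # Compare only the leading "major.minor" of torch.__version__ so build
    # suffixes such as "+cu113" or ".dev20220210" do not affect the check.
    actual = tuple(int(part) for part in torch.__version__.split(".")[:2])
    assert actual == expected, f"wrong torch version: {torch.__version__}"

if __name__ == "__main__":
    # The nightly job's pin at the time of this commit; adjust to the build under test.
    assert_torch_major_minor((1, 12))
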
15 changes: 12 additions & 3 deletions benchmarks/oss.py
@@ -109,7 +109,8 @@ def validate_benchmark(measurements, final_loss, args, check_regression):
assert max_memory < 1.05 * golden_data["reference_memory"], (
f"Memory use regression detected: " f"{max_memory} vs. {1.05* golden_data['reference_memory']}"
)
assert abs(cast(float, final_loss) - golden_data["reference_loss"]) < 1e-2, (
# any min_loss less than golden + epsilon is OK.
assert cast(float, final_loss) - golden_data["reference_loss"] < 1e-2, (
f"Loss regression detected: " f"{final_loss} vs. {golden_data['reference_loss']}"
)
logging.info("[Regression Test] VALID")
@@ -176,6 +177,7 @@ def train(

measurements = []
final_loss: Optional[float] = -1.0
min_loss = 100.0
need_profiling = args.profile

for epoch in range(args.epochs):
@@ -264,14 +266,21 @@ def run_closure(closure, scaler, optimizer):
logging.info("... State dict collected")

measurements.append(n_items / epoch_runtime)
min_loss = min(min_loss, final_loss)
if dist.get_rank() == 0:
logging.info(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss:.3f}")
logging.info(
f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. "
f"Loss {final_loss:.3f} min loss {min_loss:.3f}"
)

training_stop = time.monotonic()
img_per_sec = n_items / (training_stop - training_start) * args.epochs
logging.info(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec inc. checkpoint")

validate_benchmark(measurements, final_loss, args, check_regression)
# Use min_loss to check instead of final_loss since the final_loss is a bit random.
# If the training min_loss reaches a certain number, we can be reasonably certain the
# training process was correct.
validate_benchmark(measurements, min_loss, args, check_regression)

dist.destroy_process_group() # type: ignore

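
The regression check above now compares the minimum loss seen across epochs against the golden value, since the final-epoch loss is noisier. A minimal sketch of that pattern, with illustrative numbers rather than the benchmark's real golden data:

# Illustrative values only; the benchmark reads its golden numbers from stored data.
golden_reference_loss = 0.026
epsilon = 1e-2

# Stand-in for the per-epoch final losses produced by the training loop.
epoch_losses = [0.90, 0.31, 0.12, 0.05, 0.024]

min_loss = float("inf")
for final_loss in epoch_losses:
    min_loss = min(min_loss, final_loss)

# Any min_loss below golden + epsilon passes; only staying above the golden
# value by more than epsilon is flagged (note: no abs(), unlike the old check).
assert min_loss - golden_reference_loss < epsilon, (
    f"Loss regression detected: {min_loss} vs. {golden_reference_loss}"
)
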
2 changes: 1 addition & 1 deletion fairscale/nn/data_parallel/fsdp_optim_utils.py
@@ -119,7 +119,7 @@ def _unflatten_optim_state(
if not combined_state:
return {}, global_to_local_id

# copy non-tensor state to all global entries
# copy non-tensor state (like the "step" count) to all global entries
unflat_state = {i: copy.deepcopy(non_tensor_state[0]) for i in range(sum(num_global_params))}

if non_tensor_state[0].keys() == combined_state[0].keys():
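
A toy illustration of the step referenced in the comment above: every unflattened parameter entry is seeded with a deep copy of the shared non-tensor state, and tensor state is filled in per parameter afterwards (values and counts below are made up):

import copy

# Shared non-tensor optimizer state, e.g. Adam's "step" count.
non_tensor_state = [{"step": 10}]
num_global_params = [3, 2]  # unflattened params per flat param

unflat_state = {i: copy.deepcopy(non_tensor_state[0]) for i in range(sum(num_global_params))}
assert unflat_state[4]["step"] == 10
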
25 changes: 21 additions & 4 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -367,13 +367,16 @@ def __init__(
# In a unit test dummy environment, the process_group_reduce_scatter can be None.
if self.process_group_reduce_scatter is not None:
reduce_scatter_group_size = self.process_group_reduce_scatter.size()
# Roll back to use the default process group for reduce scatter operation when the world size and reduce scatter process group size are different.
# Roll back to use the default process group for reduce scatter operation when
# the world size and reduce scatter process group size are different.
if self.world_size != reduce_scatter_group_size:
self.process_group_reduce_scatter = self.process_group
logging.warn(
"Rolled back to use the default process group for the reduce scatter operation because the reduce_scatter process group"
f"size is {reduce_scatter_group_size}, which is different with the world size {self.world_size}. Please make sure the process_group"
"parameter uses all the available ranks for the optimized performance."
"Rolled back to use the default process group for the reduce scatter "
"operation because the reduce_scatter process group "
f"size is {reduce_scatter_group_size}, which is different with the "
f"world size {self.world_size}. Please make sure the process_group "
"parameter uses all the available ranks for the optimal performance."
)
self.reshard_after_forward = self._orig_reshard_after_forward = reshard_after_forward
self.disable_reshard_on_root = disable_reshard_on_root
@@ -2309,6 +2312,20 @@ def _fsdp_instances(self) -> List["FullyShardedDataParallel"]:
return [m for m in self.modules() if isinstance(m, FullyShardedDataParallel)]

def _remove_uncollectable_params_from_optim_state_dict(self, osd: Dict) -> Dict:
"""Return a new state dict filtering out the ones like MoE layers, which has
``no_broadcast_optim_state`` flag set.
We also make rooms for the optimizer state on rank 0.
"""
# In PyTorch version 1.12, Adam's `step` state changed from an int to a singleton
# tensor. We convert it back here. Otherwise, the step counter will be treated
# like a singleton tensor and comparison with the original state dict would fail.
for _, bufs in osd["state"].items():
if "step" in bufs.keys():
assert type(bufs["step"]) is int or ou.is_singleton_tensor(bufs["step"])
if ou.is_singleton_tensor(bufs["step"]):
bufs["step"] = bufs["step"].item()
# Get uncollected_ids.
uncollected_ids = [i for i, m in enumerate(self._fsdp_instances) if m.no_broadcast_optim_state]
new_dct = {"state": {k: v for k, v in osd["state"].items() if k not in uncollected_ids}}
if self.rank == 0:
Expand Down
15 changes: 11 additions & 4 deletions fairscale/utils/testing.py
@@ -136,6 +136,15 @@ def get_smi_ver() -> str:
return tuple(int(n) for n in numbering)


def make_cudnn_deterministic() -> None:
"""Make cudnn (matmul) deterministic"""
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# TF32 also makes things nondeterministic. Disable it.
torch.backends.cuda.matmul.allow_tf32 = False # type: ignore
torch.backends.cudnn.allow_tf32 = False # type: ignore


def dist_init(rank: int, world_size: int, filename: str, filename_rpc: str = "") -> bool:
"""
Initialize torch distributed, based on a temporary file shared across ranks, which makes it possible for unrelated
@@ -218,8 +227,7 @@ def test_runner(
) -> None:
# At this point we're in a new process, torch options need to be set again
if deterministic:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
make_cudnn_deterministic()
torch.manual_seed(1357)

test_func(rank, *args, **kwargs)
@@ -270,8 +278,7 @@ def worker_process(
)

if torch.cuda.is_available() and not hasattr(torch.backends.cudnn, "flags"):
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
make_cudnn_deterministic()

try:
with context:
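
For context on the new make_cudnn_deterministic helper: disabling TF32 matters on Ampere GPUs such as the A100 because TF32 matmuls lose precision and can break tight allclose or equality comparisons between ranks or against CPU references. A hedged, illustrative check of that effect, assuming a CUDA device is available:

import torch
from fairscale.utils.testing import make_cudnn_deterministic

def tf32_free_matmul_check() -> None:
    # Illustrative only: with TF32 disabled, a CUDA float32 matmul should track
    # the CPU float32 reference closely; with TF32 enabled on an A100 it may
    # not, which is what made some comparison-based tests flaky.
    make_cudnn_deterministic()
    a, b = torch.randn(256, 256), torch.randn(256, 256)
    gpu_result = (a.cuda() @ b.cuda()).cpu()
    assert torch.allclose(gpu_result, a @ b, atol=1e-4)

if torch.cuda.is_available():
    tf32_free_matmul_check()
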
6 changes: 3 additions & 3 deletions requirements-dev.txt
@@ -14,9 +14,9 @@ mypy == 0.910
pre-commit >= 2.15.0

# Tools for unit tests & coverage.
pytest == 5.4.1
pytest-cov == 2.10.0
pytest-timeout == 1.4.2
pytest == 7.0.0
pytest-cov == 3.0.0
pytest-timeout == 2.1.0
remote-pdb >= 2.1.0
parameterized >= 0.8.1

9 changes: 9 additions & 0 deletions tests/experimental/nn/test_multiprocess_pipe.py
@@ -22,6 +22,7 @@

from fairscale.experimental.nn.distributed_pipeline import DistributedLoss, DistributedPipeline, PipelineModulesGraph
from fairscale.utils import torch_version
from fairscale.utils.testing import skip_if_single_gpu

pytestmark = pytest.mark.skipif(
not torch.cuda.is_available() or torch_version() < (1, 9, 0),
@@ -103,20 +104,23 @@ def create(devices):


@rpc_test()
@skip_if_single_gpu
def create_multiple_layers():
model = [RemoteModuleParams(nn.Linear, (4, 4), {}), RemoteModuleParams(nn.ReLU, (), {})]
pipe = create_sequence_pipeline(model, balance=[1, 1], chunks=1, devices=["worker0/cpu", "worker0/cpu"])


@rpc_test(world_size=2)
@pytest.mark.parametrize("devices", DEVICES)
@skip_if_single_gpu
def create_multiple_workers(devices):
model = [RemoteModuleParams(nn.Linear, (4, 4), {}), RemoteModuleParams(nn.ReLU, (), {})]
pipe = create_sequence_pipeline(model, balance=[1, 1], chunks=1, devices=devices[:2])


@rpc_test(world_size=2)
@pytest.mark.parametrize("devices", DEVICES)
@skip_if_single_gpu
def parameter_rrefs(devices):
model = [RemoteModuleParams(nn.Linear, (4, 4), {}), RemoteModuleParams(nn.ReLU, (), {})]
pipe = create_sequence_pipeline(model, balance=[1, 1], chunks=1, devices=devices[:2])
@@ -149,6 +153,7 @@ def forward_chunks(devices):
@rpc_test(world_size=2)
@pytest.mark.parametrize("devices", DEVICES)
@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
@skip_if_single_gpu
def forward_multi(devices, checkpoint):
device = devices[0].split("/")[1]
torch.random.manual_seed(3)
@@ -166,6 +171,7 @@ def forward_multi(devices, checkpoint):

@rpc_test(world_size=2)
@pytest.mark.parametrize("devices", DEVICES)
@skip_if_single_gpu
def backward(devices):
device = devices[0].split("/")[1]
torch.random.manual_seed(3)
@@ -183,6 +189,7 @@ def backward(devices):

@rpc_test(world_size=2)
@pytest.mark.parametrize("devices", DEVICES)
@skip_if_single_gpu
def update(devices):
device = devices[0].split("/")[1]
torch.random.manual_seed(3)
@@ -223,6 +230,7 @@ def extract_partitions(graph: PipelineModulesGraph, pipeline: DistributedPipelin

@rpc_test(world_size=2)
@pytest.mark.parametrize("devices", DEVICES)
@skip_if_single_gpu
def multi_input_multi_output_layers(devices):
device = devices[0].split("/")[1]
torch.random.manual_seed(3)
@@ -289,6 +297,7 @@ def forward(self, input):

@rpc_test(world_size=2)
@pytest.mark.parametrize("devices", DEVICES)
@skip_if_single_gpu
def auto_graph_extract(devices):
from fairscale.experimental.nn.distributed_pipeline.trace import make_graph

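
The skip_if_single_gpu decorator imported above is conventionally a thin wrapper around pytest.mark.skipif; a hedged sketch of the pattern (fairscale's actual implementation in fairscale.utils.testing may differ in detail):

import pytest
import torch

# These RPC pipeline tests place stages on two CUDA devices, so they are only
# meaningful when at least two GPUs are visible.
skip_if_single_gpu = pytest.mark.skipif(
    torch.cuda.device_count() < 2, reason="multiple GPUs required"
)
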
7 changes: 5 additions & 2 deletions tests/nn/data_parallel/test_fsdp.py
@@ -801,8 +801,11 @@ def __init__(self, group, wrapper_config, checkpoint_act=False, delay_before_fre

if wrapper_config is not None:
# we create a process group of size 1 for the expert params
expert_group = torch.distributed.new_group([group.rank()]) # world size 1 means no shard
expert = FullyShardedDataParallel(expert, expert_group, **wrapper_config)
# we also need to pass that group as the reduce_scatter group.
expert_group = torch.distributed.new_group([group.rank()])
expert = FullyShardedDataParallel(
expert, process_group=expert_group, process_group_reduce_scatter=expert_group, **wrapper_config
)

shared = FullyShardedDataParallel(shared, group, **wrapper_config)

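
A standalone sketch of the pattern used in the test change above: wrap per-rank expert parameters in their own world-size-1 group and reuse that group for reduce-scatter, so FSDP does not roll back to the default group. The helper name and setup here are illustrative:

import torch
import torch.distributed as dist
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

def wrap_expert(expert: torch.nn.Module, **wrapper_config) -> FSDP:
    # A group containing only this rank (world size 1) means the expert's
    # parameters are neither sharded nor synced across ranks. Passing the same
    # group as process_group_reduce_scatter avoids the rollback-to-default
    # warning in FullyShardedDataParallel.__init__.
    expert_group = dist.new_group([dist.get_rank()])
    return FSDP(
        expert,
        process_group=expert_group,
        process_group_reduce_scatter=expert_group,
        **wrapper_config,
    )
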
3 changes: 2 additions & 1 deletion tests/nn/data_parallel/test_fsdp_grad_acc.py
@@ -13,7 +13,7 @@
import torch

from fairscale.nn.data_parallel import FullyShardedDataParallel
from fairscale.utils.testing import DummyProcessGroup, objects_are_equal
from fairscale.utils.testing import DummyProcessGroup, make_cudnn_deterministic, objects_are_equal

from .test_fsdp import DistributedTest, NestedWrappedModule, rename_test, spawn_and_init

@@ -64,6 +64,7 @@ def _test_nested_wrapper(self, rank, group, config):

@classmethod
def _test_grad_acc(self, model, batch_dim, use_no_sync_context=True):
make_cudnn_deterministic()
# Generate two input batches. We'll test that we get the same grads if
# we train on them sequentially while accumulating grads (with no_sync
# or without no_sync) vs. concatenating the batches and training in one go.
2 changes: 1 addition & 1 deletion tests/nn/data_parallel/test_fsdp_multiple_wrapping.py
@@ -66,7 +66,7 @@ def forward(self, x):
# We use strings for precision and flatten instead of bool to
# make the pytest output more readable.
@skip_if_no_cuda
@pytest.mark.parametrize("world_size", [1, 2])
@pytest.mark.parametrize("world_size", [1, 2] if torch.cuda.device_count() > 1 else [1])
@pytest.mark.parametrize("precision", ["full", "mixed"])
@pytest.mark.parametrize("flatten", ["flatten", "no_flatten"])
def test(world_size, precision, flatten):
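
The conditional parametrization above trims the 2-rank case at collection time on single-GPU machines instead of skipping it at run time; a minimal sketch of the same idiom (test name is illustrative):

import pytest
import torch

# Only generate the 2-GPU case when a second device is present, so a
# single-GPU CI worker never even collects it.
WORLD_SIZES = [1, 2] if torch.cuda.device_count() > 1 else [1]

@pytest.mark.parametrize("world_size", WORLD_SIZES)
def test_world_size_cases(world_size: int) -> None:
    assert world_size in (1, 2)
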
