
Fix for distributed tests on pytorch>=1.12 #2141

Merged on Aug 1, 2022 (20 commits)
Changes from 13 commits

Commits
04219b6  added new DistributedTest class to replace broken distributed_test de…  (mrwyattii, Jul 26, 2022)
acd03b1  Merge branch 'master' into dist-test-refactor  (mrwyattii, Jul 26, 2022)
927f717  fix that broke old dist tests  (mrwyattii, Jul 26, 2022)
b9125dd  fix for running test multiple times, extending parameterization support  (mrwyattii, Jul 28, 2022)
5673c0e  improvements to ensure capturing all necessary pytest fixtures for a …  (mrwyattii, Jul 28, 2022)
3efe06e  extended capabilities to have multiple test per DistributedTest subclass  (mrwyattii, Jul 28, 2022)
a4f87cd  Merge branch 'master' into dist-test-refactor  (mrwyattii, Jul 28, 2022)
8a3e214  converted more tests, updated inference tests to use latest pytorch  (mrwyattii, Jul 28, 2022)
cbb5c92  fix broken refactors of tests  (mrwyattii, Jul 28, 2022)
b5f31d9  fix for broken pytest.skip  (mrwyattii, Jul 29, 2022)
c70b133  moving refactored tests to new subdirs  (mrwyattii, Jul 29, 2022)
31dc986  moved more refactored tests  (mrwyattii, Jul 29, 2022)
c2459d6  added __init__ for imports  (mrwyattii, Jul 29, 2022)
1332bc2  Merge branch 'master' into dist-test-refactor  (tjruwase, Jul 30, 2022)
efa483b  Merge branch 'master' into dist-test-refactor  (tjruwase, Jul 31, 2022)
970198a  Merge branch 'master' into dist-test-refactor  (mrwyattii, Aug 1, 2022)
d12fc85  update monitor tests  (mrwyattii, Aug 1, 2022)
d6fa9c4  add 'future' to dev reqs for torch12 testing  (mrwyattii, Aug 1, 2022)
2d6cbce  Update nv-transformers-v100.yml  (mrwyattii, Aug 1, 2022)
d159a4b  Merge branch 'master' into dist-test-refactor  (mrwyattii, Aug 1, 2022)
7 changes: 4 additions & 3 deletions .github/workflows/nv-inference.yml
@@ -17,7 +17,7 @@ concurrency:

jobs:
unit-tests:
-runs-on: [self-hosted, nvidia, cu111, v100]
+runs-on: [self-hosted, nvidia, cu113, v100]

steps:
- uses: actions/checkout@v2
@@ -31,7 +31,7 @@ jobs:
nvcc --version
pip install --upgrade pip
pip uninstall --yes torch torchvision
-pip install torch==1.8.2+cu111 torchvision==0.9.2+cu111 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
+pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu113
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"

@@ -60,4 +60,5 @@ jobs:
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
cd tests
-TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -m 'inference' unit/ --torch_ver="1.8" --cuda_ver="11.1"
+EXPECTED_TORCH=$(pip index versions torch | grep -oP -m1 "^\s*LATEST.*\s\K\d+\.\d+")
+TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -m 'inference' unit/ --torch_ver=$EXPECTED_TORCH --cuda_ver="11.3"
5 changes: 3 additions & 2 deletions .github/workflows/nv-torch-latest-v100.yml
@@ -60,5 +60,6 @@ jobs:
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
cd tests
-TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -n 4 unit/
-TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -m 'sequential' unit/
+EXPECTED_TORCH=$(pip index versions torch | grep -oP -m1 "^\s*LATEST.*\s\K\d+\.\d+")
+TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -n 4 unit/ --torch_ver=$EXPECTED_TORCH --cuda_ver="11.3"
+TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -m 'sequential' unit/ --torch_ver=$EXPECTED_TORCH --cuda_ver="11.3"
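
Note on the new version pinning: the EXPECTED_TORCH line shells out to pip to discover the newest
torch MAJOR.MINOR release, so the workflows no longer hard-code --torch_ver. A minimal Python sketch
of what that grep extracts (the helper name is hypothetical, and "pip index versions" is an
experimental pip subcommand):

    import re
    import subprocess

    def expected_torch_version() -> str:
        # Equivalent of: pip index versions torch | grep -oP -m1 "^\s*LATEST.*\s\K\d+\.\d+"
        out = subprocess.run(["pip", "index", "versions", "torch"],
                             capture_output=True, text=True, check=True).stdout
        # Pull MAJOR.MINOR off the "LATEST: x.y.z" line, e.g. "1.12.0" -> "1.12"
        match = re.search(r"^\s*LATEST.*?(\d+\.\d+)", out, flags=re.MULTILINE)
        if match is None:
            raise RuntimeError("could not parse the latest torch version from pip output")
        return match.group(1)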
Empty file added: tests/__init__.py
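
Aside: the empty tests/__init__.py makes the tests directory an importable package (commit c2459d6,
"added __init__ for imports"), so the absolute imports used by the refactored tests resolve:

    from tests.unit.common import DistributedTest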
11 changes: 11 additions & 0 deletions tests/conftest.py
@@ -43,3 +43,14 @@ def check_environment(pytestconfig):
pytest.exit(
f"expected cuda version {expected_cuda_version} did not match found cuda version {torch.version.cuda}",
returncode=2)


# Override of pytest "runtest" for DistributedTest class
# This hook is run before the default pytest_runtest_call
@pytest.hookimpl(tryfirst=True)
def pytest_runtest_call(item):
# We want to use our own launching function for distributed tests
if getattr(item.cls, "is_dist_test", False):
dist_test_class = item.cls()
dist_test_class._run_test(item._request)
item.runtest = lambda: True # Dummy function so test is not run twice
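
For context, this hook intercepts pytest's call phase: tryfirst=True makes it run before the
built-in pytest_runtest_call, the test class launches its own worker processes, and the lambda
defuses pytest's normal runner so the test body is not executed a second time in the parent
process. A rough sketch of the resulting control flow (illustrative only):

    # item = a collected test on a DistributedTest subclass, e.g. TestInit.test
    # pytest_runtest_call(item)               <- this tryfirst hook runs first
    #   TestInit()._run_test(item._request)   #    spawns one worker per rank
    #   item.runtest = lambda: True           #    built-in call phase becomes a no-op
    # item.runtest()                          <- pytest's normal call now does nothing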
62 changes: 62 additions & 0 deletions tests/unit/comm/test_dist.py
@@ -0,0 +1,62 @@
import torch
import deepspeed.comm as dist

from tests.unit.common import DistributedTest

import pytest


class TestInit(DistributedTest):
world_size = 3

def test(self):
assert dist.is_initialized()
assert dist.get_world_size() == 3
assert dist.get_rank() < 3


# Demonstration of pytest's parameterization and fixtures
@pytest.fixture(params=["hello"])
def greeting(request):
return request.param


@pytest.mark.parametrize("number,color", [(1138, "purple")])
class TestDistArgs(DistributedTest):
world_size = 2
""" Classes that use DistributedTest class must define a test* method """
@pytest.mark.parametrize("shape", ["icosahedron"])
def test(self, number, color, shape, greeting):
"""Ensure that we can parse args to DistributedTest methods. """
assert dist.get_world_size() == 2
assert number == 1138
assert color == "purple"
assert shape == "icosahedron"
assert greeting == "hello"


# Demonstration of distributed tests grouped in single class
@pytest.mark.parametrize("number", [1138])
class TestGroupedDistTest(DistributedTest):
world_size = 2

def test_one(self, number):
assert dist.get_world_size() == 2
assert number == 1138

@pytest.mark.parametrize("color", ["purple"])
def test_two(self, number, color):
assert dist.get_world_size() == 2
assert number == 1138
assert color == "purple"


class TestDistAllReduce(DistributedTest):
world_size = [1, 2, 4]

def test(self):
x = torch.ones(1, 3).cuda() * (dist.get_rank() + 1)
sum_of_ranks = (dist.get_world_size() * (dist.get_world_size() + 1)) // 2
result = torch.ones(1, 3).cuda() * sum_of_ranks
dist.all_reduce(x)
assert torch.all(x == result)
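
As a check on the arithmetic in TestDistAllReduce: rank r contributes a tensor of (r + 1)s, so the
all-reduced value is the triangular number n(n + 1)/2 for n ranks. A small CPU-only sketch of the
expected result (names are illustrative):

    import torch

    def expected_all_reduce(world_size: int) -> torch.Tensor:
        # sum of (rank + 1) over ranks 0..world_size-1 is a triangular number
        return torch.ones(1, 3) * (world_size * (world_size + 1) // 2)

    # world_size=4: ranks contribute 1 + 2 + 3 + 4 = 10
    assert torch.all(expected_all_reduce(4) == torch.ones(1, 3) * 10)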
112 changes: 108 additions & 4 deletions tests/unit/common.py
@@ -1,15 +1,17 @@
import os
import time
+import inspect
+from abc import ABC
+from pathlib import Path

import torch
import torch.multiprocessing as mp
import deepspeed
import deepspeed.comm as dist
from torch.multiprocessing import Process

-import deepspeed

import pytest

-from pathlib import Path
+from _pytest.outcomes import Skipped

# Worker timeout *after* the first worker has completed.
DEEPSPEED_UNIT_WORKER_TIMEOUT = 120
@@ -60,6 +62,108 @@ def set_cuda_visibile():
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(dev_id_list)


class DistributedTest(ABC):
is_dist_test = True
world_size = 2
backend = "nccl"

def _run_test(self, request):
self.current_test = self._get_current_test_func(request)
self.test_kwargs = self._get_test_kwargs(request)
if isinstance(self.world_size, int):
self.world_size = [self.world_size]
for procs in self.world_size:
self._launch_procs(procs)
time.sleep(0.5)

def _get_current_test_func(self, request):
# DistributedTest subclasses may have multiple test methods
func_name = request.function.__name__
return getattr(self, func_name)

def _get_test_kwargs(self, request):
# Grab fixture / parametrize kwargs from pytest request object
test_kwargs = {}
params = inspect.getfullargspec(self.current_test).args
params.remove("self")
for p in params:
test_kwargs[p] = request.getfixturevalue(p)
return test_kwargs

def _launch_procs(self, num_procs):
mp.set_start_method('forkserver', force=True)
skip_msg = mp.Queue() # Allows forked processes to share pytest.skip reason
processes = []
for local_rank in range(num_procs):
p = Process(target=self._dist_init, args=(local_rank, num_procs, skip_msg))
p.start()
processes.append(p)

# Now loop and wait for a test to complete. The spin-wait here isn't a big
# deal because the number of processes will be O(#GPUs) << O(#CPUs).
any_done = False
while not any_done:
for p in processes:
if not p.is_alive():
any_done = True
break

# Wait for all other processes to complete
for p in processes:
p.join(DEEPSPEED_UNIT_WORKER_TIMEOUT)

failed = [(rank, p) for rank, p in enumerate(processes) if p.exitcode != 0]
for rank, p in failed:
# If it still hasn't terminated, kill it because it hung.
if p.exitcode is None:
p.terminate()
pytest.fail(f'Worker {rank} hung.', pytrace=False)
if p.exitcode < 0:
pytest.fail(f'Worker {rank} killed by signal {-p.exitcode}',
pytrace=False)
if p.exitcode > 0:
pytest.fail(f'Worker {rank} exited with code {p.exitcode}',
pytrace=False)

if not skip_msg.empty():
# This assumes all skip messages are the same; it may be useful to
# add a check here to assert that all skip messages are equal
pytest.skip(skip_msg.get())

def _dist_init(self, local_rank, num_procs, skip_msg):
"""Initialize deepspeed.comm and execute the user function. """
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = get_master_port()
os.environ['LOCAL_RANK'] = str(local_rank)
# NOTE: unit tests don't support multi-node so local_rank == global rank
os.environ['RANK'] = str(local_rank)
os.environ['WORLD_SIZE'] = str(num_procs)

# turn off NCCL logging if set
os.environ.pop('NCCL_DEBUG', None)

set_cuda_visibile()

deepspeed.init_distributed(dist_backend=self.backend)
dist.barrier()

if torch.cuda.is_available():
torch.cuda.set_device(local_rank)

try:
self.current_test(**self.test_kwargs)
except BaseException as e:
if isinstance(e, Skipped):
skip_msg.put(e.msg)
else:
raise e

# make sure all ranks finish at the same time
dist.barrier()
# tear down after test completes
dist.destroy_process_group()


def distributed_test(world_size=2, backend='nccl'):
"""A decorator for executing a function (e.g., a unit test) in a distributed manner.
This decorator manages the spawning and joining of processes, initialization of
…
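
Taken together, these changes swap the decorator-driven flow for the class-based one. A hedged
before/after sketch of a converted test (the test body is illustrative, not taken from this diff):

    # old style, via the distributed_test decorator being phased out here:
    @distributed_test(world_size=2)
    def test_all_reduce_old():
        assert dist.get_world_size() == 2

    # new style, via the DistributedTest base class:
    class TestAllReduce(DistributedTest):
        world_size = 2

        def test(self):
            assert dist.get_world_size() == 2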