Commit: showing 3 changed files with 122 additions and 0 deletions.
easybuild/easyconfigs/p/PyTorch/PyTorch-1.12.1_fix-skip-decorators.patch (118 additions, 0 deletions)
The decorators are implemented to run when the test function is called, which is after
the test's `setUp` method has spawned subprocesses; those subprocesses may use NCCL to
synchronize and thus fail when not enough GPUs are available.
So replace the custom code with calls to the `unittest` skip decorators, which mark the
test as skipped before `setUp` runs.
See https://github.com/pytorch/pytorch/pull/89750

Author: Alexander Grund (TU Dresden)

diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py
index 8baf7d03d9f..3dc922ee923 100644
--- a/torch/testing/_internal/common_distributed.py
+++ b/torch/testing/_internal/common_distributed.py
@@ -124,46 +124,11 @@ def skip_if_odd_worldsize(func):
     return wrapper

 def require_n_gpus_for_nccl_backend(n, backend):
-    def decorator(func):
-        @wraps(func)
-        def wrapper(*args, **kwargs):
-            if backend == "nccl" and torch.cuda.device_count() < n:
-                sys.exit(TEST_SKIPS[f"multi-gpu-{n}"].exit_code)
-            else:
-                return func(*args, **kwargs)
-
-        return wrapper
-
-    return decorator
+    return skip_if_lt_x_gpu(n) if backend == "nccl" else unittest.skipIf(False, None)


 def skip_if_lt_x_gpu(x):
-    def decorator(func):
-        @wraps(func)
-        def wrapper(*args, **kwargs):
-            if torch.cuda.is_available() and torch.cuda.device_count() >= x:
-                return func(*args, **kwargs)
-            sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)
-
-        return wrapper
-
-    return decorator
-
-
-# This decorator helps avoiding initializing cuda while testing other backends
-def nccl_skip_if_lt_x_gpu(backend, x):
-    def decorator(func):
-        @wraps(func)
-        def wrapper(*args, **kwargs):
-            if backend != "nccl":
-                return func(*args, **kwargs)
-            if torch.cuda.is_available() and torch.cuda.device_count() >= x:
-                return func(*args, **kwargs)
-            sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)
-
-        return wrapper
-
-    return decorator
+    return unittest.skipIf(torch.cuda.device_count() < x, TEST_SKIPS[f"multi-gpu-{x}"].message)


 def verify_ddp_error_logged(model_DDP, err_substr):
diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py
index 1414a0376b1..1f6b00e6edf 100644
--- a/torch/testing/_internal/distributed/distributed_test.py
+++ b/torch/testing/_internal/distributed/distributed_test.py
@@ -56,7 +56,6 @@ from torch.testing._internal.common_distributed import (
     skip_if_small_worldsize,
     skip_if_odd_worldsize,
     skip_if_lt_x_gpu,
-    nccl_skip_if_lt_x_gpu,
     skip_if_no_gpu,
     require_n_gpus_for_nccl_backend,
     requires_nccl_version,
@@ -4604,7 +4603,7 @@ class DistributedTest:
             BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo",
             "get_future is only supported on mpi, nccl and gloo",
         )
-        @nccl_skip_if_lt_x_gpu(BACKEND, 2)
+        @require_n_gpus_for_nccl_backend(BACKEND, 2)
         def test_accumulate_gradients_no_sync(self):
             """
             Runs _test_accumulate_gradients_no_sync using default inputs
@@ -4615,7 +4614,7 @@ class DistributedTest:
             BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo",
             "get_future is only supported on mpi, nccl and gloo",
         )
-        @nccl_skip_if_lt_x_gpu(BACKEND, 2)
+        @require_n_gpus_for_nccl_backend(BACKEND, 2)
        def test_accumulate_gradients_no_sync_grad_is_view(self):
             """
             Runs _test_accumulate_gradients_no_sync using default inputs
@@ -4626,7 +4625,7 @@ class DistributedTest:
             BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo",
             "get_future is only supported on mpi, nccl and gloo",
         )
-        @nccl_skip_if_lt_x_gpu(BACKEND, 2)
+        @require_n_gpus_for_nccl_backend(BACKEND, 2)
         def test_accumulate_gradients_no_sync_allreduce_hook(self):
             """
             Runs multiple iterations on _test_accumulate_gradients_no_sync
@@ -4654,7 +4653,7 @@ class DistributedTest:
             BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo",
             "get_future is only supported on mpi, nccl and gloo",
         )
-        @nccl_skip_if_lt_x_gpu(BACKEND, 2)
+        @require_n_gpus_for_nccl_backend(BACKEND, 2)
         def test_accumulate_gradients_no_sync_allreduce_with_then_hook(self):
             """
             Runs multiple iterations on _test_accumulate_gradients_no_sync using allreduce
@@ -4688,7 +4687,7 @@ class DistributedTest:
             BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo",
             "get_future is only supported on mpi, nccl and gloo",
         )
-        @nccl_skip_if_lt_x_gpu(BACKEND, 2)
+        @require_n_gpus_for_nccl_backend(BACKEND, 2)
         def test_get_future(self):
             def mult(fut):
                 return [t * 3 for t in fut.wait()]
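
To make the behavioural difference concrete, here is a minimal, self-contained sketch; it is not part of the patch and is much simplified from PyTorch's real harness. A hand-rolled wrapper checks its skip condition only when the test body is invoked, i.e. after `setUp` has already run, while `unittest.skipIf` marks the test as skipped up front, so `setUp` is never called. `NUM_GPUS` stands in for `torch.cuda.device_count()`, the hard-coded exit code stands in for `TEST_SKIPS[...].exit_code`, and both decorator names are illustrative only.

import functools
import sys
import unittest

NUM_GPUS = 0  # stand-in for torch.cuda.device_count(); pretend no GPUs


def old_style_skip_if_lt_x_gpu(x):
    # Old pattern (shown for contrast, not attached to a test below):
    # the condition is evaluated inside the wrapper, so only after
    # unittest has called setUp, which in PyTorch's MultiProcessTestCase
    # is where the worker subprocesses get spawned.
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if NUM_GPUS >= x:
                return func(*args, **kwargs)
            sys.exit(77)  # stand-in for TEST_SKIPS[...].exit_code
        return wrapper
    return decorator


def new_style_skip_if_lt_x_gpu(x):
    # New pattern: the condition is evaluated once, at decoration time,
    # and unittest reports a skip without ever calling setUp.
    return unittest.skipIf(NUM_GPUS < x, f"need at least {x} GPUs")


class Demo(unittest.TestCase):
    def setUp(self):
        # In the real harness this spawns subprocesses that may try to
        # sync via NCCL and fail when too few GPUs are present.
        print(f"setUp ran for {self._testMethodName}")

    @new_style_skip_if_lt_x_gpu(2)
    def test_needs_two_gpus(self):
        pass  # never reached; setUp above never prints


if __name__ == "__main__":
    unittest.main(verbosity=2)

Running this reports `test_needs_two_gpus` as skipped with the given reason, and the `setUp` print never appears. It also shows why the patch keeps `require_n_gpus_for_nccl_backend` as a thin wrapper: call sites still need a decorator for non-NCCL backends, hence the no-op `unittest.skipIf(False, None)`.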