Skip to content

Commit

Permalink
Fix test failure in GPU CI if NCCL_DEBUG is enabled.
Browse files Browse the repository at this point in the history
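If NCCL_DEBUG is enabled, NCCL prints extra status information. Make the
test accept this by matching the expected result at the end of the output
with a regular expression instead of requiring an exact match.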
  • Loading branch information
hawkinsp committed Sep 28, 2022
1 parent 96abd9a commit eabb91e
Showing 1 changed file with 5 additions and 7 deletions.
12 changes: 5 additions & 7 deletions tests/multiprocess_gpu_test.py
Expand Up @@ -94,13 +94,12 @@ def test_gpu_distributed_initialize(self):
num_gpus_per_task = 1
num_tasks = num_gpus // num_gpus_per_task

os.environ["JAX_PORT"] = str(port)
os.environ["NUM_TASKS"] = str(num_tasks)

with contextlib.ExitStack() as exit_stack:
subprocesses = []
for task in range(num_tasks):
env = os.environ.copy()
env["JAX_PORT"] = str(port)
env["NUM_TASKS"] = str(num_tasks)
env["TASK"] = str(task)
env["CUDA_VISIBLE_DEVICES"] = ",".join(
str((task * num_gpus_per_task) + i) for i in range(num_gpus_per_task))
Expand Down Expand Up @@ -139,13 +138,12 @@ def test_distributed_jax_cuda_visible_devices(self):
num_gpus_per_task = 1
num_tasks = num_gpus // num_gpus_per_task

os.environ["JAX_PORT"] = str(port)
os.environ["NUM_TASKS"] = str(num_tasks)

with contextlib.ExitStack() as exit_stack:
subprocesses = []
for task in range(num_tasks):
env = os.environ.copy()
env["JAX_PORT"] = str(port)
env["NUM_TASKS"] = str(num_tasks)
env["TASK"] = str(task)
visible_devices = ",".join(
str((task * num_gpus_per_task) + i) for i in range(num_gpus_per_task))
Expand All @@ -167,7 +165,7 @@ def test_distributed_jax_cuda_visible_devices(self):
for proc in subprocesses:
out, _ = proc.communicate()
self.assertEqual(proc.returncode, 0)
self.assertEqual(out, f'{num_gpus_per_task},{num_gpus},[{num_gpus}.]')
self.assertRegex(out, f'{num_gpus_per_task},{num_gpus},\\[{num_gpus}.\\]$')
finally:
for proc in subprocesses:
proc.kill()
Expand Down

0 comments on commit eabb91e

Please sign in to comment.