Skip to content

Commit

Permalink
Merge pull request #11543 from nvcastet:fix_multigpu_test
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 462418103
  • Loading branch information
jax authors committed Jul 21, 2022
2 parents 1e05a1c + 7589c6d commit a4e7548
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions tests/distributed_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,16 +93,20 @@ def test_gpu_distributed_initialize(self):
args = [
sys.executable,
"-c",
('"import jax, os; '
('import jax, os; '
'jax.distributed.initialize('
'f"localhost:{os.environ["JAX_PORT"]}", '
'os.environ["NUM_TASKS"], os.environ["TASK"])"'
'f\'localhost:{os.environ["JAX_PORT"]}\', '
'int(os.environ["NUM_TASKS"]), int(os.environ["TASK"])); '
'print(f\'{jax.local_device_count()},{jax.device_count()}\', end="")'
)
]
subprocesses.append(subprocess.Popen(args, env=env, shell=True))
subprocesses.append(subprocess.Popen(args, env=env, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, universal_newlines=True))

for i in range(num_tasks):
self.assertEqual(subprocesses[i].wait(), 0)
for proc in subprocesses:
out, _ = proc.communicate()
self.assertEqual(proc.returncode, 0)
self.assertEqual(out, f'{num_gpus_per_task},{num_gpus}')


if __name__ == "__main__":
Expand Down

0 comments on commit a4e7548

Please sign in to comment.