Skip to content

Commit

Permalink
pytorch#2213 fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
fco-dv committed Nov 2, 2021
1 parent 596b29e commit beca8c9
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 17 deletions.
12 changes: 6 additions & 6 deletions tests/ignite/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,14 +266,14 @@ def _xla_template_worker_task(index, fn, args):
fn(index, *args)


def _xla_execute(fn, args, nprocs, **kwargs):
def _xla_execute(fn, args, nprocs):

import torch_xla.distributed.xla_multiprocessing as xmp

spawn_kwargs = {}
if "COLAB_TPU_ADDR" in os.environ:
spawn_kwargs["start_method"] = "fork"
spawn_kwargs.update(kwargs)

try:
xmp.spawn(_xla_template_worker_task, args=(fn, args), nprocs=nprocs, **spawn_kwargs)
except SystemExit as ex_:
Expand Down Expand Up @@ -306,20 +306,20 @@ def _hvd_task_with_init(func, args):
hvd.shutdown()


def _gloo_hvd_execute(func, args, np=1, do_init=False, **kwargs):
def _gloo_hvd_execute(func, args, np=1, do_init=False):
try:
# old API
from horovod.run.runner import run
except ImportError:
# new API: https://github.com/horovod/horovod/pull/2099
from horovod import run

spawn_kwargs = dict(use_gloo=True, np=np, **kwargs)
kwargs = dict(use_gloo=True, np=np)

if do_init:
return run(_hvd_task_with_init, args=(func, args), **spawn_kwargs)
return run(_hvd_task_with_init, args=(func, args), **kwargs)

return run(func, args=args, **spawn_kwargs)
return run(func, args=args, **kwargs)


@pytest.fixture()
Expand Down
5 changes: 3 additions & 2 deletions tests/ignite/distributed/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,14 +215,15 @@ def _test(data_src, data_others, safe_mode):
idist.broadcast(None, src=0)


def _test_distrib_barrier(device, **kwargs):
def _test_distrib_barrier(device, kwargs_dict=None):

t = torch.tensor([idist.get_rank()], device=device, dtype=torch.float)
true_res = sum([i for i in range(idist.get_world_size())])

if idist.get_rank() == 0:
t += 10.0
idist.barrier(**kwargs)

idist.barrier(**kwargs_dict) if kwargs_dict else idist.barrier()

tt = idist.all_reduce(t)
assert tt.item() == true_res + 10.0
Expand Down
2 changes: 1 addition & 1 deletion tests/ignite/distributed/utils/test_horovod.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def test_idist_barrier_kwargs_hvd(gloo_hvd_executor):
postscale_factor=1.0,
process_set=global_process_set,
)
gloo_hvd_executor(_test_distrib_barrier, (device,), np=np, do_init=True, **kwargs_dict)
gloo_hvd_executor(_test_distrib_barrier, (device, kwargs_dict,), np=np, do_init=True)


def _test_idist_methods_overhead(ok_factor, sync_model):
Expand Down
8 changes: 4 additions & 4 deletions tests/ignite/distributed/utils/test_native.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,11 +281,11 @@ def test_idist_barrier_kwargs_nccl(distributed_context_single_node_nccl):
from torch.distributed import GroupMember

kwargs_dict = {"group": GroupMember.WORLD, "async_op": False, "device_ids": None}
_test_distrib_barrier(device, **kwargs_dict)
_test_distrib_barrier(device, kwargs_dict)

kwargs_dict.update({"tag": "barrier", "payload": b"", "replicas": []})
with pytest.warns(UserWarning, match=r"Extra keys : \['payload', 'replicas', 'tag'\] will not be used by nccl."):
_test_distrib_barrier(device, **kwargs_dict)
_test_distrib_barrier(device, kwargs_dict)


@pytest.mark.distributed
Expand All @@ -296,11 +296,11 @@ def test_idist_barrier_kwargs_gloo(distributed_context_single_node_gloo):
from torch.distributed import GroupMember

kwargs_dict = {"group": GroupMember.WORLD, "async_op": False, "device_ids": None}
_test_distrib_barrier(device, **kwargs_dict)
_test_distrib_barrier(device, kwargs_dict)

kwargs_dict.update({"tag": "barrier", "payload": b"", "replicas": []})
with pytest.warns(UserWarning, match=r"Extra keys : \['payload', 'replicas', 'tag'\] will not be used by gloo."):
_test_distrib_barrier(device, **kwargs_dict)
_test_distrib_barrier(device, kwargs_dict)


def _test_idist_methods_overhead(ok_factor):
Expand Down
8 changes: 4 additions & 4 deletions tests/ignite/distributed/utils/test_xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,9 @@ def test_idist_barrier_xla():
_test_distrib_barrier(device)


def _test_idist_barrier_xla_in_child_proc(index):
def _test_idist_barrier_xla_in_child_proc(index, kwargs_dict=None):
device = idist.device()
_test_distrib_barrier(device)
_test_distrib_barrier(device, kwargs_dict)


@pytest.mark.tpu
Expand All @@ -196,7 +196,7 @@ def test_idist_barrier_kwargs_xla():

device = idist.device()
kwargs_dict = {"tag": "barrier", "payload": b"", "replicas": []}
_test_distrib_barrier(device, **kwargs_dict)
_test_distrib_barrier(device, kwargs_dict)


@pytest.mark.tpu
Expand All @@ -213,7 +213,7 @@ def test_idist_barrier_xla_in_child_proc(xmp_executor):
def test_idist_barrier_kwargs_xla_in_child_proc(xmp_executor):
n = int(os.environ["NUM_TPU_WORKERS"])
kwargs_dict = {"tag": "barrier", "payload": b"", "replicas": []}
xmp_executor(_test_idist_barrier_xla_in_child_proc, args=(), nprocs=n, **kwargs_dict)
xmp_executor(_test_idist_barrier_xla_in_child_proc, args=(kwargs_dict,), nprocs=n)


@pytest.mark.tpu
Expand Down

0 comments on commit beca8c9

Please sign in to comment.