Skip to content

Commit

Permalink
pytorch#2213 refactor _check_barrier_fn_kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
fco-dv committed Nov 5, 2021
1 parent 5aa6ade commit 6419d75
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 11 deletions.
18 changes: 11 additions & 7 deletions ignite/distributed/comp_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,15 +278,19 @@ def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
pass

def _check_barrier_fn_kwargs(self, barrier_fn: Callable, kwargs_dict: Dict[str, Any]) -> Dict[str, Any]:
bnd_keys = set()
for param in signature(barrier_fn).parameters.values():
if param.kind == param.POSITIONAL_OR_KEYWORD:
bnd_keys.add(param.name)
extra_keys = sorted(list(set(kwargs_dict) - bnd_keys))
fn_params_name = set(
map(
lambda param: param.name,
filter(
lambda param: param.kind == param.POSITIONAL_OR_KEYWORD, signature(barrier_fn).parameters.values()
),
)
)
extra_keys = kwargs_dict.keys() - fn_params_name
if extra_keys:
warnings.warn(f"Extra keys : {extra_keys} will not be used by {self._backend}.")
for k in extra_keys:
del kwargs_dict[k]
for k in extra_keys:
del kwargs_dict[k]
return kwargs_dict

@abstractmethod
Expand Down
2 changes: 1 addition & 1 deletion ignite/distributed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ def barrier(**kwargs: Any) -> None:
| Argument ``tag="barrier"`` is redefined.
.. versionchanged:: 0.5.1
Method now accepts ``kwargs`` for all supported backends.
Method now accepts ``kwargs`` for all supported backends.
"""
if _need_to_sync and isinstance(_model, _SerialModel):
sync(temporary=True)
Expand Down
8 changes: 6 additions & 2 deletions tests/ignite/distributed/utils/test_native.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,9 @@ def test_idist_barrier_kwargs_nccl(distributed_context_single_node_nccl):
_test_distrib_barrier(device, kwargs_dict)

kwargs_dict.update({"tag": "barrier", "payload": b"", "replicas": []})
with pytest.warns(UserWarning, match=r"Extra keys : \['payload', 'replicas', 'tag'\] will not be used by nccl."):
with pytest.warns(
UserWarning, match=r"Extra keys : \{((, )?('payload'|'replicas'|'tag')(, )?)+\} will not be used by nccl."
):
_test_distrib_barrier(device, kwargs_dict)


Expand All @@ -299,7 +301,9 @@ def test_idist_barrier_kwargs_gloo(distributed_context_single_node_gloo):
_test_distrib_barrier(device, kwargs_dict)

kwargs_dict.update({"tag": "barrier", "payload": b"", "replicas": []})
with pytest.warns(UserWarning, match=r"Extra keys : \['payload', 'replicas', 'tag'\] will not be used by gloo."):
with pytest.warns(
UserWarning, match=r"Extra keys : \{((, )?('payload'|'replicas'|'tag')(, )?)+\} will not be used by gloo."
):
_test_distrib_barrier(device, kwargs_dict)


Expand Down
3 changes: 2 additions & 1 deletion tests/ignite/distributed/utils/test_xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,8 @@ def test_idist_barrier_kwargs_xla():

kwargs_dict.update({"group": GroupMember.WORLD, "async_op": False, "device_ids": None})
with pytest.warns(
UserWarning, match=r"Extra keys : \['async_op', 'device_ids', 'group'\] will not be used by xla-tpu."
UserWarning,
match=r"Extra keys : \{((, )?('async_op'|'group'|'device_ids')(, )?)+\} will not be used by xla-tpu.",
):
_test_distrib_barrier(device, kwargs_dict)

Expand Down

0 comments on commit 6419d75

Please sign in to comment.