From 6419d75608f77991b8ef936eb11cdf2a28351546 Mon Sep 17 00:00:00 2001 From: FrAnCOisCokELaER Date: Fri, 5 Nov 2021 18:33:08 +0100 Subject: [PATCH] #2213 refactor _check_barrier_fn_kwargs --- ignite/distributed/comp_models/base.py | 18 +++++++++++------- ignite/distributed/utils.py | 2 +- tests/ignite/distributed/utils/test_native.py | 8 ++++++-- tests/ignite/distributed/utils/test_xla.py | 3 ++- 4 files changed, 20 insertions(+), 11 deletions(-) diff --git a/ignite/distributed/comp_models/base.py b/ignite/distributed/comp_models/base.py index 0144d4d16c5..aef14246e89 100644 --- a/ignite/distributed/comp_models/base.py +++ b/ignite/distributed/comp_models/base.py @@ -278,15 +278,19 @@ def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor: pass def _check_barrier_fn_kwargs(self, barrier_fn: Callable, kwargs_dict: Dict[str, Any]) -> Dict[str, Any]: - bnd_keys = set() - for param in signature(barrier_fn).parameters.values(): - if param.kind == param.POSITIONAL_OR_KEYWORD: - bnd_keys.add(param.name) - extra_keys = sorted(list(set(kwargs_dict) - bnd_keys)) + fn_params_name = set( + map( + lambda param: param.name, + filter( + lambda param: param.kind == param.POSITIONAL_OR_KEYWORD, signature(barrier_fn).parameters.values() + ), + ) + ) + extra_keys = kwargs_dict.keys() - fn_params_name if extra_keys: warnings.warn(f"Extra keys : {extra_keys} will not be used by {self._backend}.") - for k in extra_keys: - del kwargs_dict[k] + for k in extra_keys: + del kwargs_dict[k] return kwargs_dict @abstractmethod diff --git a/ignite/distributed/utils.py b/ignite/distributed/utils.py index 447142b2f8b..ca6c7db4f01 100644 --- a/ignite/distributed/utils.py +++ b/ignite/distributed/utils.py @@ -439,7 +439,7 @@ def barrier(**kwargs: Any) -> None: | Argument ``tag="barrier"`` is redefined. .. versionchanged:: 0.5.1 - Method now accepts ``kwargs`` for all supported backends. + Method now accepts ``kwargs`` for all supported backends. """ if _need_to_sync and isinstance(_model, _SerialModel): sync(temporary=True) diff --git a/tests/ignite/distributed/utils/test_native.py b/tests/ignite/distributed/utils/test_native.py index 08890fb0928..3a49b836123 100644 --- a/tests/ignite/distributed/utils/test_native.py +++ b/tests/ignite/distributed/utils/test_native.py @@ -284,7 +284,9 @@ def test_idist_barrier_kwargs_nccl(distributed_context_single_node_nccl): _test_distrib_barrier(device, kwargs_dict) kwargs_dict.update({"tag": "barrier", "payload": b"", "replicas": []}) - with pytest.warns(UserWarning, match=r"Extra keys : \['payload', 'replicas', 'tag'\] will not be used by nccl."): + with pytest.warns( + UserWarning, match=r"Extra keys : \{((, )?('payload'|'replicas'|'tag')(, )?)+\} will not be used by nccl." + ): _test_distrib_barrier(device, kwargs_dict) @@ -299,7 +301,9 @@ def test_idist_barrier_kwargs_gloo(distributed_context_single_node_gloo): _test_distrib_barrier(device, kwargs_dict) kwargs_dict.update({"tag": "barrier", "payload": b"", "replicas": []}) - with pytest.warns(UserWarning, match=r"Extra keys : \['payload', 'replicas', 'tag'\] will not be used by gloo."): + with pytest.warns( + UserWarning, match=r"Extra keys : \{((, )?('payload'|'replicas'|'tag')(, )?)+\} will not be used by gloo." + ): _test_distrib_barrier(device, kwargs_dict) diff --git a/tests/ignite/distributed/utils/test_xla.py b/tests/ignite/distributed/utils/test_xla.py index c39178f00a4..b66ecc408a5 100644 --- a/tests/ignite/distributed/utils/test_xla.py +++ b/tests/ignite/distributed/utils/test_xla.py @@ -202,7 +202,8 @@ def test_idist_barrier_kwargs_xla(): kwargs_dict.update({"group": GroupMember.WORLD, "async_op": False, "device_ids": None}) with pytest.warns( - UserWarning, match=r"Extra keys : \['async_op', 'device_ids', 'group'\] will not be used by xla-tpu." + UserWarning, + match=r"Extra keys : \{((, )?('async_op'|'group'|'device_ids')(, )?)+\} will not be used by xla-tpu.", ): _test_distrib_barrier(device, kwargs_dict)