pytorch#2213 pass args and kwargs to idist.barrier method
fco-dv committed Dec 11, 2021
1 parent bc3e06d commit fb2b897
Showing 5 changed files with 62 additions and 13 deletions.
21 changes: 18 additions & 3 deletions ignite/distributed/comp_models/base.py
@@ -1,6 +1,7 @@
 from abc import ABCMeta, abstractmethod
 from inspect import signature
 from numbers import Number
-from typing import Any, Callable, List, Optional, Union, cast
+from typing import Any, Callable, Dict, List, Optional, Union, cast
 
 import torch

@@ -275,8 +276,22 @@ def _do_all_gather(self, tensor: torch.Tensor) -> torch.Tensor:
     def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
         pass
 
+    def _check_signature(self, fn: Callable, *args: Any, **kwargs: Any) -> None:
+        try:
+            fn_signature = signature(fn)
+            fn_signature.bind(*args, **kwargs)
+        except TypeError as exc:
+            fn_params = list(fn_signature.parameters)
+            exception_msg = str(exc)
+            passed_params = list(args) + list(kwargs)
+            raise ValueError(
+                f"Error calling {fn} for {self._backend}: "
+                f"takes parameters {fn_params} but will be called with {passed_params} "
+                f"({exception_msg})."
+            )
+
     @abstractmethod
-    def barrier(self) -> None:
+    def barrier(self, *args: Any, **kwargs: Any) -> None:
         pass


@@ -358,5 +373,5 @@ def _do_all_gather(self, tensor: torch.Tensor) -> torch.Tensor:
     def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
         return tensor
 
-    def barrier(self) -> None:
+    def barrier(self, *args: Any, **kwargs: Any) -> None:
         pass
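
To see what `_check_signature` buys, here is a minimal standalone sketch of the same mechanism. The names `check_signature` and `fake_barrier` are illustrative stand-ins, not part of the commit:

```python
from inspect import signature
from typing import Any, Callable


def check_signature(fn: Callable, *args: Any, **kwargs: Any) -> None:
    # Bind the proposed arguments against fn's signature;
    # a TypeError from bind() means the call would not be valid.
    try:
        signature(fn).bind(*args, **kwargs)
    except TypeError as exc:
        raise ValueError(
            f"{fn.__name__} takes {list(signature(fn).parameters)}, "
            f"got {list(args) + list(kwargs)} ({exc})"
        )


def fake_barrier(tag: str, payload: bytes = b"") -> None:
    pass


check_signature(fake_barrier, tag="sync")  # OK: binds cleanly

try:
    check_signature(fake_barrier, device_ids=[0])  # another backend's kwarg
except ValueError as e:
    print(e)  # fake_barrier takes ['tag', 'payload'], got ['device_ids'] (...)
```

Validating before forwarding means the caller gets a backend-specific `ValueError` instead of an opaque `TypeError` raised deep inside the backend call.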
18 changes: 14 additions & 4 deletions ignite/distributed/comp_models/horovod.py
@@ -2,6 +2,7 @@
 from typing import Any, Callable, Mapping, Optional, Tuple, cast
 
 import torch
+from packaging import version
 
 from ignite.distributed.comp_models.base import ComputationModel

@@ -194,7 +195,16 @@ def _do_all_gather(self, tensor: torch.Tensor) -> torch.Tensor:
     def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
         return hvd.broadcast(tensor, root_rank=src)
 
-    def barrier(self) -> None:
-        # https://github.com/horovod/horovod/issues/159#issuecomment-424834603
-        # hvd.allreduce(torch.tensor(0, device=self.device()), name="barrier")
-        hvd.allreduce(torch.tensor(0, device="cpu"), name="barrier")
+    def barrier(self, *args: Any, **kwargs: Any) -> None:
+        if version.parse(hvd.__version__) < version.parse("0.23.0"):
+            if len(args) or len(kwargs):
+                warnings.warn(
+                    f"Arguments {list(args) + list(kwargs)} are not passed to horovod barrier method. "
+                    f"Please use horovod version >= '0.23.0'"
+                )
+            # https://github.com/horovod/horovod/issues/159#issuecomment-424834603
+            # hvd.allreduce(torch.tensor(0, device=self.device()), name="barrier")
+            hvd.allreduce(torch.tensor(0, device="cpu"), name="barrier")
+        else:
+            self._check_signature(hvd.barrier, *args, **kwargs)
+            hvd.barrier(*args, **kwargs)
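
The version gate relies on `packaging.version` for a proper comparison, since plain string comparison misorders versions like "0.9.0" and "0.23.0". A small sketch of just that check; `has_native_barrier` is a hypothetical helper, not part of the commit:

```python
from packaging import version


def has_native_barrier(hvd_version: str) -> bool:
    # hvd.barrier() and its process_set argument only exist from
    # horovod 0.23.0 onward; older releases fall back to the
    # dummy-allreduce trick shown in the diff above.
    return version.parse(hvd_version) >= version.parse("0.23.0")


assert not has_native_barrier("0.22.1")
assert has_native_barrier("0.23.0")
assert not has_native_barrier("0.9.0")  # lexicographic comparison would get this wrong
```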
5 changes: 3 additions & 2 deletions ignite/distributed/comp_models/native.py
@@ -432,8 +432,9 @@ def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
         dist.broadcast(tensor, src=src)
         return tensor
 
-    def barrier(self) -> None:
-        dist.barrier()
+    def barrier(self, *args: Any, **kwargs: Any) -> None:
+        self._check_signature(dist.barrier, *args, **kwargs)
+        dist.barrier(*args, **kwargs)
 
 
 def _expand_hostlist(nodelist: str) -> List[str]:
     """Expand a compressed hostlist string and returns all hosts listed.
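
With the native backend, the effect of `_check_signature` is that an invalid keyword fails fast with a descriptive `ValueError` instead of reaching `torch.distributed`. A hedged usage sketch, assuming a gloo/nccl process group is already initialized through ignite:

```python
import ignite.distributed as idist

# Forwarded to dist.barrier(group=..., async_op=..., device_ids=...) after validation.
idist.barrier()
idist.barrier(async_op=False)

# `tag` belongs to the xla-tpu backend, not to dist.barrier, so
# _check_signature raises ValueError before any collective runs.
try:
    idist.barrier(tag="oops")
except ValueError as e:
    print(e)
```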
12 changes: 10 additions & 2 deletions ignite/distributed/comp_models/xla.py
@@ -1,3 +1,4 @@
+import warnings
 from typing import Any, Callable, Mapping, Optional, Tuple, cast
 
 import torch
@@ -160,5 +161,12 @@ def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
         xm.all_reduce("sum", [tensor,])
         return tensor
 
-    def barrier(self) -> None:
-        xm.rendezvous("barrier")
+    def barrier(self, *args: Any, **kwargs: Any) -> None:
+        if not len(args) and "tag" not in kwargs:
+            warnings.warn(
+                f"`tag` parameter is mandatory and is set by default to `barrier` for {self._backend} `rendezvous`"
+                f" method."
+            )
+            kwargs["tag"] = "barrier"
+        self._check_signature(xm.rendezvous, *args, **kwargs)
+        xm.rendezvous(*args, **kwargs)
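
Because `xm.rendezvous` has a required `tag` parameter, the XLA implementation injects a default when the caller passes nothing. Below is a standalone mirror of that defaulting logic; `normalize_barrier_args` is a hypothetical helper that returns the normalized arguments instead of calling `xm.rendezvous`, which needs a TPU runtime:

```python
import warnings
from typing import Any, Dict, Tuple


def normalize_barrier_args(*args: Any, **kwargs: Any) -> Tuple[tuple, Dict[str, Any]]:
    # Mirror of the diff above: with no positional args and no `tag`
    # kwarg, default the rendezvous tag to "barrier".
    if not len(args) and "tag" not in kwargs:
        warnings.warn("`tag` is mandatory for xla-tpu rendezvous; defaulting to 'barrier'")
        kwargs["tag"] = "barrier"
    return args, kwargs


assert normalize_barrier_args() == ((), {"tag": "barrier"})
assert normalize_barrier_args("sync") == (("sync",), {})  # positional tag, no default
```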
19 changes: 17 additions & 2 deletions ignite/distributed/utils.py
@@ -421,13 +421,28 @@ def broadcast(
     return _model.broadcast(tensor, src=src, safe_mode=safe_mode)
 
 
-def barrier() -> None:
+def barrier(*args: Any, **kwargs: Any) -> None:
     """Helper method to synchronize all processes.
 
+    Args:
+        args: acceptable args according to provided backend
+        kwargs: acceptable kwargs according to provided backend
+
+            - | "nccl" or "gloo" : ``group`` (default, GroupMember.WORLD), ``async_op`` (default, False),
+              | ``device_ids`` (default, None).
+
+            - | "horovod" : for ``horovod.__version__ >= "0.23.0"``, ``process_set`` (default, global_process_set).
+
+            - | "xla-tpu" : ``tag``, ``payload`` (default, b""), ``replicas`` (default, []).
+
+    .. versionchanged:: 0.5.1
+        Method now accepts ``args`` and ``kwargs`` for all supported backends.
+    """
     if _need_to_sync and isinstance(_model, _SerialModel):
         sync(temporary=True)
 
-    _model.barrier()
+    _model.barrier(*args, **kwargs)


def set_local_rank(index: int) -> None:
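
Putting it together, `idist.barrier` now accepts whatever the active backend's barrier accepts. A usage sketch, with the caveat that each backend-specific kwarg is only valid under its own backend (in a serial run, the `_SerialModel` barrier shown in base.py simply ignores the arguments); `my_process_set` is a hypothetical name:

```python
import ignite.distributed as idist

idist.barrier()                  # plain synchronization on every backend

# Backend-specific extras, validated against the underlying API:
idist.barrier(async_op=False)    # nccl/gloo: forwarded to dist.barrier
idist.barrier(tag="epoch_end")   # xla-tpu: forwarded to xm.rendezvous
# horovod >= 0.23.0: e.g. idist.barrier(process_set=my_process_set)
```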
