diff --git a/torchft/manager.py b/torchft/manager.py
index c11aab18..85af6235 100644
--- a/torchft/manager.py
+++ b/torchft/manager.py
@@ -31,6 +31,7 @@
 import socket
 import uuid
 from concurrent.futures import ThreadPoolExecutor
+from contextlib import nullcontext
 from datetime import timedelta
 from enum import Enum
 from typing import TYPE_CHECKING, Callable, Dict, List, Optional, TypeVar, cast
@@ -182,6 +183,10 @@ def __init__(
         self._pg = pg
         self._manager: Optional[ManagerServer] = None
 
+        self._recovery_stream: Optional["torch.cuda.Stream"] = (
+            torch.cuda.Stream() if torch.cuda.is_available() else None
+        )
+
         if rank == 0:
             if port is None:
                 port = int(os.environ.get(MANAGER_PORT_ENV, 0))
@@ -491,53 +496,63 @@ def _async_quorum(
             self._quorum_id = quorum_id
 
         if allow_heal:
-            if quorum.recover_dst_ranks:
-                self._logger.info(
-                    f"peers need recovery from us {quorum.recover_dst_ranks}"
-                )
-                self._checkpoint_transport.send_checkpoint(
-                    dst_ranks=quorum.recover_dst_ranks,
-                    step=max_step,
-                    state_dict=self._manager_state_dict(),
-                    timeout=self._timeout,
-                )
-
-            # See manager.rs for healing conditions
-            if heal:
-                self._healing = True
-                self._logger.info(
-                    f"healing required, fetching checkpoint metadata from {recover_src_manager_address=} {max_step=}"
-                )
-                primary_client = ManagerClient(
-                    recover_src_manager_address, connect_timeout=self._connect_timeout
-                )
-                checkpoint_metadata = primary_client._checkpoint_metadata(
-                    self._rank, timeout=self._timeout
-                )
-                recover_src_rank = quorum.recover_src_rank
-                assert (
-                    recover_src_rank is not None
-                ), "must have a recover rank when healing"
-
-                self._logger.info(
-                    f"fetching checkpoint from {recover_src_rank=} with {checkpoint_metadata=}"
-                )
-
-                # we apply the user state dict only when safe from the main thread
-                # save it for now
-                self._pending_state_dict = self._checkpoint_transport.recv_checkpoint(
-                    src_rank=recover_src_rank,
-                    metadata=checkpoint_metadata,
-                    step=max_step,
-                    timeout=self._timeout,
-                )
-
-                # pyre-fixme[6]: got object
-                self.load_state_dict(self._pending_state_dict["torchft"])
-
-                # This isn't strictly needed as loading the state_dict above should
-                # restore the correct step but it makes writing tests simpler.
-                self._step = max_step
+            # run recovery on the recovery stream if available
+            recovery_stream = self._recovery_stream
+            with (
+                torch.cuda.stream(recovery_stream)
+                if recovery_stream is not None
+                else nullcontext()
+            ):
+                if quorum.recover_dst_ranks:
+                    self._logger.info(
+                        f"peers need recovery from us {quorum.recover_dst_ranks}"
+                    )
+                    self._checkpoint_transport.send_checkpoint(
+                        dst_ranks=quorum.recover_dst_ranks,
+                        step=max_step,
+                        state_dict=self._manager_state_dict(),
+                        timeout=self._timeout,
+                    )
+
+                # See manager.rs for healing conditions
+                if heal:
+                    self._healing = True
+                    self._logger.info(
+                        f"healing required, fetching checkpoint metadata from {recover_src_manager_address=} {max_step=}"
+                    )
+                    primary_client = ManagerClient(
+                        recover_src_manager_address,
+                        connect_timeout=self._connect_timeout,
+                    )
+                    checkpoint_metadata = primary_client._checkpoint_metadata(
+                        self._rank, timeout=self._timeout
+                    )
+                    recover_src_rank = quorum.recover_src_rank
+                    assert (
+                        recover_src_rank is not None
+                    ), "must have a recover rank when healing"
+
+                    self._logger.info(
+                        f"fetching checkpoint from {recover_src_rank=} with {checkpoint_metadata=}"
+                    )
+
+                    # we apply the user state dict only when safe from the main thread
+                    # save it for now
+                    self._pending_state_dict = (
+                        self._checkpoint_transport.recv_checkpoint(
+                            src_rank=recover_src_rank,
+                            metadata=checkpoint_metadata,
+                            step=max_step,
+                            timeout=self._timeout,
+                        )
+                    )
+
+                    # pyre-fixme[6]: got object
+                    self.load_state_dict(self._pending_state_dict["torchft"])
+
+                    # This isn't strictly needed as loading the state_dict above should
+                    # restore the correct step but it makes writing tests simpler.
+                    self._step = max_step
 
     def _apply_pending_state_dict(self) -> None:
         assert self._healing, "must be in healing state"
@@ -584,6 +599,10 @@ def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
             # never return an error.
             work.wait()
 
+        # make sure recovery is complete before committing
+        if self._recovery_stream is not None:
+            self._recovery_stream.synchronize()
+
         self._pending_work = []
 
         # apply state_dict if healing
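Note (not part of the patch): the change relies on the standard PyTorch side-stream pattern — enqueue work under `torch.cuda.stream(...)` so it can overlap with the default stream, then call `synchronize()` on that stream before anything consumes its results. Below is a minimal standalone sketch of that pattern; the names (`side_stream`, `run_in_background`, `join_background_work`) and the dummy workload are illustrative, not torchft APIs, and the `nullcontext()` fallback mirrors how the patch keeps a single code path on CPU-only hosts.

```python
import torch
from contextlib import nullcontext

# Side stream for background work; None on CPU-only hosts (illustrative only).
side_stream = torch.cuda.Stream() if torch.cuda.is_available() else None


def run_in_background(t: torch.Tensor) -> torch.Tensor:
    # Kernels launched inside this context go to the side stream, so the
    # default (training) stream is not serialized behind them.
    with torch.cuda.stream(side_stream) if side_stream is not None else nullcontext():
        return t * 2  # stand-in for the checkpoint send/recv work


def join_background_work() -> None:
    # Block until the side stream has drained, analogous to the
    # _recovery_stream.synchronize() call added to should_commit().
    if side_stream is not None:
        side_stream.synchronize()
```

Fully synchronizing the side stream before commit, as `should_commit()` now does, is the simplest way to guarantee the recovered state is ready before the main thread consumes it.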