torchft/manager.py (113 changes: 66 additions & 47 deletions)
@@ -31,6 +31,7 @@
 import socket
 import uuid
 from concurrent.futures import ThreadPoolExecutor
+from contextlib import nullcontext
 from datetime import timedelta
 from enum import Enum
 from typing import TYPE_CHECKING, Callable, Dict, List, Optional, TypeVar, cast
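The one new import is contextlib.nullcontext, a standard-library no-op context manager. A minimal sketch of why it is useful here, assuming nothing beyond the standard library: it lets a with block be entered unconditionally, whether or not an optional resource such as a CUDA stream exists.

from contextlib import nullcontext

def maybe_scoped(ctx=None):
    # Enter ctx when one is provided, otherwise fall back to a do-nothing context.
    with ctx if ctx is not None else nullcontext():
        return "ran inside the (possibly no-op) context"

print(maybe_scoped())               # no context given: nullcontext is used
print(maybe_scoped(nullcontext()))  # any real context manager also works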
@@ -182,6 +183,10 @@ def __init__(
         self._pg = pg
         self._manager: Optional[ManagerServer] = None

+        self._recovery_stream: Optional["torch.cuda.Stream"] = (
+            torch.cuda.Stream() if torch.cuda.is_available() else None
+        )
+
         if rank == 0:
             if port is None:
                 port = int(os.environ.get(MANAGER_PORT_ENV, 0))
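A minimal sketch of the constructor pattern above, using only public PyTorch APIs (the class and method names below are illustrative, not torchft's): the side stream is allocated once, only when CUDA is present, so CPU-only hosts keep a None field and later fall back to nullcontext.

from contextlib import nullcontext
from typing import Callable, Optional

import torch

class RecoveryScoped:
    def __init__(self) -> None:
        # One dedicated stream for recovery work; None on CPU-only machines.
        self._recovery_stream: Optional["torch.cuda.Stream"] = (
            torch.cuda.Stream() if torch.cuda.is_available() else None
        )

    def run_recovery(self, fn: Callable[[], None]) -> None:
        # Scope fn to the recovery stream when it exists, otherwise run it normally.
        stream = self._recovery_stream
        with torch.cuda.stream(stream) if stream is not None else nullcontext():
            fn()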
@@ -491,53 +496,63 @@ def _async_quorum(
         self._quorum_id = quorum_id

         if allow_heal:
-            if quorum.recover_dst_ranks:
-                self._logger.info(
-                    f"peers need recovery from us {quorum.recover_dst_ranks}"
-                )
-                self._checkpoint_transport.send_checkpoint(
-                    dst_ranks=quorum.recover_dst_ranks,
-                    step=max_step,
-                    state_dict=self._manager_state_dict(),
-                    timeout=self._timeout,
-                )
-
-            # See manager.rs for healing conditions
-            if heal:
-                self._healing = True
-                self._logger.info(
-                    f"healing required, fetching checkpoint metadata from {recover_src_manager_address=} {max_step=}"
-                )
-                primary_client = ManagerClient(
-                    recover_src_manager_address, connect_timeout=self._connect_timeout
-                )
-                checkpoint_metadata = primary_client._checkpoint_metadata(
-                    self._rank, timeout=self._timeout
-                )
-                recover_src_rank = quorum.recover_src_rank
-                assert (
-                    recover_src_rank is not None
-                ), "must have a recover rank when healing"
-
-                self._logger.info(
-                    f"fetching checkpoint from {recover_src_rank=} with {checkpoint_metadata=}"
-                )
-
-                # we apply the user state dict only when safe from the main thread
-                # save it for now
-                self._pending_state_dict = self._checkpoint_transport.recv_checkpoint(
-                    src_rank=recover_src_rank,
-                    metadata=checkpoint_metadata,
-                    step=max_step,
-                    timeout=self._timeout,
-                )
-
-                # pyre-fixme[6]: got object
-                self.load_state_dict(self._pending_state_dict["torchft"])
-
-                # This isn't strictly needed as loading the state_dict above should
-                # restore the correct step but it makes writing tests simpler.
-                self._step = max_step
+            # run recovery on the recovery stream if available
+            recovery_stream = self._recovery_stream
+            with (
+                torch.cuda.stream(recovery_stream)
+                if recovery_stream is not None
+                else nullcontext()
+            ):
+                if quorum.recover_dst_ranks:
+                    self._logger.info(
+                        f"peers need recovery from us {quorum.recover_dst_ranks}"
+                    )
+                    self._checkpoint_transport.send_checkpoint(
+                        dst_ranks=quorum.recover_dst_ranks,
+                        step=max_step,
+                        state_dict=self._manager_state_dict(),
+                        timeout=self._timeout,
+                    )
+
+                # See manager.rs for healing conditions
+                if heal:
+                    self._healing = True
+                    self._logger.info(
+                        f"healing required, fetching checkpoint metadata from {recover_src_manager_address=} {max_step=}"
+                    )
+                    primary_client = ManagerClient(
+                        recover_src_manager_address,
+                        connect_timeout=self._connect_timeout,
+                    )
+                    checkpoint_metadata = primary_client._checkpoint_metadata(
+                        self._rank, timeout=self._timeout
+                    )
+                    recover_src_rank = quorum.recover_src_rank
+                    assert (
+                        recover_src_rank is not None
+                    ), "must have a recover rank when healing"

+                    self._logger.info(
+                        f"fetching checkpoint from {recover_src_rank=} with {checkpoint_metadata=}"
+                    )
+
+                    # we apply the user state dict only when safe from the main thread
+                    # save it for now
+                    self._pending_state_dict = (
+                        self._checkpoint_transport.recv_checkpoint(
+                            src_rank=recover_src_rank,
+                            metadata=checkpoint_metadata,
+                            step=max_step,
+                            timeout=self._timeout,
+                        )
+                    )
+
+                    # pyre-fixme[6]: got object
+                    self.load_state_dict(self._pending_state_dict["torchft"])
+
+                    # This isn't strictly needed as loading the state_dict above should
+                    # restore the correct step but it makes writing tests simpler.
+                    self._step = max_step

     def _apply_pending_state_dict(self) -> None:
         assert self._healing, "must be in healing state"
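For context on what scoping the checkpoint send/recv to a side stream buys: work issued inside torch.cuda.stream(recovery_stream) is queued asynchronously and can overlap with whatever the default stream is doing. A small self-contained sketch of that pattern (illustrative tensors and names, not torchft's transport API):

from contextlib import nullcontext

import torch

recovery_stream = torch.cuda.Stream() if torch.cuda.is_available() else None

def recover_into(dst: torch.Tensor, src: torch.Tensor) -> None:
    # Queue the "recovery" copy on the side stream when CUDA is available;
    # on a CPU-only host the nullcontext branch makes this an ordinary copy.
    if recovery_stream is not None:
        # The side stream must first wait for the already-queued work that produced src.
        recovery_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(recovery_stream) if recovery_stream is not None else nullcontext():
        dst.copy_(src, non_blocking=True)

Anything queued this way is asynchronous with respect to the default stream, which is why the change also adds a synchronize() before committing (next hunk).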
Expand Down Expand Up @@ -584,6 +599,10 @@ def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
# never return an error.
work.wait()

# make sure recovery is complete before committing
if self._recovery_stream is not None:
self._recovery_stream.synchronize()

self._pending_work = []

# apply state_dict if healing
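A self-contained sketch of why the synchronize() guard matters, with hypothetical tensors rather than torchft's commit path: work queued on a side stream has not necessarily finished when the host reaches the commit point, so the host blocks on the stream before relying on its results.

import torch

if torch.cuda.is_available():
    side = torch.cuda.Stream()
    src = torch.randn(1 << 20, device="cuda")
    dst = torch.empty_like(src)

    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        dst.copy_(src, non_blocking=True)  # queued on the side stream, not yet finished

    side.synchronize()  # the analogue of the guard added to should_commit()
    assert torch.equal(dst, src)  # safe to read only after synchronizing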