1 change: 1 addition & 0 deletions torchft/device_mesh.py
@@ -69,6 +69,7 @@ def __init__(
self.replicate_dim_name: str = mesh_dim_names[replicate_dim]
self.parent = parent
self.flatten_meshes: Dict[str, DeviceMesh] = {}
self._flatten_mapping: Dict[str, "DeviceMesh"] = {}
self._device_type: str
if mesh is not None:
self._device_type = mesh.device_type
61 changes: 52 additions & 9 deletions torchft/manager.py
@@ -88,6 +88,8 @@
# crash if call to quorum fails, all replicas will crash.
QUORUM_RETRIES_ENV: str = "TORCHFT_QUORUM_RETRIES"

TORCH_FR_DUMP_TEMP_FILE_ENV: str = "TORCH_FR_DUMP_TEMP_FILE"

T = TypeVar("T")


@@ -109,6 +111,17 @@ def get_timeout(
return default_timeout_sec


def extract_trailing_digits(s: str) -> int:
"""
Extracts the trailing digits from the end of the string s and returns them as an int.
Returns 0 if no trailing digits are found.
"""
i = len(s) - 1
while i >= 0 and s[i].isdigit():
i -= 1
return int(s[i + 1 :]) if i < len(s) - 1 else 0
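
As a quick illustration of the helper above (not part of the diff): a hedged usage sketch, assuming extract_trailing_digits is importable from torchft.manager once this change lands; the replica ids are made up.

from torchft.manager import extract_trailing_digits

assert extract_trailing_digits("train_ddp_7") == 7      # digits after the last non-digit character
assert extract_trailing_digits("replica42") == 42
assert extract_trailing_digits("no_digits_here") == 0   # no trailing digits -> 0, not ""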


class WorldSizeMode(Enum):
"""
This controls the numerics for the job when doing allreduces across replicas
@@ -223,6 +236,9 @@ def __init__(
self._load_state_dict_fns: Dict[str, Callable[[object], None]] = {}
self._user_state_dicts: Dict[str, Callable[[], object]] = {}

self._original_fr_dump_temp_file: Optional[str] = os.environ.get(
TORCH_FR_DUMP_TEMP_FILE_ENV
)
self._replica_id = replica_id

# Protects state dict
@@ -257,7 +273,7 @@ def __init__(
store_port = store_port or int(os.environ["MASTER_PORT"])
self._group_rank: int = rank if rank is not None else int(os.environ["RANK"])
group_rank = self._group_rank
group_world_size = world_size or int(os.environ["WORLD_SIZE"])
self._group_world_size: int = world_size or int(os.environ["WORLD_SIZE"])
self._min_replica_size = min_replica_size

if checkpoint_transport is None:
@@ -310,7 +326,7 @@ def __init__(
hostname=hostname,
bind=bind,
store_addr=f"{store_addr}:{store_port}",
world_size=group_world_size,
world_size=self._group_world_size,
heartbeat_interval=heartbeat_interval,
connect_timeout=connect_timeout,
quorum_retries=self._quorum_retries,
@@ -338,6 +354,8 @@ def __init__(
self._participating_replica_world_size: int = 0
self._is_state_dict_read_allowed = True

self._update_fr_path()

def allow_state_dict_read(self) -> None:
if self._is_state_dict_read_allowed:
return
@@ -665,16 +683,21 @@ def _async_quorum(
self._logger.info(f"reconfiguring for {quorum_id=} {store_prefixed_addr=}")
# We use the replica rank and world as we want all replicas in the PG.
try:
self._quorum_id = quorum_id
with torch.profiler.record_function("torchft::manager::_pg::configure"):
# Reset GPU state for Flight Recorder
if torch.accelerator.is_available():
torch.accelerator.synchronize()
torch._C._distributed_c10d._reset_fr_recording_nccl() # pyre-ignore
self._update_fr_path()

self._pg.configure(
store_prefixed_addr,
self._replica_id if self._replica_id is not None else "0",
replica_rank,
replica_world_size,
quorum_id,
)
self._quorum_id = quorum_id
except Exception as e:
self._logger.exception(f"got exception in pg configure: {e}")
self.report_error(e)
@@ -749,6 +772,21 @@ def _async_quorum(
else None
)

def _update_fr_path(self) -> None:
if self._original_fr_dump_temp_file is not None:
folder = f"{self._original_fr_dump_temp_file}_quorum_{self._quorum_id}"
os.makedirs(folder, exist_ok=True)

filename = (
self._group_rank
if self._replica_id is None
else (
extract_trailing_digits(self._replica_id) * self._group_world_size
+ self._group_rank
)
)
os.environ[TORCH_FR_DUMP_TEMP_FILE_ENV] = f"{folder}/{filename}"
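
A short worked example of the path logic in _update_fr_path, using made-up values (the base path, replica id, ranks and quorum id are all hypothetical; extract_trailing_digits is the helper added earlier in this file):

from torchft.manager import extract_trailing_digits

base = "/tmp/fr_trace"            # TORCH_FR_DUMP_TEMP_FILE as set before the Manager starts
quorum_id = 3
replica_id = "replica1"
group_rank, group_world_size = 2, 8

folder = f"{base}_quorum_{quorum_id}"
filename = extract_trailing_digits(replica_id) * group_world_size + group_rank
print(f"{folder}/{filename}")     # -> /tmp/fr_trace_quorum_3/10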

def _apply_pending_state_dict(self) -> None:
assert self._healing, "must be in healing state"

@@ -1244,14 +1282,19 @@ def _assert_same_stream(self) -> None:
def wait(self, timeout: Optional[timedelta] = None) -> bool:
self._assert_same_stream()

with get_stream_context(self._stream):
self._work.wait()
self._set_future_callback()
try:
with get_stream_context(self._stream):
self._work.wait()
self._set_future_callback()

with get_stream_context(self._stream):
self._managed_fut_tail.wait()
with get_stream_context(self._stream):
self._managed_fut_tail.wait()

return True
return True
except Exception as e:
self._manager._logger.exception(f"got exception waiting for work {e}")
self._manager.report_error(e)
return False
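
A hedged caller-side sketch (the helper name and timeout value are assumptions, not part of this PR): with the try/except above, a failed collective is reported to the Manager via report_error() and wait() returns False instead of raising, so callers can treat the boolean as "skip this step and let the next quorum recover".

from datetime import timedelta

def finish_managed_work(work) -> bool:
    # Hypothetical helper around a managed work handle: with the change above,
    # wait() has already logged and reported the failure, so the caller only
    # needs to check the return value.
    ok = work.wait(timeout=timedelta(seconds=30))
    if not ok:
        # Error was already passed to manager.report_error(); skip the optimizer
        # step and rely on the next quorum to recover.
        pass
    return ok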

def block_current_stream(self, timeout: Optional[timedelta] = None) -> None:
self._assert_same_stream()
54 changes: 45 additions & 9 deletions torchft/process_group.py
@@ -278,7 +278,12 @@ def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
raise NotImplementedError("not implemented")

def configure(
self, store_addr: str, replica_id: str, rank: int, world_size: int
self,
store_addr: str,
replica_id: str,
rank: int,
world_size: int,
quorum_id: Optional[int] = None,
) -> None:
"""
This reconfigures the ProcessGroup to use a new store, rank and world size.
@@ -408,6 +413,7 @@ def __init__(
self._timeout = timeout
self._replica_id: str | None = None
self._rank: int | None = None
self._quorum_id: int | None = None

self.errors_logger: logging.Logger = logging.getLogger("torchft_errors")

@@ -419,13 +425,19 @@ def getBackendName(self) -> str:
raise NotImplementedError("not implemented")

def configure(
self, store_addr: str, replica_id: str, rank: int, world_size: int
self,
store_addr: str,
replica_id: str,
rank: int,
world_size: int,
quorum_id: Optional[int] = None,
) -> None:
pg = self._pg
self._replica_id = replica_id
self._quorum_id = quorum_id
self._rank = rank
if isinstance(pg, ProcessGroup):
pg.configure(store_addr, replica_id, rank, world_size)
pg.configure(store_addr, replica_id, rank, world_size, quorum_id)
return

# abort if already initialized
@@ -443,6 +455,7 @@ def abort(self, errored: bool = True) -> None:
"job_id": os.environ.get("JOB_ID", "unknown"),
"replica_id": self._replica_id,
"rank": self._rank,
"quorum_id": self._quorum_id,
"error": "process_group_abort",
},
)
@@ -615,6 +628,8 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
# pyre-fixme[16]: no attribute ProcessGroupGloo
backend_class = BaseProcessGroupGloo(store, rank, world_size, self._timeout)
backend_class._set_sequence_number_for_group()
backend_class.options.global_ranks_in_group = list(range(world_size))
backend_class.options.group_name = f"torchft_quorum_{self._quorum_id}"
pg._register_backend(
torch.device("cpu"), ProcessGroup.BackendType.GLOO, backend_class
)
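
A hedged end-to-end sketch of the new parameter (the store address, replica id and quorum id are placeholders, and a reachable store is needed for configure() to actually complete): passing quorum_id through configure() is what lets the backend be created with a per-quorum group_name, so Flight Recorder entries and abort logs can be attributed to a specific quorum.

from torchft.process_group import ProcessGroupGloo

pg = ProcessGroupGloo()  # constructed with the default timeout
pg.configure(
    store_addr="localhost:29500/torchft/3/pg",  # placeholder prefixed store address
    replica_id="replica0",
    rank=0,
    world_size=2,
    quorum_id=3,  # gloo backend is created with group_name "torchft_quorum_3"
)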
@@ -813,6 +828,7 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
opts = BaseProcessGroupNCCL.Options()
opts.config.blocking = False
opts.global_ranks_in_group = list(range(world_size))
opts.group_name = f"torchft_quorum_{self._quorum_id}"

pg = BaseProcessGroup(store, rank, world_size)
pg._set_default_backend(ProcessGroup.BackendType.NCCL)
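
Illustrative only (the quorum ids are hypothetical): the f-string above is what ties NCCL and gloo Flight Recorder entries back to a torchft quorum, producing names such as:

for quorum_id in (None, 0, 7):
    print(f"torchft_quorum_{quorum_id}")
# torchft_quorum_None   (only if configure() was never passed a quorum_id)
# torchft_quorum_0
# torchft_quorum_7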
@@ -979,7 +995,12 @@ def __init__(self, rank: int, world: int) -> None:
self.configure_count = 0

def configure(
self, store_addr: str, replica_id: str, rank: int, world_size: int
self,
store_addr: str,
replica_id: str,
rank: int,
world_size: int,
quorum_id: Optional[int] = None,
) -> None:
self.configure_count += 1

@@ -1138,11 +1159,16 @@ def __init__(self, pg: ProcessGroup) -> None:
self._error: Optional[Exception] = None

def configure(
self, store_addr: str, replica_id: str, rank: int, world_size: int
self,
store_addr: str,
replica_id: str,
rank: int,
world_size: int,
quorum_id: Optional[int] = None,
) -> None:
self._error = None

super().configure(store_addr, replica_id, rank, world_size)
super().configure(store_addr, replica_id, rank, world_size, quorum_id)

def report_error(self, e: Exception) -> None:
"""
@@ -1194,11 +1220,16 @@ def __init__(self, pg: ProcessGroup) -> None:
self._future_error: Optional[Exception] = None

def configure(
self, store_addr: str, replica_id: str, rank: int, world_size: int
self,
store_addr: str,
replica_id: str,
rank: int,
world_size: int,
quorum_id: Optional[int] = None,
) -> None:
self._future_error = None

super().configure(store_addr, replica_id, rank, world_size)
super().configure(store_addr, replica_id, rank, world_size, quorum_id)

def report_future_error(self, e: Exception) -> None:
"""
@@ -1412,7 +1443,12 @@ def shutdown(self) -> None:
self._p.kill()

def configure(
self, store_addr: str, replica_id: str, rank: int, world_size: int
self,
store_addr: str,
replica_id: str,
rank: int,
world_size: int,
quorum_id: Optional[int] = None,
) -> None:
self._world_size = world_size
