From b22119c874ad00f3b20a48deca4a5035358c6bec Mon Sep 17 00:00:00 2001
From: Tushar Jain
Date: Fri, 17 Oct 2025 10:19:21 -0700
Subject: [PATCH 1/2] handle exception waiting for work (#287)

Summary:
work.wait() can throw, so wrap it in a try/except to handle the failure
gracefully by reporting the error to the manager, which causes should_commit
to fail.

Differential Revision: D84880993
---
 torchft/device_mesh.py |  1 +
 torchft/manager.py     | 17 +++++++++++------
 2 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/torchft/device_mesh.py b/torchft/device_mesh.py
index 959970a..1e9861a 100644
--- a/torchft/device_mesh.py
+++ b/torchft/device_mesh.py
@@ -69,6 +69,7 @@ def __init__(
         self.replicate_dim_name: str = mesh_dim_names[replicate_dim]
         self.parent = parent
         self.flatten_meshes: Dict[str, DeviceMesh] = {}
+        self._flatten_mapping: Dict[str, "DeviceMesh"] = {}
         self._device_type: str
         if mesh is not None:
             self._device_type = mesh.device_type
diff --git a/torchft/manager.py b/torchft/manager.py
index 4fc8a83..2cc57cb 100644
--- a/torchft/manager.py
+++ b/torchft/manager.py
@@ -1244,14 +1244,19 @@ def _assert_same_stream(self) -> None:
     def wait(self, timeout: Optional[timedelta] = None) -> bool:
         self._assert_same_stream()
 
-        with get_stream_context(self._stream):
-            self._work.wait()
-            self._set_future_callback()
+        try:
+            with get_stream_context(self._stream):
+                self._work.wait()
+                self._set_future_callback()
 
-        with get_stream_context(self._stream):
-            self._managed_fut_tail.wait()
+            with get_stream_context(self._stream):
+                self._managed_fut_tail.wait()
 
-        return True
+            return True
+        except Exception as e:
+            self._manager._logger.exception(f"got exception waiting for work {e}")
+            self._manager.report_error(e)
+            return False
 
     def block_current_stream(self, timeout: Optional[timedelta] = None) -> None:
         self._assert_same_stream()

From 229e4fa5016b1d9305205ac4e3c82902639afa05 Mon Sep 17 00:00:00 2001
From: Tushar Jain
Date: Fri, 17 Oct 2025 10:19:21 -0700
Subject: [PATCH 2/2] reset flight recorder trace (#283)

Summary:
- call the FR API to reset the trace after every quorum
- we reset so that every quorum starts a fresh FR trace, since the PGs could
  have changed and the FR trace from previous errors has already been dumped
- change the env var that determines the dump file after every quorum

Differential Revision: D84260745
---
 torchft/manager.py       | 44 +++++++++++++++++++++++++++++---
 torchft/process_group.py | 54 +++++++++++++++++++++++++++++++++-------
 2 files changed, 86 insertions(+), 12 deletions(-)

diff --git a/torchft/manager.py b/torchft/manager.py
index 2cc57cb..69742c5 100644
--- a/torchft/manager.py
+++ b/torchft/manager.py
@@ -88,6 +88,8 @@
 # crash if call to quorum fails, all replicas will crash.
 QUORUM_RETRIES_ENV: str = "TORCHFT_QUORUM_RETRIES"
 
+TORCH_FR_DUMP_TEMP_FILE_ENV: str = "TORCH_FR_DUMP_TEMP_FILE"
+
 T = TypeVar("T")
 
 
@@ -109,6 +111,17 @@ def get_timeout(
     return default_timeout_sec
 
 
+def extract_trailing_digits(s: str) -> int:
+    """
+    Extracts the trailing digits from the end of the string s.
+    Returns 0 if no trailing digits are found.
+ """ + i = len(s) - 1 + while i >= 0 and s[i].isdigit(): + i -= 1 + return int(s[i + 1 :]) if i < len(s) - 1 else 0 + + class WorldSizeMode(Enum): """ This controls the numerics for the job when doing allreduces across replicas @@ -223,6 +236,9 @@ def __init__( self._load_state_dict_fns: Dict[str, Callable[[object], None]] = {} self._user_state_dicts: Dict[str, Callable[[], object]] = {} + self._original_fr_dump_temp_file: Optional[str] = os.environ.get( + TORCH_FR_DUMP_TEMP_FILE_ENV + ) self._replica_id = replica_id # Protects state dict @@ -257,7 +273,7 @@ def __init__( store_port = store_port or int(os.environ["MASTER_PORT"]) self._group_rank: int = rank if rank is not None else int(os.environ["RANK"]) group_rank = self._group_rank - group_world_size = world_size or int(os.environ["WORLD_SIZE"]) + self._group_world_size: int = world_size or int(os.environ["WORLD_SIZE"]) self._min_replica_size = min_replica_size if checkpoint_transport is None: @@ -310,7 +326,7 @@ def __init__( hostname=hostname, bind=bind, store_addr=f"{store_addr}:{store_port}", - world_size=group_world_size, + world_size=self._group_world_size, heartbeat_interval=heartbeat_interval, connect_timeout=connect_timeout, quorum_retries=self._quorum_retries, @@ -338,6 +354,8 @@ def __init__( self._participating_replica_world_size: int = 0 self._is_state_dict_read_allowed = True + self._update_fr_path() + def allow_state_dict_read(self) -> None: if self._is_state_dict_read_allowed: return @@ -665,16 +683,21 @@ def _async_quorum( self._logger.info(f"reconfiguring for {quorum_id=} {store_prefixed_addr=}") # We use the replica rank and world as we want all replicas in the PG. try: + self._quorum_id = quorum_id with torch.profiler.record_function("torchft::manager::_pg::configure"): + # Reset GPU state for Flight Recorder if torch.accelerator.is_available(): torch.accelerator.synchronize() + torch._C._distributed_c10d._reset_fr_recording_nccl() # pyre-ignore + self._update_fr_path() + self._pg.configure( store_prefixed_addr, self._replica_id if self._replica_id is not None else "0", replica_rank, replica_world_size, + quorum_id, ) - self._quorum_id = quorum_id except Exception as e: self._logger.exception(f"got exception in pg configure: {e}") self.report_error(e) @@ -749,6 +772,21 @@ def _async_quorum( else None ) + def _update_fr_path(self) -> None: + if self._original_fr_dump_temp_file is not None: + folder = f"{self._original_fr_dump_temp_file}_quorum_{self._quorum_id}" + os.makedirs(folder, exist_ok=True) + + filename = ( + self._group_rank + if self._replica_id is None + else ( + extract_trailing_digits(self._replica_id) * self._group_world_size + + self._group_rank + ) + ) + os.environ[TORCH_FR_DUMP_TEMP_FILE_ENV] = f"{folder}/{filename}" + def _apply_pending_state_dict(self) -> None: assert self._healing, "must be in healing state" diff --git a/torchft/process_group.py b/torchft/process_group.py index c462928..6f9e206 100644 --- a/torchft/process_group.py +++ b/torchft/process_group.py @@ -278,7 +278,12 @@ def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work: raise NotImplementedError("not implemented") def configure( - self, store_addr: str, replica_id: str, rank: int, world_size: int + self, + store_addr: str, + replica_id: str, + rank: int, + world_size: int, + quorum_id: Optional[int] = None, ) -> None: """ This reconfigures the ProcessGroup to use a new store, rank and world size. 
@@ -408,6 +413,7 @@ def __init__(
         self._timeout = timeout
         self._replica_id: str | None = None
         self._rank: int | None = None
+        self._quorum_id: int | None = None
 
         self.errors_logger: logging.Logger = logging.getLogger("torchft_errors")
 
@@ -419,13 +425,19 @@ def getBackendName(self) -> str:
         raise NotImplementedError("not implemented")
 
     def configure(
-        self, store_addr: str, replica_id: str, rank: int, world_size: int
+        self,
+        store_addr: str,
+        replica_id: str,
+        rank: int,
+        world_size: int,
+        quorum_id: Optional[int] = None,
     ) -> None:
         pg = self._pg
         self._replica_id = replica_id
+        self._quorum_id = quorum_id
         self._rank = rank
         if isinstance(pg, ProcessGroup):
-            pg.configure(store_addr, replica_id, rank, world_size)
+            pg.configure(store_addr, replica_id, rank, world_size, quorum_id)
             return
 
         # abort if already initialized
@@ -443,6 +455,7 @@ def abort(self, errored: bool = True) -> None:
                 "job_id": os.environ.get("JOB_ID", "unknown"),
                 "replica_id": self._replica_id,
                 "rank": self._rank,
+                "quorum_id": self._quorum_id,
                 "error": "process_group_abort",
             },
         )
@@ -615,6 +628,8 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGro
         # pyre-fixme[16]: no attribute ProcessGroupGloo
         backend_class = BaseProcessGroupGloo(store, rank, world_size, self._timeout)
         backend_class._set_sequence_number_for_group()
+        backend_class.options.global_ranks_in_group = list(range(world_size))
+        backend_class.options.group_name = f"torchft_quorum_{self._quorum_id}"
         pg._register_backend(
             torch.device("cpu"), ProcessGroup.BackendType.GLOO, backend_class
         )
@@ -813,6 +828,7 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGro
         opts = BaseProcessGroupNCCL.Options()
         opts.config.blocking = False
         opts.global_ranks_in_group = list(range(world_size))
+        opts.group_name = f"torchft_quorum_{self._quorum_id}"
 
         pg = BaseProcessGroup(store, rank, world_size)
         pg._set_default_backend(ProcessGroup.BackendType.NCCL)
@@ -979,7 +995,12 @@ def __init__(self, rank: int, world: int) -> None:
         self.configure_count = 0
 
     def configure(
-        self, store_addr: str, replica_id: str, rank: int, world_size: int
+        self,
+        store_addr: str,
+        replica_id: str,
+        rank: int,
+        world_size: int,
+        quorum_id: Optional[int] = None,
     ) -> None:
         self.configure_count += 1
 
@@ -1138,11 +1159,16 @@ def __init__(self, pg: ProcessGroup) -> None:
         self._error: Optional[Exception] = None
 
     def configure(
-        self, store_addr: str, replica_id: str, rank: int, world_size: int
+        self,
+        store_addr: str,
+        replica_id: str,
+        rank: int,
+        world_size: int,
+        quorum_id: Optional[int] = None,
     ) -> None:
         self._error = None
 
-        super().configure(store_addr, replica_id, rank, world_size)
+        super().configure(store_addr, replica_id, rank, world_size, quorum_id)
 
     def report_error(self, e: Exception) -> None:
         """
@@ -1194,11 +1220,16 @@ def __init__(self, pg: ProcessGroup) -> None:
         self._future_error: Optional[Exception] = None
 
     def configure(
-        self, store_addr: str, replica_id: str, rank: int, world_size: int
+        self,
+        store_addr: str,
+        replica_id: str,
+        rank: int,
+        world_size: int,
+        quorum_id: Optional[int] = None,
    ) -> None:
         self._future_error = None
 
-        super().configure(store_addr, replica_id, rank, world_size)
+        super().configure(store_addr, replica_id, rank, world_size, quorum_id)
 
     def report_future_error(self, e: Exception) -> None:
         """
@@ -1412,7 +1443,12 @@ def shutdown(self) -> None:
         self._p.kill()
 
     def configure(
-        self, store_addr: str, replica_id: str, rank: int, world_size: int
+        self,
+        store_addr: str,
+        replica_id: str,
+        rank: int,
+        world_size: int,
+        quorum_id: Optional[int] = None,
     ) -> None:
         self._world_size = world_size
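
Editor's note: to make the first patch's behavior concrete, here is a minimal,
self-contained sketch of the failure path it introduces: an exception raised by
the underlying work's wait() is caught, reported to the manager, and wait()
returns False, so should_commit() fails and the step is not committed. The
_FakeManager, _FakeWork, and _ManagedWorkSketch classes are stand-ins invented
for illustration only; the real Manager.should_commit() also coordinates across
replicas, and the real wrapper additionally handles CUDA streams and timeouts.

from typing import Optional


class _FakeManager:
    """Stand-in for torchft's Manager, invented for illustration only."""

    def __init__(self) -> None:
        self._errored: Optional[Exception] = None

    def report_error(self, e: Exception) -> None:
        self._errored = e

    def should_commit(self) -> bool:
        # The real should_commit() also coordinates across replicas; here we
        # only model the local "an error was reported" check.
        return self._errored is None


class _FakeWork:
    """Stand-in for a collective's Work handle whose wait() fails."""

    def wait(self) -> None:
        raise RuntimeError("simulated collective failure")


class _ManagedWorkSketch:
    """Sketch of the patched wait(): catch the error, report it, return False."""

    def __init__(self, manager: _FakeManager, work: _FakeWork) -> None:
        self._manager = manager
        self._work = work

    def wait(self) -> bool:
        try:
            self._work.wait()
            return True
        except Exception as e:
            self._manager.report_error(e)
            return False


manager = _FakeManager()
assert _ManagedWorkSketch(manager, _FakeWork()).wait() is False
assert manager.should_commit() is False  # the step will not be committed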
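
Editor's note: to illustrate how the second patch derives the per-quorum Flight
Recorder dump location in _update_fr_path, here is a rough sketch that
recomputes the path outside the Manager. The base path, replica id, rank, world
size, and quorum id below are made-up example values, not values taken from the
patch; the real method also creates the folder with os.makedirs and writes the
result into the TORCH_FR_DUMP_TEMP_FILE env var.

def extract_trailing_digits(s: str) -> int:
    # Same helper as in the patch: trailing digits of s as an int, else 0.
    i = len(s) - 1
    while i >= 0 and s[i].isdigit():
        i -= 1
    return int(s[i + 1 :]) if i < len(s) - 1 else 0


# Made-up example values (not from the patch).
original_fr_dump_temp_file = "/tmp/fr_trace"  # original TORCH_FR_DUMP_TEMP_FILE
replica_id, group_rank, group_world_size, quorum_id = "train_ft_7", 2, 8, 3

# The folder is suffixed with the quorum id so each quorum gets a fresh trace.
folder = f"{original_fr_dump_temp_file}_quorum_{quorum_id}"
# The file name is a globally unique rank: replica index * group size + group rank.
filename = extract_trailing_digits(replica_id) * group_world_size + group_rank
print(f"{folder}/{filename}")  # -> /tmp/fr_trace_quorum_3/58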