diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index 85a47a8..f9a65d9 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -25,6 +25,7 @@ jobs: sudo apt-get install -y protobuf-compiler + pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128 pip install .[dev] -v pip install -r docs/requirements.txt diff --git a/torchft/device_mesh.py b/torchft/device_mesh.py index 959970a..bb25574 100644 --- a/torchft/device_mesh.py +++ b/torchft/device_mesh.py @@ -10,6 +10,7 @@ init_device_mesh, ProcessGroup as BaseProcessGroup, ) +from torch.distributed._mesh_layout import _MeshLayout from torch.distributed.tensor.device_mesh import _mesh_resources from torchft.manager import Manager @@ -69,12 +70,15 @@ def __init__( self.replicate_dim_name: str = mesh_dim_names[replicate_dim] self.parent = parent self.flatten_meshes: Dict[str, DeviceMesh] = {} + self._flatten_mapping: Dict[str, "DeviceMesh"] = {} self._device_type: str if mesh is not None: self._device_type = mesh.device_type + self._layout: _MeshLayout = mesh._layout else: assert parent is not None self._device_type = parent.device_type + self._layout: _MeshLayout = parent._layout self._flatten_mesh_list: tuple[DeviceMesh, ...] = tuple() self._thread_id: Optional[int] = None self._hash: Optional[int] = None diff --git a/torchft/manager.py b/torchft/manager.py index 0212c6c..61ac397 100644 --- a/torchft/manager.py +++ b/torchft/manager.py @@ -88,6 +88,8 @@ # crash if call to quorum fails, all replicas will crash. QUORUM_RETRIES_ENV: str = "TORCHFT_QUORUM_RETRIES" +TORCH_FR_DUMP_TEMP_FILE_ENV: str = "TORCH_FR_DUMP_TEMP_FILE" + T = TypeVar("T") @@ -109,6 +111,17 @@ def get_timeout( return default_timeout_sec +def extract_trailing_digits(s: str) -> int: + """ + Extracts the trailing digits from the end of the string s. + Returns an empty string if no trailing digits are found. + """ + i = len(s) - 1 + while i >= 0 and s[i].isdigit(): + i -= 1 + return int(s[i + 1 :]) if i < len(s) - 1 else 0 + + class WorldSizeMode(Enum): """ This controls the numerics for the job when doing allreduces across replicas @@ -223,6 +236,9 @@ def __init__( self._load_state_dict_fns: Dict[str, Callable[[object], None]] = {} self._user_state_dicts: Dict[str, Callable[[], object]] = {} + self._original_fr_dump_temp_file: Optional[str] = os.environ.get( + TORCH_FR_DUMP_TEMP_FILE_ENV + ) self._replica_id = replica_id # Protects state dict @@ -257,7 +273,7 @@ def __init__( store_port = store_port or int(os.environ["MASTER_PORT"]) self._group_rank: int = rank if rank is not None else int(os.environ["RANK"]) group_rank = self._group_rank - group_world_size = world_size or int(os.environ["WORLD_SIZE"]) + self._group_world_size: int = world_size or int(os.environ["WORLD_SIZE"]) self._min_replica_size = min_replica_size if checkpoint_transport is None: @@ -310,7 +326,7 @@ def __init__( hostname=hostname, bind=bind, store_addr=f"{store_addr}:{store_port}", - world_size=group_world_size, + world_size=self._group_world_size, heartbeat_interval=heartbeat_interval, connect_timeout=connect_timeout, quorum_retries=self._quorum_retries, @@ -338,6 +354,8 @@ def __init__( self._participating_replica_world_size: int = 0 self._is_state_dict_read_allowed = True + self._update_fr_path() + def allow_state_dict_read(self) -> None: if self._is_state_dict_read_allowed: return @@ -674,16 +692,21 @@ def _async_quorum( self._logger.info(f"reconfiguring for {quorum_id=} {store_prefixed_addr=}") # We use the replica rank and world as we want all replicas in the PG. try: + self._quorum_id = quorum_id with torch.profiler.record_function("torchft::manager::_pg::configure"): + # Reset GPU state for Flight Recorder if torch.accelerator.is_available(): torch.accelerator.synchronize() + torch._C._distributed_c10d._reset_fr_recording_nccl() # pyre-ignore + self._update_fr_path() + self._pg.configure( store_prefixed_addr, self._replica_id if self._replica_id is not None else "0", replica_rank, replica_world_size, + quorum_id, ) - self._quorum_id = quorum_id except Exception as e: self._logger.exception(f"got exception in pg configure: {e}") self.report_error(e) @@ -758,6 +781,21 @@ def _async_quorum( else None ) + def _update_fr_path(self) -> None: + if self._original_fr_dump_temp_file is not None: + folder = f"{self._original_fr_dump_temp_file}_quorum_{self._quorum_id}" + os.makedirs(folder, exist_ok=True) + + filename = ( + self._group_rank + if self._replica_id is None + else ( + extract_trailing_digits(self._replica_id) * self._group_world_size + + self._group_rank + ) + ) + os.environ[TORCH_FR_DUMP_TEMP_FILE_ENV] = f"{folder}/{filename}" + def _apply_pending_state_dict(self) -> None: assert self._healing, "must be in healing state" @@ -1253,14 +1291,19 @@ def _assert_same_stream(self) -> None: def wait(self, timeout: Optional[timedelta] = None) -> bool: self._assert_same_stream() - with get_stream_context(self._stream): - self._work.wait() - self._set_future_callback() + try: + with get_stream_context(self._stream): + self._work.wait() + self._set_future_callback() - with get_stream_context(self._stream): - self._managed_fut_tail.wait() + with get_stream_context(self._stream): + self._managed_fut_tail.wait() - return True + return True + except Exception as e: + self._manager._logger.exception(f"got exception waiting for work {e}") + self._manager.report_error(e) + return False def block_current_stream(self, timeout: Optional[timedelta] = None) -> None: self._assert_same_stream() diff --git a/torchft/process_group.py b/torchft/process_group.py index c462928..6f9e206 100644 --- a/torchft/process_group.py +++ b/torchft/process_group.py @@ -278,7 +278,12 @@ def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work: raise NotImplementedError("not implemented") def configure( - self, store_addr: str, replica_id: str, rank: int, world_size: int + self, + store_addr: str, + replica_id: str, + rank: int, + world_size: int, + quorum_id: Optional[int] = None, ) -> None: """ This reconfigures the ProcessGroup to use a new store, rank and world size. @@ -408,6 +413,7 @@ def __init__( self._timeout = timeout self._replica_id: str | None = None self._rank: int | None = None + self._quorum_id: int | None = None self.errors_logger: logging.Logger = logging.getLogger("torchft_errors") @@ -419,13 +425,19 @@ def getBackendName(self) -> str: raise NotImplementedError("not implemented") def configure( - self, store_addr: str, replica_id: str, rank: int, world_size: int + self, + store_addr: str, + replica_id: str, + rank: int, + world_size: int, + quorum_id: Optional[int] = None, ) -> None: pg = self._pg self._replica_id = replica_id + self._quorum_id = quorum_id self._rank = rank if isinstance(pg, ProcessGroup): - pg.configure(store_addr, replica_id, rank, world_size) + pg.configure(store_addr, replica_id, rank, world_size, quorum_id) return # abort if already initialized @@ -443,6 +455,7 @@ def abort(self, errored: bool = True) -> None: "job_id": os.environ.get("JOB_ID", "unknown"), "replica_id": self._replica_id, "rank": self._rank, + "quorum_id": self._quorum_id, "error": "process_group_abort", }, ) @@ -615,6 +628,8 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGro # pyre-fixme[16]: no attribute ProcessGroupGloo backend_class = BaseProcessGroupGloo(store, rank, world_size, self._timeout) backend_class._set_sequence_number_for_group() + backend_class.options.global_ranks_in_group = list(range(world_size)) + backend_class.options.group_name = f"torchft_quorum_{self._quorum_id}" pg._register_backend( torch.device("cpu"), ProcessGroup.BackendType.GLOO, backend_class ) @@ -813,6 +828,7 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGro opts = BaseProcessGroupNCCL.Options() opts.config.blocking = False opts.global_ranks_in_group = list(range(world_size)) + opts.group_name = f"torchft_quorum_{self._quorum_id}" pg = BaseProcessGroup(store, rank, world_size) pg._set_default_backend(ProcessGroup.BackendType.NCCL) @@ -979,7 +995,12 @@ def __init__(self, rank: int, world: int) -> None: self.configure_count = 0 def configure( - self, store_addr: str, replica_id: str, rank: int, world_size: int + self, + store_addr: str, + replica_id: str, + rank: int, + world_size: int, + quorum_id: Optional[int] = None, ) -> None: self.configure_count += 1 @@ -1138,11 +1159,16 @@ def __init__(self, pg: ProcessGroup) -> None: self._error: Optional[Exception] = None def configure( - self, store_addr: str, replica_id: str, rank: int, world_size: int + self, + store_addr: str, + replica_id: str, + rank: int, + world_size: int, + quorum_id: Optional[int] = None, ) -> None: self._error = None - super().configure(store_addr, replica_id, rank, world_size) + super().configure(store_addr, replica_id, rank, world_size, quorum_id) def report_error(self, e: Exception) -> None: """ @@ -1194,11 +1220,16 @@ def __init__(self, pg: ProcessGroup) -> None: self._future_error: Optional[Exception] = None def configure( - self, store_addr: str, replica_id: str, rank: int, world_size: int + self, + store_addr: str, + replica_id: str, + rank: int, + world_size: int, + quorum_id: Optional[int] = None, ) -> None: self._future_error = None - super().configure(store_addr, replica_id, rank, world_size) + super().configure(store_addr, replica_id, rank, world_size, quorum_id) def report_future_error(self, e: Exception) -> None: """ @@ -1412,7 +1443,12 @@ def shutdown(self) -> None: self._p.kill() def configure( - self, store_addr: str, replica_id: str, rank: int, world_size: int + self, + store_addr: str, + replica_id: str, + rank: int, + world_size: int, + quorum_id: Optional[int] = None, ) -> None: self._world_size = world_size