From 5039b5651ec7144038bb9b1ea333ab3a7f4cd7e6 Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Wed, 30 Apr 2025 17:17:41 -0700 Subject: [PATCH] manager: gracefully handle errors from configure+checkpoint --- torchft/manager.py | 139 ++++++++++++++++++++++------------------ torchft/manager_test.py | 74 +++++++++++++++++++++ 2 files changed, 150 insertions(+), 63 deletions(-) diff --git a/torchft/manager.py b/torchft/manager.py index c00c8257..c3cad7c3 100644 --- a/torchft/manager.py +++ b/torchft/manager.py @@ -508,12 +508,16 @@ def _async_quorum( self._logger.info(f"reconfiguring for {quorum_id=} {store_prefixed_addr=}") # We use the replica rank and world as we want all replicas in the PG. - # TODO: handle configure errors - with torch.profiler.record_function("torchft::manager::_pg.configure"): - self._pg.configure( - store_prefixed_addr, replica_rank, replica_world_size - ) - self._quorum_id = quorum_id + try: + with torch.profiler.record_function("torchft::manager::_pg.configure"): + self._pg.configure( + store_prefixed_addr, replica_rank, replica_world_size + ) + self._quorum_id = quorum_id + except Exception as e: + self._logger.exception(f"got exception in pg configure: {e}") + self.report_error(e) + return if allow_heal: # run recovery on the recovery stream if available @@ -523,62 +527,67 @@ def _async_quorum( if recovery_stream is not None else nullcontext() ): - if quorum.recover_dst_ranks: - self._logger.info( - f"peers need recovery from us {quorum.recover_dst_ranks}" - ) - with torch.profiler.record_function( - "torchft::manager::_checkpoint_transport::send_checkpoint" - ): - self._checkpoint_transport.send_checkpoint( - dst_ranks=quorum.recover_dst_ranks, - step=max_step, - state_dict=self._manager_state_dict(), - timeout=self._timeout, + try: + if quorum.recover_dst_ranks: + self._logger.info( + f"peers need recovery from us {quorum.recover_dst_ranks}" ) - - # See manager.rs for healing conditions - if heal: - self._healing = True - self._logger.info( - f"healing required, fetching checkpoint metadata from {recover_src_manager_address=} {max_step=}" - ) - primary_client = ManagerClient( - recover_src_manager_address, - connect_timeout=self._connect_timeout, - ) - checkpoint_metadata = primary_client._checkpoint_metadata( - self._rank, timeout=self._timeout - ) - recover_src_rank = quorum.recover_src_rank - assert ( - recover_src_rank is not None - ), "must have a recover rank when healing" - - self._logger.info( - f"fetching checkpoint from {recover_src_rank=} with {checkpoint_metadata=}" - ) - - # we apply the user state dict only when safe from the main thread - # save it for now - with torch.profiler.record_function( - "torchft::manager::_checkpoint_transport::recv_checkpoint" - ): - self._pending_state_dict = ( - self._checkpoint_transport.recv_checkpoint( - src_rank=recover_src_rank, - metadata=checkpoint_metadata, + with torch.profiler.record_function( + "torchft::manager::_checkpoint_transport::send_checkpoint" + ): + self._checkpoint_transport.send_checkpoint( + dst_ranks=quorum.recover_dst_ranks, step=max_step, + state_dict=self._manager_state_dict(), timeout=self._timeout, ) + + # See manager.rs for healing conditions + if heal: + self._healing = True + self._logger.info( + f"healing required, fetching checkpoint metadata from {recover_src_manager_address=} {max_step=}" ) + primary_client = ManagerClient( + recover_src_manager_address, + connect_timeout=self._connect_timeout, + ) + checkpoint_metadata = primary_client._checkpoint_metadata( + self._rank, timeout=self._timeout + ) + recover_src_rank = quorum.recover_src_rank + assert ( + recover_src_rank is not None + ), "must have a recover rank when healing" - # pyre-fixme[6]: got object - self.load_state_dict(self._pending_state_dict["torchft"]) + self._logger.info( + f"fetching checkpoint from {recover_src_rank=} with {checkpoint_metadata=}" + ) - # This isn't strictly needed as loading the state_dict above should - # restore the correct step but it makes writing tests simpler. - self._step = max_step + # we apply the user state dict only when safe from the main thread + # save it for now + with torch.profiler.record_function( + "torchft::manager::_checkpoint_transport::recv_checkpoint" + ): + self._pending_state_dict = ( + self._checkpoint_transport.recv_checkpoint( + src_rank=recover_src_rank, + metadata=checkpoint_metadata, + step=max_step, + timeout=self._timeout, + ) + ) + + # pyre-fixme[6]: got object + self.load_state_dict(self._pending_state_dict["torchft"]) + + # This isn't strictly needed as loading the state_dict above should + # restore the correct step but it makes writing tests simpler. + self._step = max_step + except Exception as e: + self._logger.exception(f"got exception in recovery: {e}") + self.report_error(e) + return def _apply_pending_state_dict(self) -> None: assert self._healing, "must be in healing state" @@ -587,15 +596,19 @@ def _apply_pending_state_dict(self) -> None: assert self._quorum_future is not None, "must call step before should_commit" self._quorum_future.result() - self._logger.info("applying pending state dict") + pending_state_dict = self._pending_state_dict - assert self._pending_state_dict is not None, "checkpoint was not staged" - assert ( - self._load_state_dict is not None - ), "user load_state_dict is not initialized." - self._load_state_dict(self._pending_state_dict["user"]) - self._pending_state_dict = None - self._logger.info("Loaded state dict.") + if pending_state_dict is None: + assert self.errored(), "checkpoint was not staged and no error occured" + else: + self._logger.info("applying pending state dict") + + assert ( + self._load_state_dict is not None + ), "user load_state_dict is not initialized." + self._load_state_dict(pending_state_dict["user"]) + self._pending_state_dict = None + self._logger.info("Loaded state dict.") @torch.profiler.record_function("torchft::manager::should_commit") def should_commit(self, timeout: Optional[timedelta] = None) -> bool: diff --git a/torchft/manager_test.py b/torchft/manager_test.py index 362545f1..be2dec27 100644 --- a/torchft/manager_test.py +++ b/torchft/manager_test.py @@ -14,6 +14,7 @@ from torch.distributed import TCPStore from torchft._torchft import QuorumResult +from torchft.checkpointing.transport import CheckpointTransport from torchft.manager import MANAGER_ADDR_KEY, REPLICA_ID_KEY, Manager, WorldSizeMode from torchft.process_group import ProcessGroup, _DummyWork @@ -648,6 +649,79 @@ def test_quorum_skip_init(self, client_mock: MagicMock) -> None: manager.start_quorum() self.assertEqual(client_mock()._quorum.call_args.kwargs["init_sync"], True) + @patch("torchft.manager.ManagerClient", autospec=True) + def test_quorum_checkpoint_errors(self, client_mock: MagicMock) -> None: + manager = self._create_manager(use_async_quorum=True) + client_mock().should_commit = MagicMock(return_value=False) + + transport = MagicMock(spec=CheckpointTransport) + transport.send_checkpoint.side_effect = RuntimeError("send failure") + transport.recv_checkpoint.side_effect = RuntimeError("recv failure") + manager._checkpoint_transport = transport + + quorum = QuorumResult() + quorum.quorum_id = 123 + quorum.replica_rank = 1 + quorum.replica_world_size = 2 + quorum.recover_src_manager_address = "manager address" + quorum.recover_src_rank = 0 + quorum.store_address = f"localhost:{self.store.port}" + quorum.max_step = 20 + quorum.max_rank = None + quorum.max_world_size = 2 + quorum.heal = True + + client_mock()._quorum.return_value = quorum + + manager.start_quorum() + manager.wait_quorum() + self.assertFalse(manager.should_commit()) + + error = manager.errored() + self.assertIsNotNone(error) + with self.assertRaisesRegex(RuntimeError, "recv failure"): + raise error + + quorum.recover_dst_ranks = [0] + manager.start_quorum() + manager.wait_quorum() + self.assertFalse(manager.should_commit()) + + error = manager.errored() + self.assertIsNotNone(error) + with self.assertRaisesRegex(RuntimeError, "send failure"): + raise error + + @patch("torchft.manager.ManagerClient", autospec=True) + def test_quorum_configure_errors(self, client_mock: MagicMock) -> None: + manager = self._create_manager(use_async_quorum=True) + client_mock().should_commit = MagicMock(return_value=False) + + # pyre-ignore[16]: mock + manager._pg.configure.side_effect = RuntimeError("configure failure") + + quorum = QuorumResult() + quorum.quorum_id = 123 + quorum.replica_rank = 1 + quorum.replica_world_size = 2 + quorum.recover_src_manager_address = "manager address" + quorum.recover_src_rank = 0 + quorum.store_address = f"localhost:{self.store.port}" + quorum.max_step = 20 + quorum.max_rank = None + quorum.max_world_size = 2 + + client_mock()._quorum.return_value = quorum + + manager.start_quorum() + manager.wait_quorum() + self.assertFalse(manager.should_commit()) + + error = manager.errored() + self.assertIsNotNone(error) + with self.assertRaisesRegex(RuntimeError, "configure failure"): + raise error + @patch("torchft.manager.ManagerClient", autospec=True) def test_max_retries(self, client_mock: MagicMock) -> None: # Create a manager with max_retries=2