diff --git a/torchft/futures.py b/torchft/futures.py
new file mode 100644
index 00000000..3ff98013
--- /dev/null
+++ b/torchft/futures.py
@@ -0,0 +1,165 @@
+import asyncio
+import threading
+from datetime import timedelta
+from typing import Optional, TypeVar
+from unittest.mock import Mock
+
+from torch.futures import Future
+
+T = TypeVar("T")
+
+
+class _TimerHandle:
+    def __init__(self) -> None:
+        self._lock = threading.Lock()
+        self._lock.acquire()
+        self._timer_handle: Optional[asyncio.TimerHandle] = None
+
+    def set_timer(self, timer_handle: asyncio.TimerHandle) -> None:
+        assert self._lock.locked()
+
+        self._timer_handle = timer_handle
+        self._lock.release()
+
+    def cancel(self) -> None:
+        with self._lock:
+            assert self._timer_handle is not None
+            self._timer_handle.cancel()
+            self._timer_handle = None
+
+
+class _TimeoutManager:
+    """
+    This class manages timeouts for futures. It uses a background thread with an
+    event loop to schedule the timeouts.
+    """
+
+    def __init__(self) -> None:
+        self._lock = threading.Lock()
+        self._event_loop: Optional[asyncio.AbstractEventLoop] = None
+        self._event_loop_thread: Optional[threading.Thread] = None
+        self._next_timer_id = 0
+
+    def _maybe_start_event_loop(self) -> asyncio.AbstractEventLoop:
+        """
+        Start the event loop if it has not already been started.
+        """
+        with self._lock:
+            if self._event_loop is None:
+                self._event_loop = asyncio.new_event_loop()
+                self._event_loop_thread = threading.Thread(
+                    target=self._event_loop.run_forever,
+                    daemon=True,
+                    name="TimeoutManager",
+                )
+                self._event_loop_thread.start()
+            # pyre-fixme[7]: optional
+            return self._event_loop
+
+    def shutdown(self) -> None:
+        """
+        Shutdown the event loop and cancel all pending timeouts.
+        """
+        with self._lock:
+            if self._event_loop is not None:
+                self._event_loop.call_soon_threadsafe(self._event_loop.stop)
+                assert self._event_loop_thread is not None
+                self._event_loop_thread.join()
+                self._event_loop = None
+                self._event_loop_thread = None
+
+    def register(self, fut: Future[T], timeout: timedelta) -> Future[T]:
+        """
+        Registers a future that will be cancelled after the specified timeout.
+        """
+        # bypass timeout for mock futures
+        if isinstance(fut, Mock):
+            return fut
+
+        loop = self._maybe_start_event_loop()
+
+        # pyre-fixme[29]: Future is not a function
+        timed_fut: Future[T] = Future()
+        handle: _TimerHandle = _TimerHandle()
+        # pyre-fixme[6]: *args
+        loop.call_soon_threadsafe(self._register, loop, timed_fut, timeout, handle)
+
+        def callback(fut: Future[T]) -> None:
+            handle.cancel()
+            try:
+                timed_fut.set_result(fut.wait())
+            except Exception as e:
+                try:
+                    # this can throw if the future is already done
+                    # pyre-fixme[6]: e is not T
+                    timed_fut.set_exception(e)
+                except Exception:
+                    pass
+
+        fut.add_done_callback(callback)
+        return timed_fut
+
+    @classmethod
+    def _register(
+        cls,
+        loop: asyncio.AbstractEventLoop,
+        fut: Future[T],
+        timeout: timedelta,
+        handle: _TimerHandle,
+    ) -> None:
+        timer_handle = loop.call_later(
+            timeout.total_seconds(),
+            lambda: fut.set_exception(
+                # pyre-fixme[6]: e is not T
+                TimeoutError(f"future did not complete within {timeout}")
+            ),
+        )
+        handle.set_timer(timer_handle)
+
+
+_TIMEOUT_MANAGER = _TimeoutManager()
+
+
+def future_timeout(fut: Future[T], timeout: timedelta) -> Future[T]:
+    """
+    Return a Future that completes with the result of the given Future within
+    the given timeout or with a TimeoutError.
+
+    Args:
+        fut: The Future to wait for
+        timeout: The timeout to wait for the Future to complete
+
+    Returns:
+        The future with a timeout
+    """
+    return _TIMEOUT_MANAGER.register(fut, timeout)
+
+
+def future_wait(fut: Future[T], timeout: timedelta) -> T:
+    """
+    Wait for a Future to complete up to a timeout.
+
+    Args:
+        fut: The Future to wait for
+        timeout: The timeout to wait for the Future to complete
+
+    Returns:
+        The result of the Future if it completed within the timeout.
+
+    Raises:
+        TimeoutError if the Future did not complete within the timeout.
+        Any other exception that occurred in the Future.
+    """
+
+    event: threading.Event = threading.Event()
+
+    def callback(fut: Future[T]) -> T:
+        event.set()
+        return fut.wait()
+
+    fut = fut.then(callback)
+
+    if not event.wait(timeout=timeout.total_seconds()):
+        raise TimeoutError(f"future did not complete within {timeout}")
+
+    return fut.wait()
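A quick usage sketch of the two helpers above (illustrative only, not part of the patch; the never-completing future stands in for a hung collective):

```python
from datetime import timedelta

from torch.futures import Future

from torchft.futures import future_timeout, future_wait

# A future that never completes, standing in for a hung collective.
fut: Future[int] = Future()

# future_timeout returns a new Future that mirrors `fut`, or fails with
# TimeoutError once the deadline fires on the background event loop.
timed = future_timeout(fut, timeout=timedelta(seconds=0.1))
try:
    timed.wait()
except TimeoutError as e:
    print(f"caught: {e}")

# future_wait blocks the calling thread instead of returning a Future.
done: Future[int] = Future()
done.set_result(42)
assert future_wait(done, timeout=timedelta(seconds=1.0)) == 42
```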
diff --git a/torchft/futures_test.py b/torchft/futures_test.py
new file mode 100644
index 00000000..b9ea1916
--- /dev/null
+++ b/torchft/futures_test.py
@@ -0,0 +1,47 @@
+from datetime import timedelta
+from unittest import TestCase
+
+from torch.futures import Future
+
+from torchft.futures import future_timeout, future_wait
+
+
+class FuturesTest(TestCase):
+    def test_future_wait(self) -> None:
+        # pyre-fixme[29]: Future is not a function
+        fut = Future()
+        with self.assertRaisesRegex(TimeoutError, "future did not complete within"):
+            future_wait(fut, timeout=timedelta(seconds=0.01))
+
+        # pyre-fixme[29]: Future is not a function
+        fut = Future()
+        fut.set_result(1)
+        self.assertEqual(future_wait(fut, timeout=timedelta(seconds=1.0)), 1)
+
+        # pyre-fixme[29]: Future is not a function
+        fut = Future()
+        fut.set_exception(RuntimeError("test"))
+        with self.assertRaisesRegex(RuntimeError, "test"):
+            future_wait(fut, timeout=timedelta(seconds=1.0))
+
+    def test_future_timeout(self) -> None:
+        # pyre-fixme[29]: Future is not a function
+        fut = Future()
+        timed_fut = future_timeout(fut, timeout=timedelta(seconds=0.01))
+        with self.assertRaisesRegex(TimeoutError, "future did not complete within"):
+            timed_fut.wait()
+
+    def test_future_timeout_result(self) -> None:
+        # pyre-fixme[29]: Future is not a function
+        fut = Future()
+        timed_fut = future_timeout(fut, timeout=timedelta(seconds=10))
+        fut.set_result(1)
+        self.assertEqual(timed_fut.wait(), 1)
+
+    def test_future_timeout_exception(self) -> None:
+        # pyre-fixme[29]: Future is not a function
+        fut = Future()
+        timed_fut = future_timeout(fut, timeout=timedelta(seconds=10))
+        fut.set_exception(RuntimeError("test"))
+        with self.assertRaisesRegex(RuntimeError, "test"):
+            timed_fut.wait()
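For reviewers unfamiliar with the pattern `_TimeoutManager` relies on: `loop.call_later` is not thread-safe, which is why `register` bounces the timer setup through `call_soon_threadsafe` onto the loop's own thread. A minimal standalone sketch of that mechanism (names here are assumptions, independent of torchft):

```python
import asyncio
import threading

# One daemon thread runs an asyncio event loop forever; other threads
# schedule work onto it, exactly as _TimeoutManager does.
loop = asyncio.new_event_loop()
thread = threading.Thread(target=loop.run_forever, daemon=True)
thread.start()

fired = threading.Event()

def arm() -> None:
    # call_later must run on the loop's own thread, hence the bounce
    # through call_soon_threadsafe below.
    loop.call_later(0.05, fired.set)

loop.call_soon_threadsafe(arm)
assert fired.wait(timeout=1.0)

# Mirror of _TimeoutManager.shutdown().
loop.call_soon_threadsafe(loop.stop)
thread.join()
```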
diff --git a/torchft/manager.py b/torchft/manager.py
index 33c7f28c..3e057518 100644
--- a/torchft/manager.py
+++ b/torchft/manager.py
@@ -39,6 +39,7 @@
 from torch.distributed import ReduceOp, TCPStore
 
 from torchft.checkpointing import CheckpointServer
+from torchft.futures import future_timeout
 from torchft.torchft import Manager as _Manager, ManagerClient
 
 if TYPE_CHECKING:
@@ -168,7 +169,7 @@ def _manager_state_dict() -> Dict[str, T]:
 
         self._step = 0
         self._quorum_id = -1
-        self._errored = False
+        self._errored: Optional[Exception] = None
         self._healing = False
         self._pending_work: List[torch.futures.Future[object]] = []
         self._batches_committed = 0
@@ -241,13 +242,13 @@ def callback(
         except Exception as e:
             logger.exception(f"got exception in all reduce -- skipping remaining: {e}")
-            self.report_error()
+            self.report_error(e)
 
             fut = torch.futures.Future()  # pyre-fixme[29]: not a function
             fut.set_result(grad)
             return fut
 
-    def report_error(self) -> None:
+    def report_error(self, e: Exception) -> None:
         """
         Report an error to the manager.
@@ -257,14 +258,14 @@ def report_error(self) -> None:
         This should be called when an error occurs that leads to a corrupted
         gradient that needs to be discarded.
         """
-        self._errored = True
+        self._errored = e
 
-    def errored(self) -> bool:
+    def errored(self) -> Optional[Exception]:
         """
         Get whether an error has occurred.
 
         Returns:
-            whether an error has occurred
+            The error or None if no error has occurred.
         """
         return self._errored
@@ -281,6 +282,9 @@ def wrap_future(
             default: the default value to complete the Future with if an error
                 occurs
         """
+        # add a timeout to the future
+        fut = future_timeout(fut, self._timeout)
+
         # schedule error handling as a continuation on the Future
         def callback(
             fut: torch.futures.Future[T],
@@ -291,7 +295,7 @@ def callback(
                 return fut.value()
             except Exception as e:
                 logger.exception(f"got exception in future -- skipping remaining: {e}")
-                self.report_error()
+                self.report_error(e)
                 return default
 
         fut = fut.then(callback)
@@ -313,7 +317,7 @@ def step(self) -> None:
         self._step += 1
         self._batches_committed += self.num_participants()
 
-        self._errored = False
+        self._errored = None
         self._healing = False
 
         self._ckpt_server.allow_checkpoint(self._step)
@@ -428,7 +432,7 @@ def should_commit(self) -> bool:
         """
         for work in self._pending_work:
             # check at the beginning of since .wait() may trigger errors
-            if self._errored:
+            if self._errored is not None:
                 break
 
            # We swallow the error at in a future then callback so this will
@@ -442,7 +446,7 @@
             self._apply_pending_state_dict()
 
         enough_replicas = self.num_participants() >= self._min_replica_size
-        local_should_commit = enough_replicas and not self._errored
+        local_should_commit = enough_replicas and self._errored is None
         should_commit = self._client.should_commit(
             self._rank, self._step, local_should_commit
         )
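The control flow of the updated `wrap_future` in isolation, for reference while reviewing the tests below (a sketch with stand-in names: `wrap_future_sketch` and `errors` are invented here, not the torchft API):

```python
from datetime import timedelta
from typing import List

import torch

from torchft.futures import future_timeout

errors: List[Exception] = []  # stand-in for Manager.report_error / _errored

def wrap_future_sketch(
    fut: torch.futures.Future, default: object, timeout: timedelta
) -> torch.futures.Future:
    # 1. arm a timeout so a hung future eventually fails...
    fut = future_timeout(fut, timeout)

    # 2. ...then swallow any error (including the TimeoutError) into the
    # recorded state and substitute the default value.
    def callback(fut: torch.futures.Future) -> object:
        try:
            return fut.value()
        except Exception as e:
            errors.append(e)
            return default

    return fut.then(callback)

never_set = torch.futures.Future()  # pyre-fixme[29]: not a function
wrapped = wrap_future_sketch(never_set, default=0, timeout=timedelta(seconds=0.05))
assert wrapped.wait() == 0  # the default value; no exception escapes
assert isinstance(errors[0], TimeoutError)
```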
diff --git a/torchft/manager_test.py b/torchft/manager_test.py
index a733ab57..38d0d529 100644
--- a/torchft/manager_test.py
+++ b/torchft/manager_test.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from datetime import timedelta
 from unittest import TestCase
 from unittest.mock import MagicMock, create_autospec, patch
 
@@ -24,6 +25,7 @@ def _create_manager(
         use_async_quorum: bool = True,
         min_replica_size: int = 2,
         world_size_mode: WorldSizeMode = WorldSizeMode.DYNAMIC,
+        timeout: timedelta = timedelta(seconds=60),
     ) -> Manager:
         pg = create_autospec(ProcessGroup)
         self.store = TCPStore(
@@ -47,6 +49,7 @@ def _create_manager(
             state_dict=lambda: {},
             use_async_quorum=use_async_quorum,
             world_size_mode=world_size_mode,
+            timeout=timeout,
         )
         return manager
 
@@ -382,25 +385,44 @@ def test_quorum_fixed_world_size(self, client_mock: MagicMock) -> None:
     def test_manager_report_error(self, client_mock: MagicMock) -> None:
         manager = self._create_manager()
 
-        self.assertFalse(manager.errored())
-        manager.report_error()
-        self.assertTrue(manager.errored())
+        self.assertIsNone(manager.errored())
+        e = RuntimeError("some error")
+        manager.report_error(e)
+        self.assertIs(manager.errored(), e)
 
     @patch("torchft.manager.ManagerClient", autospec=True)
     def test_manager_wrap_future(self, client_mock: MagicMock) -> None:
         manager = self._create_manager()
 
-        self.assertFalse(manager.errored())
+        self.assertIsNone(manager.errored())
 
         fut = torch.futures.Future()  # pyre-fixme[29]: not a function
         wrapped_fut = manager.wrap_future(fut, 2)
+        self.assertIsNone(manager.errored())
 
-        fut.set_exception(RuntimeError("injected failure"))
-
+        e = RuntimeError("injected failure")
+        fut.set_exception(e)
+        self.assertIs(manager.errored(), e)
         self.assertEqual(wrapped_fut.value(), 2)
-        self.assertTrue(manager.errored())
+        self.assertEqual(manager._pending_work, [wrapped_fut])
+
+    @patch("torchft.manager.ManagerClient", autospec=True)
+    def test_manager_wrap_future_timeout(self, client_mock: MagicMock) -> None:
+        manager = self._create_manager(timeout=timedelta(seconds=0.01))
+
+        self.assertFalse(manager.errored())
+
+        fut = torch.futures.Future()  # pyre-fixme[29]: not a function
+        wrapped_fut = manager.wrap_future(fut, 2)
+        wrapped_fut.wait()
+        error = manager.errored()
+        self.assertIsNotNone(error)
+        with self.assertRaisesRegex(
+            TimeoutError, "future did not complete within.*0.01"
+        ):
+            raise error
+
     @patch("torchft.manager.ManagerClient", autospec=True)
     def test_manager_numerics(self, client_mock: MagicMock) -> None:
         manager = self._create_manager()
diff --git a/torchft/process_group.py b/torchft/process_group.py
index b94543a4..735aa4eb 100644
--- a/torchft/process_group.py
+++ b/torchft/process_group.py
@@ -20,15 +20,11 @@
 import threading
 from abc import ABC
 from datetime import timedelta
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Type
+from typing import TYPE_CHECKING, Dict, List, Optional, Type
 
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
-from torch._C._distributed_c10d import (
-    _register_process_group,
-    _unregister_process_group,
-)
 
 # pyre-fixme[21]: no attribute ProcessGroupNCCL
 # pyre-fixme[21]: no attribute ProcessGroupGloo
@@ -440,10 +436,40 @@ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
         return _DummyWork(tensors)
 
 
-class ManagedProcessGroup(ErrorSwallowingProcessGroupWrapper):
+class _ManagedWork(Work):
+    def __init__(self, manager: "Manager", work: Work, default_result: object) -> None:
+        super().__init__()
+
+        self._manager = manager
+        self._work = work
+        self._default_result = default_result
+
+    def wait(self, timeout: Optional[timedelta] = None) -> bool:
+        try:
+            if timeout is not None:
+                self._work.wait(timeout)
+            else:
+                self._work.wait()
+        except Exception as e:
+            self._manager.report_error(e)
+
+        return True
+
+    def get_future(self) -> Future[object]:
+        return self._manager.wrap_future(self._work.get_future(), self._default_result)
+
+
+class ManagedProcessGroup(ProcessGroupWrapper):
     """
     This is a wrapper around any ProcessGroup that is managed by a torchft
     Manager.
+
+    This uses the ProcessGroup that is configured in the Manager. The world size
+    is dynamic and will report the number of active participants in the quorum to
+    the model.
+
+    Any errors will be asynchronously reported to the manager and only successes
+    will be returned to the caller.
     """
 
     def __init__(self, manager: "Manager") -> None:
@@ -451,18 +477,21 @@ def __init__(self, manager: "Manager") -> None:
 
         self._manager = manager
 
-    def report_error(self, e: Exception) -> None:
-        """
-        Report an error to this process group. This will cause all future
-        operations to be skipped until the process group is reconfigured via
-        ``configure``.
+    def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
+        if self._manager.errored() is not None:
+            return _DummyWork(tensors)
 
-        Args:
-            e: exception to report
-        """
-        super().report_error(e)
+        try:
+            work = super().allreduce(tensors, opts)
+        except Exception as e:
+            self._manager.report_error(e)
+            return _DummyWork(tensors)
 
-        self._manager.report_error()
+        return _ManagedWork(
+            self._manager,
+            work,
+            tensors,
+        )
 
     def size(self) -> int:
         return self._manager.num_participants()
diff --git a/torchft/process_group_test.py b/torchft/process_group_test.py
index a8e807d0..a5f73e04 100644
--- a/torchft/process_group_test.py
+++ b/torchft/process_group_test.py
@@ -29,6 +29,7 @@
     ProcessGroupWrapper,
     _DummyWork,
     _ErrorSwallowingWork,
+    _ManagedWork,
     extend_device_mesh,
 )
 
@@ -238,13 +239,19 @@ def test_error_swallowing_process_group_wrapper(self) -> None:
     def test_managed_process_group(self) -> None:
         manager = Mock(spec=Manager)
+        manager.errored.return_value = None
         manager._pg = ProcessGroupDummy(0, 1)
         pg = ManagedProcessGroup(manager)
         manager.num_participants.return_value = 123
 
         self.assertEqual(pg.size(), 123)
 
-        err = RuntimeError("test")
-        pg.report_error(err)
-        self.assertEqual(pg.error(), err)
-        self.assertEqual(manager.report_error.call_count, 1)
+        t = torch.zeros(10)
+        work = pg.allreduce([t], ReduceOp.SUM)
+        self.assertIsInstance(work, _ManagedWork)
+        work.wait()
+        fut = work.get_future()
+        fut.wait()
+
+        self.assertEqual(manager.report_error.call_count, 0)
+        self.assertEqual(manager.wrap_future.call_count, 1)
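End to end, the intended shape of a training step with these pieces is roughly the following. This is a hedged sketch: the `manager` construction, `model`, and `optimizer` are elided assumptions, and only `Manager.step`/`errored`/`should_commit` and `ManagedProcessGroup.allreduce` actually appear in this diff.

```python
import torch
from torch.distributed import ReduceOp

from torchft.process_group import ManagedProcessGroup

def train_step(manager, pg: ManagedProcessGroup, model, optimizer) -> None:
    manager.step()
    # ... forward/backward producing gradients ...
    for p in model.parameters():
        if p.grad is not None:
            # _ManagedWork.wait() never raises: failures and timeouts are
            # routed to manager.report_error(e) instead, and the op degrades
            # to a no-op on the local tensors.
            pg.allreduce([p.grad], ReduceOp.SUM).wait()

    # A reported error (including a TimeoutError from future_timeout) makes
    # local_should_commit false, so the optimizer step is skipped.
    if manager.should_commit():
        optimizer.step()
```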