Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 165 additions & 0 deletions torchft/futures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
import asyncio
import threading
from datetime import timedelta
from typing import Optional, TypeVar
from unittest.mock import Mock

from torch.futures import Future

# Generic result type carried by the futures handled in this module.
T = TypeVar("T")


class _TimerHandle:
    """
    Thread-safe holder for an ``asyncio.TimerHandle``.

    The timer is handed over with ``set_timer`` and may be cancelled with
    ``cancel`` from another thread. ``cancel`` may run before ``set_timer``:
    the cancellation is remembered and applied to the timer as soon as it is
    handed over, so ``cancel`` never has to wait for ``set_timer``.
    """

    def __init__(self) -> None:
        # Guards the two fields below; only held for short, non-blocking work.
        self._lock = threading.Lock()
        self._timer_handle: Optional[asyncio.TimerHandle] = None
        # True once cancel() has been called.
        self._cancelled = False

    def set_timer(self, timer_handle: asyncio.TimerHandle) -> None:
        """
        Hand over the scheduled timer.

        If ``cancel`` was already called the timer is cancelled immediately
        instead of being stored.

        Args:
            timer_handle: the handle returned by the event loop when the
                timeout was scheduled
        """
        with self._lock:
            assert self._timer_handle is None, "timer handle may only be set once"

            if self._cancelled:
                # cancel() won the race; make sure the timeout never fires.
                timer_handle.cancel()
                return

            self._timer_handle = timer_handle

    def cancel(self) -> None:
        """
        Cancel the timer, now if it is already set or on hand-over otherwise.

        Calling this more than once is a no-op.
        """
        with self._lock:
            self._cancelled = True
            if self._timer_handle is not None:
                self._timer_handle.cancel()
                self._timer_handle = None


class _TimeoutManager:
    """
    This class manages timeouts for futures. It uses a background thread with an
    event loop to schedule the timeouts.
    """

    def __init__(self) -> None:
        # Guards creation and teardown of the event loop and its thread.
        self._lock = threading.Lock()
        self._event_loop: Optional[asyncio.AbstractEventLoop] = None
        self._event_loop_thread: Optional[threading.Thread] = None
        self._next_timer_id = 0

    def _maybe_start_event_loop(self) -> asyncio.AbstractEventLoop:
        """
        Start the event loop if it has not already been started.

        Returns:
            The running event loop; the same loop is returned until
            ``shutdown`` is called.
        """
        with self._lock:
            if self._event_loop is None:
                self._event_loop = asyncio.new_event_loop()
                self._event_loop_thread = threading.Thread(
                    target=self._event_loop.run_forever,
                    daemon=True,
                    name="TimeoutManager",
                )
                self._event_loop_thread.start()
            # pyre-fixme[7]: optional
            return self._event_loop

    def shutdown(self) -> None:
        """
        Shutdown the event loop and cancel all pending timeouts.

        The loop is stopped, its thread is joined and the loop is closed so
        its resources are released. Timeouts that were still pending will
        never fire. Calling this when no loop is running is a no-op; a later
        ``register`` call starts a fresh loop.
        """
        with self._lock:
            loop = self._event_loop
            if loop is None:
                return

            thread = self._event_loop_thread
            assert thread is not None

            loop.call_soon_threadsafe(loop.stop)
            thread.join()
            # The loop is no longer running once its thread has exited, so it
            # can be closed here instead of being left open.
            loop.close()

            self._event_loop = None
            self._event_loop_thread = None

    def register(self, fut: Future[T], timeout: timedelta) -> Future[T]:
        """
        Registers a future that will be cancelled after the specified timeout.

        Args:
            fut: the future to watch
            timeout: how long ``fut`` is given to complete

        Returns:
            A new future that completes with the outcome of ``fut``, or with
            a TimeoutError if the timeout fires first. Mock futures are
            returned unchanged.
        """
        # bypass timeout for mock futures
        if isinstance(fut, Mock):
            return fut

        loop = self._maybe_start_event_loop()

        # pyre-fixme[29]: Future is not a function
        timed_fut: Future[T] = Future()
        handle: _TimerHandle = _TimerHandle()
        # The timer is scheduled on the event loop thread.
        # pyre-fixme[6]: *args
        loop.call_soon_threadsafe(self._register, loop, timed_fut, timeout, handle)

        def callback(fut: Future[T]) -> None:
            # The original future finished, so the timeout is no longer needed.
            handle.cancel()
            try:
                timed_fut.set_result(fut.wait())
            except Exception as e:
                try:
                    # this can throw if the future is already done
                    # pyre-fixme[6]: e is not T
                    timed_fut.set_exception(e)
                except Exception:
                    pass

        fut.add_done_callback(callback)
        return timed_fut

    @classmethod
    def _register(
        cls,
        loop: asyncio.AbstractEventLoop,
        fut: Future[T],
        timeout: timedelta,
        handle: _TimerHandle,
    ) -> None:
        """
        Schedule the timeout for ``fut`` on ``loop`` and publish the timer.

        Args:
            loop: the event loop to schedule the timer on
            fut: the future to fail with a TimeoutError when the timer fires
            timeout: delay before the timer fires
            handle: receives the scheduled timer so it can be cancelled
        """

        def on_timeout() -> None:
            # The future may already have been completed by the time the
            # timer fires; leave its outcome untouched in that case.
            if fut.done():
                return
            try:
                fut.set_exception(
                    # pyre-fixme[6]: e is not T
                    TimeoutError(f"future did not complete within {timeout}")
                )
            except Exception:
                # Completing an already-done future can throw; the future was
                # completed concurrently, so there is nothing left to do and
                # the error should not surface in the event loop.
                pass

        timer_handle = loop.call_later(timeout.total_seconds(), on_timeout)
        handle.set_timer(timer_handle)


# Module-level timeout manager used by future_timeout(). Its event loop thread
# is only started on the first register() call.
_TIMEOUT_MANAGER = _TimeoutManager()


def future_timeout(fut: Future[T], timeout: timedelta) -> Future[T]:
    """
    Put a deadline on a Future.

    The work is delegated to the module-level timeout manager, which returns
    a Future that carries the outcome of ``fut`` when it completes in time
    and a TimeoutError otherwise.

    Args:
        fut: The Future to put a deadline on
        timeout: How long the Future is given to complete

    Returns:
        The Future produced by the timeout manager for ``fut``
    """
    timed_fut = _TIMEOUT_MANAGER.register(fut, timeout)
    return timed_fut


def future_wait(fut: Future[T], timeout: timedelta) -> T:
    """
    Block on a Future, giving up once the timeout has elapsed.

    A continuation is chained onto ``fut`` that flags completion and forwards
    the outcome; the caller then waits on that flag for at most ``timeout``.

    Args:
        fut: The Future to wait for
        timeout: The longest time to wait for the Future to complete

    Returns:
        The result of the Future if it completed within the timeout.

    Raises:
        TimeoutError if the Future did not complete within the timeout.
        Any error raised by waiting on the chained Future.
    """
    completed: threading.Event = threading.Event()

    def _signal_and_forward(done_fut: Future[T]) -> T:
        # Runs once the original future is done: wake the waiter, then pass
        # the result (or error) on to the chained future.
        completed.set()
        return done_fut.wait()

    chained = fut.then(_signal_and_forward)

    finished_in_time = completed.wait(timeout=timeout.total_seconds())
    if finished_in_time:
        return chained.wait()

    raise TimeoutError(f"future did not complete within {timeout}")
47 changes: 47 additions & 0 deletions torchft/futures_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from datetime import timedelta
from unittest import TestCase

from torch.futures import Future

from torchft.futures import future_timeout, future_wait


class FuturesTest(TestCase):
    """Exercises ``future_wait`` and ``future_timeout``."""

    def test_future_wait(self) -> None:
        # a future that never completes must time out
        # pyre-fixme[29]: Future is not a function
        pending = Future()
        with self.assertRaisesRegex(TimeoutError, "future did not complete within"):
            future_wait(pending, timeout=timedelta(seconds=0.01))

        # an already completed future hands back its value
        # pyre-fixme[29]: Future is not a function
        succeeded = Future()
        succeeded.set_result(1)
        value = future_wait(succeeded, timeout=timedelta(seconds=1.0))
        self.assertEqual(value, 1)

        # an already failed future re-raises its error
        # pyre-fixme[29]: Future is not a function
        failed = Future()
        failed.set_exception(RuntimeError("test"))
        with self.assertRaisesRegex(RuntimeError, "test"):
            future_wait(failed, timeout=timedelta(seconds=1.0))

    def test_future_timeout(self) -> None:
        # the source future is never completed, so the timeout has to fire
        # pyre-fixme[29]: Future is not a function
        source = Future()
        timed = future_timeout(source, timeout=timedelta(seconds=0.01))
        with self.assertRaisesRegex(TimeoutError, "future did not complete within"):
            timed.wait()

    def test_future_timeout_result(self) -> None:
        # a result set before the deadline is forwarded
        # pyre-fixme[29]: Future is not a function
        source = Future()
        timed = future_timeout(source, timeout=timedelta(seconds=10))
        source.set_result(1)
        self.assertEqual(timed.wait(), 1)

    def test_future_timeout_exception(self) -> None:
        # an error set before the deadline is forwarded
        # pyre-fixme[29]: Future is not a function
        source = Future()
        timed = future_timeout(source, timeout=timedelta(seconds=10))
        source.set_exception(RuntimeError("test"))
        with self.assertRaisesRegex(RuntimeError, "test"):
            timed.wait()
24 changes: 14 additions & 10 deletions torchft/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from torch.distributed import ReduceOp, TCPStore

from torchft.checkpointing import CheckpointServer
from torchft.futures import future_timeout
from torchft.torchft import Manager as _Manager, ManagerClient

if TYPE_CHECKING:
Expand Down Expand Up @@ -168,7 +169,7 @@ def _manager_state_dict() -> Dict[str, T]:

self._step = 0
self._quorum_id = -1
self._errored = False
self._errored: Optional[Exception] = None
self._healing = False
self._pending_work: List[torch.futures.Future[object]] = []
self._batches_committed = 0
Expand Down Expand Up @@ -241,13 +242,13 @@ def callback(

except Exception as e:
logger.exception(f"got exception in all reduce -- skipping remaining: {e}")
self.report_error()
self.report_error(e)

fut = torch.futures.Future() # pyre-fixme[29]: not a function
fut.set_result(grad)
return fut

def report_error(self) -> None:
def report_error(self, e: Exception) -> None:
"""
Report an error to the manager.

Expand All @@ -257,14 +258,14 @@ def report_error(self) -> None:
This should be called when an error occurs that leads to a corrupted
gradient that needs to be discarded.
"""
self._errored = True
self._errored = e

def errored(self) -> bool:
def errored(self) -> Optional[Exception]:
"""
Get whether an error has occurred.

Returns:
whether an error has occurred
The error or None if no error has occurred.
"""
return self._errored

Expand All @@ -281,6 +282,9 @@ def wrap_future(
default: the default value to complete the Future with if an error occurs
"""

# add a timeout to the future
fut = future_timeout(fut, self._timeout)

# schedule error handling as a continuation on the Future
def callback(
fut: torch.futures.Future[T],
Expand All @@ -291,7 +295,7 @@ def callback(
return fut.value()
except Exception as e:
logger.exception(f"got exception in future -- skipping remaining: {e}")
self.report_error()
self.report_error(e)
return default

fut = fut.then(callback)
Expand All @@ -313,7 +317,7 @@ def step(self) -> None:
self._step += 1
self._batches_committed += self.num_participants()

self._errored = False
self._errored = None
self._healing = False
self._ckpt_server.allow_checkpoint(self._step)

Expand Down Expand Up @@ -428,7 +432,7 @@ def should_commit(self) -> bool:
"""
for work in self._pending_work:
# check at the beginning of each iteration since .wait() may trigger errors
if self._errored:
if self._errored is not None:
break

# We swallow the error at in a future then callback so this will
Expand All @@ -442,7 +446,7 @@ def should_commit(self) -> bool:
self._apply_pending_state_dict()

enough_replicas = self.num_participants() >= self._min_replica_size
local_should_commit = enough_replicas and not self._errored
local_should_commit = enough_replicas and self._errored is None
should_commit = self._client.should_commit(
self._rank, self._step, local_should_commit
)
Expand Down
36 changes: 29 additions & 7 deletions torchft/manager_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from datetime import timedelta
from unittest import TestCase
from unittest.mock import MagicMock, create_autospec, patch

Expand All @@ -24,6 +25,7 @@ def _create_manager(
use_async_quorum: bool = True,
min_replica_size: int = 2,
world_size_mode: WorldSizeMode = WorldSizeMode.DYNAMIC,
timeout: timedelta = timedelta(seconds=60),
) -> Manager:
pg = create_autospec(ProcessGroup)
self.store = TCPStore(
Expand All @@ -47,6 +49,7 @@ def _create_manager(
state_dict=lambda: {},
use_async_quorum=use_async_quorum,
world_size_mode=world_size_mode,
timeout=timeout,
)
return manager

Expand Down Expand Up @@ -382,25 +385,44 @@ def test_quorum_fixed_world_size(self, client_mock: MagicMock) -> None:
def test_manager_report_error(self, client_mock: MagicMock) -> None:
manager = self._create_manager()

self.assertFalse(manager.errored())
manager.report_error()
self.assertTrue(manager.errored())
self.assertIsNone(manager.errored())
e = RuntimeError("some error")
manager.report_error(e)
self.assertIs(manager.errored(), e)

@patch("torchft.manager.ManagerClient", autospec=True)
def test_manager_wrap_future(self, client_mock: MagicMock) -> None:
manager = self._create_manager()

self.assertFalse(manager.errored())
self.assertIsNone(manager.errored())

fut = torch.futures.Future() # pyre-fixme[29]: not a function
wrapped_fut = manager.wrap_future(fut, 2)
self.assertIsNone(manager.errored())

fut.set_exception(RuntimeError("injected failure"))

e = RuntimeError("injected failure")
fut.set_exception(e)
self.assertIs(manager.errored(), e)
self.assertEqual(wrapped_fut.value(), 2)
self.assertTrue(manager.errored())

self.assertEqual(manager._pending_work, [wrapped_fut])

@patch("torchft.manager.ManagerClient", autospec=True)
def test_manager_wrap_future_timeout(self, client_mock: MagicMock) -> None:
    """
    A wrapped future that never completes must report a TimeoutError to the
    manager once the manager's timeout has elapsed.
    """
    manager = self._create_manager(timeout=timedelta(seconds=0.01))

    # errored() returns the stored exception or None, so check for None
    # explicitly rather than for any falsy value.
    self.assertIsNone(manager.errored())

    fut = torch.futures.Future()  # pyre-fixme[29]: not a function
    wrapped_fut = manager.wrap_future(fut, 2)
    # fut is never completed, so this returns only after the timeout fired
    wrapped_fut.wait()
    error = manager.errored()
    self.assertIsNotNone(error)
    # re-raise the recorded error to check both its type and its message
    with self.assertRaisesRegex(
        TimeoutError, "future did not complete within.*0.01"
    ):
        raise error

@patch("torchft.manager.ManagerClient", autospec=True)
def test_manager_numerics(self, client_mock: MagicMock) -> None:
manager = self._create_manager()
Expand Down
Loading
Loading