diff --git a/python/monarch/future.py b/python/monarch/future.py index f65c87011..a13971d12 100644 --- a/python/monarch/future.py +++ b/python/monarch/future.py @@ -5,21 +5,72 @@ # LICENSE file in the root directory of this source tree. import asyncio -from typing import Generator, Generic, TypeVar +from functools import partial +from typing import Generator, Generic, Optional, TypeVar R = TypeVar("R") +def _incomplete(impl, self): + try: + return self._set_result(impl()) + except Exception as e: + self._set_exception(e) + raise + + +async def _aincomplete(impl, self): + try: + return self._set_result(await impl()) + except Exception as e: + self._set_exception(e) + raise + + # TODO: consolidate with monarch.common.future class ActorFuture(Generic[R]): def __init__(self, impl, blocking_impl=None): - self._impl = impl - self._blocking_impl = blocking_impl + if blocking_impl is None: + blocking_impl = partial(asyncio.run, impl()) + self._get = partial(_incomplete, blocking_impl) + self._aget = partial(_aincomplete, impl) - def get(self) -> R: - if self._blocking_impl is not None: - return self._blocking_impl() - return asyncio.run(self._impl()) + def get(self, timeout: Optional[float] = None) -> R: + if timeout is not None: + return asyncio.run(asyncio.wait_for(self._aget(self), timeout)) + return self._get(self) def __await__(self) -> Generator[R, None, R]: - return self._impl().__await__() + return self._aget(self).__await__() + + def _set_result(self, result): + def f(self): + return result + + async def af(self): + return result + + self._get, self._aget = f, af + return result + + def _set_exception(self, e): + def f(self): + raise e + + async def af(self): + raise e + + self._get, self._aget = f, af + + # compatibility with old tensor engine Future objects + # hopefully we do not need done(), add_callback because + # they are harder to implement right. + def result(self, timeout: Optional[float] = None) -> R: + return self.get(timeout) + + def exception(self, timeout: Optional[float] = None): + try: + self.get(timeout) + return None + except Exception as e: + return e diff --git a/python/tests/test_python_actors.py b/python/tests/test_python_actors.py index 12f81047a..00718adc6 100644 --- a/python/tests/test_python_actors.py +++ b/python/tests/test_python_actors.py @@ -29,6 +29,7 @@ MonarchContext, ) from monarch.debugger import init_debugging +from monarch.future import ActorFuture from monarch.mesh_controller import spawn_tensor_engine @@ -672,3 +673,100 @@ async def test_async_concurrency(): # actually concurrently processing messages. await am.no_more.call() await fut + + +async def awaitit(f): + return await f + + +def test_actor_future(): + v = 0 + + async def incr(): + nonlocal v + v += 1 + return v + + # can use async implementation from sync + # if no non-blocking is provided + f = ActorFuture(incr) + assert f.get() == 1 + assert v == 1 + assert f.get() == 1 + assert asyncio.run(awaitit(f)) == 1 + + f = ActorFuture(incr) + assert asyncio.run(awaitit(f)) == 2 + assert f.get() == 2 + + def incr2(): + nonlocal v + v += 2 + return v + + # Use non-blocking optimization if provided + f = ActorFuture(incr, incr2) + assert f.get() == 4 + assert asyncio.run(awaitit(f)) == 4 + + async def nope(): + nonlocal v + v += 1 + raise ValueError("nope") + + f = ActorFuture(nope) + + with pytest.raises(ValueError): + f.get() + + assert v == 5 + + with pytest.raises(ValueError): + f.get() + + assert v == 5 + + with pytest.raises(ValueError): + asyncio.run(awaitit(f)) + + assert v == 5 + + def nope(): + nonlocal v + v += 1 + raise ValueError("nope") + + f = ActorFuture(incr, nope) + + with pytest.raises(ValueError): + f.get() + + assert v == 6 + + with pytest.raises(ValueError): + f.result() + + assert f.exception() is not None + + assert v == 6 + + with pytest.raises(ValueError): + asyncio.run(awaitit(f)) + + assert v == 6 + + async def seven(): + return 7 + + f = ActorFuture(seven) + + assert 7 == f.get(timeout=0.001) + + async def neverfinish(): + f = asyncio.Future() + await f + + f = ActorFuture(neverfinish) + + with pytest.raises(asyncio.exceptions.TimeoutError): + f.get(timeout=0.1)