Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions dbos/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
cast,
)

import psycopg

from dbos._outcome import Immediate, NoResult, Outcome, Pending
from dbos._utils import GlobalParams, retriable_postgres_exception

Expand Down Expand Up @@ -831,10 +829,10 @@ def record_get_result(func: Callable[[], R]) -> R:
return r

outcome = (
wfOutcome.wrap(init_wf)
wfOutcome.wrap(init_wf, dbos=dbos)
.also(DBOSAssumeRole(rr))
.also(enterWorkflowCtxMgr(attributes))
.then(record_get_result)
.then(record_get_result, dbos=dbos)
)
return outcome() # type: ignore

Expand Down Expand Up @@ -1146,7 +1144,7 @@ def check_existing_result() -> Union[NoResult, R]:

outcome = (
stepOutcome.then(record_step_result)
.intercept(check_existing_result)
.intercept(check_existing_result, dbos=dbos)
.also(EnterDBOSStep(attributes))
)
return outcome()
Expand Down
80 changes: 67 additions & 13 deletions dbos/_outcome.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,24 @@
import contextlib
import inspect
import time
from typing import Any, Callable, Coroutine, Optional, Protocol, TypeVar, Union, cast
from typing import (
TYPE_CHECKING,
Any,
Callable,
Coroutine,
Optional,
Protocol,
TypeVar,
Union,
cast,
)

from dbos._context import EnterDBOSStepRetry
from dbos._error import DBOSException
from dbos._registrations import get_dbos_func_name

if TYPE_CHECKING:
from ._dbos import DBOS

T = TypeVar("T")
R = TypeVar("R")
Expand All @@ -24,10 +39,15 @@ def __new__(cls, *args: Any, **kwargs: Any) -> "NoResult":
class Outcome(Protocol[T]):

def wrap(
self, before: Callable[[], Callable[[Callable[[], T]], R]]
self,
before: Callable[[], Callable[[Callable[[], T]], R]],
*,
dbos: Optional["DBOS"] = None,
) -> "Outcome[R]": ...

def then(self, next: Callable[[Callable[[], T]], R]) -> "Outcome[R]": ...
def then(
self, next: Callable[[Callable[[], T]], R], *, dbos: Optional["DBOS"] = None
) -> "Outcome[R]": ...

def also(
self, cm: contextlib.AbstractContextManager[Any, bool]
Expand All @@ -41,7 +61,10 @@ def retry(
) -> "Outcome[T]": ...

def intercept(
self, interceptor: Callable[[], Union[NoResult, T]]
self,
interceptor: Callable[[], Union[NoResult, T]],
*,
dbos: Optional["DBOS"] = None,
) -> "Outcome[T]": ...

def __call__(self) -> Union[T, Coroutine[Any, Any, T]]: ...
Expand All @@ -63,11 +86,17 @@ class Immediate(Outcome[T]):
def __init__(self, func: Callable[[], T]):
self._func = func

def then(self, next: Callable[[Callable[[], T]], R]) -> "Immediate[R]":
def then(
self,
next: Callable[[Callable[[], T]], R],
dbos: Optional["DBOS"] = None,
) -> "Immediate[R]":
return Immediate(lambda: next(self._func))

def wrap(
self, before: Callable[[], Callable[[Callable[[], T]], R]]
self,
before: Callable[[], Callable[[Callable[[], T]], R]],
dbos: Optional["DBOS"] = None,
) -> "Immediate[R]":
return Immediate(lambda: before()(self._func))

Expand All @@ -79,7 +108,10 @@ def _intercept(
return intercepted if not isinstance(intercepted, NoResult) else func()

def intercept(
self, interceptor: Callable[[], Union[NoResult, T]]
self,
interceptor: Callable[[], Union[NoResult, T]],
*,
dbos: Optional["DBOS"] = None,
) -> "Immediate[T]":
return Immediate[T](lambda: Immediate._intercept(self._func, interceptor))

Expand Down Expand Up @@ -142,7 +174,12 @@ def _raise(ex: BaseException) -> T:
async def _wrap(
func: Callable[[], Coroutine[Any, Any, T]],
before: Callable[[], Callable[[Callable[[], T]], R]],
*,
dbos: Optional["DBOS"] = None,
) -> R:
# Make sure the executor pool is configured correctly
if dbos is not None:
await dbos._configure_asyncio_thread_pool()
after = await asyncio.to_thread(before)
try:
value = await func()
Expand All @@ -151,12 +188,17 @@ async def _wrap(
return await asyncio.to_thread(after, lambda: Pending._raise(exp))

def wrap(
self, before: Callable[[], Callable[[Callable[[], T]], R]]
self,
before: Callable[[], Callable[[Callable[[], T]], R]],
*,
dbos: Optional["DBOS"] = None,
) -> "Pending[R]":
return Pending[R](lambda: Pending._wrap(self._func, before))
return Pending[R](lambda: Pending._wrap(self._func, before, dbos=dbos))

def then(self, next: Callable[[Callable[[], T]], R]) -> "Pending[R]":
return Pending[R](lambda: Pending._wrap(self._func, lambda: next))
def then(
self, next: Callable[[Callable[[], T]], R], *, dbos: Optional["DBOS"] = None
) -> "Pending[R]":
return Pending[R](lambda: Pending._wrap(self._func, lambda: next, dbos=dbos))

@staticmethod
async def _also( # type: ignore
Expand All @@ -173,12 +215,24 @@ def also(self, cm: contextlib.AbstractContextManager[Any, bool]) -> "Pending[T]"
async def _intercept(
func: Callable[[], Coroutine[Any, Any, T]],
interceptor: Callable[[], Union[NoResult, T]],
*,
dbos: Optional["DBOS"] = None,
) -> T:
# Make sure the executor pool is configured correctly
if dbos is not None:
await dbos._configure_asyncio_thread_pool()
intercepted = await asyncio.to_thread(interceptor)
return intercepted if not isinstance(intercepted, NoResult) else await func()

def intercept(self, interceptor: Callable[[], Union[NoResult, T]]) -> "Pending[T]":
return Pending[T](lambda: Pending._intercept(self._func, interceptor))
def intercept(
self,
interceptor: Callable[[], Union[NoResult, T]],
*,
dbos: Optional["DBOS"] = None,
) -> "Pending[T]":
return Pending[T](
lambda: Pending._intercept(self._func, interceptor, dbos=dbos)
)

@staticmethod
async def _retry(
Expand Down
31 changes: 31 additions & 0 deletions tests/test_dbos.py
Original file line number Diff line number Diff line change
Expand Up @@ -1250,12 +1250,43 @@ def test_workflow(var: str) -> str:
var = "test"
assert test_workflow(var) == var

# Start the workflow asynchornously
wf = dbos.start_workflow(test_workflow, var)
assert wf.get_result() == var

DBOS.destroy()
DBOS(config=config)
DBOS.launch()

assert test_workflow(var) == var

wf = dbos.start_workflow(test_workflow, var)
assert wf.get_result() == var


@pytest.mark.asyncio
async def test_destroy_semantics_async(dbos: DBOS, config: DBOSConfig) -> None:

@DBOS.workflow()
async def test_workflow(var: str) -> str:
return var

var = "test"
assert await test_workflow(var) == var

# Start the workflow asynchornously
wf = await dbos.start_workflow_async(test_workflow, var)
assert await wf.get_result() == var

DBOS.destroy()
DBOS(config=config)
DBOS.launch()

assert await test_workflow(var) == var

wf = await dbos.start_workflow_async(test_workflow, var)
assert await wf.get_result() == var


def test_double_decoration(dbos: DBOS) -> None:
with pytest.raises(
Expand Down