Skip to content

Commit 32f3312

Browse files
authored
Use to_thread for sync transaction functions (#454)
Otherwise, the sync transaction functions will block the entire event loop.
1 parent 381dc4b commit 32f3312

File tree

5 files changed

+121
-12
lines changed

5 files changed

+121
-12
lines changed

dbos/_core.py

Lines changed: 22 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -56,6 +56,7 @@
5656
DBOSWorkflowConflictIDError,
5757
DBOSWorkflowFunctionNotFoundError,
5858
)
59+
from ._logger import dbos_logger
5960
from ._registrations import (
6061
DEFAULT_MAX_RECOVERY_ATTEMPTS,
6162
get_config_name,
@@ -96,6 +97,14 @@
9697
TEMP_SEND_WF_NAME = "<temp>.temp_send_workflow"
9798

9899

100+
def check_is_in_coroutine() -> bool:
101+
try:
102+
asyncio.get_running_loop()
103+
return True
104+
except RuntimeError:
105+
return False
106+
107+
99108
class WorkflowHandleFuture(Generic[R]):
100109

101110
def __init__(self, workflow_id: str, future: Future[R], dbos: "DBOS"):
@@ -828,6 +837,11 @@ def record_get_result(func: Callable[[], R]) -> R:
828837
dbos._sys_db.record_get_result(workflow_id, serialized_r, None)
829838
return r
830839

840+
if check_is_in_coroutine() and not inspect.iscoroutinefunction(func):
841+
dbos_logger.warning(
842+
f"Sync workflow ({get_dbos_func_name(func)}) shouldn't be invoked from within another async function. Define it as async or use asyncio.to_thread instead."
843+
)
844+
831845
outcome = (
832846
wfOutcome.wrap(init_wf, dbos=dbos)
833847
.also(DBOSAssumeRole(rr))
@@ -1009,6 +1023,10 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
10091023
assert (
10101024
ctx.is_workflow()
10111025
), "Transactions must be called from within workflows"
1026+
if check_is_in_coroutine():
1027+
dbos_logger.warning(
1028+
f"Transaction function ({get_dbos_func_name(func)}) shouldn't be invoked from within another async function. Use asyncio.to_thread instead."
1029+
)
10121030
with DBOSAssumeRole(rr):
10131031
return invoke_tx(*args, **kwargs)
10141032
else:
@@ -1153,6 +1171,10 @@ def check_existing_result() -> Union[NoResult, R]:
11531171

11541172
@wraps(func)
11551173
def wrapper(*args: Any, **kwargs: Any) -> Any:
1174+
if check_is_in_coroutine() and not inspect.iscoroutinefunction(func):
1175+
dbos_logger.warning(
1176+
f"Sync step ({get_dbos_func_name(func)}) shouldn't be invoked from within another async function. Define it as async or use asyncio.to_thread instead."
1177+
)
11561178
# If the step is called from a workflow, run it as a step.
11571179
# Otherwise, run it as a normal function.
11581180
ctx = get_local_dbos_context()

tests/test_async.py

Lines changed: 90 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,9 @@
55

66
import pytest
77
import sqlalchemy as sa
8+
from opentelemetry._logs import set_logger_provider
9+
from opentelemetry.sdk._logs import LoggerProvider, LoggingHandler
10+
from opentelemetry.sdk._logs.export import BatchLogRecordProcessor, InMemoryLogExporter
811

912
# Public API
1013
from dbos import (
@@ -19,6 +22,8 @@
1922
from dbos._dbos import WorkflowHandle
2023
from dbos._dbos_config import ConfigFile
2124
from dbos._error import DBOSAwaitedWorkflowCancelledError, DBOSException
25+
from dbos._logger import dbos_logger
26+
from dbos._registrations import get_dbos_func_name
2227

2328

2429
@pytest.mark.asyncio
@@ -31,7 +36,7 @@ async def test_async_workflow(dbos: DBOS) -> None:
3136
async def test_workflow(var1: str, var2: str) -> str:
3237
nonlocal wf_counter
3338
wf_counter += 1
34-
res1 = test_transaction(var1)
39+
res1 = await asyncio.to_thread(test_transaction, var1)
3540
res2 = await test_step(var2)
3641
DBOS.logger.info("I'm test_workflow")
3742
return res1 + res2
@@ -88,7 +93,7 @@ async def test_async_step(dbos: DBOS) -> None:
8893
async def test_workflow(var1: str, var2: str) -> str:
8994
nonlocal wf_counter
9095
wf_counter += 1
91-
res1 = test_transaction(var1)
96+
res1 = await asyncio.to_thread(test_transaction, var1)
9297
res2 = await test_step(var2)
9398
DBOS.logger.info("I'm test_workflow")
9499
return res1 + res2
@@ -325,6 +330,7 @@ def test_async_tx_raises(config: ConfigFile) -> None:
325330
async def test_async_tx() -> None:
326331
pass
327332

333+
assert "is a coroutine function" in str(exc_info.value)
328334
# destroy call needed to avoid "functions were registered but DBOS() was not called" warning
329335
DBOS.destroy(destroy_registry=True)
330336

@@ -343,12 +349,12 @@ async def test_workflow(var1: str, var2: str) -> str:
343349
wf_el_id = id(asyncio.get_running_loop())
344350
nonlocal wf_counter
345351
wf_counter += 1
346-
res2 = test_step(var2)
352+
res2 = await test_step(var2)
347353
DBOS.logger.info("I'm test_workflow")
348354
return var1 + res2
349355

350356
@DBOS.step()
351-
def test_step(var: str) -> str:
357+
async def test_step(var: str) -> str:
352358
nonlocal step_el_id
353359
step_el_id = id(asyncio.get_running_loop())
354360
nonlocal step_counter
@@ -605,3 +611,83 @@ async def run_workflow_task() -> str:
605611
# Verify the workflow completes despite the task cancellation
606612
handle: WorkflowHandleAsync[str] = await DBOS.retrieve_workflow_async(wfid)
607613
assert await handle.get_result() == "completed"
614+
615+
616+
@pytest.mark.asyncio
617+
async def test_check_async_violation(dbos: DBOS) -> None:
618+
# Set up in-memory log exporter
619+
log_exporter = InMemoryLogExporter() # type: ignore
620+
log_processor = BatchLogRecordProcessor(log_exporter)
621+
log_provider = LoggerProvider()
622+
log_provider.add_log_record_processor(log_processor)
623+
set_logger_provider(log_provider)
624+
dbos_logger.addHandler(LoggingHandler(logger_provider=log_provider))
625+
626+
@DBOS.workflow()
627+
def sync_workflow() -> str:
628+
return "sync"
629+
630+
@DBOS.step()
631+
def sync_step() -> str:
632+
return "step"
633+
634+
@DBOS.workflow()
635+
async def async_workflow_sync_step() -> str:
636+
return sync_step()
637+
638+
@DBOS.transaction()
639+
def sync_transaction() -> str:
640+
return "txn"
641+
642+
@DBOS.workflow()
643+
async def async_workflow_sync_txn() -> str:
644+
return sync_transaction()
645+
646+
# Calling a sync workflow should log a warning
647+
sync_workflow()
648+
649+
log_processor.force_flush(timeout_millis=5000)
650+
logs = log_exporter.get_finished_logs()
651+
assert len(logs) == 1
652+
assert (
653+
logs[0].log_record.body is not None
654+
and f"Sync workflow ({get_dbos_func_name(sync_workflow)}) shouldn't be invoked from within another async function."
655+
in logs[0].log_record.body
656+
)
657+
log_exporter.clear()
658+
659+
# Calling a sync step from within an async workflow should log a warning
660+
await async_workflow_sync_step()
661+
log_processor.force_flush(timeout_millis=5000)
662+
logs = log_exporter.get_finished_logs()
663+
assert len(logs) == 1
664+
assert (
665+
logs[0].log_record.body is not None
666+
and f"Sync step ({get_dbos_func_name(sync_step)}) shouldn't be invoked from within another async function."
667+
in logs[0].log_record.body
668+
)
669+
log_exporter.clear()
670+
671+
# Directly calling a sync step should log a warning
672+
sync_step()
673+
log_processor.force_flush(timeout_millis=5000)
674+
logs = log_exporter.get_finished_logs()
675+
assert len(logs) == 1
676+
assert (
677+
logs[0].log_record.body is not None
678+
and f"Sync step ({get_dbos_func_name(sync_step)}) shouldn't be invoked from within another async function."
679+
in logs[0].log_record.body
680+
)
681+
log_exporter.clear()
682+
683+
# Calling a sync transaction from within an async workflow should log a warning
684+
await async_workflow_sync_txn()
685+
log_processor.force_flush(timeout_millis=5000)
686+
logs = log_exporter.get_finished_logs()
687+
assert len(logs) == 1
688+
assert (
689+
logs[0].log_record.body is not None
690+
and f"Transaction function ({get_dbos_func_name(sync_transaction)}) shouldn't be invoked from within another async function."
691+
in logs[0].log_record.body
692+
)
693+
log_exporter.clear()

tests/test_async_workflow_management.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -148,7 +148,7 @@ def step_three(x: int) -> int:
148148

149149
wfid = str(uuid.uuid4())
150150
with SetWorkflowID(wfid):
151-
assert simple_workflow(input_val) == output
151+
assert await asyncio.to_thread(simple_workflow, input_val) == output
152152

153153
assert step_one_count == 1
154154
assert step_two_count == 1

tests/test_dbos.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
# mypy: disable-error-code="no-redef"
22

3+
import asyncio
34
import datetime
45
import logging
56
import os
@@ -1668,17 +1669,17 @@ async def async_step(x: int) -> int:
16681669
assert DBOS.workflow_id is None
16691670
return x
16701671

1671-
assert step(5) == 5
1672+
assert await asyncio.to_thread(step, 5) == 5
16721673
assert await async_step(5) == 5
16731674

16741675
DBOS(config=config)
16751676

1676-
assert step(5) == 5
1677+
assert await asyncio.to_thread(step, 5) == 5
16771678
assert await async_step(5) == 5
16781679

16791680
DBOS.launch()
16801681

1681-
assert step(5) == 5
1682+
assert await asyncio.to_thread(step, 5) == 5
16821683
assert await async_step(5) == 5
16831684

16841685
assert len(DBOS.list_workflows()) == 0

tests/test_workflow_introspection.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -828,18 +828,18 @@ async def test_callchild_first_asyncio(dbos: DBOS) -> None:
828828
async def parentWorkflow() -> str:
829829
handle = await dbos.start_workflow_async(child_workflow)
830830
child_id = await handle.get_result()
831-
stepOne()
832-
stepTwo()
831+
await stepOne()
832+
await stepTwo()
833833
return child_id
834834

835835
@DBOS.step()
836-
def stepOne() -> str:
836+
async def stepOne() -> str:
837837
workflow_id = DBOS.workflow_id
838838
assert workflow_id is not None
839839
return workflow_id
840840

841841
@DBOS.step()
842-
def stepTwo() -> None:
842+
async def stepTwo() -> None:
843843
return
844844

845845
@DBOS.workflow()

0 commit comments

Comments
 (0)