Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions src/firebolt/common/util.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from asyncio import get_event_loop
from asyncio import get_event_loop, new_event_loop
from functools import lru_cache, wraps
from typing import TYPE_CHECKING, Any, Callable, Type, TypeVar

Expand Down Expand Up @@ -39,6 +39,15 @@ def fix_url_schema(url: str) -> str:
def async_to_sync(f: Callable) -> Callable:
@wraps(f)
def sync(*args: Any, **kwargs: Any) -> Any:
return get_event_loop().run_until_complete(f(*args, **kwargs))
close = False
try:
loop = get_event_loop()
except RuntimeError:
loop = new_event_loop()
close = True
res = loop.run_until_complete(f(*args, **kwargs))
if close:
loop.close()
return res

return sync
Empty file added tests/unit/common/__init__.py
Empty file.
52 changes: 52 additions & 0 deletions tests/unit/common/test_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from asyncio import run
from threading import Thread

from pytest import raises

from firebolt.common.util import async_to_sync


def test_async_to_sync_happy_path():
"""async_to_sync properly converts coroutine to sync function"""

class JobMarker(Exception):
pass

async def task():
raise JobMarker()

for i in range(3):
with raises(JobMarker):
async_to_sync(task)()


def test_async_to_sync_thread():
"""async_to_sync properly works in threads"""

marks = [False] * 3

async def task(id: int):
marks[id] = True

ts = [Thread(target=async_to_sync(task), args=[i]) for i in range(3)]
[t.start() for t in ts]
[t.join() for t in ts]
assert all(marks)


def test_async_to_sync_after_run():
"""async_to_sync properly runs after asyncio.run"""

class JobMarker(Exception):
pass

async def task():
raise JobMarker()

with raises(JobMarker):
run(task())

# Here local event loop is closed by run

with raises(JobMarker):
async_to_sync(task)()