diff --git a/src/firebolt/common/util.py b/src/firebolt/common/util.py index 39f76fdd0d3..102193395b4 100644 --- a/src/firebolt/common/util.py +++ b/src/firebolt/common/util.py @@ -1,4 +1,4 @@ -from asyncio import get_event_loop +from asyncio import get_event_loop, new_event_loop from functools import lru_cache, wraps from typing import TYPE_CHECKING, Any, Callable, Type, TypeVar @@ -39,6 +39,15 @@ def fix_url_schema(url: str) -> str: def async_to_sync(f: Callable) -> Callable: @wraps(f) def sync(*args: Any, **kwargs: Any) -> Any: - return get_event_loop().run_until_complete(f(*args, **kwargs)) + close = False + try: + loop = get_event_loop() + except RuntimeError: + loop = new_event_loop() + close = True + res = loop.run_until_complete(f(*args, **kwargs)) + if close: + loop.close() + return res return sync diff --git a/tests/unit/common/__init__.py b/tests/unit/common/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/unit/common/test_util.py b/tests/unit/common/test_util.py new file mode 100644 index 00000000000..57437754a30 --- /dev/null +++ b/tests/unit/common/test_util.py @@ -0,0 +1,52 @@ +from asyncio import run +from threading import Thread + +from pytest import raises + +from firebolt.common.util import async_to_sync + + +def test_async_to_sync_happy_path(): + """async_to_sync properly converts coroutine to sync function""" + + class JobMarker(Exception): + pass + + async def task(): + raise JobMarker() + + for i in range(3): + with raises(JobMarker): + async_to_sync(task)() + + +def test_async_to_sync_thread(): + """async_to_sync properly works in threads""" + + marks = [False] * 3 + + async def task(id: int): + marks[id] = True + + ts = [Thread(target=async_to_sync(task), args=[i]) for i in range(3)] + [t.start() for t in ts] + [t.join() for t in ts] + assert all(marks) + + +def test_async_to_sync_after_run(): + """async_to_sync properly runs after asyncio.run""" + + class JobMarker(Exception): + pass + + async def task(): + raise JobMarker() + + with raises(JobMarker): + run(task()) + + # Here local event loop is closed by run + + with raises(JobMarker): + async_to_sync(task)()