From 1dab310edf88d21008581107627ea7bb3c6f66c5 Mon Sep 17 00:00:00 2001 From: ptiurin Date: Tue, 22 Nov 2022 17:11:39 +0000 Subject: [PATCH 1/4] fix: use trio for async_to_sync --- setup.cfg | 1 + src/firebolt/utils/util.py | 16 +++------------- 2 files changed, 4 insertions(+), 13 deletions(-) diff --git a/setup.cfg b/setup.cfg index 306f6be1da4..4b8a2a95871 100755 --- a/setup.cfg +++ b/setup.cfg @@ -32,6 +32,7 @@ install_requires = pydantic[dotenv]>=1.8.2,<1.10 readerwriterlock==1.0.9 sqlparse>=0.4.2 + trio python_requires = >=3.7 include_package_data = True package_dir = diff --git a/src/firebolt/utils/util.py b/src/firebolt/utils/util.py index 56c07b07b4f..a59cd421538 100644 --- a/src/firebolt/utils/util.py +++ b/src/firebolt/utils/util.py @@ -4,7 +4,7 @@ new_event_loop, set_event_loop, ) -from functools import lru_cache, wraps +from functools import lru_cache, partial, wraps from threading import Thread from typing import ( TYPE_CHECKING, @@ -16,6 +16,7 @@ TypeVar, ) +import trio # type: ignore from httpx import URL T = TypeVar("T") @@ -166,18 +167,7 @@ def async_to_sync(f: Callable, async_job_thread: AsyncJobThread = None) -> Calla @wraps(f) def sync(*args: Any, **kwargs: Any) -> Any: - try: - loop = get_event_loop() - except RuntimeError: - loop = new_event_loop() - set_event_loop(loop) - # We are inside a running loop - if loop.is_running(): - nonlocal async_job_thread - if not async_job_thread: - async_job_thread = AsyncJobThread() - return async_job_thread.execute(f(*args, **kwargs)) - return loop.run_until_complete(f(*args, **kwargs)) + return trio.run(partial(f, *args, **kwargs)) return sync From d1b99bf4eafc006f431d0cef15bfb50cb5e57029 Mon Sep 17 00:00:00 2001 From: ptiurin Date: Wed, 23 Nov 2022 15:58:58 +0000 Subject: [PATCH 2/4] Remove leftover asyncio code --- src/firebolt/async_db/cursor.py | 2 +- src/firebolt/db/connection.py | 9 ++-- src/firebolt/db/cursor.py | 12 ++--- src/firebolt/utils/util.py | 93 ++------------------------------- 4 files changed, 14 insertions(+), 102 deletions(-) diff --git a/src/firebolt/async_db/cursor.py b/src/firebolt/async_db/cursor.py index bc457b4ba3a..363314a6ace 100644 --- a/src/firebolt/async_db/cursor.py +++ b/src/firebolt/async_db/cursor.py @@ -679,7 +679,7 @@ def __exit__( class Cursor(BaseCursor): """ - Executes asyncio queries to Firebolt Database. + Executes async queries to Firebolt Database. Should not be created directly; use :py:func:`connection.cursor ` diff --git a/src/firebolt/db/connection.py b/src/firebolt/db/connection.py index e05b96af581..10fa91309e1 100644 --- a/src/firebolt/db/connection.py +++ b/src/firebolt/db/connection.py @@ -11,7 +11,7 @@ from firebolt.async_db.connection import async_connect_factory from firebolt.db.cursor import Cursor from firebolt.utils.exception import ConnectionClosedError -from firebolt.utils.util import AsyncJobThread, async_to_sync +from firebolt.utils.util import async_to_sync class Connection(AsyncBaseConnection): @@ -31,7 +31,7 @@ class Connection(AsyncBaseConnection): are not implemented. """ - __slots__ = AsyncBaseConnection.__slots__ + ("_closing_lock", "_async_job_thread") + __slots__ = AsyncBaseConnection.__slots__ + ("_closing_lock",) cursor_class = Cursor @@ -40,18 +40,17 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: # Holding this lock for write means that connection is closing itself. # cursor() should hold this lock for read to read/write state self._closing_lock = RWLockWrite() - self._async_job_thread = AsyncJobThread() def cursor(self) -> Cursor: with self._closing_lock.gen_rlock(): - c = super()._cursor(async_job_thread=self._async_job_thread) + c = super()._cursor() assert isinstance(c, Cursor) # typecheck return c @wraps(AsyncBaseConnection._aclose) def close(self) -> None: with self._closing_lock.gen_wlock(): - async_to_sync(self._aclose, self._async_job_thread)() + async_to_sync(self._aclose)() # Context manager support def __enter__(self) -> Connection: diff --git a/src/firebolt/db/cursor.py b/src/firebolt/db/cursor.py index 1ba5cde7a84..96e1f976def 100644 --- a/src/firebolt/db/cursor.py +++ b/src/firebolt/db/cursor.py @@ -14,7 +14,7 @@ check_not_closed, check_query_executed, ) -from firebolt.utils.util import AsyncJobThread, async_to_sync +from firebolt.utils.util import async_to_sync class Cursor(AsyncBaseCursor): @@ -34,13 +34,11 @@ class Cursor(AsyncBaseCursor): __slots__ = AsyncBaseCursor.__slots__ + ( "_query_lock", "_idx_lock", - "_async_job_thread", ) def __init__(self, *args: Any, **kwargs: Any) -> None: self._query_lock = RWLockWrite() self._idx_lock = Lock() - self._async_job_thread: AsyncJobThread = kwargs.pop("async_job_thread") super().__init__(*args, **kwargs) @wraps(AsyncBaseCursor.execute) @@ -52,7 +50,7 @@ def execute( async_execution: Optional[bool] = False, ) -> Union[int, str]: with self._query_lock.gen_wlock(): - return async_to_sync(super().execute, self._async_job_thread)( + return async_to_sync(super().execute)( query, parameters, skip_parsing, async_execution ) @@ -64,7 +62,7 @@ def executemany( async_execution: Optional[bool] = False, ) -> Union[int, str]: with self._query_lock.gen_wlock(): - return async_to_sync(super().executemany, self._async_job_thread)( + return async_to_sync(super().executemany)( query, parameters_seq, async_execution ) @@ -106,9 +104,9 @@ def __iter__(self) -> Generator[List[ColType], None, None]: @wraps(AsyncBaseCursor.get_status) def get_status(self, query_id: str) -> QueryStatus: with self._query_lock.gen_rlock(): - return async_to_sync(super().get_status, self._async_job_thread)(query_id) + return async_to_sync(super().get_status)(query_id) @wraps(AsyncBaseCursor.cancel) def cancel(self, query_id: str) -> None: with self._query_lock.gen_rlock(): - return async_to_sync(super().cancel, self._async_job_thread)(query_id) + return async_to_sync(super().cancel)(query_id) diff --git a/src/firebolt/utils/util.py b/src/firebolt/utils/util.py index a59cd421538..e05ae677fb5 100644 --- a/src/firebolt/utils/util.py +++ b/src/firebolt/utils/util.py @@ -1,22 +1,7 @@ -from asyncio import ( - AbstractEventLoop, - get_event_loop, - new_event_loop, - set_event_loop, -) from functools import lru_cache, partial, wraps -from threading import Thread -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Coroutine, - Optional, - Type, - TypeVar, -) - -import trio # type: ignore +from typing import TYPE_CHECKING, Any, Callable, Type, TypeVar + +import trio from httpx import URL T = TypeVar("T") @@ -85,81 +70,11 @@ def fix_url_schema(url: str) -> str: return url if url.startswith("http") else f"https://{url}" -class AsyncJobThread: - """Thread runner that allows running async tasks synchronously in a separate thread. - - Caches loop to be reused in all threads. - It allows running async functions synchronously inside a running event loop. - Since nesting loops is not allowed, we create a separate thread for a new event loop - - Attributes: - result (Any): Value, returned by coroutine execution - exception (Optional[BaseException]): If any, exception that occurred - during coroutine execution - """ - - def __init__(self) -> None: - self._loop: Optional[AbstractEventLoop] = None - self.result: Any = None - self.exception: Optional[BaseException] = None - - def _initialize_loop(self) -> None: - """Initialize a loop once to use for later execution. - - Tries to get a running loop. - Creates a new loop if no active one, and sets it as active. - """ - if not self._loop: - try: - # despite the docs, this function fails if no loop is set - self._loop = get_event_loop() - except RuntimeError: - self._loop = new_event_loop() - set_event_loop(self._loop) - - def _run(self, coro: Coroutine) -> None: - """Run coroutine in an event loop. - - Execution return value is stored into ``result`` field. - If an exception occurs, it will be caught and stored into ``exception`` field. - - Args: - coro (Coroutine): Coroutine to execute - """ - try: - self._initialize_loop() - assert self._loop is not None - self.result = self._loop.run_until_complete(coro) - except BaseException as e: - self.exception = e - - def execute(self, coro: Coroutine) -> Any: - """Execute coroutine in a separate thread. - - Args: - coro (Coroutine): Coroutine to execute - - Returns: - Any: Coroutine execution return value - - Raises: - exception: Exeption, occured within coroutine - """ - thread = Thread(target=self._run, args=[coro]) - thread.start() - thread.join() - if self.exception: - raise self.exception - return self.result - - -def async_to_sync(f: Callable, async_job_thread: AsyncJobThread = None) -> Callable: +def async_to_sync(f: Callable) -> Callable: """Convert async function to sync. Args: f (Callable): function to convert - async_job_thread (AsyncJobThread): Job thread instance to use for async excution - (Default value = None) Returns: Callable: regular function, which can be executed synchronously From 0604dfda23bbf52fe537f5aa33a89ec22d61d696 Mon Sep 17 00:00:00 2001 From: ptiurin Date: Wed, 23 Nov 2022 18:21:51 +0000 Subject: [PATCH 3/4] Adding test --- setup.cfg | 1 + tests/integration/dbapi/sync/test_queries.py | 38 ++++++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/setup.cfg b/setup.cfg index 4b8a2a95871..cb468377baf 100755 --- a/setup.cfg +++ b/setup.cfg @@ -56,6 +56,7 @@ dev = pytest-mock==3.6.1 pytest-timeout==2.1.0 pytest-xdist==2.5.0 + trio-typing[mypy]==0.6.* types-cryptography==3.3.18 [options.package_data] diff --git a/tests/integration/dbapi/sync/test_queries.py b/tests/integration/dbapi/sync/test_queries.py index f1c009d9b98..651a1ab4e59 100644 --- a/tests/integration/dbapi/sync/test_queries.py +++ b/tests/integration/dbapi/sync/test_queries.py @@ -464,3 +464,41 @@ async def test_server_side_async_execution_get_status( # assert ( # type(status) is QueryStatus, # ), "get_status() did not return a QueryStatus object." + + +def test_multi_thread_connection_sharing( + engine_url: str, + database_name: str, + password_auth: Auth, + account_name: str, + api_endpoint: str, +) -> None: + + exceptions = [] + + connection = connect( + auth=password_auth, + database=database_name, + account_name=account_name, + engine_url=engine_url, + api_endpoint=api_endpoint, + ) + + def run_query(): + try: + cursor = connection.cursor() + cursor.execute("select 1") + cursor.fetchall() + except BaseException as e: + exceptions.append(e) + + thread_1 = Thread(target=run_query) + thread_2 = Thread(target=run_query) + + thread_1.start() + thread_1.join() + thread_2.start() + thread_2.join() + + connection.close() + assert not exceptions From 7f38604c02f9f3d588bb1ee5315e3ced98591fff Mon Sep 17 00:00:00 2001 From: ptiurin Date: Thu, 24 Nov 2022 14:58:42 +0000 Subject: [PATCH 4/4] Adding test comment --- tests/integration/dbapi/sync/test_queries.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/integration/dbapi/sync/test_queries.py b/tests/integration/dbapi/sync/test_queries.py index 651a1ab4e59..508f70374e4 100644 --- a/tests/integration/dbapi/sync/test_queries.py +++ b/tests/integration/dbapi/sync/test_queries.py @@ -473,6 +473,12 @@ def test_multi_thread_connection_sharing( account_name: str, api_endpoint: str, ) -> None: + """ + Test to verify sharing the same connection between different + threads works. With asyncio synching an async function this used + to fail due to a different loop having exclusive rights to the + Httpx client. Trio fixes this issue. + """ exceptions = []