Merged
2 changes: 2 additions & 0 deletions setup.cfg
@@ -32,6 +32,7 @@ install_requires =
pydantic[dotenv]>=1.8.2,<1.10
readerwriterlock==1.0.9
sqlparse>=0.4.2
+ trio
python_requires = >=3.7
include_package_data = True
package_dir =
@@ -55,6 +56,7 @@ dev =
pytest-mock==3.6.1
pytest-timeout==2.1.0
pytest-xdist==2.5.0
+ trio-typing[mypy]==0.6.*
types-cryptography==3.3.18

[options.package_data]
2 changes: 1 addition & 1 deletion src/firebolt/async_db/cursor.py
@@ -679,7 +679,7 @@ def __exit__(

class Cursor(BaseCursor):
"""
- Executes asyncio queries to Firebolt Database.
+ Executes async queries to Firebolt Database.
Should not be created directly;
use :py:func:`connection.cursor <firebolt.async_db.connection.Connection>`

9 changes: 4 additions & 5 deletions src/firebolt/db/connection.py
@@ -11,7 +11,7 @@
from firebolt.async_db.connection import async_connect_factory
from firebolt.db.cursor import Cursor
from firebolt.utils.exception import ConnectionClosedError
- from firebolt.utils.util import AsyncJobThread, async_to_sync
+ from firebolt.utils.util import async_to_sync


class Connection(AsyncBaseConnection):
@@ -31,7 +31,7 @@ class Connection(AsyncBaseConnection):
are not implemented.
"""

- __slots__ = AsyncBaseConnection.__slots__ + ("_closing_lock", "_async_job_thread")
+ __slots__ = AsyncBaseConnection.__slots__ + ("_closing_lock",)

cursor_class = Cursor

@@ -40,18 +40,17 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
# Holding this lock for write means that connection is closing itself.
# cursor() should hold this lock for read to read/write state
self._closing_lock = RWLockWrite()
- self._async_job_thread = AsyncJobThread()

def cursor(self) -> Cursor:
with self._closing_lock.gen_rlock():
- c = super()._cursor(async_job_thread=self._async_job_thread)
+ c = super()._cursor()
assert isinstance(c, Cursor) # typecheck
return c

@wraps(AsyncBaseConnection._aclose)
def close(self) -> None:
with self._closing_lock.gen_wlock():
- async_to_sync(self._aclose, self._async_job_thread)()
+ async_to_sync(self._aclose)()

# Context manager support
def __enter__(self) -> Connection:
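The comments in __init__ above describe the locking scheme that survives this change: close() takes _closing_lock for writing while cursor() takes it for reading, so no new cursor can be created while the connection is shutting down, yet concurrent cursor() calls never block each other. Below is a minimal standalone sketch of that reader/writer pattern using the same readerwriterlock package the connection already depends on; ToyConnection is an illustrative stand-in, not the real class.

from readerwriterlock.rwlock import RWLockWrite


class ToyConnection:
    def __init__(self) -> None:
        self._closing_lock = RWLockWrite()
        self._closed = False

    def cursor(self) -> str:
        # Readers: any number of cursor() calls may hold the lock at once.
        with self._closing_lock.gen_rlock():
            if self._closed:
                raise RuntimeError("connection is closed")
            return "cursor"

    def close(self) -> None:
        # Writer: waits for active readers to finish, then excludes new ones.
        with self._closing_lock.gen_wlock():
            self._closed = True


conn = ToyConnection()
print(conn.cursor())  # prints: cursor
conn.close()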
12 changes: 5 additions & 7 deletions src/firebolt/db/cursor.py
@@ -14,7 +14,7 @@
check_not_closed,
check_query_executed,
)
- from firebolt.utils.util import AsyncJobThread, async_to_sync
+ from firebolt.utils.util import async_to_sync


class Cursor(AsyncBaseCursor):
@@ -34,13 +34,11 @@ class Cursor(AsyncBaseCursor):
__slots__ = AsyncBaseCursor.__slots__ + (
"_query_lock",
"_idx_lock",
"_async_job_thread",
)

def __init__(self, *args: Any, **kwargs: Any) -> None:
self._query_lock = RWLockWrite()
self._idx_lock = Lock()
- self._async_job_thread: AsyncJobThread = kwargs.pop("async_job_thread")
super().__init__(*args, **kwargs)

@wraps(AsyncBaseCursor.execute)
@@ -52,7 +50,7 @@ def execute(
async_execution: Optional[bool] = False,
) -> Union[int, str]:
with self._query_lock.gen_wlock():
- return async_to_sync(super().execute, self._async_job_thread)(
+ return async_to_sync(super().execute)(
query, parameters, skip_parsing, async_execution
)

@@ -64,7 +62,7 @@ def executemany(
async_execution: Optional[bool] = False,
) -> Union[int, str]:
with self._query_lock.gen_wlock():
- return async_to_sync(super().executemany, self._async_job_thread)(
+ return async_to_sync(super().executemany)(
query, parameters_seq, async_execution
)

@@ -106,9 +104,9 @@ def __iter__(self) -> Generator[List[ColType], None, None]:
@wraps(AsyncBaseCursor.get_status)
def get_status(self, query_id: str) -> QueryStatus:
with self._query_lock.gen_rlock():
- return async_to_sync(super().get_status, self._async_job_thread)(query_id)
+ return async_to_sync(super().get_status)(query_id)

@wraps(AsyncBaseCursor.cancel)
def cancel(self, query_id: str) -> None:
with self._query_lock.gen_rlock():
- return async_to_sync(super().cancel, self._async_job_thread)(query_id)
+ return async_to_sync(super().cancel)(query_id)
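For context on the pattern repeated above: every sync cursor method takes the appropriate reader/writer lock and then delegates to its async counterpart on the base class through async_to_sync, which is rewritten on top of trio later in this PR (see utils/util.py below). The following is a toy sketch of that delegation, assuming the firebolt package from this branch is installed; ToyAsyncCursor and ToySyncCursor are made-up names, and the real methods additionally hold the query lock shown in the diff.

from functools import wraps

import trio

from firebolt.utils.util import async_to_sync


class ToyAsyncCursor:
    async def execute(self, query: str) -> int:
        await trio.sleep(0)  # stand-in for the real async HTTP round trip
        return 1


class ToySyncCursor(ToyAsyncCursor):
    @wraps(ToyAsyncCursor.execute)
    def execute(self, query: str) -> int:
        # Each sync method simply runs its async counterpart to completion.
        return async_to_sync(super().execute)(query)


print(ToySyncCursor().execute("select 1"))  # prints: 1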
105 changes: 5 additions & 100 deletions src/firebolt/utils/util.py
@@ -1,21 +1,7 @@
- from asyncio import (
- AbstractEventLoop,
- get_event_loop,
- new_event_loop,
- set_event_loop,
- )
- from functools import lru_cache, wraps
- from threading import Thread
- from typing import (
- TYPE_CHECKING,
- Any,
- Callable,
- Coroutine,
- Optional,
- Type,
- TypeVar,
- )
+ from functools import lru_cache, partial, wraps
+ from typing import TYPE_CHECKING, Any, Callable, Type, TypeVar

+ import trio
from httpx import URL

T = TypeVar("T")
@@ -84,100 +70,19 @@ def fix_url_schema(url: str) -> str:
return url if url.startswith("http") else f"https://{url}"


- class AsyncJobThread:
- """Thread runner that allows running async tasks synchronously in a separate thread.
-
- Caches loop to be reused in all threads.
- It allows running async functions synchronously inside a running event loop.
- Since nesting loops is not allowed, we create a separate thread for a new event loop
-
- Attributes:
- result (Any): Value, returned by coroutine execution
- exception (Optional[BaseException]): If any, exception that occurred
- during coroutine execution
- """
-
- def __init__(self) -> None:
- self._loop: Optional[AbstractEventLoop] = None
- self.result: Any = None
- self.exception: Optional[BaseException] = None
-
- def _initialize_loop(self) -> None:
- """Initialize a loop once to use for later execution.
-
- Tries to get a running loop.
- Creates a new loop if no active one, and sets it as active.
- """
- if not self._loop:
- try:
- # despite the docs, this function fails if no loop is set
- self._loop = get_event_loop()
- except RuntimeError:
- self._loop = new_event_loop()
- set_event_loop(self._loop)
-
- def _run(self, coro: Coroutine) -> None:
- """Run coroutine in an event loop.
-
- Execution return value is stored into ``result`` field.
- If an exception occurs, it will be caught and stored into ``exception`` field.
-
- Args:
- coro (Coroutine): Coroutine to execute
- """
- try:
- self._initialize_loop()
- assert self._loop is not None
- self.result = self._loop.run_until_complete(coro)
- except BaseException as e:
- self.exception = e
-
- def execute(self, coro: Coroutine) -> Any:
- """Execute coroutine in a separate thread.
-
- Args:
- coro (Coroutine): Coroutine to execute
-
- Returns:
- Any: Coroutine execution return value
-
- Raises:
- exception: Exeption, occured within coroutine
- """
- thread = Thread(target=self._run, args=[coro])
- thread.start()
- thread.join()
- if self.exception:
- raise self.exception
- return self.result
-
-
- def async_to_sync(f: Callable, async_job_thread: AsyncJobThread = None) -> Callable:
+ def async_to_sync(f: Callable) -> Callable:
"""Convert async function to sync.

Args:
f (Callable): function to convert
- async_job_thread (AsyncJobThread): Job thread instance to use for async excution
- (Default value = None)

Returns:
Callable: regular function, which can be executed synchronously
"""

@wraps(f)
def sync(*args: Any, **kwargs: Any) -> Any:
- try:
- loop = get_event_loop()
- except RuntimeError:
- loop = new_event_loop()
- set_event_loop(loop)
- # We are inside a running loop
- if loop.is_running():
- nonlocal async_job_thread
- if not async_job_thread:
- async_job_thread = AsyncJobThread()
- return async_job_thread.execute(f(*args, **kwargs))
- return loop.run_until_complete(f(*args, **kwargs))
+ return trio.run(partial(f, *args, **kwargs))
Collaborator comment: nice


return sync

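The heart of this file's change: async_to_sync no longer creates or caches event loops and no longer needs a helper thread; it hands the coroutine function straight to trio.run, which starts a fresh trio event loop for each call. Because trio.run expects an async callable rather than an already-created coroutine, arguments are bound up front with functools.partial. A minimal standalone sketch of that behaviour follows (the add coroutine is illustrative only).

from functools import partial

import trio


async def add(a: int, b: int) -> int:
    await trio.sleep(0)  # placeholder for real async work
    return a + b


# trio.run accepts the async function plus positional arguments directly...
print(trio.run(add, 1, 2))  # prints: 3

# ...but keyword arguments have to be bound first, which is why
# async_to_sync wraps the call in functools.partial.
print(trio.run(partial(add, a=1, b=2)))  # prints: 3

One consequence worth noting: each wrapped call now runs on its own short-lived trio loop, so there is no shared asyncio loop state to conflict with the thread that created the connection, which is exactly what the new multi-thread test below exercises.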
44 changes: 44 additions & 0 deletions tests/integration/dbapi/sync/test_queries.py
@@ -464,3 +464,47 @@ async def test_server_side_async_execution_get_status(
# assert (
# type(status) is QueryStatus,
# ), "get_status() did not return a QueryStatus object."


+ def test_multi_thread_connection_sharing(
+ engine_url: str,
+ database_name: str,
+ password_auth: Auth,
+ account_name: str,
+ api_endpoint: str,
+ ) -> None:
+ """
+ Test to verify sharing the same connection between different
+ threads works. With asyncio synching an async function this used
+ to fail due to a different loop having exclusive rights to the
+ Httpx client. Trio fixes this issue.
+ """
+
+ exceptions = []
+
+ connection = connect(
+ auth=password_auth,
+ database=database_name,
+ account_name=account_name,
+ engine_url=engine_url,
+ api_endpoint=api_endpoint,
+ )
+
+ def run_query():
+ try:
+ cursor = connection.cursor()
+ cursor.execute("select 1")
+ cursor.fetchall()
+ except BaseException as e:
+ exceptions.append(e)
+
+ thread_1 = Thread(target=run_query)
+ thread_2 = Thread(target=run_query)
+
+ thread_1.start()
+ thread_1.join()
+ thread_2.start()
+ thread_2.join()
+
+ connection.close()
+ assert not exceptions