diff --git a/docsrc/Connecting_and_queries.rst b/docsrc/Connecting_and_queries.rst index 0536ff6d50f..c81a19b1d78 100644 --- a/docsrc/Connecting_and_queries.rst +++ b/docsrc/Connecting_and_queries.rst @@ -386,6 +386,17 @@ In addition, server-side asynchronous queries can be cancelled calling ``cancel( **Returns**: ``CANCELED_EXECUTION`` +Thread safety +============================== + +Thread safety is set to 2, meaning it's safe to share the module and +:ref:`Connection ` object across threads. +:ref:`Cursor ` is a lightweight object that should be instantiated +by calling ``connection.cursor()`` within a thread and should not be shared across different threads. +Similarly, in an asynchronous context the Cursor obejct should not be shared across tasks +as it will lead to a nondeterministic data returned. Follow the best practice from the +:ref:`connecting_and_queries:Running multiple queries in parallel`. + Using DATE and DATETIME values ============================== diff --git a/setup.cfg b/setup.cfg index d4029615406..f6d3ed41a08 100755 --- a/setup.cfg +++ b/setup.cfg @@ -35,7 +35,6 @@ install_requires = python-dateutil>=2.8.2 readerwriterlock>=1.0.9 sqlparse>=0.4.2 - tricycle>=0.2.2 trio>=0.22.0 python_requires = >=3.7 include_package_data = True diff --git a/src/firebolt/async_db/cursor.py b/src/firebolt/async_db/cursor.py index a5348597422..fa1dc4993d1 100644 --- a/src/firebolt/async_db/cursor.py +++ b/src/firebolt/async_db/cursor.py @@ -18,7 +18,6 @@ ) from httpx import Response, codes -from tricycle import RWLock from firebolt.async_db.util import is_db_available, is_engine_running from firebolt.common._types import ( @@ -98,8 +97,6 @@ class Cursor(BaseCursor): """ - __slots__ = BaseCursor.__slots__ + ("_async_query_lock",) - def __init__( self, *args: Any, @@ -108,7 +105,6 @@ def __init__( **kwargs: Any, ) -> None: super().__init__(*args, **kwargs) - self._async_query_lock = RWLock() self._client = client self.connection = connection @@ -438,26 +434,22 @@ async def cancel(self, query_id: str) -> None: @wraps(BaseCursor.fetchone) async def fetchone(self) -> Optional[List[ColType]]: - async with self._async_query_lock.read_locked(): - """Fetch the next row of a query result set.""" - return super().fetchone() + """Fetch the next row of a query result set.""" + return super().fetchone() @wraps(BaseCursor.fetchmany) async def fetchmany(self, size: Optional[int] = None) -> List[List[ColType]]: - async with self._async_query_lock.read_locked(): - """ - Fetch the next set of rows of a query result; - size is cursor.arraysize by default. - """ - return super().fetchmany(size) + """ + Fetch the next set of rows of a query result; + size is cursor.arraysize by default. + """ + return super().fetchmany(size) @wraps(BaseCursor.fetchall) async def fetchall(self) -> List[List[ColType]]: - async with self._async_query_lock.read_locked(): - """Fetch all remaining rows of a query result.""" - return super().fetchall() + """Fetch all remaining rows of a query result.""" + return super().fetchall() @wraps(BaseCursor.nextset) async def nextset(self) -> None: - async with self._async_query_lock.read_locked(): - return super().nextset() + return super().nextset() diff --git a/src/firebolt/client/auth/base.py b/src/firebolt/client/auth/base.py index 75d6b29d816..50bac64565c 100644 --- a/src/firebolt/client/auth/base.py +++ b/src/firebolt/client/auth/base.py @@ -1,6 +1,7 @@ from time import time -from typing import Generator, Optional +from typing import AsyncGenerator, Generator, Optional +from anyio import Lock from httpx import Auth as HttpxAuth from httpx import Request, Response, codes @@ -35,6 +36,7 @@ def __init__(self, use_token_cache: bool = True): self._use_token_cache = use_token_cache self._token: Optional[str] = self._get_cached_token() self._expires: Optional[int] = None + self._lock = Lock() def copy(self) -> "Auth": """Make another auth object with same credentials. @@ -119,3 +121,30 @@ def auth_flow(self, request: Request) -> Generator[Request, Response, None]: yield from self.get_new_token_generator() request.headers["Authorization"] = f"Bearer {self.token}" yield request + + async def async_auth_flow( + self, request: Request + ) -> AsyncGenerator[Request, Response]: + """ + Execute the authentication flow asynchronously. + + Overridden in order to lock and ensure no more than + one authentication request is sent at a time. This + avoids excessive load on the auth server. + """ + if self.requires_request_body: + await request.aread() + + async with self._lock: + flow = self.auth_flow(request) + request = next(flow) + + while True: + response = yield request + if self.requires_response_body: + await response.aread() + + try: + request = flow.send(response) + except StopIteration: + break diff --git a/tests/unit/async_db/test_connection.py b/tests/unit/async_db/test_connection.py index 6da6cf46445..a649fea7e1d 100644 --- a/tests/unit/async_db/test_connection.py +++ b/tests/unit/async_db/test_connection.py @@ -1,3 +1,4 @@ +from asyncio import run from re import Pattern from typing import Callable, List from unittest.mock import patch @@ -422,3 +423,34 @@ async def test_connect_no_user_agent( ) as connection: await connection.cursor().execute("select*") ut.assert_called_once_with([], []) + + +def test_from_asyncio( + httpx_mock: HTTPXMock, + auth_callback: Callable, + auth_url: str, + query_callback: Callable, + query_url: str, + settings: Settings, + db_name: str, +): + async def async_flow() -> None: + async with ( + await connect( + engine_url=settings.server, + database=db_name, + username="u", + password="p", + account_name=settings.account_name, + api_endpoint=settings.server, + ) + ) as connection: + cursor = connection.cursor() + await cursor.execute("SELECT 1") + await cursor.fetchone() + await cursor.fetchmany(1) + await cursor.fetchall() + + httpx_mock.add_callback(auth_callback, url=auth_url) + httpx_mock.add_callback(query_callback, url=query_url) + run(async_flow()) diff --git a/tests/unit/client/test_client_async.py b/tests/unit/client/test_client_async.py index 81d513727fc..d1f3392ae21 100644 --- a/tests/unit/client/test_client_async.py +++ b/tests/unit/client/test_client_async.py @@ -1,9 +1,11 @@ -from re import Pattern -from typing import Callable +from re import Pattern, compile +from types import MethodType +from typing import Any, Callable -from httpx import codes +from httpx import Request, Response, codes from pytest import raises from pytest_httpx import HTTPXMock +from trio import open_nursery, sleep from firebolt.client import DEFAULT_API_URL, AsyncClient from firebolt.client.auth import Token, UsernamePassword @@ -117,3 +119,49 @@ async def test_client_account_id( api_endpoint=settings.server, ) as c: assert await c.account_id == account_id, "Invalid account id returned." + + +async def test_concurent_auth_lock( + httpx_mock: HTTPXMock, + server: str, + test_username: str, + test_password: str, + test_token: str, + auth_url: str, + check_token_callback: Callable, +) -> None: + CONCURENT_COUNT = 10 + url = "https://url" + + checked_creds_times = 0 + + async def mock_send_handling_redirects(self, *args: Any, **kwargs: Any) -> Response: + # simulate network delay so the context switches + await sleep(0.01) + return await AsyncClient._send_handling_redirects(self, *args, **kwargs) + + def check_credentials( + request: Request = None, + **kwargs, + ) -> Response: + nonlocal checked_creds_times + checked_creds_times += 1 + return Response( + status_code=codes.OK, + json={"expires_in": 2**32, "access_token": test_token}, + ) + + httpx_mock.add_callback(check_token_callback, url=compile(f"{url}/.")) + httpx_mock.add_callback(check_credentials, url=auth_url) + + async with AsyncClient( + auth=UsernamePassword(test_username, test_password), + api_endpoint=server, + ) as c: + c._send_handling_redirects = MethodType(mock_send_handling_redirects, c) + urls = [f"{url}/{i}" for i in range(CONCURENT_COUNT)] + async with open_nursery() as nursery: + for url in urls: + nursery.start_soon(c.get, url) + + assert checked_creds_times == 1