diff --git a/docsrc/Connecting_and_queries.rst b/docsrc/Connecting_and_queries.rst index 5f9a3da1915..94de9681056 100644 --- a/docsrc/Connecting_and_queries.rst +++ b/docsrc/Connecting_and_queries.rst @@ -571,6 +571,17 @@ In addition, server-side asynchronous queries can be cancelled calling ``cancel( **Returns**: ``CANCELED_EXECUTION`` +Thread safety +============================== + +Thread safety is set to 2, meaning it's safe to share the module and +:ref:`Connection ` object across threads. +:ref:`Cursor ` is a lightweight object that should be instantiated +by calling ``connection.cursor()`` within a thread and should not be shared across different threads. +Similarly, in an asynchronous context the Cursor obejct should not be shared across tasks +as it will lead to a nondeterministic data returned. Follow the best practice from the +:ref:`connecting_and_queries:Running multiple queries in parallel`. + Using DATE and DATETIME values ============================== diff --git a/setup.cfg b/setup.cfg index cc7dae3c4af..c6f3339eb8d 100755 --- a/setup.cfg +++ b/setup.cfg @@ -34,7 +34,6 @@ install_requires = python-dateutil>=2.8.2 readerwriterlock>=1.0.9 sqlparse>=0.4.2 - tricycle>=0.2.2 trio>=0.22.0 python_requires = >=3.7 include_package_data = True diff --git a/src/firebolt/async_db/cursor.py b/src/firebolt/async_db/cursor.py index 7d2d5e25a33..3e8d108c670 100644 --- a/src/firebolt/async_db/cursor.py +++ b/src/firebolt/async_db/cursor.py @@ -17,7 +17,6 @@ ) from httpx import Response, codes -from tricycle import RWLock from firebolt.async_db.util import is_db_available, is_engine_running from firebolt.client import AsyncClient @@ -69,8 +68,6 @@ class Cursor(BaseCursor): """ - __slots__ = BaseCursor.__slots__ + ("_async_query_lock",) - def __init__( self, *args: Any, @@ -79,7 +76,6 @@ def __init__( **kwargs: Any, ) -> None: super().__init__(*args, **kwargs) - self._async_query_lock = RWLock() self._client = client self.connection = connection @@ -392,29 +388,25 @@ async def cancel(self, query_id: str) -> None: @wraps(BaseCursor.fetchone) async def fetchone(self) -> Optional[List[ColType]]: - async with self._async_query_lock.read_locked(): - """Fetch the next row of a query result set.""" - return super().fetchone() + """Fetch the next row of a query result set.""" + return super().fetchone() @wraps(BaseCursor.fetchmany) async def fetchmany(self, size: Optional[int] = None) -> List[List[ColType]]: - async with self._async_query_lock.read_locked(): - """ - Fetch the next set of rows of a query result; - size is cursor.arraysize by default. - """ - return super().fetchmany(size) + """ + Fetch the next set of rows of a query result; + size is cursor.arraysize by default. + """ + return super().fetchmany(size) @wraps(BaseCursor.fetchall) async def fetchall(self) -> List[List[ColType]]: - async with self._async_query_lock.read_locked(): - """Fetch all remaining rows of a query result.""" - return super().fetchall() + """Fetch all remaining rows of a query result.""" + return super().fetchall() @wraps(BaseCursor.nextset) async def nextset(self) -> None: - async with self._async_query_lock.read_locked(): - return super().nextset() + return super().nextset() @check_not_closed def __enter__(self) -> Cursor: diff --git a/src/firebolt/client/auth/base.py b/src/firebolt/client/auth/base.py index cacd286913e..a4650f0cf21 100644 --- a/src/firebolt/client/auth/base.py +++ b/src/firebolt/client/auth/base.py @@ -1,8 +1,9 @@ from time import time -from typing import Generator, Optional +from typing import AsyncGenerator, Generator, Optional from httpx import Auth as HttpxAuth from httpx import Request, Response, codes +from trio import Lock from firebolt.utils.token_storage import TokenSecureStorage from firebolt.utils.util import Timer, cached_property @@ -35,6 +36,7 @@ def __init__(self, use_token_cache: bool = True): self._use_token_cache = use_token_cache self._token: Optional[str] = self._get_cached_token() self._expires: Optional[int] = None + self._lock = Lock() def copy(self) -> "Auth": """Make another auth object with same credentials. @@ -120,3 +122,30 @@ def auth_flow(self, request: Request) -> Generator[Request, Response, None]: yield from self.get_new_token_generator() request.headers["Authorization"] = f"Bearer {self.token}" yield request + + async def async_auth_flow( + self, request: Request + ) -> AsyncGenerator[Request, Response]: + """ + Execute the authentication flow asynchronously. + + Overridden in order to lock and ensure no more than + one authentication request is sent at a time. This + avoids excessive load on the auth server. + """ + if self.requires_request_body: + await request.aread() + + async with self._lock: + flow = self.auth_flow(request) + request = next(flow) + + while True: + response = yield request + if self.requires_response_body: + await response.aread() + + try: + request = flow.send(response) + except StopIteration: + break diff --git a/tests/unit/client/test_client_async.py b/tests/unit/client/test_client_async.py index a1eefddc970..cfc1da203a1 100644 --- a/tests/unit/client/test_client_async.py +++ b/tests/unit/client/test_client_async.py @@ -1,12 +1,14 @@ -from re import Pattern -from typing import Callable +from re import Pattern, compile +from types import MethodType +from typing import Any, Callable -from httpx import codes +from httpx import Request, Response, codes from pytest import raises from pytest_httpx import HTTPXMock +from trio import open_nursery, sleep from firebolt.client import AsyncClient -from firebolt.client.auth import Auth +from firebolt.client.auth import Auth, ClientCredentials from firebolt.utils.urls import AUTH_SERVICE_ACCOUNT_URL from firebolt.utils.util import fix_url_schema @@ -110,3 +112,51 @@ async def test_client_account_id( api_endpoint=server, ) as c: assert await c.account_id == account_id, "Invalid account id returned." + + +async def test_concurent_auth_lock( + httpx_mock: HTTPXMock, + account_name: str, + server: str, + client_id: str, + client_secret: str, + access_token: str, + auth_url: str, + check_token_callback: Callable, +) -> None: + CONCURENT_COUNT = 10 + url = "https://url" + + checked_creds_times = 0 + + async def mock_send_handling_redirects(self, *args: Any, **kwargs: Any) -> Response: + # simulate network delay so the context switches + await sleep(0.01) + return await AsyncClient._send_handling_redirects(self, *args, **kwargs) + + def check_credentials( + request: Request = None, + **kwargs, + ) -> Response: + nonlocal checked_creds_times + checked_creds_times += 1 + return Response( + status_code=codes.OK, + json={"expires_in": 2**32, "access_token": access_token}, + ) + + httpx_mock.add_callback(check_token_callback, url=compile(f"{url}/.")) + httpx_mock.add_callback(check_credentials, url=auth_url) + + async with AsyncClient( + auth=ClientCredentials(client_id, client_secret), + api_endpoint=server, + account_name=account_name, + ) as c: + c._send_handling_redirects = MethodType(mock_send_handling_redirects, c) + urls = [f"{url}/{i}" for i in range(CONCURENT_COUNT)] + async with open_nursery() as nursery: + for url in urls: + nursery.start_soon(c.get, url) + + assert checked_creds_times == 1