Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions docsrc/Connecting_and_queries.rst
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,17 @@ In addition, server-side asynchronous queries can be cancelled calling ``cancel(
**Returns**: ``CANCELED_EXECUTION``


Thread safety
==============================

Thread safety is set to 2, meaning it's safe to share the module and
:ref:`Connection <firebolt.db:Connection>` object across threads.
:ref:`Cursor <firebolt.db:Cursor>` is a lightweight object that should be instantiated
by calling ``connection.cursor()`` within a thread and should not be shared across different threads.
Similarly, in an asynchronous context the Cursor obejct should not be shared across tasks
as it will lead to a nondeterministic data returned. Follow the best practice from the
:ref:`connecting_and_queries:Running multiple queries in parallel`.


Using DATE and DATETIME values
==============================
Expand Down
1 change: 0 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ install_requires =
python-dateutil>=2.8.2
readerwriterlock>=1.0.9
sqlparse>=0.4.2
tricycle>=0.2.2
trio>=0.22.0
python_requires = >=3.7
include_package_data = True
Expand Down
28 changes: 10 additions & 18 deletions src/firebolt/async_db/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
)

from httpx import Response, codes
from tricycle import RWLock

from firebolt.async_db.util import is_db_available, is_engine_running
from firebolt.client import AsyncClient
Expand Down Expand Up @@ -69,8 +68,6 @@ class Cursor(BaseCursor):

"""

__slots__ = BaseCursor.__slots__ + ("_async_query_lock",)

def __init__(
self,
*args: Any,
Expand All @@ -79,7 +76,6 @@ def __init__(
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)
self._async_query_lock = RWLock()
self._client = client
self.connection = connection

Expand Down Expand Up @@ -392,29 +388,25 @@ async def cancel(self, query_id: str) -> None:

@wraps(BaseCursor.fetchone)
async def fetchone(self) -> Optional[List[ColType]]:
async with self._async_query_lock.read_locked():
"""Fetch the next row of a query result set."""
return super().fetchone()
"""Fetch the next row of a query result set."""
return super().fetchone()

@wraps(BaseCursor.fetchmany)
async def fetchmany(self, size: Optional[int] = None) -> List[List[ColType]]:
async with self._async_query_lock.read_locked():
"""
Fetch the next set of rows of a query result;
size is cursor.arraysize by default.
"""
return super().fetchmany(size)
"""
Fetch the next set of rows of a query result;
size is cursor.arraysize by default.
"""
return super().fetchmany(size)

@wraps(BaseCursor.fetchall)
async def fetchall(self) -> List[List[ColType]]:
async with self._async_query_lock.read_locked():
"""Fetch all remaining rows of a query result."""
return super().fetchall()
"""Fetch all remaining rows of a query result."""
return super().fetchall()

@wraps(BaseCursor.nextset)
async def nextset(self) -> None:
async with self._async_query_lock.read_locked():
return super().nextset()
return super().nextset()

@check_not_closed
def __enter__(self) -> Cursor:
Expand Down
31 changes: 30 additions & 1 deletion src/firebolt/client/auth/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from time import time
from typing import Generator, Optional
from typing import AsyncGenerator, Generator, Optional

from httpx import Auth as HttpxAuth
from httpx import Request, Response, codes
from trio import Lock

from firebolt.utils.token_storage import TokenSecureStorage
from firebolt.utils.util import Timer, cached_property
Expand Down Expand Up @@ -35,6 +36,7 @@ def __init__(self, use_token_cache: bool = True):
self._use_token_cache = use_token_cache
self._token: Optional[str] = self._get_cached_token()
self._expires: Optional[int] = None
self._lock = Lock()

def copy(self) -> "Auth":
"""Make another auth object with same credentials.
Expand Down Expand Up @@ -120,3 +122,30 @@ def auth_flow(self, request: Request) -> Generator[Request, Response, None]:
yield from self.get_new_token_generator()
request.headers["Authorization"] = f"Bearer {self.token}"
yield request

async def async_auth_flow(
self, request: Request
) -> AsyncGenerator[Request, Response]:
"""
Execute the authentication flow asynchronously.

Overridden in order to lock and ensure no more than
one authentication request is sent at a time. This
avoids excessive load on the auth server.
"""
if self.requires_request_body:
await request.aread()

async with self._lock:
flow = self.auth_flow(request)
request = next(flow)

while True:
response = yield request
if self.requires_response_body:
await response.aread()

try:
request = flow.send(response)
except StopIteration:
break
58 changes: 54 additions & 4 deletions tests/unit/client/test_client_async.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from re import Pattern
from typing import Callable
from re import Pattern, compile
from types import MethodType
from typing import Any, Callable

from httpx import codes
from httpx import Request, Response, codes
from pytest import raises
from pytest_httpx import HTTPXMock
from trio import open_nursery, sleep

from firebolt.client import AsyncClient
from firebolt.client.auth import Auth
from firebolt.client.auth import Auth, ClientCredentials
from firebolt.utils.urls import AUTH_SERVICE_ACCOUNT_URL
from firebolt.utils.util import fix_url_schema

Expand Down Expand Up @@ -110,3 +112,51 @@ async def test_client_account_id(
api_endpoint=server,
) as c:
assert await c.account_id == account_id, "Invalid account id returned."


async def test_concurent_auth_lock(
httpx_mock: HTTPXMock,
account_name: str,
server: str,
client_id: str,
client_secret: str,
access_token: str,
auth_url: str,
check_token_callback: Callable,
) -> None:
CONCURENT_COUNT = 10
url = "https://url"

checked_creds_times = 0

async def mock_send_handling_redirects(self, *args: Any, **kwargs: Any) -> Response:
# simulate network delay so the context switches
await sleep(0.01)
return await AsyncClient._send_handling_redirects(self, *args, **kwargs)

def check_credentials(
request: Request = None,
**kwargs,
) -> Response:
nonlocal checked_creds_times
checked_creds_times += 1
return Response(
status_code=codes.OK,
json={"expires_in": 2**32, "access_token": access_token},
)

httpx_mock.add_callback(check_token_callback, url=compile(f"{url}/."))
httpx_mock.add_callback(check_credentials, url=auth_url)

async with AsyncClient(
auth=ClientCredentials(client_id, client_secret),
api_endpoint=server,
account_name=account_name,
) as c:
c._send_handling_redirects = MethodType(mock_send_handling_redirects, c)
urls = [f"{url}/{i}" for i in range(CONCURENT_COUNT)]
async with open_nursery() as nursery:
for url in urls:
nursery.start_soon(c.get, url)

assert checked_creds_times == 1