Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/firebolt/client/client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Any, Optional

from anyio._core._eventloop import get_asynclib
from async_property import async_cached_property # type: ignore
from httpx import URL
from httpx import AsyncClient as HttpxAsyncClient
Expand All @@ -20,6 +21,15 @@
mixin_for,
)

# Explicitly import all available backend not get into
# anyio race condition during backend import
for backend in ("asyncio", "trio"):
try:
get_asynclib(backend)
except ModuleNotFoundError:
# Not all backends might be installed
pass
Comment on lines +26 to +31
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm just worried since we're using a "private" API for this hack it might change without notice. As we're pinning the httpx version we're fine for now, but if we were to upgrade it might cause issues. Hopefully it will be picked up by nightly.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will investigate how this could be properly fixed


FireboltClientMixinBase = mixin_for(HttpxClient) # type: Any


Expand Down
54 changes: 53 additions & 1 deletion tests/integration/dbapi/sync/test_queries.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
from datetime import date, datetime
from decimal import Decimal
from threading import Thread
from typing import Any, List

from pytest import mark, raises

from firebolt.async_db._types import ColType, Column
from firebolt.db import Connection, Cursor, DataError, OperationalError
from firebolt.client.auth import UsernamePassword
from firebolt.db import (
Connection,
Cursor,
DataError,
OperationalError,
connect,
)


def assert_deep_eq(got: Any, expected: Any, msg: str) -> bool:
Expand Down Expand Up @@ -336,3 +344,47 @@ def test_set_invalid_parameter(connection: Connection):
c.execute("set some_invalid_parameter = 1")

assert len(c._set_parameters) == 0


# Run test multiple times since the issue is flaky
@mark.parametrize("_", range(5))
def test_anyio_backend_import_issue(
engine_url: str,
database_name: str,
username: str,
password: str,
account_name: str,
api_endpoint: str,
_: int,
) -> None:
threads_cnt = 3
requests_cnt = 8
# collect threads exceptions in an array because they're ignored otherwise
exceptions = []

def run_query(idx: int):
nonlocal username, password, database_name, engine_url, account_name, api_endpoint
try:
with connect(
auth=UsernamePassword(username, password),
database=database_name,
account_name=account_name,
engine_url=engine_url,
api_endpoint=api_endpoint,
) as c:
cursor = c.cursor()
cursor.execute(f"select {idx}")
except BaseException as e:
exceptions.append(e)

def run_queries_parallel() -> None:
nonlocal requests_cnt
threads = [Thread(target=run_query, args=(i,)) for i in range(requests_cnt)]
[t.start() for t in threads]
[t.join() for t in threads]

threads = [Thread(target=run_queries_parallel) for _ in range(threads_cnt)]

[t.start() for t in threads]
[t.join() for t in threads]
assert len(exceptions) == 0, exceptions