diff --git a/src/firebolt/client/client.py b/src/firebolt/client/client.py index f53c7f9144..569c4998a3 100644 --- a/src/firebolt/client/client.py +++ b/src/firebolt/client/client.py @@ -1,5 +1,6 @@ from typing import Any, Optional +from anyio._core._eventloop import get_asynclib from async_property import async_cached_property # type: ignore from httpx import URL from httpx import AsyncClient as HttpxAsyncClient @@ -20,6 +21,15 @@ mixin_for, ) +# Explicitly import all available backend not get into +# anyio race condition during backend import +for backend in ("asyncio", "trio"): + try: + get_asynclib(backend) + except ModuleNotFoundError: + # Not all backends might be installed + pass + FireboltClientMixinBase = mixin_for(HttpxClient) # type: Any diff --git a/tests/integration/dbapi/sync/test_queries.py b/tests/integration/dbapi/sync/test_queries.py index 4900a94ef1..06558bca5e 100644 --- a/tests/integration/dbapi/sync/test_queries.py +++ b/tests/integration/dbapi/sync/test_queries.py @@ -1,11 +1,19 @@ from datetime import date, datetime from decimal import Decimal +from threading import Thread from typing import Any, List from pytest import mark, raises from firebolt.async_db._types import ColType, Column -from firebolt.db import Connection, Cursor, DataError, OperationalError +from firebolt.client.auth import UsernamePassword +from firebolt.db import ( + Connection, + Cursor, + DataError, + OperationalError, + connect, +) def assert_deep_eq(got: Any, expected: Any, msg: str) -> bool: @@ -336,3 +344,47 @@ def test_set_invalid_parameter(connection: Connection): c.execute("set some_invalid_parameter = 1") assert len(c._set_parameters) == 0 + + +# Run test multiple times since the issue is flaky +@mark.parametrize("_", range(5)) +def test_anyio_backend_import_issue( + engine_url: str, + database_name: str, + username: str, + password: str, + account_name: str, + api_endpoint: str, + _: int, +) -> None: + threads_cnt = 3 + requests_cnt = 8 + # collect threads exceptions in an array because they're ignored otherwise + exceptions = [] + + def run_query(idx: int): + nonlocal username, password, database_name, engine_url, account_name, api_endpoint + try: + with connect( + auth=UsernamePassword(username, password), + database=database_name, + account_name=account_name, + engine_url=engine_url, + api_endpoint=api_endpoint, + ) as c: + cursor = c.cursor() + cursor.execute(f"select {idx}") + except BaseException as e: + exceptions.append(e) + + def run_queries_parallel() -> None: + nonlocal requests_cnt + threads = [Thread(target=run_query, args=(i,)) for i in range(requests_cnt)] + [t.start() for t in threads] + [t.join() for t in threads] + + threads = [Thread(target=run_queries_parallel) for _ in range(threads_cnt)] + + [t.start() for t in threads] + [t.join() for t in threads] + assert len(exceptions) == 0, exceptions