diff --git a/.github/workflows/integration-tests-v2.yml b/.github/workflows/integration-tests-v2.yml index 67b1c20595..03a76ecd2b 100644 --- a/.github/workflows/integration-tests-v2.yml +++ b/.github/workflows/integration-tests-v2.yml @@ -31,7 +31,7 @@ jobs: with: firebolt-client-id: ${{ secrets.FIREBOLT_CLIENT_ID_STG_NEW_IDN }} firebolt-client-secret: ${{ secrets.FIREBOLT_CLIENT_SECRET_STG_NEW_IDN }} - account: ${{ vars.FIREBOLT_ACCOUNT }} + account: ${{ vars.FIREBOLT_ACCOUNT_V1 }} api-endpoint: "api.staging.firebolt.io" - name: Run integration tests @@ -42,6 +42,7 @@ jobs: ENGINE_NAME: ${{ steps.setup.outputs.engine_name }} STOPPED_ENGINE_NAME: ${{ steps.setup.outputs.stopped_engine_name }} API_ENDPOINT: "api.staging.firebolt.io" - ACCOUNT_NAME: ${{ vars.FIREBOLT_ACCOUNT }} + ACCOUNT_NAME_V1: ${{ vars.FIREBOLT_ACCOUNT_V1 }} + ACCOUNT_NAME_V2: ${{ vars.FIREBOLT_ACCOUNT_V2 }} run: | pytest -n 6 --dist loadgroup --timeout_method "signal" -o log_cli=true -o log_cli_level=WARNING tests/integration -k "not V1" --runslow diff --git a/src/firebolt/client/client.py b/src/firebolt/client/client.py index a90380cceb..04116951bd 100644 --- a/src/firebolt/client/client.py +++ b/src/firebolt/client/client.py @@ -1,4 +1,5 @@ from abc import ABCMeta, abstractmethod +from collections import namedtuple from json import JSONDecodeError from typing import Any, Dict, Optional @@ -44,6 +45,8 @@ FireboltClientMixinBase = mixin_for(HttpxClient) # type: Any +_AccountInfo = namedtuple("_AccountInfo", ["id", "version"]) + class FireboltClientMixin(FireboltClientMixinBase): """HttpxAsyncClient mixin with Firebolt authentication functionality.""" @@ -100,11 +103,16 @@ class Client(FireboltClientMixin, HttpxClient, metaclass=ABCMeta): def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs, transport=KeepaliveTransport()) - @cached_property + @property @abstractmethod def account_id(self) -> str: ... + @property + @abstractmethod + def _account_version(self) -> int: + ... + class ClientV2(Client): """An HTTP client, based on httpx.Client. @@ -131,6 +139,35 @@ def __init__( ) @cached_property + def _account_info(self) -> _AccountInfo: + response = self.get( + url=self._api_endpoint.copy_with( + path=ACCOUNT_BY_NAME_URL.format(account_name=self.account_name) + ) + ) + if response.status_code == HttpxCodes.NOT_FOUND: + assert self.account_name is not None + raise AccountNotFoundOrNoAccessError(self.account_name) + # process all other status codes + response.raise_for_status() + account_id = response.json()["id"] + # If no version assume 1 + account_version = int(response.json().get("infraVersion", 1)) + return _AccountInfo(id=account_id, version=account_version) + + @property + def _account_version(self) -> int: + """User account version. 2 means both database and engine v2 are supported. + + Returns: + int: Account version + + Raises: + AccountNotFoundError: No account found with provided name + """ + return self._account_info.version + + @property def account_id(self) -> str: """User account ID. @@ -143,17 +180,7 @@ def account_id(self) -> str: Raises: AccountNotFoundError: No account found with provided name """ - response = self.get( - url=self._api_endpoint.copy_with( - path=ACCOUNT_BY_NAME_URL.format(account_name=self.account_name) - ) - ) - if response.status_code == HttpxCodes.NOT_FOUND: - assert self.account_name is not None - raise AccountNotFoundOrNoAccessError(self.account_name) - # process all other status codes - response.raise_for_status() - return response.json()["id"] + return self._account_info.id def _send_handling_redirects( self, request: Request, *args: Any, **kwargs: Any @@ -188,6 +215,13 @@ def __init__( ) self._auth_endpoint = URL(fix_url_schema(api_endpoint)) + @property + def _account_version(self) -> int: + """User account version. Hardcoded since it's not returned + by the backend for V1. + """ + return 1 + @cached_property def account_id(self) -> str: """User account ID. @@ -285,6 +319,11 @@ def __init__(self, *args: Any, **kwargs: Any): async def account_id(self) -> str: ... + @property + @abstractmethod + async def _account_version(self) -> int: + ... + class AsyncClientV2(AsyncClient): """An HTTP client, based on httpx.Client. @@ -309,24 +348,12 @@ def __init__( api_endpoint=api_endpoint, **kwargs, ) - self.acount_id_cache: Dict[str, str] = {} - - @property - async def account_id(self) -> str: - """User account ID. - - If account_name was provided during Client construction, returns its ID; - gets default account otherwise. - - Returns: - str: Account ID + self.acount_info_cache: Dict[str, _AccountInfo] = {} - Raises: - AccountNotFoundError: No account found with provided name - """ + async def _account_info(self) -> _AccountInfo: # manual caching to avoid async_cached_property issues - if self.account_name in self.acount_id_cache: - return self.acount_id_cache[self.account_name] + if self.account_name in self.acount_info_cache: + return self.acount_info_cache[self.account_name] response = await self.get( url=self._api_endpoint.copy_with( @@ -339,10 +366,39 @@ async def account_id(self) -> str: # process all other status codes response.raise_for_status() account_id = response.json()["id"] + account_version = int(response.json().get("infraVersion", 1)) + account_info = _AccountInfo(id=account_id, version=account_version) # cache for future use if self.account_name: - self.acount_id_cache[self.account_name] = account_id - return account_id + self.acount_info_cache[self.account_name] = account_info + return account_info + + @property + async def account_id(self) -> str: + """User account ID. + + If account_name was provided during Client construction, returns its ID; + gets default account otherwise. + + Returns: + str: Account ID + + Raises: + AccountNotFoundError: No account found with provided name + """ + return (await self._account_info()).id + + @property + async def _account_version(self) -> int: + """User account version. 2 means both database and engine v2 are supported. + + Returns: + int: Account version + + Raises: + AccountNotFoundError: No account found with provided name + """ + return (await self._account_info()).version async def _send_handling_redirects( self, request: Request, *args: Any, **kwargs: Any @@ -378,6 +434,12 @@ def __init__( self.acount_id_cache: Dict[str, str] = {} self._auth_endpoint = URL(fix_url_schema(api_endpoint)) + @property + async def _account_version(self) -> int: + """User account version. Hardcoded since it's not returned + by the backend for V1.""" + return 1 + @property async def account_id(self) -> str: """User account ID. diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 6a7f9d43ec..b48a34eb5c 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -13,7 +13,8 @@ ENGINE_NAME_ENV = "ENGINE_NAME" STOPPED_ENGINE_NAME_ENV = "STOPPED_ENGINE_NAME" DATABASE_NAME_ENV = "DATABASE_NAME" -ACCOUNT_NAME_ENV = "ACCOUNT_NAME" +ACCOUNT_NAME_V1_ENV = "ACCOUNT_NAME_V1" +ACCOUNT_NAME_V2_ENV = "ACCOUNT_NAME_V2" API_ENDPOINT_ENV = "API_ENDPOINT" SERVICE_ID_ENV = "SERVICE_ID" SERVICE_SECRET_ENV = "SERVICE_SECRET" @@ -93,7 +94,12 @@ def use_db_name(database_name: str): @fixture(scope="session") def account_name() -> str: - return must_env(ACCOUNT_NAME_ENV) + return must_env(ACCOUNT_NAME_V1_ENV) + + +@fixture(scope="session") +def account_name_v2() -> str: + return must_env(ACCOUNT_NAME_V2_ENV) @fixture(scope="session") diff --git a/tests/integration/dbapi/async/V2/conftest.py b/tests/integration/dbapi/async/V2/conftest.py index a32f1e1758..ae0b1e296f 100644 --- a/tests/integration/dbapi/async/V2/conftest.py +++ b/tests/integration/dbapi/async/V2/conftest.py @@ -59,6 +59,22 @@ async def connection_system_engine( yield connection +@fixture +async def connection_system_engine_v2( + database_name: str, + auth: Auth, + account_name_v2: str, + api_endpoint: str, +) -> Connection: + async with await connect( + database=database_name, + auth=auth, + account_name=account_name_v2, + api_endpoint=api_endpoint, + ) as connection: + yield connection + + @fixture async def connection_system_engine_no_db( auth: Auth, diff --git a/tests/integration/dbapi/async/V2/test_system_engine_async.py b/tests/integration/dbapi/async/V2/test_system_engine_async.py index 617503ad65..67e6138b77 100644 --- a/tests/integration/dbapi/async/V2/test_system_engine_async.py +++ b/tests/integration/dbapi/async/V2/test_system_engine_async.py @@ -16,6 +16,9 @@ async def test_system_engine( timezone_name: str, ) -> None: """Connecting with engine name is handled properly.""" + assert ( + await connection_system_engine._client._account_version + ) == 1, "Invalid account version" with connection_system_engine.cursor() as c: assert await c.execute(all_types_query) == 1, "Invalid row count returned" assert c.rowcount == 1, "Invalid rowcount value" @@ -68,3 +71,12 @@ async def test_system_engine_no_db( all_types_query_system_engine_response, timezone_name, ) + + +async def test_system_engine_v2_account(connection_system_engine_v2: Connection): + assert ( + await connection_system_engine_v2._client.account_id + ), "Can't get account id explicitly" + assert ( + await connection_system_engine_v2._client._account_version + ) == 2, "Invalid account version" diff --git a/tests/integration/dbapi/sync/V2/conftest.py b/tests/integration/dbapi/sync/V2/conftest.py index be1ec69543..351602a413 100644 --- a/tests/integration/dbapi/sync/V2/conftest.py +++ b/tests/integration/dbapi/sync/V2/conftest.py @@ -59,6 +59,22 @@ def connection_system_engine( yield connection +@fixture +def connection_system_engine_v2( + database_name: str, + auth: Auth, + account_name_v2: str, + api_endpoint: str, +) -> Connection: + with connect( + database=database_name, + auth=auth, + account_name=account_name_v2, + api_endpoint=api_endpoint, + ) as connection: + yield connection + + @fixture def connection_system_engine_no_db( auth: Auth, diff --git a/tests/integration/dbapi/sync/V2/test_system_engine.py b/tests/integration/dbapi/sync/V2/test_system_engine.py index 04ca5c9e27..9654dd2404 100644 --- a/tests/integration/dbapi/sync/V2/test_system_engine.py +++ b/tests/integration/dbapi/sync/V2/test_system_engine.py @@ -16,6 +16,9 @@ def test_system_engine( timezone_name: str, ) -> None: """Connecting with engine name is handled properly.""" + assert ( + connection_system_engine._client._account_version == 1 + ), "Invalid account version" with connection_system_engine.cursor() as c: assert c.execute(all_types_query) == 1, "Invalid row count returned" assert c.rowcount == 1, "Invalid rowcount value" @@ -68,3 +71,12 @@ def test_system_engine_no_db( all_types_query_system_engine_response, timezone_name, ) + + +def test_system_engine_v2_account(connection_system_engine_v2: Connection): + assert ( + connection_system_engine_v2._client.account_id + ), "Can't get account id explicitly" + assert ( + connection_system_engine_v2._client._account_version == 2 + ), "Invalid account version" diff --git a/tests/unit/client/V1/test_client.py b/tests/unit/client/V1/test_client.py index 325835bd41..916a6d8eae 100644 --- a/tests/unit/client/V1/test_client.py +++ b/tests/unit/client/V1/test_client.py @@ -112,6 +112,7 @@ def test_client_account_id( api_endpoint=server, ) as c: assert c.account_id == account_id, "Invalid account id returned" + assert c._account_version == 1, "Invalid account version returned" # FIR-14945 diff --git a/tests/unit/client/V1/test_client_async.py b/tests/unit/client/V1/test_client_async.py index 952c56b47f..fa4f9135bf 100644 --- a/tests/unit/client/V1/test_client_async.py +++ b/tests/unit/client/V1/test_client_async.py @@ -120,6 +120,7 @@ async def test_client_account_id( api_endpoint=server, ) as c: assert await c.account_id == account_id, "Invalid account id returned." + assert await c._account_version == 1, "Invalid account version returned." async def test_concurent_auth_lock( diff --git a/tests/unit/client/V2/test_client.py b/tests/unit/client/V2/test_client.py index bae20e4edc..fbdcc692b0 100644 --- a/tests/unit/client/V2/test_client.py +++ b/tests/unit/client/V2/test_client.py @@ -107,8 +107,33 @@ def test_client_account_id( auth=auth, base_url=fix_url_schema(server), api_endpoint=server, - ) as c: - assert c.account_id == account_id, "Invalid account id returned" + ) as cursor: + assert cursor.account_id == account_id, "Invalid account id returned" + assert cursor._account_version == 1, "Invalid account version returned" + + +def test_client_account_v2( + httpx_mock: HTTPXMock, + auth: Auth, + account_name: str, + account_id: str, + account_id_url: Pattern, + account_id_v2_callback: Callable, + auth_url: str, + auth_callback: Callable, + server: str, +): + httpx_mock.add_callback(account_id_v2_callback, url=account_id_url) + httpx_mock.add_callback(auth_callback, url=auth_url) + + with ClientV2( + account_name=account_name, + auth=auth, + base_url=fix_url_schema(server), + api_endpoint=server, + ) as cursor: + assert cursor.account_id == account_id, "Invalid account id returned" + assert cursor._account_version == 2, "Invalid account version returned" # FIR-14945 diff --git a/tests/unit/client/V2/test_client_async.py b/tests/unit/client/V2/test_client_async.py index 35904cc0c8..6b82a595e3 100644 --- a/tests/unit/client/V2/test_client_async.py +++ b/tests/unit/client/V2/test_client_async.py @@ -113,6 +113,31 @@ async def test_client_account_id( api_endpoint=server, ) as c: assert await c.account_id == account_id, "Invalid account id returned." + assert await c._account_version == 1, "Invalid account version returned" + + +async def test_client_account_v2( + httpx_mock: HTTPXMock, + auth: Auth, + account_name: str, + account_id: str, + account_id_url: Pattern, + account_id_v2_callback: Callable, + auth_url: str, + auth_callback: Callable, + server: str, +): + httpx_mock.add_callback(account_id_v2_callback, url=account_id_url) + httpx_mock.add_callback(auth_callback, url=auth_url) + + async with AsyncClient( + account_name=account_name, + auth=auth, + base_url=fix_url_schema(server), + api_endpoint=server, + ) as c: + assert await c.account_id == account_id, "Invalid account id returned." + assert await c._account_version == 2, "Invalid account version returned" async def test_concurent_auth_lock( diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 37241c047d..0e9f9a088f 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -73,6 +73,16 @@ def account_id() -> str: return "mock_account_id" +@fixture +def account_version_1() -> int: + return 1 + + +@fixture +def account_version_2() -> int: + return 2 + + @fixture def account_name() -> str: return "mock_account_name" @@ -163,6 +173,27 @@ def account_id_url(server: str, account_name: str) -> Pattern: @fixture def account_id_callback( account_id: str, + account_version_1: int, + account_name: str, +) -> Callable: + def do_mock( + request: Request, + **kwargs, + ) -> Response: + if request.url.path.split("/")[-2] != account_name: + raise AccountNotFoundError(request.url.path.split("/")[-2]) + return Response( + status_code=httpx.codes.OK, + json={"id": account_id, "infraVersion": account_version_1}, + ) + + return do_mock + + +@fixture +def account_id_v2_callback( + account_id: str, + account_version_2: int, account_name: str, ) -> Callable: def do_mock( @@ -171,7 +202,10 @@ def do_mock( ) -> Response: if request.url.path.split("/")[-2] != account_name: raise AccountNotFoundError(request.url.path.split("/")[-2]) - return Response(status_code=httpx.codes.OK, json={"id": account_id}) + return Response( + status_code=httpx.codes.OK, + json={"id": account_id, "infraVersion": account_version_2}, + ) return do_mock