Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions .github/workflows/integration-tests-v2.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ jobs:
with:
firebolt-client-id: ${{ secrets.FIREBOLT_CLIENT_ID_STG_NEW_IDN }}
firebolt-client-secret: ${{ secrets.FIREBOLT_CLIENT_SECRET_STG_NEW_IDN }}
account: ${{ vars.FIREBOLT_ACCOUNT }}
account: ${{ vars.FIREBOLT_ACCOUNT_V1 }}
api-endpoint: "api.staging.firebolt.io"

- name: Run integration tests
Expand All @@ -42,6 +42,7 @@ jobs:
ENGINE_NAME: ${{ steps.setup.outputs.engine_name }}
STOPPED_ENGINE_NAME: ${{ steps.setup.outputs.stopped_engine_name }}
API_ENDPOINT: "api.staging.firebolt.io"
ACCOUNT_NAME: ${{ vars.FIREBOLT_ACCOUNT }}
ACCOUNT_NAME_V1: ${{ vars.FIREBOLT_ACCOUNT_V1 }}
ACCOUNT_NAME_V2: ${{ vars.FIREBOLT_ACCOUNT_V2 }}
run: |
pytest -n 6 --dist loadgroup --timeout_method "signal" -o log_cli=true -o log_cli_level=WARNING tests/integration -k "not V1" --runslow
122 changes: 92 additions & 30 deletions src/firebolt/client/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import ABCMeta, abstractmethod
from collections import namedtuple
from json import JSONDecodeError
from typing import Any, Dict, Optional

Expand Down Expand Up @@ -44,6 +45,8 @@

FireboltClientMixinBase = mixin_for(HttpxClient) # type: Any

_AccountInfo = namedtuple("_AccountInfo", ["id", "version"])


class FireboltClientMixin(FireboltClientMixinBase):
"""HttpxAsyncClient mixin with Firebolt authentication functionality."""
Expand Down Expand Up @@ -100,11 +103,16 @@ class Client(FireboltClientMixin, HttpxClient, metaclass=ABCMeta):
def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs, transport=KeepaliveTransport())

@cached_property
@property
@abstractmethod
def account_id(self) -> str:
...

@property
@abstractmethod
def _account_version(self) -> int:
...


class ClientV2(Client):
"""An HTTP client, based on httpx.Client.
Expand All @@ -131,6 +139,35 @@ def __init__(
)

@cached_property
def _account_info(self) -> _AccountInfo:
response = self.get(
url=self._api_endpoint.copy_with(
path=ACCOUNT_BY_NAME_URL.format(account_name=self.account_name)
)
)
if response.status_code == HttpxCodes.NOT_FOUND:
assert self.account_name is not None
raise AccountNotFoundOrNoAccessError(self.account_name)
# process all other status codes
response.raise_for_status()
account_id = response.json()["id"]
# If no version assume 1
account_version = int(response.json().get("infraVersion", 1))
return _AccountInfo(id=account_id, version=account_version)

@property
def _account_version(self) -> int:
"""User account version. 2 means both database and engine v2 are supported.

Returns:
int: Account version

Raises:
AccountNotFoundError: No account found with provided name
"""
return self._account_info.version

@property
def account_id(self) -> str:
"""User account ID.

Expand All @@ -143,17 +180,7 @@ def account_id(self) -> str:
Raises:
AccountNotFoundError: No account found with provided name
"""
response = self.get(
url=self._api_endpoint.copy_with(
path=ACCOUNT_BY_NAME_URL.format(account_name=self.account_name)
)
)
if response.status_code == HttpxCodes.NOT_FOUND:
assert self.account_name is not None
raise AccountNotFoundOrNoAccessError(self.account_name)
# process all other status codes
response.raise_for_status()
return response.json()["id"]
return self._account_info.id

def _send_handling_redirects(
self, request: Request, *args: Any, **kwargs: Any
Expand Down Expand Up @@ -188,6 +215,13 @@ def __init__(
)
self._auth_endpoint = URL(fix_url_schema(api_endpoint))

@property
def _account_version(self) -> int:
"""User account version. Hardcoded since it's not returned
by the backend for V1.
"""
return 1

@cached_property
def account_id(self) -> str:
"""User account ID.
Expand Down Expand Up @@ -285,6 +319,11 @@ def __init__(self, *args: Any, **kwargs: Any):
async def account_id(self) -> str:
...

@property
@abstractmethod
async def _account_version(self) -> int:
...


class AsyncClientV2(AsyncClient):
"""An HTTP client, based on httpx.Client.
Expand All @@ -309,24 +348,12 @@ def __init__(
api_endpoint=api_endpoint,
**kwargs,
)
self.acount_id_cache: Dict[str, str] = {}

@property
async def account_id(self) -> str:
"""User account ID.

If account_name was provided during Client construction, returns its ID;
gets default account otherwise.

Returns:
str: Account ID
self.acount_info_cache: Dict[str, _AccountInfo] = {}

Raises:
AccountNotFoundError: No account found with provided name
"""
async def _account_info(self) -> _AccountInfo:
# manual caching to avoid async_cached_property issues
if self.account_name in self.acount_id_cache:
return self.acount_id_cache[self.account_name]
if self.account_name in self.acount_info_cache:
return self.acount_info_cache[self.account_name]

response = await self.get(
url=self._api_endpoint.copy_with(
Expand All @@ -339,10 +366,39 @@ async def account_id(self) -> str:
# process all other status codes
response.raise_for_status()
account_id = response.json()["id"]
account_version = int(response.json().get("infraVersion", 1))
account_info = _AccountInfo(id=account_id, version=account_version)
# cache for future use
if self.account_name:
self.acount_id_cache[self.account_name] = account_id
return account_id
self.acount_info_cache[self.account_name] = account_info
return account_info

@property
async def account_id(self) -> str:
"""User account ID.

If account_name was provided during Client construction, returns its ID;
gets default account otherwise.

Returns:
str: Account ID

Raises:
AccountNotFoundError: No account found with provided name
"""
return (await self._account_info()).id

@property
async def _account_version(self) -> int:
"""User account version. 2 means both database and engine v2 are supported.

Returns:
int: Account version

Raises:
AccountNotFoundError: No account found with provided name
"""
return (await self._account_info()).version

async def _send_handling_redirects(
self, request: Request, *args: Any, **kwargs: Any
Expand Down Expand Up @@ -378,6 +434,12 @@ def __init__(
self.acount_id_cache: Dict[str, str] = {}
self._auth_endpoint = URL(fix_url_schema(api_endpoint))

@property
async def _account_version(self) -> int:
"""User account version. Hardcoded since it's not returned
by the backend for V1."""
return 1

@property
async def account_id(self) -> str:
"""User account ID.
Expand Down
10 changes: 8 additions & 2 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
ENGINE_NAME_ENV = "ENGINE_NAME"
STOPPED_ENGINE_NAME_ENV = "STOPPED_ENGINE_NAME"
DATABASE_NAME_ENV = "DATABASE_NAME"
ACCOUNT_NAME_ENV = "ACCOUNT_NAME"
ACCOUNT_NAME_V1_ENV = "ACCOUNT_NAME_V1"
ACCOUNT_NAME_V2_ENV = "ACCOUNT_NAME_V2"
API_ENDPOINT_ENV = "API_ENDPOINT"
SERVICE_ID_ENV = "SERVICE_ID"
SERVICE_SECRET_ENV = "SERVICE_SECRET"
Expand Down Expand Up @@ -93,7 +94,12 @@ def use_db_name(database_name: str):

@fixture(scope="session")
def account_name() -> str:
return must_env(ACCOUNT_NAME_ENV)
return must_env(ACCOUNT_NAME_V1_ENV)


@fixture(scope="session")
def account_name_v2() -> str:
return must_env(ACCOUNT_NAME_V2_ENV)


@fixture(scope="session")
Expand Down
16 changes: 16 additions & 0 deletions tests/integration/dbapi/async/V2/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,22 @@ async def connection_system_engine(
yield connection


@fixture
async def connection_system_engine_v2(
database_name: str,
auth: Auth,
account_name_v2: str,
api_endpoint: str,
) -> Connection:
async with await connect(
database=database_name,
auth=auth,
account_name=account_name_v2,
api_endpoint=api_endpoint,
) as connection:
yield connection


@fixture
async def connection_system_engine_no_db(
auth: Auth,
Expand Down
12 changes: 12 additions & 0 deletions tests/integration/dbapi/async/V2/test_system_engine_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ async def test_system_engine(
timezone_name: str,
) -> None:
"""Connecting with engine name is handled properly."""
assert (
await connection_system_engine._client._account_version
) == 1, "Invalid account version"
with connection_system_engine.cursor() as c:
assert await c.execute(all_types_query) == 1, "Invalid row count returned"
assert c.rowcount == 1, "Invalid rowcount value"
Expand Down Expand Up @@ -68,3 +71,12 @@ async def test_system_engine_no_db(
all_types_query_system_engine_response,
timezone_name,
)


async def test_system_engine_v2_account(connection_system_engine_v2: Connection):
assert (
await connection_system_engine_v2._client.account_id
), "Can't get account id explicitly"
assert (
await connection_system_engine_v2._client._account_version
) == 2, "Invalid account version"
16 changes: 16 additions & 0 deletions tests/integration/dbapi/sync/V2/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,22 @@ def connection_system_engine(
yield connection


@fixture
def connection_system_engine_v2(
database_name: str,
auth: Auth,
account_name_v2: str,
api_endpoint: str,
) -> Connection:
with connect(
database=database_name,
auth=auth,
account_name=account_name_v2,
api_endpoint=api_endpoint,
) as connection:
yield connection


@fixture
def connection_system_engine_no_db(
auth: Auth,
Expand Down
12 changes: 12 additions & 0 deletions tests/integration/dbapi/sync/V2/test_system_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ def test_system_engine(
timezone_name: str,
) -> None:
"""Connecting with engine name is handled properly."""
assert (
connection_system_engine._client._account_version == 1
), "Invalid account version"
with connection_system_engine.cursor() as c:
assert c.execute(all_types_query) == 1, "Invalid row count returned"
assert c.rowcount == 1, "Invalid rowcount value"
Expand Down Expand Up @@ -68,3 +71,12 @@ def test_system_engine_no_db(
all_types_query_system_engine_response,
timezone_name,
)


def test_system_engine_v2_account(connection_system_engine_v2: Connection):
assert (
connection_system_engine_v2._client.account_id
), "Can't get account id explicitly"
assert (
connection_system_engine_v2._client._account_version == 2
), "Invalid account version"
1 change: 1 addition & 0 deletions tests/unit/client/V1/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def test_client_account_id(
api_endpoint=server,
) as c:
assert c.account_id == account_id, "Invalid account id returned"
assert c._account_version == 1, "Invalid account version returned"


# FIR-14945
Expand Down
1 change: 1 addition & 0 deletions tests/unit/client/V1/test_client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ async def test_client_account_id(
api_endpoint=server,
) as c:
assert await c.account_id == account_id, "Invalid account id returned."
assert await c._account_version == 1, "Invalid account version returned."


async def test_concurent_auth_lock(
Expand Down
29 changes: 27 additions & 2 deletions tests/unit/client/V2/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,33 @@ def test_client_account_id(
auth=auth,
base_url=fix_url_schema(server),
api_endpoint=server,
) as c:
assert c.account_id == account_id, "Invalid account id returned"
) as cursor:
assert cursor.account_id == account_id, "Invalid account id returned"
assert cursor._account_version == 1, "Invalid account version returned"


def test_client_account_v2(
httpx_mock: HTTPXMock,
auth: Auth,
account_name: str,
account_id: str,
account_id_url: Pattern,
account_id_v2_callback: Callable,
auth_url: str,
auth_callback: Callable,
server: str,
):
httpx_mock.add_callback(account_id_v2_callback, url=account_id_url)
httpx_mock.add_callback(auth_callback, url=auth_url)

with ClientV2(
account_name=account_name,
auth=auth,
base_url=fix_url_schema(server),
api_endpoint=server,
) as cursor:
assert cursor.account_id == account_id, "Invalid account id returned"
assert cursor._account_version == 2, "Invalid account version returned"


# FIR-14945
Expand Down
Loading