Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/firebolt/async_db/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,18 @@ async def connect_inner(
account_name=account_name,
api_endpoint=api_endpoint,
)
elif account_name:
# In above if branches account name is validated since it's used to
# resolve or get an engine url.
# We need to manually validate account_name if none of the above
# cases are triggered.
async with AsyncClient(
auth=auth,
base_url=api_endpoint,
account_name=account_name,
api_endpoint=api_endpoint,
) as client:
await client.account_id

assert engine_url is not None

Expand Down
44 changes: 39 additions & 5 deletions tests/unit/async_db/test_connection.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from re import Pattern
from typing import Callable, List

from httpx import codes
Expand All @@ -7,9 +8,10 @@

from firebolt.async_db import Connection, connect
from firebolt.async_db._types import ColType
from firebolt.client.auth import Token, UsernamePassword
from firebolt.client.auth import Auth, Token, UsernamePassword
from firebolt.common.settings import Settings
from firebolt.utils.exception import (
AccountNotFoundError,
ConfigurationError,
ConnectionClosedError,
FireboltEngineError,
Expand Down Expand Up @@ -71,7 +73,6 @@ async def test_cursor_initialized(
database=db_name,
username="u",
password="p",
account_name="a",
api_endpoint=settings.server,
)
) as connection:
Expand Down Expand Up @@ -116,7 +117,6 @@ async def test_connect_access_token(
engine_url=settings.server,
database=db_name,
access_token=access_token,
account_name="a",
api_endpoint=settings.server,
)
) as connection:
Expand Down Expand Up @@ -147,7 +147,7 @@ async def test_connect_engine_name(
auth_url: str,
query_callback: Callable,
query_url: str,
account_id_url: str,
account_id_url: Pattern,
account_id_callback: Callable,
engine_id: str,
get_engine_url: str,
Expand Down Expand Up @@ -223,7 +223,7 @@ async def test_connect_default_engine(
auth_url: str,
query_callback: Callable,
query_url: str,
account_id_url: str,
account_id_url: Pattern,
account_id_callback: Callable,
engine_id: str,
get_engine_url: str,
Expand Down Expand Up @@ -353,3 +353,37 @@ async def test_connect_with_auth(
api_endpoint=settings.server,
) as connection:
await connection.cursor().execute("select*")


@mark.asyncio
async def test_connect_account_name(
httpx_mock: HTTPXMock,
auth: Auth,
settings: Settings,
db_name: str,
auth_url: str,
check_credentials_callback: Callable,
account_id_url: Pattern,
account_id_callback: Callable,
):
httpx_mock.add_callback(check_credentials_callback, url=auth_url)
httpx_mock.add_callback(account_id_callback, url=account_id_url)

with raises(AccountNotFoundError):
async with await connect(
auth=auth,
database=db_name,
engine_url=settings.server,
account_name="invalid",
api_endpoint=settings.server,
):
pass

async with await connect(
auth=auth,
database=db_name,
engine_url=settings.server,
account_name=settings.account_name,
api_endpoint=settings.server,
):
pass
3 changes: 2 additions & 1 deletion tests/unit/client/test_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from re import Pattern
from typing import Callable

from httpx import codes
Expand Down Expand Up @@ -93,7 +94,7 @@ def test_client_account_id(
test_username: str,
test_password: str,
account_id: str,
account_id_url: str,
account_id_url: Pattern,
account_id_callback: Callable,
auth_url: str,
auth_callback: Callable,
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/client/test_client_async.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from re import Pattern
from typing import Callable

from httpx import codes
Expand Down Expand Up @@ -104,7 +105,7 @@ async def test_client_account_id(
test_username: str,
test_password: str,
account_id: str,
account_id_url: str,
account_id_url: Pattern,
account_id_callback: Callable,
auth_url: str,
auth_callback: Callable,
Expand Down
60 changes: 39 additions & 21 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
from json import loads
from re import Pattern, compile
from typing import Callable, List

import httpx
from httpx import Response
from httpx import Request, Response
from pydantic import SecretStr
from pyfakefs.fake_filesystem_unittest import Patcher
from pytest import fixture

from firebolt.client.auth import Auth, UsernamePassword
from firebolt.common.settings import Settings
from firebolt.model.provider import Provider
from firebolt.model.region import Region, RegionKey
from firebolt.utils.exception import (
AccountNotFoundError,
DatabaseError,
DataError,
Error,
Expand Down Expand Up @@ -51,6 +54,16 @@ def global_fake_fs(request) -> None:
yield


@fixture
def username() -> str:
return "email@domain.com"


@fixture
def password() -> str:
return "*****"


@fixture
def server() -> str:
return "api.mock.firebolt.io"
Expand Down Expand Up @@ -107,16 +120,21 @@ def mock_regions(region_1, region_2) -> List[Region]:


@fixture
def settings(server, region_1) -> Settings:
def settings(server: str, region_1: str, username: str, password: str) -> Settings:
return Settings(
server=server,
user="email@domain.com",
password=SecretStr("*****"),
user=username,
password=SecretStr(password),
default_region=region_1.name,
account_name=None,
)


@fixture
def auth(username: str, password: str) -> Auth:
return UsernamePassword(username, password)


@fixture
def auth_callback(auth_url: str) -> Callable:
def do_mock(
Expand Down Expand Up @@ -148,30 +166,30 @@ def db_description() -> str:


@fixture
def account_id_url(settings: Settings) -> str:
if not settings.account_name: # if None or ''
return f"https://{settings.server}{ACCOUNT_URL}"
else:
return (
f"https://{settings.server}{ACCOUNT_BY_NAME_URL}"
f"?account_name={settings.account_name}"
)
def account_id_url(settings: Settings) -> Pattern:
base = f"https://{settings.server}{ACCOUNT_BY_NAME_URL}?account_name="
default_base = f"https://{settings.server}{ACCOUNT_URL}"
base = base.replace("/", "\\/").replace("?", "\\?")
default_base = default_base.replace("/", "\\/").replace("?", "\\?")
return compile(f"(?:{base}.*|{default_base})")


@fixture
def account_id_callback(
account_id: str, account_id_url: str, settings: Settings
account_id: str,
settings: Settings,
) -> Callable:
def do_mock(
request: httpx.Request = None,
request: Request,
**kwargs,
) -> Response:
assert request.url == account_id_url
if account_id_url.endswith(ACCOUNT_URL): # account_name shouldn't be specified.
if "account_name" not in request.url.params:
return Response(
status_code=httpx.codes.OK, json={"account": {"id": account_id}}
)
# In this case, an account_name *should* be specified.
if request.url.params["account_name"] != settings.account_name:
raise AccountNotFoundError(request.url.params["account_name"])
return Response(status_code=httpx.codes.OK, json={"account_id": account_id})

return do_mock
Expand All @@ -194,7 +212,7 @@ def get_engine_callback(
get_engine_url: str, engine_id: str, settings: Settings
) -> Callable:
def do_mock(
request: httpx.Request = None,
request: Request = None,
**kwargs,
) -> Response:
assert request.url == get_engine_url
Expand Down Expand Up @@ -230,7 +248,7 @@ def get_providers_url(settings: Settings, account_id: str, engine_id: str) -> st
@fixture
def get_providers_callback(get_providers_url: str, provider: Provider) -> Callable:
def do_mock(
request: httpx.Request = None,
request: Request = None,
**kwargs,
) -> Response:
assert request.url == get_providers_url
Expand Down Expand Up @@ -269,7 +287,7 @@ def database_by_name_url(settings: Settings, account_id: str, db_name: str) -> s
@fixture
def database_by_name_callback(account_id: str, database_id: str) -> str:
def do_mock(
request: httpx.Request = None,
request: Request = None,
**kwargs,
) -> Response:
return Response(
Expand Down Expand Up @@ -312,7 +330,7 @@ def db_api_exceptions():

@fixture
def check_token_callback(access_token: str) -> Callable:
def check_token(request: httpx.Request = None, **kwargs) -> Response:
def check_token(request: Request = None, **kwargs) -> Response:
prefix = "Bearer "
assert request, "empty request"
assert "authorization" in request.headers, "missing authorization header"
Expand All @@ -329,7 +347,7 @@ def check_token(request: httpx.Request = None, **kwargs) -> Response:
@fixture
def check_credentials_callback(settings: Settings, access_token: str) -> Callable:
def check_credentials(
request: httpx.Request = None,
request: Request = None,
**kwargs,
) -> Response:
assert request, "empty request"
Expand Down
47 changes: 42 additions & 5 deletions tests/unit/db/test_connection.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from re import Pattern
from typing import Callable, List

from httpx import codes
Expand All @@ -6,10 +7,14 @@
from pytest_httpx import HTTPXMock

from firebolt.async_db._types import ColType
from firebolt.client.auth import Token, UsernamePassword
from firebolt.client.auth import Auth, Token, UsernamePassword
from firebolt.common.settings import Settings
from firebolt.db import Connection, connect
from firebolt.utils.exception import ConfigurationError, ConnectionClosedError
from firebolt.utils.exception import (
AccountNotFoundError,
ConfigurationError,
ConnectionClosedError,
)
from firebolt.utils.token_storage import TokenSecureStorage
from firebolt.utils.urls import ACCOUNT_ENGINE_BY_NAME_URL

Expand Down Expand Up @@ -104,7 +109,6 @@ def test_connect_access_token(
engine_url=settings.server,
database=db_name,
access_token=access_token,
account_name="a",
api_endpoint=settings.server,
)
) as connection:
Expand Down Expand Up @@ -134,7 +138,7 @@ def test_connect_engine_name(
auth_url: str,
query_callback: Callable,
query_url: str,
account_id_url: str,
account_id_url: Pattern,
account_id_callback: Callable,
engine_id: str,
get_engine_url: str,
Expand Down Expand Up @@ -190,7 +194,7 @@ def test_connect_default_engine(
auth_url: str,
query_callback: Callable,
query_url: str,
account_id_url: str,
account_id_url: Pattern,
account_id_callback: Callable,
engine_id: str,
get_engine_url: str,
Expand Down Expand Up @@ -323,3 +327,36 @@ def test_connect_with_auth(
api_endpoint=settings.server,
) as connection:
connection.cursor().execute("select*")


def test_connect_account_name(
httpx_mock: HTTPXMock,
auth: Auth,
settings: Settings,
db_name: str,
auth_url: str,
check_credentials_callback: Callable,
account_id_url: Pattern,
account_id_callback: Callable,
):
httpx_mock.add_callback(check_credentials_callback, url=auth_url)
httpx_mock.add_callback(account_id_callback, url=account_id_url)

with raises(AccountNotFoundError):
with connect(
auth=auth,
database=db_name,
engine_url=settings.server,
account_name="invalid",
api_endpoint=settings.server,
):
pass

with connect(
auth=auth,
database=db_name,
engine_url=settings.server,
account_name=settings.account_name,
api_endpoint=settings.server,
):
pass
Loading