diff --git a/src/firebolt/async_db/connection.py b/src/firebolt/async_db/connection.py index 7a33d3ca42a..85d73e9010a 100644 --- a/src/firebolt/async_db/connection.py +++ b/src/firebolt/async_db/connection.py @@ -222,6 +222,18 @@ async def connect_inner( account_name=account_name, api_endpoint=api_endpoint, ) + elif account_name: + # In above if branches account name is validated since it's used to + # resolve or get an engine url. + # We need to manually validate account_name if none of the above + # cases are triggered. + async with AsyncClient( + auth=auth, + base_url=api_endpoint, + account_name=account_name, + api_endpoint=api_endpoint, + ) as client: + await client.account_id assert engine_url is not None diff --git a/tests/unit/async_db/test_connection.py b/tests/unit/async_db/test_connection.py index b58ea120690..f56803201b3 100644 --- a/tests/unit/async_db/test_connection.py +++ b/tests/unit/async_db/test_connection.py @@ -1,3 +1,4 @@ +from re import Pattern from typing import Callable, List from httpx import codes @@ -7,9 +8,10 @@ from firebolt.async_db import Connection, connect from firebolt.async_db._types import ColType -from firebolt.client.auth import Token, UsernamePassword +from firebolt.client.auth import Auth, Token, UsernamePassword from firebolt.common.settings import Settings from firebolt.utils.exception import ( + AccountNotFoundError, ConfigurationError, ConnectionClosedError, FireboltEngineError, @@ -71,7 +73,6 @@ async def test_cursor_initialized( database=db_name, username="u", password="p", - account_name="a", api_endpoint=settings.server, ) ) as connection: @@ -116,7 +117,6 @@ async def test_connect_access_token( engine_url=settings.server, database=db_name, access_token=access_token, - account_name="a", api_endpoint=settings.server, ) ) as connection: @@ -147,7 +147,7 @@ async def test_connect_engine_name( auth_url: str, query_callback: Callable, query_url: str, - account_id_url: str, + account_id_url: Pattern, account_id_callback: Callable, engine_id: str, get_engine_url: str, @@ -223,7 +223,7 @@ async def test_connect_default_engine( auth_url: str, query_callback: Callable, query_url: str, - account_id_url: str, + account_id_url: Pattern, account_id_callback: Callable, engine_id: str, get_engine_url: str, @@ -353,3 +353,37 @@ async def test_connect_with_auth( api_endpoint=settings.server, ) as connection: await connection.cursor().execute("select*") + + +@mark.asyncio +async def test_connect_account_name( + httpx_mock: HTTPXMock, + auth: Auth, + settings: Settings, + db_name: str, + auth_url: str, + check_credentials_callback: Callable, + account_id_url: Pattern, + account_id_callback: Callable, +): + httpx_mock.add_callback(check_credentials_callback, url=auth_url) + httpx_mock.add_callback(account_id_callback, url=account_id_url) + + with raises(AccountNotFoundError): + async with await connect( + auth=auth, + database=db_name, + engine_url=settings.server, + account_name="invalid", + api_endpoint=settings.server, + ): + pass + + async with await connect( + auth=auth, + database=db_name, + engine_url=settings.server, + account_name=settings.account_name, + api_endpoint=settings.server, + ): + pass diff --git a/tests/unit/client/test_client.py b/tests/unit/client/test_client.py index 7bd30c98f8b..bbf403450f8 100644 --- a/tests/unit/client/test_client.py +++ b/tests/unit/client/test_client.py @@ -1,3 +1,4 @@ +from re import Pattern from typing import Callable from httpx import codes @@ -93,7 +94,7 @@ def test_client_account_id( test_username: str, test_password: str, account_id: str, - account_id_url: str, + account_id_url: Pattern, account_id_callback: Callable, auth_url: str, auth_callback: Callable, diff --git a/tests/unit/client/test_client_async.py b/tests/unit/client/test_client_async.py index 1cc02dc4449..727fd614a6c 100644 --- a/tests/unit/client/test_client_async.py +++ b/tests/unit/client/test_client_async.py @@ -1,3 +1,4 @@ +from re import Pattern from typing import Callable from httpx import codes @@ -104,7 +105,7 @@ async def test_client_account_id( test_username: str, test_password: str, account_id: str, - account_id_url: str, + account_id_url: Pattern, account_id_callback: Callable, auth_url: str, auth_callback: Callable, diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 2bba104ed58..b9b2ae7f996 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,16 +1,19 @@ from json import loads +from re import Pattern, compile from typing import Callable, List import httpx -from httpx import Response +from httpx import Request, Response from pydantic import SecretStr from pyfakefs.fake_filesystem_unittest import Patcher from pytest import fixture +from firebolt.client.auth import Auth, UsernamePassword from firebolt.common.settings import Settings from firebolt.model.provider import Provider from firebolt.model.region import Region, RegionKey from firebolt.utils.exception import ( + AccountNotFoundError, DatabaseError, DataError, Error, @@ -51,6 +54,16 @@ def global_fake_fs(request) -> None: yield +@fixture +def username() -> str: + return "email@domain.com" + + +@fixture +def password() -> str: + return "*****" + + @fixture def server() -> str: return "api.mock.firebolt.io" @@ -107,16 +120,21 @@ def mock_regions(region_1, region_2) -> List[Region]: @fixture -def settings(server, region_1) -> Settings: +def settings(server: str, region_1: str, username: str, password: str) -> Settings: return Settings( server=server, - user="email@domain.com", - password=SecretStr("*****"), + user=username, + password=SecretStr(password), default_region=region_1.name, account_name=None, ) +@fixture +def auth(username: str, password: str) -> Auth: + return UsernamePassword(username, password) + + @fixture def auth_callback(auth_url: str) -> Callable: def do_mock( @@ -148,30 +166,30 @@ def db_description() -> str: @fixture -def account_id_url(settings: Settings) -> str: - if not settings.account_name: # if None or '' - return f"https://{settings.server}{ACCOUNT_URL}" - else: - return ( - f"https://{settings.server}{ACCOUNT_BY_NAME_URL}" - f"?account_name={settings.account_name}" - ) +def account_id_url(settings: Settings) -> Pattern: + base = f"https://{settings.server}{ACCOUNT_BY_NAME_URL}?account_name=" + default_base = f"https://{settings.server}{ACCOUNT_URL}" + base = base.replace("/", "\\/").replace("?", "\\?") + default_base = default_base.replace("/", "\\/").replace("?", "\\?") + return compile(f"(?:{base}.*|{default_base})") @fixture def account_id_callback( - account_id: str, account_id_url: str, settings: Settings + account_id: str, + settings: Settings, ) -> Callable: def do_mock( - request: httpx.Request = None, + request: Request, **kwargs, ) -> Response: - assert request.url == account_id_url - if account_id_url.endswith(ACCOUNT_URL): # account_name shouldn't be specified. + if "account_name" not in request.url.params: return Response( status_code=httpx.codes.OK, json={"account": {"id": account_id}} ) # In this case, an account_name *should* be specified. + if request.url.params["account_name"] != settings.account_name: + raise AccountNotFoundError(request.url.params["account_name"]) return Response(status_code=httpx.codes.OK, json={"account_id": account_id}) return do_mock @@ -194,7 +212,7 @@ def get_engine_callback( get_engine_url: str, engine_id: str, settings: Settings ) -> Callable: def do_mock( - request: httpx.Request = None, + request: Request = None, **kwargs, ) -> Response: assert request.url == get_engine_url @@ -230,7 +248,7 @@ def get_providers_url(settings: Settings, account_id: str, engine_id: str) -> st @fixture def get_providers_callback(get_providers_url: str, provider: Provider) -> Callable: def do_mock( - request: httpx.Request = None, + request: Request = None, **kwargs, ) -> Response: assert request.url == get_providers_url @@ -269,7 +287,7 @@ def database_by_name_url(settings: Settings, account_id: str, db_name: str) -> s @fixture def database_by_name_callback(account_id: str, database_id: str) -> str: def do_mock( - request: httpx.Request = None, + request: Request = None, **kwargs, ) -> Response: return Response( @@ -312,7 +330,7 @@ def db_api_exceptions(): @fixture def check_token_callback(access_token: str) -> Callable: - def check_token(request: httpx.Request = None, **kwargs) -> Response: + def check_token(request: Request = None, **kwargs) -> Response: prefix = "Bearer " assert request, "empty request" assert "authorization" in request.headers, "missing authorization header" @@ -329,7 +347,7 @@ def check_token(request: httpx.Request = None, **kwargs) -> Response: @fixture def check_credentials_callback(settings: Settings, access_token: str) -> Callable: def check_credentials( - request: httpx.Request = None, + request: Request = None, **kwargs, ) -> Response: assert request, "empty request" diff --git a/tests/unit/db/test_connection.py b/tests/unit/db/test_connection.py index 9572ddc151a..5d44a26e799 100644 --- a/tests/unit/db/test_connection.py +++ b/tests/unit/db/test_connection.py @@ -1,3 +1,4 @@ +from re import Pattern from typing import Callable, List from httpx import codes @@ -6,10 +7,14 @@ from pytest_httpx import HTTPXMock from firebolt.async_db._types import ColType -from firebolt.client.auth import Token, UsernamePassword +from firebolt.client.auth import Auth, Token, UsernamePassword from firebolt.common.settings import Settings from firebolt.db import Connection, connect -from firebolt.utils.exception import ConfigurationError, ConnectionClosedError +from firebolt.utils.exception import ( + AccountNotFoundError, + ConfigurationError, + ConnectionClosedError, +) from firebolt.utils.token_storage import TokenSecureStorage from firebolt.utils.urls import ACCOUNT_ENGINE_BY_NAME_URL @@ -104,7 +109,6 @@ def test_connect_access_token( engine_url=settings.server, database=db_name, access_token=access_token, - account_name="a", api_endpoint=settings.server, ) ) as connection: @@ -134,7 +138,7 @@ def test_connect_engine_name( auth_url: str, query_callback: Callable, query_url: str, - account_id_url: str, + account_id_url: Pattern, account_id_callback: Callable, engine_id: str, get_engine_url: str, @@ -190,7 +194,7 @@ def test_connect_default_engine( auth_url: str, query_callback: Callable, query_url: str, - account_id_url: str, + account_id_url: Pattern, account_id_callback: Callable, engine_id: str, get_engine_url: str, @@ -323,3 +327,36 @@ def test_connect_with_auth( api_endpoint=settings.server, ) as connection: connection.cursor().execute("select*") + + +def test_connect_account_name( + httpx_mock: HTTPXMock, + auth: Auth, + settings: Settings, + db_name: str, + auth_url: str, + check_credentials_callback: Callable, + account_id_url: Pattern, + account_id_callback: Callable, +): + httpx_mock.add_callback(check_credentials_callback, url=auth_url) + httpx_mock.add_callback(account_id_callback, url=account_id_url) + + with raises(AccountNotFoundError): + with connect( + auth=auth, + database=db_name, + engine_url=settings.server, + account_name="invalid", + api_endpoint=settings.server, + ): + pass + + with connect( + auth=auth, + database=db_name, + engine_url=settings.server, + account_name=settings.account_name, + api_endpoint=settings.server, + ): + pass diff --git a/tests/unit/service/test_database.py b/tests/unit/service/test_database.py index 1b5bb1ea726..35dec331ad6 100644 --- a/tests/unit/service/test_database.py +++ b/tests/unit/service/test_database.py @@ -1,4 +1,4 @@ -import re +from re import Pattern, compile from typing import Callable from pytest_httpx import HTTPXMock @@ -18,7 +18,7 @@ def test_database_create( region_url: str, settings: Settings, account_id_callback: Callable, - account_id_url: str, + account_id_url: Pattern, create_databases_callback: Callable, databases_url: str, db_name: str, @@ -46,7 +46,7 @@ def test_database_get_by_name( provider_url: str, settings: Settings, account_id_callback: Callable, - account_id_url: str, + account_id_url: Pattern, database_get_by_name_callback: Callable, database_get_by_name_url: str, database_get_callback: Callable, @@ -75,7 +75,7 @@ def test_database_get_many( provider_url: str, settings: Settings, account_id_callback: Callable, - account_id_url: str, + account_id_url: Pattern, database_get_by_name_callback: Callable, database_get_by_name_url: str, databases_get_callback: Callable, @@ -89,7 +89,7 @@ def test_database_get_many( httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback( databases_get_callback, - url=re.compile(databases_url + "?[a-zA-Z0-9=&]*"), + url=compile(databases_url + "?[a-zA-Z0-9=&]*"), method="GET", ) @@ -112,7 +112,7 @@ def test_database_update( provider_url: str, settings: Settings, account_id_callback: Callable, - account_id_url: str, + account_id_url: Pattern, database_update_callback: Callable, database_url: str, mock_database: Database, diff --git a/tests/unit/service/test_engine.py b/tests/unit/service/test_engine.py index 50bea0cb3bd..18b93af5c44 100644 --- a/tests/unit/service/test_engine.py +++ b/tests/unit/service/test_engine.py @@ -1,3 +1,4 @@ +from re import Pattern from typing import Callable, List from pydantic import ValidationError @@ -29,7 +30,7 @@ def test_engine_create( mock_engine: Engine, engine_name: str, account_id_callback: Callable, - account_id_url: str, + account_id_url: Pattern, engine_callback: Callable, engine_url: str, ): @@ -63,7 +64,7 @@ def test_engine_create_with_kwargs( mock_engine: Engine, engine_name: str, account_id_callback: Callable, - account_id_url: str, + account_id_url: Pattern, engine_callback: Callable, engine_url: str, account_id: str, @@ -121,7 +122,7 @@ def test_engine_create_with_kwargs_fail( settings: Settings, engine_name: str, account_id_callback: Callable, - account_id_url: str, + account_id_url: Pattern, ): httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback(provider_callback, url=provider_url) @@ -158,7 +159,7 @@ def test_engine_create_no_available_types( mock_instance_types: List[InstanceType], engine_name: str, account_id_callback: Callable, - account_id_url: str, + account_id_url: Pattern, engine_url: str, region_2: Region, ): @@ -192,7 +193,7 @@ def test_engine_no_attached_database( mock_engine: Engine, engine_name: str, account_id_callback: Callable, - account_id_url: str, + account_id_url: Pattern, engine_callback: Callable, engine_url: str, account_engine_callback: Callable, @@ -236,7 +237,7 @@ def test_engine_start_binding_to_missing_database( mock_engine: Engine, engine_name: str, account_id_callback: Callable, - account_id_url: str, + account_id_url: Pattern, engine_callback: Callable, engine_url: str, database_not_found_callback: Callable, @@ -279,7 +280,7 @@ def test_get_connection( mock_engine: Engine, engine_name: str, account_id_callback: Callable, - account_id_url: str, + account_id_url: Pattern, engine_callback: Callable, engine_url: str, db_name: str, @@ -320,7 +321,7 @@ def test_attach_to_database( instance_type_region_1_url: str, settings: Settings, account_id_callback: Callable, - account_id_url: str, + account_id_url: Pattern, create_databases_callback: Callable, databases_url: str, database_get_callback: Callable, @@ -382,7 +383,7 @@ def test_engine_update( mock_engine: Engine, engine_name: str, account_id_callback: Callable, - account_id_url: str, + account_id_url: Pattern, engine_callback: Callable, engine_url: str, account_engine_url: str, @@ -418,7 +419,7 @@ def test_engine_restart( settings: Settings, mock_engine: Engine, account_id_callback: Callable, - account_id_url: str, + account_id_url: Pattern, engine_callback: Callable, account_engine_url: str, bindings_callback: Callable, diff --git a/tests/unit/service/test_instance_type.py b/tests/unit/service/test_instance_type.py index d1e9510e35e..dce917c1f42 100644 --- a/tests/unit/service/test_instance_type.py +++ b/tests/unit/service/test_instance_type.py @@ -1,3 +1,4 @@ +from re import Pattern from typing import Callable, List from pytest_httpx import HTTPXMock @@ -21,7 +22,7 @@ def test_instance_type( instance_type_region_1_url: str, instance_type_region_2_url: str, account_id_callback: Callable, - account_id_url: str, + account_id_url: Pattern, settings: Settings, mock_instance_types: List[InstanceType], cheapest_instance: InstanceType, diff --git a/tests/unit/service/test_region.py b/tests/unit/service/test_region.py index 121e09bdc9e..939238f6bfe 100644 --- a/tests/unit/service/test_region.py +++ b/tests/unit/service/test_region.py @@ -1,3 +1,4 @@ +from re import Pattern from typing import Callable, List from pytest_httpx import HTTPXMock @@ -16,7 +17,7 @@ def test_region( region_callback: Callable, region_url: str, account_id_callback: Callable, - account_id_url: str, + account_id_url: Pattern, settings: Settings, mock_regions: List[Region], ): diff --git a/tests/unit/service/test_resource_manager.py b/tests/unit/service/test_resource_manager.py index 9c5e8e585e8..778025cca7a 100644 --- a/tests/unit/service/test_resource_manager.py +++ b/tests/unit/service/test_resource_manager.py @@ -1,12 +1,14 @@ +from re import Pattern from typing import Callable from pyfakefs.fake_filesystem_unittest import Patcher -from pytest import mark +from pytest import mark, raises from pytest_httpx import HTTPXMock -from firebolt.client.auth import Token, UsernamePassword +from firebolt.client.auth import Auth, Token, UsernamePassword from firebolt.common.settings import Settings from firebolt.service.manager import ResourceManager +from firebolt.utils.exception import AccountNotFoundError from firebolt.utils.token_storage import TokenSecureStorage @@ -16,7 +18,7 @@ def test_rm_credentials( check_credentials_callback: Callable, settings: Settings, auth_url: str, - account_id_url: str, + account_id_url: Pattern, account_id_callback: Callable, provider_callback: Callable, provider_url: str, @@ -68,7 +70,7 @@ def test_rm_token_cache( check_credentials_callback: Callable, settings: Settings, auth_url: str, - account_id_url: str, + account_id_url: Pattern, account_id_callback: Callable, provider_callback: Callable, provider_url: str, @@ -112,3 +114,27 @@ def test_rm_token_cache( assert ( ts.get_cached_token() is None ), "Token is cached even though caching is disabled" + + +def test_rm_invalid_account_name( + httpx_mock: HTTPXMock, + auth: Auth, + settings: Settings, + check_credentials_callback: Callable, + auth_url: str, + account_id_url: Pattern, + account_id_callback: Callable, +) -> None: + """Resource manager raises an error on invalid account name.""" + httpx_mock.add_callback(check_credentials_callback, url=auth_url) + httpx_mock.add_callback(account_id_callback, url=account_id_url) + + local_settings = Settings( + auth=auth, + account_name="invalid", + server=settings.server, + default_region=settings.default_region, + ) + + with raises(AccountNotFoundError): + ResourceManager(local_settings)