diff --git a/src/firebolt/async_db/connection.py b/src/firebolt/async_db/connection.py index 8bd4ec17066..be517e57fc6 100644 --- a/src/firebolt/async_db/connection.py +++ b/src/firebolt/async_db/connection.py @@ -18,7 +18,11 @@ FireboltEngineError, InterfaceError, ) -from firebolt.common.urls import ACCOUNT_ENGINE_BY_NAME_URL, ACCOUNT_ENGINE_URL +from firebolt.common.urls import ( + ACCOUNT_ENGINE_BY_NAME_URL, + ACCOUNT_ENGINE_URL, + ACCOUNT_ENGINE_URL_BY_DATABASE_NAME, +) from firebolt.common.util import fix_url_schema DEFAULT_TIMEOUT_SECONDS: int = 5 @@ -65,6 +69,36 @@ async def _resolve_engine_url( raise InterfaceError(f"Unable to retrieve engine endpoint: {e}.") +async def _get_database_default_engine_url( + database: str, + auth: AuthTypes, + api_endpoint: str, + account_name: Optional[str] = None, +) -> str: + async with AsyncClient( + auth=auth, + base_url=api_endpoint, + account_name=account_name, + api_endpoint=api_endpoint, + ) as client: + try: + account_id = await client.account_id + response = await client.get( + url=ACCOUNT_ENGINE_URL_BY_DATABASE_NAME.format(account_id=account_id), + params={"database_name": database}, + ) + response.raise_for_status() + return response.json()["engine_url"] + except ( + JSONDecodeError, + RequestError, + RuntimeError, + HTTPStatusError, + KeyError, + ) as e: + raise InterfaceError(f"Unable to retrieve default engine endpoint: {e}.") + + def _validate_engine_name_and_url( engine_name: Optional[str], engine_url: Optional[str] ) -> None: @@ -72,10 +106,6 @@ def _validate_engine_name_and_url( raise ConfigurationError( "Both engine_name and engine_url are provided. Provide only one to connect." ) - if not engine_name and not engine_url: - raise ConfigurationError( - "Neither engine_name nor engine_url is provided. Provide one to connect." - ) def _get_auth( @@ -120,7 +150,7 @@ async def connect_inner( api_endpoint(optional): Firebolt API endpoint. Used for authentication. Note: - Either `engine_name` or `engine_url` should be provided, but not both. + Providing both `engine_name` and `engine_url` would result in an error. """ # These parameters are optional in function signature @@ -136,7 +166,15 @@ async def connect_inner( # Mypy checks, this should never happen assert database is not None - if engine_name: + if not engine_name and not engine_url: + engine_url = await _get_database_default_engine_url( + database=database, + auth=auth, + account_name=account_name, + api_endpoint=api_endpoint, + ) + + elif engine_name: engine_url = await _resolve_engine_url( engine_name=engine_name, auth=auth, diff --git a/src/firebolt/common/urls.py b/src/firebolt/common/urls.py index 07a29d74d89..470b5e81aca 100644 --- a/src/firebolt/common/urls.py +++ b/src/firebolt/common/urls.py @@ -15,6 +15,7 @@ ACCOUNT_ENGINES_URL = "/core/v1/accounts/{account_id}/engines" ACCOUNT_ENGINE_BY_NAME_URL = ACCOUNT_ENGINES_URL + ":getIdByName" ACCOUNT_ENGINE_REVISION_URL = ACCOUNT_ENGINE_URL + "/engineRevisions/{revision_id}" +ACCOUNT_ENGINE_URL_BY_DATABASE_NAME = ACCOUNT_ENGINES_URL + ":getURLByDatabaseName" ACCOUNT_DATABASES_URL = "/core/v1/accounts/{account_id}/databases" ACCOUNT_DATABASE_URL = "/core/v1/accounts/{account_id}/databases/{database_id}" diff --git a/tests/integration/dbapi/async/conftest.py b/tests/integration/dbapi/async/conftest.py index 4354d03d98e..61282b6bcd7 100644 --- a/tests/integration/dbapi/async/conftest.py +++ b/tests/integration/dbapi/async/conftest.py @@ -42,3 +42,22 @@ async def connection_engine_name( api_endpoint=api_endpoint, ) as connection: yield connection + + +@fixture +async def connection_no_engine( + database_name: str, + username: str, + password: str, + account_name: str, + api_endpoint: str, +) -> Connection: + + async with await connect( + database=database_name, + username=username, + password=password, + account_name=account_name, + api_endpoint=api_endpoint, + ) as connection: + yield connection diff --git a/tests/integration/dbapi/async/test_queries_async.py b/tests/integration/dbapi/async/test_queries_async.py index 3292a09f192..fbe5ee329ff 100644 --- a/tests/integration/dbapi/async/test_queries_async.py +++ b/tests/integration/dbapi/async/test_queries_async.py @@ -32,6 +32,22 @@ async def test_connect_engine_name( ) +@mark.asyncio +async def test_connect_no_engine( + connection_no_engine: Connection, + all_types_query: str, + all_types_query_description: List[Column], + all_types_query_response: List[ColType], +) -> None: + """Connecting with engine name is handled properly.""" + await test_select( + connection_no_engine, + all_types_query, + all_types_query_description, + all_types_query_response, + ) + + @mark.asyncio async def test_select( connection: Connection, diff --git a/tests/integration/dbapi/sync/conftest.py b/tests/integration/dbapi/sync/conftest.py index c518bd2c83e..b55eb46c30d 100644 --- a/tests/integration/dbapi/sync/conftest.py +++ b/tests/integration/dbapi/sync/conftest.py @@ -43,3 +43,22 @@ def connection_engine_name( ) yield connection connection.close() + + +@fixture +def connection_no_engine( + database_name: str, + username: str, + password: str, + account_name: str, + api_endpoint: str, +) -> Connection: + connection = connect( + database=database_name, + username=username, + password=password, + account_name=account_name, + api_endpoint=api_endpoint, + ) + yield connection + connection.close() diff --git a/tests/integration/dbapi/sync/test_queries.py b/tests/integration/dbapi/sync/test_queries.py index 8e4ed16c44e..4900a94ef16 100644 --- a/tests/integration/dbapi/sync/test_queries.py +++ b/tests/integration/dbapi/sync/test_queries.py @@ -31,6 +31,21 @@ def test_connect_engine_name( ) +def test_connect_no_engine( + connection_no_engine: Connection, + all_types_query: str, + all_types_query_description: List[Column], + all_types_query_response: List[ColType], +) -> None: + """Connecting with engine name is handled properly.""" + test_select( + connection_no_engine, + all_types_query, + all_types_query_description, + all_types_query_response, + ) + + def test_select( connection: Connection, all_types_query: str, diff --git a/tests/unit/async_db/test_connection.py b/tests/unit/async_db/test_connection.py index 938ab85c743..314f1a74337 100644 --- a/tests/unit/async_db/test_connection.py +++ b/tests/unit/async_db/test_connection.py @@ -149,8 +149,6 @@ async def test_connect_engine_name( engine_id: str, get_engine_url: str, get_engine_callback: Callable, - get_providers_url: str, - get_providers_callback: Callable, python_query_data: List[List[ColType]], account_id: str, ): @@ -167,15 +165,6 @@ async def test_connect_engine_name( ): pass - with raises(ConfigurationError): - async with await connect( - database="db", - username="username", - password="password", - account_name="account", - ): - pass - httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback(query_callback, url=query_url) httpx_mock.add_callback(account_id_callback, url=account_id_url) @@ -222,6 +211,49 @@ async def test_connect_engine_name( assert await connection.cursor().execute("select*") == len(python_query_data) +@mark.asyncio +async def test_connect_default_engine( + settings: Settings, + db_name: str, + httpx_mock: HTTPXMock, + auth_callback: Callable, + auth_url: str, + query_callback: Callable, + query_url: str, + account_id_url: str, + account_id_callback: Callable, + engine_id: str, + get_engine_url: str, + get_engine_callback: Callable, + database_by_name_url: str, + database_by_name_callback: Callable, + database_id: str, + engine_by_db_url: str, + python_query_data: List[List[ColType]], + account_id: str, +): + httpx_mock.add_callback(auth_callback, url=auth_url) + httpx_mock.add_callback(query_callback, url=query_url) + httpx_mock.add_callback(account_id_callback, url=account_id_url) + engine_by_db_url = f"{engine_by_db_url}?database_name={db_name}" + + httpx_mock.add_response( + url=engine_by_db_url, + status_code=codes.OK, + json={ + "engine_url": settings.server, + }, + ) + async with await connect( + database=db_name, + username="u", + password="p", + account_name=settings.account_name, + api_endpoint=settings.server, + ) as connection: + assert await connection.cursor().execute("select*") == len(python_query_data) + + @mark.asyncio async def test_connection_commit(connection: Connection): # nothing happens diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 8c5849ace9e..6d37ceebb53 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -22,7 +22,9 @@ from firebolt.common.settings import Settings from firebolt.common.urls import ( ACCOUNT_BY_NAME_URL, + ACCOUNT_DATABASE_BY_NAME_URL, ACCOUNT_ENGINE_URL, + ACCOUNT_ENGINE_URL_BY_DATABASE_NAME, ACCOUNT_URL, AUTH_URL, DATABASES_URL, @@ -242,6 +244,47 @@ def get_databases_url(settings: Settings) -> str: return f"https://{settings.server}{DATABASES_URL}" +@fixture +def database_id() -> str: + return "database_id" + + +@fixture +def database_by_name_url(settings: Settings, account_id: str, db_name: str) -> str: + return ( + f"https://{settings.server}" + f"{ACCOUNT_DATABASE_BY_NAME_URL.format(account_id=account_id)}" + f"?database_name={db_name}" + ) + + +@fixture +def database_by_name_callback(account_id: str, database_id: str) -> str: + def do_mock( + request: httpx.Request = None, + **kwargs, + ) -> Response: + return Response( + status_code=httpx.codes.OK, + json={ + "database_id": { + "database_id": database_id, + "account_id": account_id, + } + }, + ) + + return do_mock + + +@fixture +def engine_by_db_url(settings: Settings, account_id: str) -> str: + return ( + f"https://{settings.server}" + f"{ACCOUNT_ENGINE_URL_BY_DATABASE_NAME.format(account_id=account_id)}" + ) + + @fixture def db_api_exceptions(): exceptions = { diff --git a/tests/unit/db/test_connection.py b/tests/unit/db/test_connection.py index 8f3779ace9e..780b8692aca 100644 --- a/tests/unit/db/test_connection.py +++ b/tests/unit/db/test_connection.py @@ -152,13 +152,6 @@ def test_connect_engine_name( password="password", ) - with raises(ConfigurationError): - connect( - database="db", - username="username", - password="password", - ) - httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback(query_callback, url=query_url) httpx_mock.add_callback(account_id_callback, url=account_id_url) @@ -186,6 +179,48 @@ def test_connect_engine_name( assert connection.cursor().execute("select*") == len(python_query_data) +def test_connect_default_engine( + settings: Settings, + db_name: str, + httpx_mock: HTTPXMock, + auth_callback: Callable, + auth_url: str, + query_callback: Callable, + query_url: str, + account_id_url: str, + account_id_callback: Callable, + engine_id: str, + get_engine_url: str, + get_engine_callback: Callable, + database_by_name_url: str, + database_by_name_callback: Callable, + database_id: str, + engine_by_db_url: str, + python_query_data: List[List[ColType]], + account_id: str, +): + httpx_mock.add_callback(auth_callback, url=auth_url) + httpx_mock.add_callback(query_callback, url=query_url) + httpx_mock.add_callback(account_id_callback, url=account_id_url) + engine_by_db_url = f"{engine_by_db_url}?database_name={db_name}" + + httpx_mock.add_response( + url=engine_by_db_url, + status_code=codes.OK, + json={ + "engine_url": settings.server, + }, + ) + with connect( + database=db_name, + username="u", + password="p", + account_name=settings.account_name, + api_endpoint=settings.server, + ) as connection: + assert connection.cursor().execute("select*") == len(python_query_data) + + def test_connection_unclosed_warnings(): c = Connection("", "", ("", ""), "") with warns(UserWarning) as winfo: