Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 45 additions & 7 deletions src/firebolt/async_db/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@
FireboltEngineError,
InterfaceError,
)
from firebolt.common.urls import ACCOUNT_ENGINE_BY_NAME_URL, ACCOUNT_ENGINE_URL
from firebolt.common.urls import (
ACCOUNT_ENGINE_BY_NAME_URL,
ACCOUNT_ENGINE_URL,
ACCOUNT_ENGINE_URL_BY_DATABASE_NAME,
)
from firebolt.common.util import fix_url_schema

DEFAULT_TIMEOUT_SECONDS: int = 5
Expand Down Expand Up @@ -65,17 +69,43 @@ async def _resolve_engine_url(
raise InterfaceError(f"Unable to retrieve engine endpoint: {e}.")


async def _get_database_default_engine_url(
database: str,
auth: AuthTypes,
api_endpoint: str,
account_name: Optional[str] = None,
) -> str:
async with AsyncClient(
auth=auth,
base_url=api_endpoint,
account_name=account_name,
api_endpoint=api_endpoint,
) as client:
try:
account_id = await client.account_id
response = await client.get(
url=ACCOUNT_ENGINE_URL_BY_DATABASE_NAME.format(account_id=account_id),
params={"database_name": database},
)
response.raise_for_status()
return response.json()["engine_url"]
except (
JSONDecodeError,
RequestError,
RuntimeError,
HTTPStatusError,
KeyError,
) as e:
raise InterfaceError(f"Unable to retrieve default engine endpoint: {e}.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if the real error is wrong credentials or expired/incorrect token?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, we could also catch AuthenticationError

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On the other side, I think AuthenticationError is pretty representative itself, so I think we should just let it propagate further



def _validate_engine_name_and_url(
engine_name: Optional[str], engine_url: Optional[str]
) -> None:
if engine_name and engine_url:
raise ConfigurationError(
"Both engine_name and engine_url are provided. Provide only one to connect."
)
if not engine_name and not engine_url:
raise ConfigurationError(
"Neither engine_name nor engine_url is provided. Provide one to connect."
)


def _get_auth(
Expand Down Expand Up @@ -120,7 +150,7 @@ async def connect_inner(
api_endpoint(optional): Firebolt API endpoint. Used for authentication.

Note:
Either `engine_name` or `engine_url` should be provided, but not both.
Providing both `engine_name` and `engine_url` would result in an error.

"""
# These parameters are optional in function signature
Expand All @@ -136,7 +166,15 @@ async def connect_inner(
# Mypy checks, this should never happen
assert database is not None

if engine_name:
if not engine_name and not engine_url:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Change the documentation on L153 to reflect the changes

engine_url = await _get_database_default_engine_url(
database=database,
auth=auth,
account_name=account_name,
api_endpoint=api_endpoint,
)

elif engine_name:
engine_url = await _resolve_engine_url(
engine_name=engine_name,
auth=auth,
Expand Down
1 change: 1 addition & 0 deletions src/firebolt/common/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
ACCOUNT_ENGINES_URL = "/core/v1/accounts/{account_id}/engines"
ACCOUNT_ENGINE_BY_NAME_URL = ACCOUNT_ENGINES_URL + ":getIdByName"
ACCOUNT_ENGINE_REVISION_URL = ACCOUNT_ENGINE_URL + "/engineRevisions/{revision_id}"
ACCOUNT_ENGINE_URL_BY_DATABASE_NAME = ACCOUNT_ENGINES_URL + ":getURLByDatabaseName"

ACCOUNT_DATABASES_URL = "/core/v1/accounts/{account_id}/databases"
ACCOUNT_DATABASE_URL = "/core/v1/accounts/{account_id}/databases/{database_id}"
Expand Down
19 changes: 19 additions & 0 deletions tests/integration/dbapi/async/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,22 @@ async def connection_engine_name(
api_endpoint=api_endpoint,
) as connection:
yield connection


@fixture
async def connection_no_engine(
database_name: str,
username: str,
password: str,
account_name: str,
api_endpoint: str,
) -> Connection:

async with await connect(
database=database_name,
username=username,
password=password,
account_name=account_name,
api_endpoint=api_endpoint,
) as connection:
yield connection
16 changes: 16 additions & 0 deletions tests/integration/dbapi/async/test_queries_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,22 @@ async def test_connect_engine_name(
)


@mark.asyncio
async def test_connect_no_engine(
connection_no_engine: Connection,
all_types_query: str,
all_types_query_description: List[Column],
all_types_query_response: List[ColType],
) -> None:
"""Connecting with engine name is handled properly."""
await test_select(
connection_no_engine,
all_types_query,
all_types_query_description,
all_types_query_response,
)


@mark.asyncio
async def test_select(
connection: Connection,
Expand Down
19 changes: 19 additions & 0 deletions tests/integration/dbapi/sync/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,22 @@ def connection_engine_name(
)
yield connection
connection.close()


@fixture
def connection_no_engine(
database_name: str,
username: str,
password: str,
account_name: str,
api_endpoint: str,
) -> Connection:
connection = connect(
database=database_name,
username=username,
password=password,
account_name=account_name,
api_endpoint=api_endpoint,
)
yield connection
connection.close()
15 changes: 15 additions & 0 deletions tests/integration/dbapi/sync/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,21 @@ def test_connect_engine_name(
)


def test_connect_no_engine(
connection_no_engine: Connection,
all_types_query: str,
all_types_query_description: List[Column],
all_types_query_response: List[ColType],
) -> None:
"""Connecting with engine name is handled properly."""
test_select(
connection_no_engine,
all_types_query,
all_types_query_description,
all_types_query_response,
)


def test_select(
connection: Connection,
all_types_query: str,
Expand Down
54 changes: 43 additions & 11 deletions tests/unit/async_db/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,6 @@ async def test_connect_engine_name(
engine_id: str,
get_engine_url: str,
get_engine_callback: Callable,
get_providers_url: str,
get_providers_callback: Callable,
python_query_data: List[List[ColType]],
account_id: str,
):
Expand All @@ -167,15 +165,6 @@ async def test_connect_engine_name(
):
pass

with raises(ConfigurationError):
async with await connect(
database="db",
username="username",
password="password",
account_name="account",
):
pass

httpx_mock.add_callback(auth_callback, url=auth_url)
httpx_mock.add_callback(query_callback, url=query_url)
httpx_mock.add_callback(account_id_callback, url=account_id_url)
Expand Down Expand Up @@ -222,6 +211,49 @@ async def test_connect_engine_name(
assert await connection.cursor().execute("select*") == len(python_query_data)


@mark.asyncio
async def test_connect_default_engine(
settings: Settings,
db_name: str,
httpx_mock: HTTPXMock,
auth_callback: Callable,
auth_url: str,
query_callback: Callable,
query_url: str,
account_id_url: str,
account_id_callback: Callable,
engine_id: str,
get_engine_url: str,
get_engine_callback: Callable,
database_by_name_url: str,
database_by_name_callback: Callable,
database_id: str,
engine_by_db_url: str,
python_query_data: List[List[ColType]],
account_id: str,
):
httpx_mock.add_callback(auth_callback, url=auth_url)
httpx_mock.add_callback(query_callback, url=query_url)
httpx_mock.add_callback(account_id_callback, url=account_id_url)
engine_by_db_url = f"{engine_by_db_url}?database_name={db_name}"

httpx_mock.add_response(
url=engine_by_db_url,
status_code=codes.OK,
json={
"engine_url": settings.server,
},
)
async with await connect(
database=db_name,
username="u",
password="p",
account_name=settings.account_name,
api_endpoint=settings.server,
) as connection:
assert await connection.cursor().execute("select*") == len(python_query_data)


@mark.asyncio
async def test_connection_commit(connection: Connection):
# nothing happens
Expand Down
43 changes: 43 additions & 0 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
from firebolt.common.settings import Settings
from firebolt.common.urls import (
ACCOUNT_BY_NAME_URL,
ACCOUNT_DATABASE_BY_NAME_URL,
ACCOUNT_ENGINE_URL,
ACCOUNT_ENGINE_URL_BY_DATABASE_NAME,
ACCOUNT_URL,
AUTH_URL,
DATABASES_URL,
Expand Down Expand Up @@ -242,6 +244,47 @@ def get_databases_url(settings: Settings) -> str:
return f"https://{settings.server}{DATABASES_URL}"


@fixture
def database_id() -> str:
return "database_id"


@fixture
def database_by_name_url(settings: Settings, account_id: str, db_name: str) -> str:
return (
f"https://{settings.server}"
f"{ACCOUNT_DATABASE_BY_NAME_URL.format(account_id=account_id)}"
f"?database_name={db_name}"
)


@fixture
def database_by_name_callback(account_id: str, database_id: str) -> str:
def do_mock(
request: httpx.Request = None,
**kwargs,
) -> Response:
return Response(
status_code=httpx.codes.OK,
json={
"database_id": {
"database_id": database_id,
"account_id": account_id,
}
},
)

return do_mock


@fixture
def engine_by_db_url(settings: Settings, account_id: str) -> str:
return (
f"https://{settings.server}"
f"{ACCOUNT_ENGINE_URL_BY_DATABASE_NAME.format(account_id=account_id)}"
)


@fixture
def db_api_exceptions():
exceptions = {
Expand Down
49 changes: 42 additions & 7 deletions tests/unit/db/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,13 +152,6 @@ def test_connect_engine_name(
password="password",
)

with raises(ConfigurationError):
connect(
database="db",
username="username",
password="password",
)

httpx_mock.add_callback(auth_callback, url=auth_url)
httpx_mock.add_callback(query_callback, url=query_url)
httpx_mock.add_callback(account_id_callback, url=account_id_url)
Expand Down Expand Up @@ -186,6 +179,48 @@ def test_connect_engine_name(
assert connection.cursor().execute("select*") == len(python_query_data)


def test_connect_default_engine(
settings: Settings,
db_name: str,
httpx_mock: HTTPXMock,
auth_callback: Callable,
auth_url: str,
query_callback: Callable,
query_url: str,
account_id_url: str,
account_id_callback: Callable,
engine_id: str,
get_engine_url: str,
get_engine_callback: Callable,
database_by_name_url: str,
database_by_name_callback: Callable,
database_id: str,
engine_by_db_url: str,
python_query_data: List[List[ColType]],
account_id: str,
):
httpx_mock.add_callback(auth_callback, url=auth_url)
httpx_mock.add_callback(query_callback, url=query_url)
httpx_mock.add_callback(account_id_callback, url=account_id_url)
engine_by_db_url = f"{engine_by_db_url}?database_name={db_name}"

httpx_mock.add_response(
url=engine_by_db_url,
status_code=codes.OK,
json={
"engine_url": settings.server,
},
)
with connect(
database=db_name,
username="u",
password="p",
account_name=settings.account_name,
api_endpoint=settings.server,
) as connection:
assert connection.cursor().execute("select*") == len(python_query_data)


def test_connection_unclosed_warnings():
c = Connection("", "", ("", ""), "")
with warns(UserWarning) as winfo:
Expand Down