Skip to content
40 changes: 25 additions & 15 deletions src/firebolt/service/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ def get_id_by_name(self, name: str) -> str:

def get_many(
self,
name_contains: str,
attached_engine_name_eq: str,
attached_engine_name_contains: str,
order_by: Union[str, DatabaseOrder],
name_contains: Optional[str] = None,
attached_engine_name_eq: Optional[str] = None,
attached_engine_name_contains: Optional[str] = None,
order_by: Optional[Union[str, DatabaseOrder]] = None,
) -> List[Database]:
"""
Get a list of databases on Firebolt.
Expand All @@ -63,20 +63,30 @@ def get_many(
A list of databases matching the filters.
"""

if isinstance(order_by, str):
order_by = DatabaseOrder[order_by]
params = {"page.first": "1000"}
if order_by:
if isinstance(order_by, str):
order_by = DatabaseOrder[order_by]
params["order_by"] = order_by.name

if name_contains:
params["filter.name_contains"] = name_contains

if attached_engine_name_eq:
params["filter.attached_engine_name_eq"] = attached_engine_name_eq

if attached_engine_name_contains:
params[
"filter.attached_engine_name_contains"
] = attached_engine_name_contains

response = self.client.get(
url=ACCOUNT_DATABASES_URL.format(account_id=self.account_id),
params={
"filter.name_contains": name_contains,
"filter.attached_engine_name_eq": attached_engine_name_eq,
"filter.attached_engine_name_contains": attached_engine_name_contains,
"order_by": order_by.name,
},
url=ACCOUNT_DATABASES_URL.format(account_id=self.account_id), params=params
)

return [
Database.parse_obj_with_service(obj=d, database_service=self)
for d in response.json()["databases"]
Database.parse_obj_with_service(obj=d["node"], database_service=self)
for d in response.json()["edges"]
]

def create(
Expand Down
79 changes: 79 additions & 0 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from logging import getLogger
from os import environ

from pytest import fixture

from firebolt.service.manager import Settings

LOGGER = getLogger(__name__)

ENGINE_URL_ENV = "ENGINE_URL"
ENGINE_NAME_ENV = "ENGINE_NAME"
STOPPED_ENGINE_URL_ENV = "STOPPED_ENGINE_URL"
STOPPED_ENGINE_NAME_ENV = "STOPPED_ENGINE_NAME"
DATABASE_NAME_ENV = "DATABASE_NAME"
USER_NAME_ENV = "USER_NAME"
PASSWORD_ENV = "PASSWORD"
ACCOUNT_NAME_ENV = "ACCOUNT_NAME"
API_ENDPOINT_ENV = "API_ENDPOINT"


def must_env(var_name: str) -> str:
assert var_name in environ, f"Expected {var_name} to be provided in environment"
LOGGER.info(f"{var_name}: {environ[var_name]}")
return environ[var_name]


@fixture(scope="session")
def rm_settings(api_endpoint, username, password) -> Settings:
return Settings(
server=api_endpoint,
user=username,
password=password,
default_region="us-east-1",
)


@fixture(scope="session")
def engine_url() -> str:
return must_env(ENGINE_URL_ENV)


@fixture(scope="session")
def stopped_engine_url() -> str:
return must_env(STOPPED_ENGINE_URL_ENV)


@fixture(scope="session")
def engine_name() -> str:
return must_env(ENGINE_NAME_ENV)


@fixture(scope="session")
def stopped_engine_name() -> str:
return must_env(STOPPED_ENGINE_URL_ENV)


@fixture(scope="session")
def database_name() -> str:
return must_env(DATABASE_NAME_ENV)


@fixture(scope="session")
def username() -> str:
return must_env(USER_NAME_ENV)


@fixture(scope="session")
def password() -> str:
return must_env(PASSWORD_ENV)


@fixture(scope="session")
def account_name() -> str:
return must_env(ACCOUNT_NAME_ENV)


@fixture(scope="session")
def api_endpoint() -> str:
return must_env(API_ENDPOINT_ENV)
62 changes: 0 additions & 62 deletions tests/integration/dbapi/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from datetime import date, datetime
from logging import getLogger
from os import environ
from typing import List

from pytest import fixture
Expand All @@ -11,67 +10,6 @@

LOGGER = getLogger(__name__)

ENGINE_URL_ENV = "ENGINE_URL"
ENGINE_NAME_ENV = "ENGINE_NAME"
STOPPED_ENGINE_URL_ENV = "STOPPED_ENGINE_URL"
STOPPED_ENGINE_NAME_ENV = "STOPPED_ENGINE_NAME"
DATABASE_NAME_ENV = "DATABASE_NAME"
USER_NAME_ENV = "USER_NAME"
PASSWORD_ENV = "PASSWORD"
ACCOUNT_NAME_ENV = "ACCOUNT_NAME"
API_ENDPOINT_ENV = "API_ENDPOINT"


def must_env(var_name: str) -> str:
assert var_name in environ, f"Expected {var_name} to be provided in environment"
LOGGER.info(f"{var_name}: {environ[var_name]}")
return environ[var_name]


@fixture(scope="session")
def engine_url() -> str:
return must_env(ENGINE_URL_ENV)


@fixture(scope="session")
def stopped_engine_url() -> str:
return must_env(STOPPED_ENGINE_URL_ENV)


@fixture(scope="session")
def engine_name() -> str:
return must_env(ENGINE_NAME_ENV)


@fixture(scope="session")
def stopped_engine_name() -> str:
return must_env(STOPPED_ENGINE_URL_ENV)


@fixture(scope="session")
def database_name() -> str:
return must_env(DATABASE_NAME_ENV)


@fixture(scope="session")
def username() -> str:
return must_env(USER_NAME_ENV)


@fixture(scope="session")
def password() -> str:
return must_env(PASSWORD_ENV)


@fixture(scope="session")
def account_name() -> str:
return must_env(ACCOUNT_NAME_ENV)


@fixture(scope="session")
def api_endpoint() -> str:
return must_env(API_ENDPOINT_ENV)


@fixture
def all_types_query() -> str:
Expand Down
26 changes: 25 additions & 1 deletion tests/integration/resource_manager/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from firebolt.service.manager import ResourceManager
from firebolt.service.manager import ResourceManager, Settings
from firebolt.service.types import EngineStatusSummary


Expand Down Expand Up @@ -47,3 +47,27 @@ def test_copy_engine():
engine_revision=rm.engine_revisions.get_by_key(engine.latest_revision_key),
)
assert engine_copy


def test_databases_get_many(rm_settings: Settings, database_name, engine_name):
rm = ResourceManager(rm_settings)

# get all databases, at least one should be returned
databases = rm.databases.get_many()
assert len(databases) > 0
assert database_name in {db.name for db in databases}

# get all databases, with name_contains
databases = rm.databases.get_many(name_contains=database_name)
assert len(databases) > 0
assert database_name in {db.name for db in databases}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please also test passing attached_engine_name_eq and attached_engine_name_contains since we know it has engines.
Also, do you think we can test order_by?
We can just see how many databases there are from the first call, and if it's more that one we also try passing order_by

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added tests for attached_engine_name_eq and attached_engine_name_contains. However, I would not do additional tests on order_by, since it looks likes we are re-implementing order_by functionality in the test

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can still do some basic checks on order_by without re-implementing the functionality. e.g. checking that the parameter is propagated correctly and is added to the http request.


# get all databases, with name_contains
databases = rm.databases.get_many(attached_engine_name_eq=engine_name)
assert len(databases) > 0
assert database_name in {db.name for db in databases}

# get all databases, with name_contains
databases = rm.databases.get_many(attached_engine_name_contains=engine_name)
assert len(databases) > 0
assert database_name in {db.name for db in databases}
12 changes: 12 additions & 0 deletions tests/unit/service/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,18 @@ def do_mock(
return do_mock


@pytest.fixture
def databases_get_callback(databases_url: str, mock_database) -> Callable:
def get_databases_callback_inner(
request: httpx.Request = None, **kwargs
) -> Response:
return to_response(
status_code=httpx.codes.OK, json={"edges": [{"node": mock_database.dict()}]}
)

return get_databases_callback_inner


@pytest.fixture
def databases_url(settings: Settings, account_id: str) -> str:
return f"https://{settings.server}" + ACCOUNT_DATABASES_URL.format(
Expand Down
38 changes: 38 additions & 0 deletions tests/unit/service/test_database.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
from typing import Callable

from pytest_httpx import HTTPXMock
Expand Down Expand Up @@ -64,3 +65,40 @@ def test_database_get_by_name(
database = manager.databases.get_by_name(name=mock_database.name)

assert database.name == mock_database.name


def test_database_get_many(
httpx_mock: HTTPXMock,
auth_callback: Callable,
auth_url: str,
provider_callback: Callable,
provider_url: str,
settings: Settings,
account_id_callback: Callable,
account_id_url: str,
database_get_by_name_callback: Callable,
database_get_by_name_url: str,
databases_get_callback: Callable,
databases_url: str,
mock_database: Database,
):

httpx_mock.add_callback(auth_callback, url=auth_url)
httpx_mock.add_callback(provider_callback, url=provider_url)
httpx_mock.add_callback(account_id_callback, url=account_id_url)
httpx_mock.add_callback(auth_callback, url=auth_url)
httpx_mock.add_callback(
databases_get_callback,
url=re.compile(databases_url + "?[a-zA-Z0-9=&]*"),
method="GET",
)

manager = ResourceManager(settings=settings)
databases = manager.databases.get_many(
name_contains=mock_database.name,
attached_engine_name_eq="mockengine",
attached_engine_name_contains="mockengine",
)

assert len(databases) == 1
assert databases[0].name == mock_database.name