From 284ca08d7acbc67ba95b2fad9e5bbd3f5d78f6de Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Mon, 12 Jun 2023 14:12:05 +0300 Subject: [PATCH 01/18] update resource manager --- src/firebolt/model/engine.py | 4 +- src/firebolt/service/base.py | 7 +-- src/firebolt/service/manager.py | 98 +++++++++++++++++++++++++++++---- src/firebolt/service/region.py | 6 +- 4 files changed, 96 insertions(+), 19 deletions(-) diff --git a/src/firebolt/model/engine.py b/src/firebolt/model/engine.py index 04c9ef378d9..6b6ebbc6f8a 100644 --- a/src/firebolt/model/engine.py +++ b/src/firebolt/model/engine.py @@ -197,8 +197,8 @@ def get_connection(self) -> Connection: # we always have firebolt Auth as a client auth auth=self._service.client.auth, # type: ignore engine_name=self.name, - account_name=self._service.settings.account_name, - api_endpoint=self._service.settings.server, + account_name=self._service.resource_manager.account_name, + api_endpoint=self._service.resource_manager.api_endpoint, ) @check_attached_to_database diff --git a/src/firebolt/service/base.py b/src/firebolt/service/base.py index d6617a43c34..079a2a18f5d 100644 --- a/src/firebolt/service/base.py +++ b/src/firebolt/service/base.py @@ -1,5 +1,4 @@ from firebolt.client import Client -from firebolt.common import Settings from firebolt.service.manager import ResourceManager @@ -9,12 +8,12 @@ def __init__(self, resource_manager: ResourceManager): @property def client(self) -> Client: - return self.resource_manager.client + return self.resource_manager._client @property def account_id(self) -> str: return self.resource_manager.account_id @property - def settings(self) -> Settings: - return self.resource_manager.settings + def _default_region(self) -> str: + return self.resource_manager.default_region diff --git a/src/firebolt/service/manager.py b/src/firebolt/service/manager.py index c8d646bfb78..f461ac07a1d 100644 --- a/src/firebolt/service/manager.py +++ b/src/firebolt/service/manager.py @@ -1,14 +1,32 @@ +import logging from typing import Optional from httpx import Timeout -from firebolt.client import Client, log_request, log_response, raise_on_4xx_5xx +from firebolt.client import ( + DEFAULT_API_URL, + Auth, + Client, + log_request, + log_response, + raise_on_4xx_5xx, +) from firebolt.common import Settings +from firebolt.db import connect from firebolt.service.provider import get_provider_id from firebolt.utils.util import fix_url_schema DEFAULT_TIMEOUT_SECONDS: int = 60 * 2 +logger = logging.getLogger(__name__) + +SETTINGS_DEPRECATION_MESSAGE = """ +Using Settings objects for ResourceManager intialization is deprecated. +Please pass parameters directly +Example: + >>> rm = ResourceManager(auth=ClientCredentials(..), default_region="us-east-1", ..) +""" + class ResourceManager: """ @@ -25,20 +43,80 @@ class ResourceManager: - instance types (AWS instance types which engines can use) """ - def __init__(self, settings: Optional[Settings] = None): - self.settings = settings or Settings() - self.client = Client( - auth=self.settings.auth, - base_url=fix_url_schema(self.settings.server), - account_name=self.settings.account_name, - api_endpoint=self.settings.server, + __slots__ = ( + "account_name", + "account_id", + "api_endpoint", + "default_region", + "_client", + "_connection", + "regions", + "instance_types", + "_provider_id", + "databases", + "engines", + "engine_revisions", + "bindings", + ) + + def __init__( + self, + settings: Optional[Settings] = None, + auth: Optional[Auth] = None, + account_name: Optional[str] = None, + default_region: Optional[str] = None, + api_endpoint: str = DEFAULT_API_URL, + ): + if settings: + logger.warning(SETTINGS_DEPRECATION_MESSAGE) + if ( + auth + or account_name + or default_region + or (api_endpoint != DEFAULT_API_URL) + ): + raise ValueError( + "Other ResourceManager parameters are not allowed " + "when Settings are provided" + ) + auth = settings.auth + account_name = settings.account_name + default_region = settings.default_region + api_endpoint = settings.server + + for param, name in ( + (auth, "auth"), + (account_name, "account_name"), + (default_region, "default_region"), + ): + if not param: + raise ValueError(f"Missing {name} value") + + # type checks + assert auth is not None + assert account_name is not None + assert default_region is not None + + self._client = Client( + auth=auth, + base_url=fix_url_schema(api_endpoint), + account_name=account_name, + api_endpoint=api_endpoint, timeout=Timeout(DEFAULT_TIMEOUT_SECONDS), event_hooks={ "request": [log_request], "response": [raise_on_4xx_5xx, log_response], }, ) - self.account_id = self.client.account_id + self._connection = connect( + auth=auth, + account_name=account_name, + api_endpoint=api_endpoint, + ) + self.account_name = account_name + self.api_endpoint = api_endpoint + self.account_id = self._client.account_id + self.default_region = default_region self._init_services() def _init_services(self) -> None: @@ -53,7 +131,7 @@ def _init_services(self) -> None: # Cloud Platform Resources (AWS) self.regions = RegionService(resource_manager=self) self.instance_types = InstanceTypeService(resource_manager=self) - self.provider_id = get_provider_id(client=self.client) + self._provider_id = get_provider_id(client=self._client) # Firebolt Resources self.databases = DatabaseService(resource_manager=self) diff --git a/src/firebolt/service/region.py b/src/firebolt/service/region.py index ff614d56288..d79f73995c3 100644 --- a/src/firebolt/service/region.py +++ b/src/firebolt/service/region.py @@ -41,11 +41,11 @@ def regions_by_key(self) -> Dict[RegionKey, Region]: def default_region(self) -> Region: """Default AWS region, could be provided from environment.""" - if not self.settings.default_region: + if not self._default_region: raise ValueError( "The environment variable FIREBOLT_DEFAULT_REGION must be set." ) - return self.get_by_name(name=self.settings.default_region) + return self.get_by_name(name=self._default_region) def get_by_name(self, name: str) -> Region: """Get an AWS region by its name (eg. us-east-1).""" @@ -61,5 +61,5 @@ def get_by_id(self, id_: str) -> Region: """Get an AWS region by region_id.""" return self.get_by_key( - RegionKey(provider_id=self.resource_manager.provider_id, region_id=id_) + RegionKey(provider_id=self.resource_manager._provider_id, region_id=id_) ) From 7aa7fe5fa9774e0e351cc2b0580c858c491e2227 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Tue, 13 Jun 2023 11:15:51 +0300 Subject: [PATCH 02/18] fix unit tests --- src/firebolt/service/manager.py | 6 ++ tests/unit/db/conftest.py | 4 +- tests/unit/db/test_util.py | 4 ++ tests/unit/db_conftest.py | 5 +- tests/unit/service/test_database.py | 41 +++--------- tests/unit/service/test_engine.py | 70 ++++++--------------- tests/unit/service/test_instance_type.py | 10 +-- tests/unit/service/test_region.py | 10 +-- tests/unit/service/test_resource_manager.py | 41 +++++++----- 9 files changed, 72 insertions(+), 119 deletions(-) diff --git a/src/firebolt/service/manager.py b/src/firebolt/service/manager.py index f461ac07a1d..971d4056d92 100644 --- a/src/firebolt/service/manager.py +++ b/src/firebolt/service/manager.py @@ -138,3 +138,9 @@ def _init_services(self) -> None: self.engines = EngineService(resource_manager=self) self.engine_revisions = EngineRevisionService(resource_manager=self) self.bindings = BindingService(resource_manager=self) + + def __del__(self) -> None: + if hasattr(self, "_client"): + self._client.close() + if hasattr(self, "_connection"): + self._connection.close() diff --git a/tests/unit/db/conftest.py b/tests/unit/db/conftest.py index ec8d299f9f8..90bd7dc3080 100644 --- a/tests/unit/db/conftest.py +++ b/tests/unit/db/conftest.py @@ -32,9 +32,9 @@ def system_connection( db_name: str, auth: Auth, account_name: str, - mock_system_connection_flow: Callable, + mock_system_engine_connection_flow: Callable, ) -> Connection: - mock_system_connection_flow() + mock_system_engine_connection_flow() with connect( database=db_name, auth=auth, diff --git a/tests/unit/db/test_util.py b/tests/unit/db/test_util.py index a21ac69b73a..452c09e6430 100644 --- a/tests/unit/db/test_util.py +++ b/tests/unit/db/test_util.py @@ -45,11 +45,15 @@ def test_is_db_not_available( def test_is_engine_running_system( + httpx_mock: HTTPXMock, system_connection: Connection, ): # System engine is always running assert is_engine_running(system_connection, "dummy") == True + # We haven't used account id endpoint since we didn't run any query, ignoring it + httpx_mock.reset(False) + def test_is_engine_running( connection: Connection, diff --git a/tests/unit/db_conftest.py b/tests/unit/db_conftest.py index 5ed5a951758..745eb9287f1 100644 --- a/tests/unit/db_conftest.py +++ b/tests/unit/db_conftest.py @@ -494,16 +494,19 @@ def inner() -> None: @fixture -def mock_system_connection_flow( +def mock_system_engine_connection_flow( httpx_mock: HTTPXMock, auth_url: str, check_credentials_callback: Callable, get_system_engine_url: str, get_system_engine_callback: Callable, + account_id_url: str, + account_id_callback: Callable, ) -> Callable: def inner() -> None: httpx_mock.add_callback(check_credentials_callback, url=auth_url) httpx_mock.add_callback(get_system_engine_callback, url=get_system_engine_url) + httpx_mock.add_callback(account_id_callback, url=account_id_url) return inner diff --git a/tests/unit/service/test_database.py b/tests/unit/service/test_database.py index 35dec331ad6..01246743a19 100644 --- a/tests/unit/service/test_database.py +++ b/tests/unit/service/test_database.py @@ -1,4 +1,4 @@ -from re import Pattern, compile +from re import compile from typing import Callable from pytest_httpx import HTTPXMock @@ -10,24 +10,19 @@ def test_database_create( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, provider_callback: Callable, provider_url: str, region_callback: Callable, region_url: str, settings: Settings, - account_id_callback: Callable, - account_id_url: Pattern, create_databases_callback: Callable, databases_url: str, db_name: str, db_description: str, + mock_system_engine_connection_flow: Callable, ): - httpx_mock.add_callback(auth_callback, url=auth_url) + mock_system_engine_connection_flow() httpx_mock.add_callback(provider_callback, url=provider_url) - httpx_mock.add_callback(account_id_callback, url=account_id_url) - httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback(region_callback, url=region_url) httpx_mock.add_callback(create_databases_callback, url=databases_url, method="POST") @@ -40,24 +35,18 @@ def test_database_create( def test_database_get_by_name( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, provider_callback: Callable, provider_url: str, settings: Settings, - account_id_callback: Callable, - account_id_url: Pattern, database_get_by_name_callback: Callable, database_get_by_name_url: str, database_get_callback: Callable, database_get_url: str, mock_database: Database, + mock_system_engine_connection_flow: Callable, ): - - httpx_mock.add_callback(auth_callback, url=auth_url) + mock_system_engine_connection_flow() httpx_mock.add_callback(provider_callback, url=provider_url) - httpx_mock.add_callback(account_id_callback, url=account_id_url) - httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback(database_get_by_name_callback, url=database_get_by_name_url) httpx_mock.add_callback(database_get_callback, url=database_get_url) @@ -69,24 +58,18 @@ def test_database_get_by_name( def test_database_get_many( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, provider_callback: Callable, provider_url: str, settings: Settings, - account_id_callback: Callable, - account_id_url: Pattern, database_get_by_name_callback: Callable, database_get_by_name_url: str, databases_get_callback: Callable, databases_url: str, mock_database: Database, + mock_system_engine_connection_flow: Callable, ): - - httpx_mock.add_callback(auth_callback, url=auth_url) + mock_system_engine_connection_flow() httpx_mock.add_callback(provider_callback, url=provider_url) - httpx_mock.add_callback(account_id_callback, url=account_id_url) - httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback( databases_get_callback, url=compile(databases_url + "?[a-zA-Z0-9=&]*"), @@ -106,26 +89,22 @@ def test_database_get_many( def test_database_update( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, provider_callback: Callable, provider_url: str, settings: Settings, - account_id_callback: Callable, - account_id_url: Pattern, database_update_callback: Callable, database_url: str, mock_database: Database, + mock_system_engine_connection_flow: Callable, ): - httpx_mock.add_callback(auth_callback, url=auth_url) + mock_system_engine_connection_flow() httpx_mock.add_callback(provider_callback, url=provider_url) - httpx_mock.add_callback(account_id_callback, url=account_id_url) httpx_mock.add_callback(database_update_callback, url=database_url, method="PATCH") manager = ResourceManager(settings=settings) - mock_database._service = manager + mock_database._service = manager.databases database = mock_database.update(description="new description") assert database.description == "new description" diff --git a/tests/unit/service/test_engine.py b/tests/unit/service/test_engine.py index 9686ecbf7ff..7abde2d8008 100644 --- a/tests/unit/service/test_engine.py +++ b/tests/unit/service/test_engine.py @@ -16,8 +16,6 @@ def test_engine_create( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, provider_callback: Callable, provider_url: str, instance_type_region_1_callback: Callable, @@ -33,14 +31,13 @@ def test_engine_create( account_id_url: Pattern, engine_callback: Callable, engine_url: str, + mock_system_engine_connection_flow: Callable, ): - httpx_mock.add_callback(auth_callback, url=auth_url) + mock_system_engine_connection_flow() httpx_mock.add_callback(provider_callback, url=provider_url) httpx_mock.add_callback( instance_type_region_1_callback, url=instance_type_region_1_url ) - httpx_mock.add_callback(account_id_callback, url=account_id_url) - httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback(region_callback, url=region_url) httpx_mock.add_callback(engine_callback, url=engine_url, method="POST") @@ -52,8 +49,6 @@ def test_engine_create( def test_engine_create_with_kwargs( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, provider_callback: Callable, provider_url: str, instance_type_region_1_callback: Callable, @@ -69,14 +64,13 @@ def test_engine_create_with_kwargs( engine_url: str, account_id: str, mock_engine_revision: EngineRevision, + mock_system_engine_connection_flow: Callable, ): - httpx_mock.add_callback(auth_callback, url=auth_url) + mock_system_engine_connection_flow() httpx_mock.add_callback(provider_callback, url=provider_url) httpx_mock.add_callback( instance_type_region_1_callback, url=instance_type_region_1_url ) - httpx_mock.add_callback(account_id_callback, url=account_id_url) - httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback(region_callback, url=region_url) # Setting to manager.engines.create defaults mock_engine.key = None @@ -111,8 +105,6 @@ def test_engine_create_with_kwargs( def test_engine_create_with_kwargs_fail( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, provider_callback: Callable, provider_url: str, instance_type_region_1_callback: Callable, @@ -123,14 +115,13 @@ def test_engine_create_with_kwargs_fail( engine_name: str, account_id_callback: Callable, account_id_url: Pattern, + mock_system_engine_connection_flow: Callable, ): - httpx_mock.add_callback(auth_callback, url=auth_url) + mock_system_engine_connection_flow() httpx_mock.add_callback(provider_callback, url=provider_url) httpx_mock.add_callback( instance_type_region_1_callback, url=instance_type_region_1_url ) - httpx_mock.add_callback(account_id_callback, url=account_id_url) - httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback(region_callback, url=region_url) manager = ResourceManager(settings=settings) @@ -149,8 +140,6 @@ def test_engine_create_with_kwargs_fail( def test_engine_create_no_available_types( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, provider_callback: Callable, provider_url: str, instance_type_empty_callback: Callable, @@ -162,14 +151,13 @@ def test_engine_create_no_available_types( account_id_url: Pattern, engine_url: str, region_2: Region, + mock_system_engine_connection_flow: Callable, ): - httpx_mock.add_callback(auth_callback, url=auth_url) + mock_system_engine_connection_flow() httpx_mock.add_callback(provider_callback, url=provider_url) httpx_mock.add_callback( instance_type_empty_callback, url=instance_type_region_2_url ) - httpx_mock.add_callback(account_id_callback, url=account_id_url) - httpx_mock.add_callback(auth_callback, url=auth_url) manager = ResourceManager(settings=settings) @@ -179,8 +167,6 @@ def test_engine_create_no_available_types( def test_engine_no_attached_database( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, provider_callback: Callable, provider_url: str, instance_type_region_1_callback: Callable, @@ -202,14 +188,13 @@ def test_engine_no_attached_database( database_url: str, no_bindings_callback: Callable, bindings_url: str, + mock_system_engine_connection_flow: Callable, ): - httpx_mock.add_callback(auth_callback, url=auth_url) + mock_system_engine_connection_flow() httpx_mock.add_callback(provider_callback, url=provider_url) httpx_mock.add_callback( instance_type_region_1_callback, url=instance_type_region_1_url ) - httpx_mock.add_callback(account_id_callback, url=account_id_url) - httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback(region_callback, url=region_url) httpx_mock.add_callback(engine_callback, url=engine_url, method="POST") httpx_mock.add_callback(no_bindings_callback, url=bindings_url) @@ -223,8 +208,6 @@ def test_engine_no_attached_database( def test_engine_start_binding_to_missing_database( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, provider_callback: Callable, provider_url: str, instance_type_region_1_callback: Callable, @@ -244,14 +227,13 @@ def test_engine_start_binding_to_missing_database( database_url: str, bindings_callback: Callable, bindings_url: str, + mock_system_engine_connection_flow: Callable, ): - httpx_mock.add_callback(auth_callback, url=auth_url) + mock_system_engine_connection_flow() httpx_mock.add_callback(provider_callback, url=provider_url) httpx_mock.add_callback( instance_type_region_1_callback, url=instance_type_region_1_url ) - httpx_mock.add_callback(account_id_callback, url=account_id_url) - httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback(region_callback, url=region_url) httpx_mock.add_callback(engine_callback, url=engine_url, method="POST") httpx_mock.add_callback(bindings_callback, url=bindings_url) @@ -266,8 +248,6 @@ def test_engine_start_binding_to_missing_database( def test_get_connection( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, provider_callback: Callable, provider_url: str, instance_type_region_1_callback: Callable, @@ -294,7 +274,6 @@ def test_get_connection( httpx_mock.add_callback( instance_type_region_1_callback, url=instance_type_region_1_url ) - httpx_mock.add_callback(account_id_callback, url=account_id_url) httpx_mock.add_callback(region_callback, url=region_url) httpx_mock.add_callback(engine_callback, url=engine_url, method="POST") httpx_mock.add_callback(bindings_callback, url=bindings_url) @@ -311,8 +290,6 @@ def test_get_connection( def test_attach_to_database( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, provider_callback: Callable, provider_url: str, region_callback: Callable, @@ -336,14 +313,13 @@ def test_attach_to_database( create_binding_url: str, bindings_callback: Callable, bindings_url: str, + mock_system_engine_connection_flow: Callable, ): - httpx_mock.add_callback(auth_callback, url=auth_url) + mock_system_engine_connection_flow() httpx_mock.add_callback(provider_callback, url=provider_url) httpx_mock.add_callback( instance_type_region_1_callback, url=instance_type_region_1_url ) - httpx_mock.add_callback(account_id_callback, url=account_id_url) - httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback(bindings_callback, url=bindings_url) httpx_mock.add_callback(create_databases_callback, url=databases_url, method="POST") httpx_mock.add_callback(database_not_found_callback, url=database_url, method="GET") @@ -369,8 +345,6 @@ def test_attach_to_database( def test_engine_update( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, provider_callback: Callable, provider_url: str, instance_type_region_1_callback: Callable, @@ -388,14 +362,10 @@ def test_engine_update( engine_url: str, account_engine_url: str, account_engine_callback: Callable, + mock_system_engine_connection_flow: Callable, ): - - httpx_mock.add_callback(auth_callback, url=auth_url) + mock_system_engine_connection_flow() httpx_mock.add_callback(provider_callback, url=provider_url) - - httpx_mock.add_callback(account_id_callback, url=account_id_url) - httpx_mock.add_callback(auth_callback, url=auth_url) - # httpx_mock.add_callback( account_engine_callback, url=account_engine_url, method="PATCH" ) @@ -412,8 +382,6 @@ def test_engine_update( def test_engine_restart( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, provider_callback: Callable, provider_url: str, settings: Settings, @@ -426,13 +394,11 @@ def test_engine_restart( bindings_url: str, database_callback: Callable, database_url: str, + mock_system_engine_connection_flow: Callable, ): - httpx_mock.add_callback(auth_callback, url=auth_url) + mock_system_engine_connection_flow() httpx_mock.add_callback(provider_callback, url=provider_url) - httpx_mock.add_callback(account_id_callback, url=account_id_url) - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback( engine_callback, url=f"{account_engine_url}:restart", method="POST" ) diff --git a/tests/unit/service/test_instance_type.py b/tests/unit/service/test_instance_type.py index dce917c1f42..ef961ce9212 100644 --- a/tests/unit/service/test_instance_type.py +++ b/tests/unit/service/test_instance_type.py @@ -1,4 +1,3 @@ -from re import Pattern from typing import Callable, List from pytest_httpx import HTTPXMock @@ -11,8 +10,6 @@ def test_instance_type( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, provider_callback: Callable, provider_url: str, instance_type_callback: Callable, @@ -21,17 +18,15 @@ def test_instance_type( instance_type_url: str, instance_type_region_1_url: str, instance_type_region_2_url: str, - account_id_callback: Callable, - account_id_url: Pattern, settings: Settings, mock_instance_types: List[InstanceType], cheapest_instance: InstanceType, region_1: Region, region_2: Region, + mock_system_engine_connection_flow: Callable, ): - httpx_mock.add_callback(auth_callback, url=auth_url) + mock_system_engine_connection_flow() httpx_mock.add_callback(provider_callback, url=provider_url) - httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback(instance_type_callback, url=instance_type_url) httpx_mock.add_callback( instance_type_region_1_callback, url=instance_type_region_1_url @@ -39,7 +34,6 @@ def test_instance_type( httpx_mock.add_callback( instance_type_empty_callback, url=instance_type_region_2_url ) - httpx_mock.add_callback(account_id_callback, url=account_id_url) manager = ResourceManager(settings=settings) assert manager.instance_types.instance_types == mock_instance_types diff --git a/tests/unit/service/test_region.py b/tests/unit/service/test_region.py index 939238f6bfe..6e6ebcda750 100644 --- a/tests/unit/service/test_region.py +++ b/tests/unit/service/test_region.py @@ -1,4 +1,3 @@ -from re import Pattern from typing import Callable, List from pytest_httpx import HTTPXMock @@ -10,22 +9,17 @@ def test_region( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, provider_callback: Callable, provider_url: str, region_callback: Callable, region_url: str, - account_id_callback: Callable, - account_id_url: Pattern, settings: Settings, mock_regions: List[Region], + mock_system_engine_connection_flow: Callable, ): - httpx_mock.add_callback(auth_callback, url=auth_url) + mock_system_engine_connection_flow() httpx_mock.add_callback(provider_callback, url=provider_url) - httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback(region_callback, url=region_url) - httpx_mock.add_callback(account_id_callback, url=account_id_url) manager = ResourceManager(settings=settings) assert manager.regions.regions == mock_regions diff --git a/tests/unit/service/test_resource_manager.py b/tests/unit/service/test_resource_manager.py index 9ddfeacb468..d07c23044db 100644 --- a/tests/unit/service/test_resource_manager.py +++ b/tests/unit/service/test_resource_manager.py @@ -10,6 +10,7 @@ from firebolt.service.manager import ResourceManager from firebolt.utils.exception import AccountNotFoundError from firebolt.utils.token_storage import TokenSecureStorage +from firebolt.utils.urls import GATEWAY_HOST_BY_ACCOUNT_NAME def test_rm_credentials( @@ -23,17 +24,17 @@ def test_rm_credentials( provider_callback: Callable, provider_url: str, access_token: str, + mock_system_engine_connection_flow: Callable, ) -> None: """Credentials, that are passed to rm are processed properly.""" url = "https://url" - httpx_mock.add_callback(check_credentials_callback, url=auth_url) + mock_system_engine_connection_flow() httpx_mock.add_callback(check_token_callback, url=url) httpx_mock.add_callback(provider_callback, url=provider_url) - httpx_mock.add_callback(account_id_callback, url=account_id_url) rm = ResourceManager(settings) - rm.client.get(url) + rm._client.get(url) @mark.nofakefs @@ -49,14 +50,14 @@ def test_rm_token_cache( provider_callback: Callable, provider_url: str, access_token: str, + mock_system_engine_connection_flow: Callable, ) -> None: """Credentials, that are passed to rm are cached properly.""" url = "https://url" - httpx_mock.add_callback(check_credentials_callback, url=auth_url) + mock_system_engine_connection_flow() httpx_mock.add_callback(check_token_callback, url=url) httpx_mock.add_callback(provider_callback, url=provider_url) - httpx_mock.add_callback(account_id_callback, url=account_id_url) with Patcher(): local_settings = Settings( @@ -70,7 +71,7 @@ def test_rm_token_cache( default_region=settings.default_region, ) rm = ResourceManager(local_settings) - rm.client.get(url) + rm._client.get(url) ts = TokenSecureStorage(settings.auth.client_id, settings.auth.client_secret) assert ts.get_cached_token() == access_token, "Invalid token value cached" @@ -88,7 +89,7 @@ def test_rm_token_cache( default_region=settings.default_region, ) rm = ResourceManager(local_settings) - rm.client.get(url) + rm._client.get(url) ts = TokenSecureStorage(settings.auth.client_id, settings.auth.client_secret) assert ( @@ -99,22 +100,28 @@ def test_rm_token_cache( def test_rm_invalid_account_name( httpx_mock: HTTPXMock, auth: Auth, - settings: Settings, - check_credentials_callback: Callable, + server: str, + region_1: str, auth_url: str, + check_credentials_callback: Callable, account_id_url: Pattern, account_id_callback: Callable, + get_system_engine_callback: Callable, ) -> None: """Resource manager raises an error on invalid account name.""" + get_system_engine_url = ( + f"https://{server}" + f"{GATEWAY_HOST_BY_ACCOUNT_NAME.format(account_name='invalid')}" + ) + httpx_mock.add_callback(check_credentials_callback, url=auth_url) + httpx_mock.add_callback(get_system_engine_callback, url=get_system_engine_url) httpx_mock.add_callback(account_id_callback, url=account_id_url) - local_settings = Settings( - auth=auth, - account_name="invalid", - server=settings.server, - default_region=settings.default_region, - ) - with raises(AccountNotFoundError): - ResourceManager(local_settings) + ResourceManager( + auth=auth, + account_name="invalid", + api_endpoint=server, + default_region=region_1, + ) From 4961ec89ab04efca78997f2e592d5951f462c51e Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Wed, 14 Jun 2023 14:33:55 +0300 Subject: [PATCH 03/18] upgrade engines service --- src/firebolt/model/__init__.py | 38 +- src/firebolt/model/database.py | 95 ++--- src/firebolt/model/engine.py | 526 ++++++------------------ src/firebolt/model/engine_revision.py | 58 --- src/firebolt/model/instance_type.py | 2 +- src/firebolt/model/provider.py | 16 +- src/firebolt/model/region.py | 4 +- src/firebolt/service/base.py | 13 +- src/firebolt/service/binding.py | 139 ------- src/firebolt/service/database.py | 33 +- src/firebolt/service/engine.py | 265 +++++------- src/firebolt/service/engine_revision.py | 39 -- src/firebolt/service/instance_type.py | 4 +- src/firebolt/service/provider.py | 2 +- src/firebolt/service/region.py | 2 +- src/firebolt/service/types.py | 217 +--------- src/firebolt/utils/exception.py | 14 + 17 files changed, 367 insertions(+), 1100 deletions(-) delete mode 100644 src/firebolt/model/engine_revision.py delete mode 100644 src/firebolt/service/binding.py delete mode 100644 src/firebolt/service/engine_revision.py diff --git a/src/firebolt/model/__init__.py b/src/firebolt/model/__init__.py index 83992de4b6e..28b6fc1f9ac 100644 --- a/src/firebolt/model/__init__.py +++ b/src/firebolt/model/__init__.py @@ -1,24 +1,28 @@ import json -from typing import Any +from dataclasses import dataclass, field, fields +from typing import ClassVar, Dict, Optional, Type, TypeVar -from pydantic import BaseModel +from firebolt.service.base import BaseService +Model = TypeVar("Model", bound="FireboltBaseModel") -class FireboltBaseModel(BaseModel): - class Config: - allow_population_by_field_name = True - extra = "forbid" - def jsonable_dict(self, *args: Any, **kwargs: Any) -> dict: - """ - Generate a dictionary representation of the service that contains serialized - primitive types, and is therefore JSON-ready. +@dataclass +class FireboltBaseModel: + _service: BaseService = field() - This could be replaced with something native once this issue is resolved: - https://github.com/samuelcolvin/pydantic/issues/1409 + @classmethod + def _get_field_overrides(cls) -> Dict[str, str]: + return { + f.metadata["db_name"]: f.name + for f in fields(cls) + if "db_name" in f.metadata + } - This function is intended to improve the compatibility with HTTPX, which - expects to take in a dictionary of primitives as input to the JSON parameter - of its request function. See: https://www.python-httpx.org/api/#helper-functions - """ - return json.loads(self.json(*args, **kwargs)) + @classmethod + def _from_dict( + cls: Type[Model], data: dict, service: Optional[BaseService] = None + ) -> Model: + data["_service"] = service + field_name_overrides = cls._get_field_overrides() + return cls(**{field_name_overrides.get(k, k): v for k, v in data.items()}) diff --git a/src/firebolt/model/database.py b/src/firebolt/model/database.py index 9cc92264c28..dc778ab0f63 100644 --- a/src/firebolt/model/database.py +++ b/src/firebolt/model/database.py @@ -1,14 +1,15 @@ from __future__ import annotations import logging +from dataclasses import asdict, dataclass, field from datetime import datetime -from typing import TYPE_CHECKING, Any, List, Optional, Sequence +from typing import TYPE_CHECKING, List, Optional, Sequence -from pydantic import Field, PrivateAttr +from pydantic import Field from firebolt.model import FireboltBaseModel from firebolt.model.region import RegionKey -from firebolt.service.types import EngineStatusSummary +from firebolt.service.types import EngineStatus from firebolt.utils.exception import AttachedEngineInUseError from firebolt.utils.urls import ACCOUNT_DATABASE_URL @@ -20,15 +21,18 @@ logger = logging.getLogger(__name__) -class DatabaseKey(FireboltBaseModel): +@dataclass +class DatabaseKey: account_id: str database_id: str -class FieldMask(FireboltBaseModel): +@dataclass +class FieldMask: paths: Sequence[str] = Field(alias="paths") +@dataclass class Database(FireboltBaseModel): """ A Firebolt database. @@ -38,35 +42,27 @@ class Database(FireboltBaseModel): """ # internal - _service: DatabaseService = PrivateAttr() + _service: DatabaseService = field() # required - name: str = Field(min_length=1, max_length=255, regex=r"^[0-9a-zA-Z_]+$") - compute_region_key: RegionKey = Field(alias="compute_region_id") + name: str = field(metadata={"db_name": "database_name"}) + description: str = field() + compute_region_key: RegionKey = field() # optional - database_key: Optional[DatabaseKey] = Field(None, alias="id") - description: Optional[str] = Field(None, max_length=255) - emoji: Optional[str] = Field(None, max_length=255) - current_status: Optional[str] - health_status: Optional[str] - data_size_full: Optional[int] - data_size_compressed: Optional[int] - is_system_database: Optional[bool] - storage_bucket_name: Optional[str] - create_time: Optional[datetime] - create_actor: Optional[str] - last_update_time: Optional[datetime] - last_update_actor: Optional[str] - desired_status: Optional[str] - - @classmethod - def parse_obj_with_service( - cls, obj: Any, database_service: DatabaseService - ) -> Database: - database = cls.parse_obj(obj) - database._service = database_service - return database + database_key: Optional[DatabaseKey] = field(default=None) + emoji: Optional[str] = field(default=None) + current_status: Optional[str] = field(default=None) + health_status: Optional[str] = field(default=None) + data_size_full: Optional[int] = field(default=None) + data_size_compressed: Optional[int] = field(default=None) + is_system_database: Optional[bool] = field(default=None) + storage_bucket_name: Optional[str] = field(default=None) + create_time: Optional[datetime] = field(default=None) + create_actor: Optional[str] = field(default=None) + last_update_time: Optional[datetime] = field(default=None) + last_update_actor: Optional[str] = field(default=None) + desired_status: Optional[str] = field(default=None) @property def database_id(self) -> Optional[str]: @@ -107,9 +103,9 @@ def delete(self) -> Database: """ for engine in self.get_attached_engines(): - if engine.current_status_summary in { - EngineStatusSummary.ENGINE_STATUS_SUMMARY_STARTING, - EngineStatusSummary.ENGINE_STATUS_SUMMARY_STOPPING, + if engine.current_status in { + EngineStatus.STARTING, + EngineStatus.STOPPING, }: raise AttachedEngineInUseError(method_name="delete") @@ -122,22 +118,21 @@ def delete(self) -> Database: ), headers={"Content-type": "application/json"}, ) - return Database.parse_obj_with_service( - response.json()["database"], self._service - ) + return Database._from_dict(response.json()["database"], self._service) def update(self, description: str) -> Database: """ Updates a database description. """ - class _DatabaseUpdateRequest(FireboltBaseModel): + @dataclass + class _DatabaseUpdateRequest: """Helper model for sending Database creation requests.""" - account_id: str - database: Database - database_id: str - update_mask: FieldMask + account_id: str = field() + database: Database = field() + database_id: Optional[str] = field() + update_mask: FieldMask = field() self.description = description @@ -146,12 +141,14 @@ class _DatabaseUpdateRequest(FireboltBaseModel): f"name={self.name}, description={self.description})" ) - payload = _DatabaseUpdateRequest( - account_id=self._service.account_id, - database=self, - database_id=self.database_id, - update_mask=FieldMask(paths=["description"]), - ).jsonable_dict(by_alias=True) + payload = asdict( + _DatabaseUpdateRequest( + account_id=self._service.account_id, + database=self, + database_id=self.database_id, + update_mask=FieldMask(paths=["description"]), + ) + ) response = self._service.client.patch( url=ACCOUNT_DATABASE_URL.format( @@ -161,9 +158,7 @@ class _DatabaseUpdateRequest(FireboltBaseModel): json=payload, ) - return Database.parse_obj_with_service( - response.json()["database"], self._service - ) + return Database._from_dict(response.json()["database"], self._service) def get_default_engine(self) -> Optional[Engine]: """ diff --git a/src/firebolt/model/engine.py b/src/firebolt/model/engine.py index 6b6ebbc6f8a..aa548bad7e8 100644 --- a/src/firebolt/model/engine.py +++ b/src/firebolt/model/engine.py @@ -3,31 +3,14 @@ import functools import logging import time -from datetime import datetime -from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence - -from pydantic import Field, PrivateAttr +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, Tuple from firebolt.db import Connection, connect from firebolt.model import FireboltBaseModel -from firebolt.model.binding import Binding from firebolt.model.database import Database -from firebolt.model.engine_revision import EngineRevision, EngineRevisionKey -from firebolt.model.region import RegionKey -from firebolt.service.types import ( - EngineStatus, - EngineStatusSummary, - EngineType, - WarmupMethod, -) -from firebolt.utils.exception import NoAttachedDatabaseError -from firebolt.utils.urls import ( - ACCOUNT_ENGINE_RESTART_URL, - ACCOUNT_ENGINE_START_URL, - ACCOUNT_ENGINE_STOP_URL, - ACCOUNT_ENGINE_URL, -) -from firebolt.utils.util import prune_dict +from firebolt.service.types import EngineStatus, WarmupMethod +from firebolt.utils.exception import DatabaseNotFoundError if TYPE_CHECKING: from firebolt.service.engine import EngineService @@ -35,74 +18,19 @@ logger = logging.getLogger(__name__) -class EngineKey(FireboltBaseModel): - account_id: str - engine_id: str - - -def wait(seconds: int, timeout_time: float, error_message: str, verbose: bool) -> None: - time.sleep(seconds) - if time.time() > timeout_time: - raise TimeoutError(error_message) - if verbose: - print(".", end="") - - -class EngineSettings(FireboltBaseModel): - """ - Engine settings. - - See also: :py:class:`EngineRevisionSpecification - ` - which also contains engine configuration. - """ - - preset: str - auto_stop_delay_duration: str = Field(regex=r"^[0-9]+[sm]$|^0$") - minimum_logging_level: str - is_read_only: bool - warm_up: str - - @classmethod - def default( - cls, - engine_type: EngineType = EngineType.GENERAL_PURPOSE, - auto_stop_delay_duration: str = "1200s", - warm_up: WarmupMethod = WarmupMethod.PRELOAD_INDEXES, - minimum_logging_level: str = "ENGINE_SETTINGS_LOGGING_LEVEL_INFO", - ) -> EngineSettings: - if engine_type == EngineType.GENERAL_PURPOSE: - preset = engine_type.GENERAL_PURPOSE.api_settings_preset_name # type: ignore # noqa: E501 - is_read_only = False - else: - preset = engine_type.DATA_ANALYTICS.api_settings_preset_name # type: ignore - is_read_only = True - - return cls( - preset=preset, - auto_stop_delay_duration=auto_stop_delay_duration, - minimum_logging_level=minimum_logging_level, - is_read_only=is_read_only, - warm_up=warm_up.api_name, - ) - - def check_attached_to_database(func: Callable) -> Callable: """(Decorator) Ensure the engine is attached to a database.""" @functools.wraps(func) def inner(self: Engine, *args: Any, **kwargs: Any) -> Any: - if self.database is None: - raise NoAttachedDatabaseError(method_name=func.__name__) + # if self.database is None: + # raise NoAttachedDatabaseError(method_name=func.__name__) return func(self, *args, **kwargs) return inner -class FieldMask(FireboltBaseModel): - paths: Sequence[str] = Field(alias="paths") - - +@dataclass class Engine(FireboltBaseModel): """ A Firebolt engine. Responsible for performing work (queries, ingestion). @@ -113,76 +41,55 @@ class Engine(FireboltBaseModel): `. """ - # internal - _service: EngineService = PrivateAttr() - - # required - name: str = Field(min_length=1, max_length=255, regex=r"^[0-9a-zA-Z_]+$") - compute_region_key: RegionKey = Field(alias="compute_region_id") - settings: EngineSettings - - # optional - key: Optional[EngineKey] = Field(None, alias="id") - description: Optional[str] - emoji: Optional[str] - current_status: Optional[EngineStatus] - current_status_summary: Optional[EngineStatusSummary] - latest_revision_key: Optional[EngineRevisionKey] = Field( - None, alias="latest_revision_id" - ) - endpoint: Optional[str] - endpoint_serving_revision_key: Optional[EngineRevisionKey] = Field( - None, alias="endpoint_serving_revision_id" + START_SQL: ClassVar[str] = "START ENGINE {}" + STOP_SQL: ClassVar[str] = "STOP ENGINE {}" + ALTER_PREFIX_SQL: ClassVar[str] = "ALTER ENGINE {} SET " + ALTER_PARAMETER_NAMES: ClassVar[Tuple] = ( + "SCALE", + "SPEC", + "AUTO_STOP", + "RENAME_TO", + "WARMUP", ) - create_time: Optional[datetime] - create_actor: Optional[str] - last_update_time: Optional[datetime] - last_update_actor: Optional[str] - last_use_time: Optional[datetime] - desired_status: Optional[str] - health_status: Optional[str] - endpoint_desired_revision_key: Optional[EngineRevisionKey] = Field( - None, alias="endpoint_desired_revision_id" - ) - - @classmethod - def parse_obj_with_service(cls, obj: Any, engine_service: EngineService) -> Engine: - engine = cls.parse_obj(obj) - engine._service = engine_service - return engine - - @property - def engine_id(self) -> str: - if self.key is None: - raise ValueError("engine key is None") - return self.key.engine_id + DROP_SQL: ClassVar[str] = "DROP ENGINE {}" + + _service: EngineService = field() + + name: str = field(metadata={"db_name": "engine_name"}) + region: str = field() + spec: str = field() + scale: int = field() + current_status: str = field(metadata={"db_name": "status"}) + _database_name: Optional[str] = field(metadata={"db_name": "attached_to"}) + version: str = field() + endpoint: str = field(metadata={"db_name": "url"}) + warmup: str = field() + auto_stop: int = field() + type: str = field() + provisioning: str = field() @property def database(self) -> Optional[Database]: - return self._service.resource_manager.bindings.get_database_bound_to_engine( - engine=self - ) - - def get_latest(self) -> Engine: - """Get an up-to-date instance of the engine from Firebolt.""" - return self._service.get(id_=self.engine_id) - - def attach_to_database( - self, database: Database, is_default_engine: bool = False - ) -> Binding: + if self._database_name: + try: + return self._service.resource_manager.databases.get(self._database_name) + except DatabaseNotFoundError: + pass + return None + + def refresh(self) -> None: + """Update attributes of the instance from Firebolt.""" + for name, value in self._service._get_dict(self.name).items(): + setattr(self, name, value) + + def attach_to_database(self, database_name: str) -> None: """ Attach this engine to a database. Args: database: Database to which the engine will be attached - is_default_engine: - Whether this engine should be used as default for this database. - Only one engine can be set as default for a single database. - This will overwrite any existing default. """ - return self._service.resource_manager.bindings.create( - engine=self, database=database, is_default_engine=is_default_engine - ) + return self._service.attach_to_database(self.name, database_name) @check_attached_to_database def get_connection(self) -> Connection: @@ -193,7 +100,7 @@ def get_connection(self) -> Connection: """ return connect( - database=self.database.name, # type: ignore # already checked by decorator + database=self._database_name, # type: ignore # already checked by decorator # we always have firebolt Auth as a client auth auth=self._service.client.auth, # type: ignore engine_name=self.name, @@ -201,292 +108,127 @@ def get_connection(self) -> Connection: api_endpoint=self._service.resource_manager.api_endpoint, ) + def _wait_for_start_stop(self) -> None: + wait_timeout = 3600 + interval_seconds = 5 + timeout_time = time.time() + wait_timeout + while self.current_status in (EngineStatus.STOPPING, EngineStatus.STARTING): + logger.info( + f"Engine {self.name} is currently " + f"{self.current_status.lower()}, waiting" + ) + time.sleep(interval_seconds) + if time.time() > timeout_time: + raise TimeoutError( + f"Excedeed timeout of {wait_timeout}s waiting for " + f"an engine in {self.current_status.lower()} state" + ) + logger.info(".[!n]") + self.refresh() + @check_attached_to_database - def start( - self, - wait_for_startup: bool = True, - wait_timeout_seconds: int = 3600, - verbose: bool = False, - ) -> Engine: + def start(self) -> Engine: """ Start an engine. If it's already started, do nothing. - Args: - wait_for_startup: - If True, wait for startup to complete. - If False, return immediately after requesting startup. - wait_timeout_seconds: - Number of seconds to wait for startup to complete - before raising a TimeoutError - verbose: - If True, print dots periodically while waiting for engine start. - If False, do not print any dots. - Returns: - The updated engine from Firebolt. + The updated engine instance. """ - timeout_time = time.time() + wait_timeout_seconds - engine = self.get_latest() - if ( - engine.current_status_summary - == EngineStatusSummary.ENGINE_STATUS_SUMMARY_RUNNING - ): - logger.info( - f"Engine (engine_id={self.engine_id}, name={self.name}) " - "is already running." + self.refresh() + self._wait_for_start_stop() + if self.current_status == EngineStatus.RUNNING: + logger.info(f"Engine {self.name} is already running.") + return self + if self.current_status in (EngineStatus.DROPPING, EngineStatus.REPAIRING): + raise ValueError( + f"Unable to start engine {self.name} because it's " + f"in {self.current_status.lower()} state" ) - return engine - # wait for engine to stop first, if it's already stopping - # FUTURE: revisit logging and consider consolidating this if & the while below. - elif ( - engine.current_status_summary - == EngineStatusSummary.ENGINE_STATUS_SUMMARY_STOPPING - ): - logger.info( - f"Engine (engine_id={engine.engine_id}, name={engine.name}) " - "is in currently stopping, waiting for it to stop first." - ) - while ( - engine.current_status_summary - != EngineStatusSummary.ENGINE_STATUS_SUMMARY_STOPPED - ): - wait( - seconds=5, - timeout_time=timeout_time, - error_message=( - "Engine " - f"(engine_id={engine.engine_id}, name={engine.name}) " - f"did not stop within {wait_timeout_seconds} seconds." - ), - verbose=True, - ) - engine = engine.get_latest() - - logger.info( - f"Engine (engine_id={engine.engine_id}, name={engine.name}) stopped." - ) - - engine = self._send_engine_request(ACCOUNT_ENGINE_START_URL) - logger.info( - f"Starting Engine (engine_id={engine.engine_id}, name={engine.name})" - ) - - # wait for engine to start - while wait_for_startup and engine.current_status_summary not in { - EngineStatusSummary.ENGINE_STATUS_SUMMARY_RUNNING, - EngineStatusSummary.ENGINE_STATUS_SUMMARY_FAILED, - }: - wait( - seconds=5, - timeout_time=timeout_time, - error_message=( # noqa: E501 - f"Could not start engine within {wait_timeout_seconds} seconds." - ), - verbose=verbose, - ) - previous_status_summary = engine.current_status_summary - engine = engine.get_latest() - if engine.current_status_summary != previous_status_summary: - logger.info( - "Engine status_summary=" - f"{getattr(engine.current_status_summary, 'name')}" - ) - - return engine + logger.info(f"Starting engine {self.name}") + with self._service._connection.cursor() as c: + c.execute(self.START_SQL.format(self.name)) + self.refresh() + return self @check_attached_to_database - def stop( - self, wait_for_stop: bool = False, wait_timeout_seconds: int = 3600 - ) -> Engine: - """Stop an Engine running on Firebolt.""" - timeout_time = time.time() + wait_timeout_seconds - - engine = self._send_engine_request(ACCOUNT_ENGINE_STOP_URL) - logger.info(f"Stopping Engine (engine_id={self.engine_id}, name={self.name})") - - while wait_for_stop and engine.current_status_summary not in { - EngineStatusSummary.ENGINE_STATUS_SUMMARY_STOPPED, - EngineStatusSummary.ENGINE_STATUS_SUMMARY_FAILED, - }: - wait( - seconds=5, - timeout_time=timeout_time, - error_message=( # noqa: E501 - f"Could not stop engine within {wait_timeout_seconds} seconds." - ), - verbose=False, - ) + def stop(self) -> Engine: + """Stop an engine. If it's already stopped, do nothing. - engine = engine.get_latest() - - return engine + Returns: + The updated engine instance. + """ + self.refresh() + self._wait_for_start_stop() + if self.current_status == EngineStatus.STOPPED: + logger.info(f"Engine {self.name} is already stopped.") + return self + if self.current_status in (EngineStatus.DROPPING, EngineStatus.REPAIRING): + raise ValueError( + f"Unable to stop engine {self.name} because it's " + f"in {self.current_status.lower()} state" + ) + logger.info(f"Stopping engine {self.name}") + with self._service._connection.cursor() as c: + c.execute(self.STOP_SQL.format(self.name)) + self.refresh() + return self def update( self, name: Optional[str] = None, - engine_type: Optional[EngineType] = None, scale: Optional[int] = None, spec: Optional[str] = None, auto_stop: Optional[int] = None, warmup: Optional[WarmupMethod] = None, - description: Optional[str] = None, - use_spot: Optional[bool] = None, ) -> Engine: """ Updates the engine and returns an updated version of the engine. If all parameters are set to None, old engine parameter values remain. """ - class _EngineUpdateRequest(FireboltBaseModel): - """Helper model for sending engine update requests.""" - - account_id: str - desired_revision: Optional[EngineRevision] - engine: Engine - engine_id: str - update_mask: FieldMask - - # Update the engine parameters - self.name = name if name else self.name - self.description = description - - # Update engine settings - engine_settings_params = { - "engine_type": engine_type, - "auto_stop_delay_duration": f"{auto_stop * 60}s" - if auto_stop is not None - else None, - "warm_up": warmup, - } - self.settings = EngineSettings.default(**prune_dict(engine_settings_params)) - - # Update the engine desired_revision if needed - desired_revision = None - if (scale or spec or use_spot is not None) and self.latest_revision_key: - rm = self._service.resource_manager - desired_revision = rm.engine_revisions.get_by_key(self.latest_revision_key) - - if spec: - instance_type_key = rm.instance_types.get_by_name( - instance_type_name=spec, - region_name=rm.regions.regions_by_key[self.compute_region_key].name, - ).key - - desired_revision.specification.db_compute_instances_type_key = ( - instance_type_key - ) - desired_revision.specification.proxy_instances_type_key = ( - instance_type_key - ) + if not any((name, scale, spec, auto_stop, warmup)): + # Nothing to be updated + return self - if scale: - desired_revision.specification.db_compute_instances_count = scale - - if use_spot is not None: - desired_revision.specification.db_compute_instances_use_spot = use_spot - - update_mask_paths = list( - prune_dict( - { - "name": name, - "description": description, - "settings.auto_stop_delay_duration": auto_stop, - "settings.warm_up": warmup, - "settings.is_read_only": engine_type, - "settings.preset": engine_type, - } - ).keys() - ) + self._wait_for_start_stop() + if self.current_status in (EngineStatus.DROPPING, EngineStatus.REPAIRING): + raise ValueError( + f"Unable to update engine {self.name} because it's " + f"in {self.current_status.lower()} state" + ) - # Send the update request - response = self._service.client.patch( - url=ACCOUNT_ENGINE_URL.format( - engine_id=self.engine_id, account_id=self._service.account_id - ), - headers={"Content-type": "application/json"}, - json=_EngineUpdateRequest( - account_id=self._service.account_id, - engine=self, - desired_revision=desired_revision, - engine_id=self.engine_id, - update_mask=FieldMask(paths=update_mask_paths), - ).jsonable_dict(by_alias=True), - ) + sql = self.ALTER_PREFIX_SQL.format(self.name) + parameters = [] + for name, value in zip( + self.ALTER_PARAMETER_NAMES, (scale, spec, auto_stop, name, warmup) + ): + if value: + sql += f"{name} = ? " + parameters.append(value) - return Engine.parse_obj_with_service( - obj=response.json()["engine"], engine_service=self._service - ) + with self._service._connection.cursor() as c: + c.execute(sql, parameters) + self.refresh() + return self @check_attached_to_database - def restart( - self, - wait_for_startup: bool = True, - wait_timeout_seconds: int = 3600, - ) -> Engine: + def restart(self) -> Engine: """ Restart an engine. - Args: - wait_for_startup: - If True, wait for startup to complete. - If False, return immediately after requesting startup. - wait_timeout_seconds: - Number of seconds to wait for startup to complete - before raising a TimeoutError. - Returns: The updated engine from Firebolt. """ - timeout_time = time.time() + wait_timeout_seconds - - engine = self._send_engine_request(ACCOUNT_ENGINE_RESTART_URL) - logger.info(f"Stopping Engine (engine_id={self.engine_id}, name={self.name})") - - while wait_for_startup and engine.current_status_summary not in { - EngineStatusSummary.ENGINE_STATUS_SUMMARY_RUNNING, - EngineStatusSummary.ENGINE_STATUS_SUMMARY_FAILED, - }: - wait( - seconds=5, - timeout_time=timeout_time, - error_message=( # noqa: E501 - f"Could not restart engine within {wait_timeout_seconds} seconds." - ), - verbose=False, - ) - - engine = engine.get_latest() - - return engine - - def delete(self) -> Engine: - """Delete an engine from Firebolt.""" - response = self._service.client.delete( - url=ACCOUNT_ENGINE_URL.format( - account_id=self._service.account_id, engine_id=self.engine_id - ), - ) - logger.info(f"Deleting Engine (engine_id={self.engine_id}, name={self.name})") - return Engine.parse_obj_with_service( - obj=response.json()["engine"], engine_service=self._service - ) - - def _send_engine_request(self, url: str) -> Engine: - response = self._service.client.post( - url=url.format( - account_id=self._service.account_id, engine_id=self.engine_id - ) - ) - return Engine.parse_obj_with_service( - obj=response.json()["engine"], engine_service=self._service - ) - - -class _EngineCreateRequest(FireboltBaseModel): - """Helper model for sending engine create requests.""" - - account_id: str - engine: Engine - engine_revision: Optional[EngineRevision] + self.stop() + self.start() + return self + + def delete(self) -> None: + """Delete an engine.""" + self.refresh() + if self.current_status == EngineStatus.DROPPING: + return + with self._service._connection.cursor() as c: + c.execute(self.DROP_SQL.format(self.name)) diff --git a/src/firebolt/model/engine_revision.py b/src/firebolt/model/engine_revision.py deleted file mode 100644 index d3257f230c3..00000000000 --- a/src/firebolt/model/engine_revision.py +++ /dev/null @@ -1,58 +0,0 @@ -from __future__ import annotations - -from datetime import datetime -from typing import Optional - -from pydantic import Field, PositiveInt - -from firebolt.model import FireboltBaseModel -from firebolt.model.instance_type import InstanceTypeKey - - -class EngineRevisionKey(FireboltBaseModel): - account_id: str - engine_id: str - engine_revision_id: str - - -class EngineRevisionSpecification(FireboltBaseModel): - """ - An EngineRevision specification. - - Determines which instance types and how many of them its engine gets. - - See Also: :py:class:`Settings - `, - which also contains engine configuration. - """ - - db_compute_instances_type_key: InstanceTypeKey = Field( - alias="db_compute_instances_type_id" - ) - db_compute_instances_count: PositiveInt - db_compute_instances_use_spot: bool = False - db_version: str = "" - proxy_instances_type_key: InstanceTypeKey = Field(alias="proxy_instances_type_id") - proxy_instances_count: PositiveInt = 1 - proxy_version: str = "" - - -class EngineRevision(FireboltBaseModel): - """ - A Firebolt engine revision, - which contains a specification (instance types, counts). - - As engines are updated with new settings, revisions are created. - """ - - specification: EngineRevisionSpecification - - # optional - key: Optional[EngineRevisionKey] = Field(None, alias="id") - current_status: Optional[str] - create_time: Optional[datetime] - create_actor: Optional[str] - last_update_time: Optional[datetime] - last_update_actor: Optional[str] - desired_status: Optional[str] - health_status: Optional[str] diff --git a/src/firebolt/model/instance_type.py b/src/firebolt/model/instance_type.py index 24decd63006..5ead1b30477 100644 --- a/src/firebolt/model/instance_type.py +++ b/src/firebolt/model/instance_type.py @@ -6,7 +6,7 @@ from firebolt.model import FireboltBaseModel -class InstanceTypeKey(FireboltBaseModel, frozen=True): # type: ignore +class InstanceTypeKey(FireboltBaseModel): # type: ignore provider_id: str region_id: str instance_type_id: str diff --git a/src/firebolt/model/provider.py b/src/firebolt/model/provider.py index 242856ef218..3cc041c2710 100644 --- a/src/firebolt/model/provider.py +++ b/src/firebolt/model/provider.py @@ -1,16 +1,16 @@ +from dataclasses import dataclass, field from datetime import datetime from typing import Optional -from pydantic import Field - from firebolt.model import FireboltBaseModel -class Provider(FireboltBaseModel, frozen=True): # type: ignore - provider_id: str = Field(alias="id") - name: str +@dataclass +class Provider(FireboltBaseModel): + provider_id: str = field(metadata={"db_name": "id"}) + name: str = field() # optional - create_time: Optional[datetime] - display_name: Optional[str] - last_update_time: Optional[datetime] + create_time: Optional[datetime] = field(default=None) + display_name: Optional[str] = field(default=None) + last_update_time: Optional[datetime] = field(default=None) diff --git a/src/firebolt/model/region.py b/src/firebolt/model/region.py index c3b476da26c..b292180ce98 100644 --- a/src/firebolt/model/region.py +++ b/src/firebolt/model/region.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass from datetime import datetime from typing import Optional @@ -6,7 +7,8 @@ from firebolt.model import FireboltBaseModel -class RegionKey(FireboltBaseModel, frozen=True): # type: ignore +@dataclass +class RegionKey: provider_id: str region_id: str diff --git a/src/firebolt/service/base.py b/src/firebolt/service/base.py index 079a2a18f5d..e677098c8ec 100644 --- a/src/firebolt/service/base.py +++ b/src/firebolt/service/base.py @@ -1,9 +1,14 @@ +from typing import TYPE_CHECKING + from firebolt.client import Client -from firebolt.service.manager import ResourceManager +from firebolt.db import Connection + +if TYPE_CHECKING: + from firebolt.service.manager import ResourceManager class BaseService: - def __init__(self, resource_manager: ResourceManager): + def __init__(self, resource_manager: "ResourceManager"): self.resource_manager = resource_manager @property @@ -17,3 +22,7 @@ def account_id(self) -> str: @property def _default_region(self) -> str: return self.resource_manager.default_region + + @property + def _connection(self) -> Connection: + return self.resource_manager._connection diff --git a/src/firebolt/service/binding.py b/src/firebolt/service/binding.py deleted file mode 100644 index a12aff81e4e..00000000000 --- a/src/firebolt/service/binding.py +++ /dev/null @@ -1,139 +0,0 @@ -import logging -from typing import List, Optional - -from firebolt.model.binding import Binding, BindingKey -from firebolt.model.database import Database -from firebolt.model.engine import Engine -from firebolt.service.base import BaseService -from firebolt.utils.exception import AlreadyBoundError -from firebolt.utils.urls import ( - ACCOUNT_BINDINGS_URL, - ACCOUNT_DATABASE_BINDING_URL, -) -from firebolt.utils.util import prune_dict - -logger = logging.getLogger(__name__) - - -class BindingService(BaseService): - def get_by_key(self, binding_key: BindingKey) -> Binding: - """Get a binding by its BindingKey""" - response = self.client.get( - url=ACCOUNT_DATABASE_BINDING_URL.format( - account_id=binding_key.account_id, - database_id=binding_key.database_id, - engine_id=binding_key.engine_id, - ) - ) - binding: dict = response.json()["binding"] - return Binding.parse_obj(binding) - - def get_many( - self, - database_id: Optional[str] = None, - engine_id: Optional[str] = None, - is_system_database: Optional[bool] = None, - ) -> List[Binding]: - """ - List bindings on Firebolt, optionally filtering by database and engine. - - Args: - database_id: - Return bindings matching the database_id. - If None, match any databases. - engine_id: - Return bindings matching the engine_id. - If None, match any engines. - is_system_database: - If True, return only system databases. - If False, return only non-system databases. - If None, do not filter on this parameter. - - Returns: - List of bindings matching the filter parameters - """ - - response = self.client.get( - url=ACCOUNT_BINDINGS_URL.format(account_id=self.account_id), - params=prune_dict( - { - "page.first": 5000, # FUTURE: pagination support w/ generator - "filter.id_database_id_eq": database_id, - "filter.id_engine_id_eq": engine_id, - "filter.is_system_database_eq": is_system_database, - } - ), - ) - return [Binding.parse_obj(i["node"]) for i in response.json()["edges"]] - - def get_database_bound_to_engine(self, engine: Engine) -> Optional[Database]: - """Get the database to which an engine is bound, if any.""" - try: - binding = self.get_many(engine_id=engine.engine_id)[0] - except IndexError: - return None - try: - return self.resource_manager.databases.get(id_=binding.database_id) - except (KeyError, IndexError): - return None - - def get_engines_bound_to_database(self, database: Database) -> List[Engine]: - """Get a list of engines that are bound to a database.""" - - bindings = self.get_many(database_id=database.database_id) - if not bindings: - return [] - return self.resource_manager.engines.get_by_ids( - ids=[b.engine_id for b in bindings] - ) - - def create( - self, engine: Engine, database: Database, is_default_engine: bool - ) -> Binding: - """ - Create a new binding between an engine and a database. - - Args: - engine: Engine to bind. - database: Database to bind. - is_default_engine: - Whether this engine should be used as default for this database. - Only one engine can be set as default for a single database. - This will overwrite any existing default. - - Returns: - New binding between the engine and database. - """ - - existing_database = self.get_database_bound_to_engine(engine=engine) - if existing_database is not None: - raise AlreadyBoundError( - f"The engine {engine.name} is already bound " - f"to {existing_database.name}!" - ) - - logger.info( - f"Attaching Engine (engine_id={engine.engine_id}, name={engine.name}) " - f"to Database (database_id={database.database_id}, " - f"name={database.name})" - ) - binding = Binding( - binding_key=BindingKey( - account_id=self.account_id, - database_id=database.database_id, - engine_id=engine.engine_id, - ), - is_default_engine=is_default_engine, - ) - - response = self.client.post( - url=ACCOUNT_DATABASE_BINDING_URL.format( - account_id=self.account_id, - database_id=database.database_id, - engine_id=engine.engine_id, - ), - json=binding.jsonable_dict( - by_alias=True, include={"binding_key": ..., "is_default_engine": ...} - ), - ) - return Binding.parse_obj(response.json()["binding"]) diff --git a/src/firebolt/service/database.py b/src/firebolt/service/database.py index 71d58450017..fd64284bb50 100644 --- a/src/firebolt/service/database.py +++ b/src/firebolt/service/database.py @@ -1,7 +1,6 @@ import logging from typing import List, Optional, Union -from firebolt.model import FireboltBaseModel from firebolt.model.database import Database from firebolt.service.base import BaseService from firebolt.service.types import DatabaseOrder @@ -22,9 +21,7 @@ def get(self, id_: str) -> Database: response = self.client.get( url=ACCOUNT_DATABASE_URL.format(account_id=self.account_id, database_id=id_) ) - return Database.parse_obj_with_service( - obj=response.json()["database"], database_service=self - ) + return Database._from_dict(response.json()["database"], self) def get_by_name(self, name: str) -> Database: """Get a database from Firebolt by its name.""" @@ -80,10 +77,7 @@ def get_many( params=prune_dict(params), ) - return [ - Database.parse_obj_with_service(obj=d["node"], database_service=self) - for d in response.json()["edges"] - ] + return [Database._from_dict(d["node"], self) for d in response.json()["edges"]] def create( self, name: str, region: Optional[str] = None, description: Optional[str] = None @@ -99,29 +93,10 @@ def create( The newly created database """ - class _DatabaseCreateRequest(FireboltBaseModel): - """Helper model for sending database creation requests.""" - - account_id: str - database: Database - - if region is None: - region_key = self.resource_manager.regions.default_region.key - else: - region_key = self.resource_manager.regions.get_by_name(name=region).key - database = Database( - name=name, compute_region_key=region_key, description=description - ) - logger.info(f"Creating Database (name={name})") response = self.client.post( url=ACCOUNT_DATABASES_URL.format(account_id=self.account_id), headers={"Content-type": "application/json"}, - json=_DatabaseCreateRequest( - account_id=self.account_id, - database=database, - ).jsonable_dict(by_alias=True), - ) - return Database.parse_obj_with_service( - obj=response.json()["database"], database_service=self + json={}, ) + return Database._from_dict(response.json()["database"], self) diff --git a/src/firebolt/service/engine.py b/src/firebolt/service/engine.py index dbee674e88a..768ec1a856a 100644 --- a/src/firebolt/service/engine.py +++ b/src/firebolt/service/engine.py @@ -1,69 +1,67 @@ from logging import getLogger -from typing import Any, Dict, List, Optional, Union +from typing import List, Optional, Union -from firebolt.model.engine import Engine, EngineSettings, _EngineCreateRequest -from firebolt.model.engine_revision import ( - EngineRevision, - EngineRevisionSpecification, -) +from firebolt.model.engine import Engine from firebolt.model.region import Region from firebolt.service.base import BaseService -from firebolt.service.types import EngineOrder, EngineType, WarmupMethod -from firebolt.utils.exception import FireboltError -from firebolt.utils.urls import ( - ACCOUNT_ENGINE_ID_BY_NAME_URL, - ACCOUNT_ENGINE_URL, - ACCOUNT_LIST_ENGINES_URL, - ENGINES_BY_IDS_URL, -) -from firebolt.utils.util import prune_dict +from firebolt.service.types import EngineStatus, EngineType, WarmupMethod +from firebolt.utils.exception import EngineNotFoundError logger = getLogger(__name__) class EngineService(BaseService): - def get(self, id_: str) -> Engine: - """Get an engine from Firebolt by its ID.""" - - response = self.client.get( - url=ACCOUNT_ENGINE_URL.format(account_id=self.account_id, engine_id=id_), - ) - engine_entry: dict = response.json()["engine"] - return Engine.parse_obj_with_service(obj=engine_entry, engine_service=self) - - def get_by_ids(self, ids: List[str]) -> List[Engine]: - """Get multiple engines from Firebolt by ID.""" - response = self.client.post( - url=ENGINES_BY_IDS_URL, - json={ - "engine_ids": [ - {"account_id": self.account_id, "engine_id": engine_id} - for engine_id in ids - ] - }, - ) - return [ - Engine.parse_obj_with_service(obj=e, engine_service=self) - for e in response.json()["engines"] - ] - - def get_by_name(self, name: str) -> Engine: + ENGINE_DB_FIELDS = ( + "engine_name", + "region", + "spec", + "scale", + "status", + "attached_to", + "version", + "url", + "warmup", + "auto_stop", + "type", + "provisioning", + ) + GET_SQL = f"SELECT {', '.join(ENGINE_DB_FIELDS)} FROM information_schema.engines" + GET_BY_NAME_SQL = GET_SQL + " WHERE engine_name=?" + GET_WHERE_SQL = " WHERE " + + CREATE_PREFIX_SQL = "CREATE ENGINE {}" + CREATE_WITH_SQL = " WITH " + CREATE_PARAMETER_NAMES = ( + "REGION", + "ENGINE_TYPE", + "SPEC", + "SCALE", + "AUTO_STOP", + "WARMUP", + ) + + ATTACH_TO_DB_SQL = "ATTACH ENGINE {} TO {}" + + def _get_dict(self, name: str) -> dict: + with self._connection.cursor() as c: + count = c.execute(self.GET_BY_NAME_SQL, (name,)) + if count == 0: + raise EngineNotFoundError(name) + return { + column.name: value for column, value in zip(c.description, c.fetchone()) + } + + def get(self, name: str) -> Engine: """Get an engine from Firebolt by its name.""" - - response = self.client.get( - url=ACCOUNT_ENGINE_ID_BY_NAME_URL.format(account_id=self.account_id), - params={"engine_name": name}, - ) - engine_id = response.json()["engine_id"]["engine_id"] - return self.get(id_=engine_id) + engine_dict = self._get_dict(name) + return Engine._from_dict(engine_dict, self) def get_many( self, name_contains: Optional[str] = None, - current_status_eq: Optional[str] = None, - current_status_not_eq: Optional[str] = None, - region_eq: Optional[str] = None, - order_by: Optional[Union[str, EngineOrder]] = None, + current_status_eq: Union[str, EngineStatus, None] = None, + current_status_not_eq: Union[str, EngineStatus, None] = None, + region_eq: Union[str, Region, None] = None, ) -> List[Engine]: """ Get a list of engines on Firebolt. @@ -78,45 +76,43 @@ def get_many( Returns: A list of engines matching the filters """ - - if isinstance(order_by, str): - order_by = EngineOrder[order_by].name - - if region_eq is not None: - region_eq = self.resource_manager.regions.get_by_name( - name=region_eq - ).key.region_id - - response = self.client.get( - url=ACCOUNT_LIST_ENGINES_URL.format(account_id=self.account_id), - params=prune_dict( - { - "page.first": 5000, # FUTURE: pagination support w/ generator - "filter.name_contains": name_contains, - "filter.current_status_eq": current_status_eq, - "filter.current_status_not_eq": current_status_not_eq, - "filter.compute_region_id_region_id_eq": region_eq, - "order_by": order_by, - } - ), - ) - return [ - Engine.parse_obj_with_service(obj=e["node"], engine_service=self) - for e in response.json()["edges"] - ] + sql = self.GET_SQL + parameters = [] + if any((name_contains, current_status_eq, current_status_not_eq, region_eq)): + condition = [] + if name_contains: + condition.append("engine_name like ?") + parameters.append(f"%{name_contains}%") + if current_status_eq: + condition.append("status = ?") + parameters.append(str(current_status_eq)) + if current_status_not_eq: + condition.append("status != ?") + parameters.append(str(current_status_eq)) + if region_eq: + condition.append("region = ?") + parameters.append(str(region_eq)) + sql += self.GET_WHERE_SQL + " AND ".join(condition) + + with self._connection.cursor() as c: + c.execute(sql, parameters) + engine_dicts = [ + {column.name: value for column, value in zip(c.description, engine_row)} + for engine_row in c.fetchall() + ] + return [ + Engine._from_dict(engine_dict, self) for engine_dict in engine_dicts + ] def create( self, name: str, region: Union[str, Region, None] = None, engine_type: Union[str, EngineType] = EngineType.GENERAL_PURPOSE, - scale: int = 2, spec: Optional[str] = None, - auto_stop: int = 20, - warmup: Union[str, WarmupMethod] = WarmupMethod.PRELOAD_INDEXES, - description: str = "", - engine_settings_kwargs: Dict[str, Any] = {}, - revision_spec_kwargs: Dict[str, Any] = {}, + scale: Optional[int] = None, + auto_stop: Optional[int] = None, + warmup: Union[str, WarmupMethod, None] = None, ) -> Engine: """ Create a new engine. @@ -125,10 +121,10 @@ def create( name: An identifier that specifies the name of the engine region: The AWS region in which the engine runs engine_type: The engine type. GENERAL_PURPOSE or DATA_ANALYTICS - scale: The number of compute instances on the engine. - The scale can be any int from 1 to 128. spec: Firebolt instance type. If not set, will default to the cheapest instance. + scale: The number of compute instances on the engine. + The scale can be any int from 1 to 128. auto_stop: The amount of time (in minutes) after which the engine automatically stops warmup: The warmup method that should be used: @@ -139,85 +135,28 @@ def create( `PRELOAD_ALL_DATA` - Full data auto-load (both indexes and table data - full warmup) - description: A short description of the engine's purpose Returns: Engine with the specified settings """ - logger.info(f"Creating Engine (name={name})") - - if isinstance(engine_type, str): - engine_type = EngineType[engine_type] - if isinstance(warmup, str): - warmup = WarmupMethod[warmup] - - if region is None: - region = self.resource_manager.regions.default_region - else: - if isinstance(region, str): - region = self.resource_manager.regions.get_by_name(name=region) - - engine = Engine( - name=name, - description=description, - compute_region_key=region.key, - settings=EngineSettings.default( - engine_type=engine_type, - auto_stop_delay_duration=f"{auto_stop * 60}s", - warm_up=warmup, - **engine_settings_kwargs, - ), - ) - - if spec: - instance_type_key = self.resource_manager.instance_types.get_by_name( - instance_type_name=spec, region_name=region.name - ).key - else: - instance_type = ( - self.resource_manager.instance_types.cheapest_instance_in_region(region) - ) - if not instance_type: - raise FireboltError( - f"No suitable default instances found in region {region}" - ) - instance_type_key = instance_type.key - - engine_revision = EngineRevision( - specification=EngineRevisionSpecification( - db_compute_instances_type_key=instance_type_key, - db_compute_instances_count=scale, - proxy_instances_type_key=instance_type_key, - **revision_spec_kwargs, - ) - ) - - return self._send_create_engine(engine=engine, engine_revision=engine_revision) - - def _send_create_engine( - self, engine: Engine, engine_revision: Optional[EngineRevision] = None - ) -> Engine: - """ - Create a new Engine on Firebolt from the local Engine object. - - Args: - engine: The engine to create - engine_revision: EngineRevision to use for configuring the engine - - Returns: - The newly created engine - """ - - response = self.client.post( - url=ACCOUNT_LIST_ENGINES_URL.format(account_id=self.account_id), - headers={"Content-type": "application/json"}, - json=_EngineCreateRequest( - account_id=self.account_id, - engine=engine, - engine_revision=engine_revision, - ).jsonable_dict(by_alias=True), - ) - return Engine.parse_obj_with_service( - obj=response.json()["engine"], engine_service=self - ) + logger.info(f"Creating engine {name}") + + sql = self.CREATE_PREFIX_SQL.format(name) + parameters = [] + if any((region, engine_type, spec, scale, auto_stop, warmup)): + sql += self.CREATE_WITH_SQL + for name, value in zip( + self.CREATE_PARAMETER_NAMES, + (region, engine_type, spec, scale, auto_stop, warmup), + ): + if value: + sql += f"{name} = ? " + parameters.append(value) + with self._connection.cursor() as c: + c.execute(sql, parameters) + return self.get(name) + + def attach_to_database(self, engine_name: str, database_name: str) -> None: + with self._connection.cursor() as c: + c.execute(self.ATTACH_TO_DB_SQL.format(engine_name, database_name)) diff --git a/src/firebolt/service/engine_revision.py b/src/firebolt/service/engine_revision.py deleted file mode 100644 index b79964d62b1..00000000000 --- a/src/firebolt/service/engine_revision.py +++ /dev/null @@ -1,39 +0,0 @@ -from firebolt.model.engine_revision import EngineRevision, EngineRevisionKey -from firebolt.service.base import BaseService -from firebolt.utils.urls import ACCOUNT_ENGINE_REVISION_URL - - -class EngineRevisionService(BaseService): - def get_by_id(self, engine_id: str, engine_revision_id: str) -> EngineRevision: - """ - Get an EngineRevision from Firebolt by engine_id and engine_revision_id. - """ - - return self.get_by_key( - EngineRevisionKey( - account_id=self.account_id, - engine_id=engine_id, - engine_revision_id=engine_revision_id, - ) - ) - - def get_by_key(self, key: EngineRevisionKey) -> EngineRevision: - """ - Fetch an EngineRevision from Firebolt by its key. - - Args: - key: Key of the desired EngineRevision - - Returns: - The requested EngineRevision - """ - - response = self.client.get( - url=ACCOUNT_ENGINE_REVISION_URL.format( - account_id=key.account_id, - engine_id=key.engine_id, - revision_id=key.engine_revision_id, - ), - ) - engine_spec: dict = response.json()["engine_revision"] - return EngineRevision.parse_obj(engine_spec) diff --git a/src/firebolt/service/instance_type.py b/src/firebolt/service/instance_type.py index 47a3ea26d45..6bc779ef60f 100644 --- a/src/firebolt/service/instance_type.py +++ b/src/firebolt/service/instance_type.py @@ -23,7 +23,7 @@ def instance_types(self) -> List[InstanceType]: url=ACCOUNT_INSTANCE_TYPES_URL.format(account_id=self.account_id), params={"page.first": 5000}, ) - return [InstanceType.parse_obj(i["node"]) for i in response.json()["edges"]] + return [InstanceType._from_dict(i["node"]) for i in response.json()["edges"]] @cached_property def instance_types_by_key(self) -> Dict[InstanceTypeKey, InstanceType]: @@ -54,7 +54,7 @@ def get_instance_types_per_region(self, region: Region) -> List[InstanceType]: ) instance_list = [ - InstanceType.parse_obj(i["node"]) for i in response.json()["edges"] + InstanceType._from_dict(i["node"]) for i in response.json()["edges"] ] # Filter out instances without storage diff --git a/src/firebolt/service/provider.py b/src/firebolt/service/provider.py index 9dcba914cc1..129c8ef896e 100644 --- a/src/firebolt/service/provider.py +++ b/src/firebolt/service/provider.py @@ -6,5 +6,5 @@ def get_provider_id(client: Client) -> str: """Get the AWS provider_id.""" response = client.get(url=PROVIDERS_URL) - providers = [Provider.parse_obj(i["node"]) for i in response.json()["edges"]] + providers = [Provider._from_dict(i["node"]) for i in response.json()["edges"]] return providers[0].provider_id diff --git a/src/firebolt/service/region.py b/src/firebolt/service/region.py index d79f73995c3..2e9f99a63ca 100644 --- a/src/firebolt/service/region.py +++ b/src/firebolt/service/region.py @@ -23,7 +23,7 @@ def regions(self) -> List[Region]: """List of available AWS regions on Firebolt.""" response = self.client.get(url=REGIONS_URL, params={"page.first": 5000}) - return [Region.parse_obj(i["node"]) for i in response.json()["edges"]] + return [Region._from_dict(i["node"], self) for i in response.json()["edges"]] @cached_property def regions_by_name(self) -> Dict[str, Region]: diff --git a/src/firebolt/service/types.py b/src/firebolt/service/types.py index 0ccdaef3573..d72cbd183c0 100644 --- a/src/firebolt/service/types.py +++ b/src/firebolt/service/types.py @@ -1,17 +1,16 @@ from enum import Enum -from types import DynamicClassAttribute class EngineType(Enum): GENERAL_PURPOSE = "GENERAL_PURPOSE" DATA_ANALYTICS = "DATA_ANALYTICS" - @DynamicClassAttribute - def api_settings_preset_name(self) -> str: + @classmethod + def from_display_name(cls, display_name: str) -> "EngineType": return { - EngineType.GENERAL_PURPOSE: "ENGINE_SETTINGS_PRESET_GENERAL_PURPOSE", - EngineType.DATA_ANALYTICS: "ENGINE_SETTINGS_PRESET_DATA_ANALYTICS", - }[self] + "General Purpose": cls.GENERAL_PURPOSE, + "Analytics": cls.DATA_ANALYTICS, + }[display_name] class WarmupMethod(Enum): @@ -19,205 +18,29 @@ class WarmupMethod(Enum): PRELOAD_INDEXES = "PRELOAD_INDEXES" PRELOAD_ALL_DATA = "PRELOAD_ALL_DATA" - @DynamicClassAttribute - def api_name(self) -> str: + @classmethod + def from_display_name(cls, display_name: str) -> "WarmupMethod": return { - WarmupMethod.MINIMAL: "ENGINE_SETTINGS_WARM_UP_MINIMAL", - WarmupMethod.PRELOAD_INDEXES: "ENGINE_SETTINGS_WARM_UP_INDEXES", - WarmupMethod.PRELOAD_ALL_DATA: "ENGINE_SETTINGS_WARM_UP_ALL", - }[self] + "Minimal": cls.MINIMAL, + "Indexes": cls.PRELOAD_INDEXES, + "All": cls.PRELOAD_ALL_DATA, + }[display_name] class EngineStatus(Enum): """ Detailed engine status. - See: https://api.dev.firebolt.io/devDocs#operation/coreV1GetEngine - """ - - ENGINE_STATUS_UNSPECIFIED = "ENGINE_STATUS_UNSPECIFIED" - """ Logical record is created, however, underlying infrastructure - is not initialized. - In other words, this means that engine is stopped.""" - - ENGINE_STATUS_CREATED = "ENGINE_STATUS_CREATED" - """Engine status was created.""" - - ENGINE_STATUS_PROVISIONING_PENDING = "ENGINE_STATUS_PROVISIONING_PENDING" - """ Engine initialization request was sent.""" - - ENGINE_STATUS_PROVISIONING_STARTED = "ENGINE_STATUS_PROVISIONING_STARTED" - """ Engine initialization request was received - and initialization process started.""" - - ENGINE_STATUS_PROVISIONING_FINISHED = "ENGINE_STATUS_PROVISIONING_FINISHED" - """ Engine initialization was finished successfully.""" - - ENGINE_STATUS_PROVISIONING_FAILED = "ENGINE_STATUS_PROVISIONING_FAILED" - """ Engine initialization failed due to error.""" - - ENGINE_STATUS_RUNNING_IDLE = "ENGINE_STATUS_RUNNING_IDLE" - """ Engine is initialized, - but there are no running or starting engine revisions.""" - - ENGINE_STATUS_RUNNING_REVISION_STARTING = "ENGINE_STATUS_RUNNING_REVISION_STARTING" - """ Engine is initialized, - there are no running engine revisions, but it's starting.""" - - ENGINE_STATUS_RUNNING_REVISION_STARTUP_FAILED = ( - "ENGINE_STATUS_RUNNING_REVISION_STARTUP_FAILED" - ) - """ Engine is initialized; - initial revision failed to provision or start.""" - - ENGINE_STATUS_RUNNING_REVISION_SERVING = "ENGINE_STATUS_RUNNING_REVISION_SERVING" - """ Engine is ready (serves an engine revision). """ - - ENGINE_STATUS_RUNNING_REVISION_CHANGING = "ENGINE_STATUS_RUNNING_REVISION_CHANGING" - """ Engine is ready (serves an engine revision); - zero-downtime replacement revision is starting.""" - - ENGINE_STATUS_RUNNING_REVISION_CHANGE_FAILED = ( - "ENGINE_STATUS_RUNNING_REVISION_CHANGE_FAILED" - ) - """ Engine is ready (serves an engine revision); - replacement revision failed to provision or start.""" - - ENGINE_STATUS_RUNNING_REVISION_RESTARTING = ( - "ENGINE_STATUS_RUNNING_REVISION_RESTARTING" - ) - """ Engine is initialized; - replacement of the revision with a downtime is in progress.""" - - ENGINE_STATUS_RUNNING_REVISION_RESTART_FAILED = ( - "ENGINE_STATUS_RUNNING_REVISION_RESTART_FAILED" - ) - """ Engine is initialized; - replacement revision failed to provision or start.""" - - ENGINE_STATUS_RUNNING_REVISIONS_TERMINATING = ( - "ENGINE_STATUS_RUNNING_REVISIONS_TERMINATING" - ) - """ Engine is initialized; - all child revisions are being terminated.""" - - # Engine termination request was sent. - ENGINE_STATUS_TERMINATION_PENDING = "ENGINE_STATUS_TERMINATION_PENDING" - """ Engine termination request was sent.""" - - ENGINE_STATUS_TERMINATION_ST = "ENGINE_STATUS_TERMINATION_STARTED" - """ Engine termination started.""" - - ENGINE_STATUS_TERMINATION_FIN = "ENGINE_STATUS_TERMINATION_FINISHED" - """ Engine termination finished.""" - - ENGINE_STATUS_TERMINATION_F = "ENGINE_STATUS_TERMINATION_FAILED" - """ Engine termination failed.""" - - ENGINE_STATUS_DELETED = "ENGINE_STATUS_DELETED" - """ Engine is soft-deleted.""" - + See: https://docs.firebolt.io/working-with-engines/understanding-engine-fundamentals.html + """ # noqa -class EngineStatusSummary(Enum): - """ - Engine summary status. - - See: https://api.dev.firebolt.io/devDocs#operation/coreV1GetEngine - """ - - ENGINE_STATUS_SUMMARY_UNSPECIFIED = "ENGINE_STATUS_SUMMARY_UNSPECIFIED" - """Status unspecified""" - - ENGINE_STATUS_SUMMARY_STOPPED = "ENGINE_STATUS_SUMMARY_STOPPED" - """ Fully stopped.""" - - ENGINE_STATUS_SUMMARY_STARTING = "ENGINE_STATUS_SUMMARY_STARTING" - """ Provisioning process is in progress; - creating cloud infra for this engine.""" - - ENGINE_STATUS_SUMMARY_STARTING_INITIALIZING = ( - "ENGINE_STATUS_SUMMARY_STARTING_INITIALIZING" - ) - """ Provisioning process is complete; - waiting for PackDB cluster to initialize and start.""" - - ENGINE_STATUS_SUMMARY_RUNNING = "ENGINE_STATUS_SUMMARY_RUNNING" - """ Fully started; - engine is ready to serve requests.""" - - ENGINE_STATUS_SUMMARY_UPGRADING = "ENGINE_STATUS_SUMMARY_UPGRADING" - """ Version of the PackDB is changing. - This is zero downtime operation that does not affect engine work. - This status is reserved for future use (not used fow now).""" - - ENGINE_STATUS_SUMMARY_RESTARTING = "ENGINE_STATUS_SUMMARY_RESTARTING" - """ Hard restart (full stop/start cycle) is in progress. - Underlying infrastructure is being recreated.""" - - ENGINE_STATUS_SUMMARY_RESTARTING_INITIALIZING = ( - "ENGINE_STATUS_SUMMARY_RESTARTING_INITIALIZING" - ) - """ Hard restart (full stop/start cycle) is in progress. - Underlying infrastructure is ready. Waiting for - PackDB cluster to initialize and start. - This status is logically the same as ENGINE_STATUS_SUMMARY_STARTING_INITIALIZING, - but used during restart cycle.""" - - ENGINE_STATUS_SUMMARY_REPAIRING = "ENGINE_STATUS_SUMMARY_REPAIRING" - """ Underlying infrastructure has issues and is being repaired. - Engine is still running, but it's not fully healthy and some queries may fail.""" - - ENGINE_STATUS_SUMMARY_STOPPING = "ENGINE_STATUS_SUMMARY_STOPPING" - """ Stop is in progress.""" - - ENGINE_STATUS_SUMMARY_DELETING = "ENGINE_STATUS_SUMMARY_DELETING" - """ Termination is in progress. - All infrastructure that belongs to this engine will be completely destroyed.""" - - ENGINE_STATUS_SUMMARY_DELETED = "ENGINE_STATUS_SUMMARY_DELETED" - """ Infrastructure is terminated, engine data is deleted.""" - - ENGINE_STATUS_SUMMARY_FAILED = "ENGINE_STATUS_SUMMARY_FAILED" - """ Failed to start or stop. - This status only indicates that there were issues during provisioning operations. - If engine enters this status, - all infrastructure should be stopped/terminated already.""" - - -class EngineOrder(Enum): - ENGINE_ORDER_UNSPECIFIED = "ENGINE_ORDER_UNSPECIFIED" - ENGINE_ORDER_NAME_ASC = "ENGINE_ORDER_NAME_ASC" - ENGINE_ORDER_NAME_DESC = "ENGINE_ORDER_NAME_DESC" - ENGINE_ORDER_COMPUTE_REGION_ID_ASC = "ENGINE_ORDER_COMPUTE_REGION_ID_ASC" - ENGINE_ORDER_COMPUTE_REGION_ID_DESC = "ENGINE_ORDER_COMPUTE_REGION_ID_DESC" - ENGINE_ORDER_CURRENT_STATUS_ASC = "ENGINE_ORDER_CURRENT_STATUS_ASC" - ENGINE_ORDER_CURRENT_STATUS_DESC = "ENGINE_ORDER_CURRENT_STATUS_DESC" - ENGINE_ORDER_CREATE_TIME_ASC = "ENGINE_ORDER_CREATE_TIME_ASC" - ENGINE_ORDER_CREATE_TIME_DESC = "ENGINE_ORDER_CREATE_TIME_DESC" - ENGINE_ORDER_CREATE_ACTOR_ASC = "ENGINE_ORDER_CREATE_ACTOR_ASC" - ENGINE_ORDER_CREATE_ACTOR_DESC = "ENGINE_ORDER_CREATE_ACTOR_DESC" - ENGINE_ORDER_LAST_UPDATE_TIME_ASC = "ENGINE_ORDER_LAST_UPDATE_TIME_ASC" - ENGINE_ORDER_LAST_UPDATE_TIME_DESC = "ENGINE_ORDER_LAST_UPDATE_TIME_DESC" - ENGINE_ORDER_LAST_UPDATE_ACTOR_ASC = "ENGINE_ORDER_LAST_UPDATE_ACTOR_ASC" - ENGINE_ORDER_LAST_UPDATE_ACTOR_DESC = "ENGINE_ORDER_LAST_UPDATE_ACTOR_DESC" - ENGINE_ORDER_LATEST_REVISION_CURRENT_STATUS_ASC = ( - "ENGINE_ORDER_LATEST_REVISION_CURRENT_STATUS_ASC" - ) - ENGINE_ORDER_LATEST_REVISION_CURRENT_STATUS_DESC = ( - "ENGINE_ORDER_LATEST_REVISION_CURRENT_STATUS_DESC" - ) - ENGINE_ORDER_LATEST_REVISION_SPECIFICATION_DB_COMPUTE_INSTANCES_COUNT_ASC = ( - "ENGINE_ORDER_LATEST_REVISION_SPECIFICATION_DB_COMPUTE_INSTANCES_COUNT_ASC" - ) - ENGINE_ORDER_LATEST_REVISION_SPECIFICATION_DB_COMPUTE_INSTANCES_COUNT_DESC = ( - "ENGINE_ORDER_LATEST_REVISION_SPECIFICATION_DB_COMPUTE_INSTANCES_COUNT_DESC" - ) - ENGINE_ORDER_LATEST_REVISION_SPECIFICATION_DB_COMPUTE_INSTANCES_TYPE_ID_ASC = ( - "ENGINE_ORDER_LATEST_REVISION_SPECIFICATION_DB_COMPUTE_INSTANCES_TYPE_ID_ASC" - ) - ENGINE_ORDER_LATEST_REVISION_SPECIFICATION_DB_COMPUTE_INSTANCES_TYPE_ID_DESC = ( - "ENGINE_ORDER_LATEST_REVISION_SPECIFICATION_DB_COMPUTE_INSTANCES_TYPE_ID_DESC" - ) + STARTING = "Starting" + STARTED = "Started" + RUNNING = "Running" + STOPPING = "Stopping" + STOPPED = "Stopped" + DROPPING = "Dropping" + REPAIRING = "Repairing" class DatabaseOrder(Enum): diff --git a/src/firebolt/utils/exception.py b/src/firebolt/utils/exception.py index a5e87212ecb..28afda1926f 100644 --- a/src/firebolt/utils/exception.py +++ b/src/firebolt/utils/exception.py @@ -13,6 +13,20 @@ def __init__(self, engine_name: str): super().__init__(f"Engine {engine_name} is not running") +class EngineNotFoundError(FireboltEngineError): + """Engine with provided name was not found.""" + + def __init__(self, engine_name: str): + super().__init__(f"Engine with name {engine_name} was not found") + + +class DatabaseNotFoundError(FireboltError): + """Database with provided name was not found.""" + + def __init__(self, database_name: str): + super().__init__(f"Database with name {database_name} was not found") + + class NoAttachedDatabaseError(FireboltEngineError): """Engine that's being accessed is not running. From 8dff4f45e52f46c24fccb4014c1b136a9271cf51 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Mon, 19 Jun 2023 11:45:53 +0300 Subject: [PATCH 04/18] update database service --- src/firebolt/model/binding.py | 38 ------- src/firebolt/service/database.py | 167 ++++++++++++++++++++----------- src/firebolt/service/engine.py | 27 ++--- 3 files changed, 124 insertions(+), 108 deletions(-) delete mode 100644 src/firebolt/model/binding.py diff --git a/src/firebolt/model/binding.py b/src/firebolt/model/binding.py deleted file mode 100644 index 26407d44a78..00000000000 --- a/src/firebolt/model/binding.py +++ /dev/null @@ -1,38 +0,0 @@ -from __future__ import annotations - -from datetime import datetime -from typing import Optional - -from pydantic import BaseModel, Field - -from firebolt.model import FireboltBaseModel - - -class BindingKey(BaseModel): - account_id: str - database_id: str - engine_id: str - - -class Binding(FireboltBaseModel): - """A binding between an engine and a database.""" - - binding_key: BindingKey = Field(alias="id") - is_default_engine: bool = Field(alias="engine_is_default") - - # optional - current_status: Optional[str] - health_status: Optional[str] - create_time: Optional[datetime] - create_actor: Optional[str] - last_update_time: Optional[datetime] - last_update_actor: Optional[str] - desired_status: Optional[str] - - @property - def database_id(self) -> str: - return self.binding_key.database_id - - @property - def engine_id(self) -> str: - return self.binding_key.engine_id diff --git a/src/firebolt/service/database.py b/src/firebolt/service/database.py index fd64284bb50..a83c39ba861 100644 --- a/src/firebolt/service/database.py +++ b/src/firebolt/service/database.py @@ -2,49 +2,58 @@ from typing import List, Optional, Union from firebolt.model.database import Database +from firebolt.model.engine import Engine from firebolt.service.base import BaseService -from firebolt.service.types import DatabaseOrder -from firebolt.utils.urls import ( - ACCOUNT_DATABASE_BY_NAME_URL, - ACCOUNT_DATABASE_URL, - ACCOUNT_DATABASES_URL, -) -from firebolt.utils.util import prune_dict +from firebolt.utils.exception import DatabaseNotFoundError logger = logging.getLogger(__name__) class DatabaseService(BaseService): - def get(self, id_: str) -> Database: - """Get a Database from Firebolt by its ID.""" - - response = self.client.get( - url=ACCOUNT_DATABASE_URL.format(account_id=self.account_id, database_id=id_) - ) - return Database._from_dict(response.json()["database"], self) - - def get_by_name(self, name: str) -> Database: - """Get a database from Firebolt by its name.""" - - database_id = self.get_id_by_name(name=name) - return self.get(id_=database_id) - - def get_id_by_name(self, name: str) -> str: - """Get a database ID from Firebolt by its name.""" - - response = self.client.get( - url=ACCOUNT_DATABASE_BY_NAME_URL.format(account_id=self.account_id), - params={"database_name": name}, - ) - database_id = response.json()["database_id"]["database_id"] - return database_id + DB_FIELDS = ( + "database_name", + "description", + "region", + "status", + "data_size_full", + "data_size_compressed", + "attached_engines", + "created_on", + "created_by", + "errors", + ) + GET_SQL = f"SELECT {', '.join(DB_FIELDS)} FROM information_schema.databases" + GET_BY_NAME_SQL = GET_SQL + " WHERE database_name=?" + GET_WHERE_SQL = " WHERE " + + CREATE_PREFIX_SQL = "CREATE DATABASE {}" + CREATE_WITH_SQL = " WITH " + IF_NOT_EXISTS_SQL = "IF NOT EXISTS " + CREATE_PARAMETER_NAMES = ( + "REGION", + "ATTACHED_ENGINES", + "DESCRIPTION", + ) + + def _get_dict(self, name: str) -> dict: + with self._connection.cursor() as c: + count = c.execute(self.GET_BY_NAME_SQL, (name,)) + if count == 0: + raise DatabaseNotFoundError(name) + return { + column.name: value for column, value in zip(c.description, c.fetchone()) + } + + def get(self, name: str) -> Database: + """Get a Database from Firebolt by its name.""" + return Database._from_dict(self._get_dict(name), self) def get_many( self, name_contains: Optional[str] = None, attached_engine_name_eq: Optional[str] = None, attached_engine_name_contains: Optional[str] = None, - order_by: Optional[Union[str, DatabaseOrder]] = None, + region_eq: Optional[str] = None, ) -> List[Database]: """ Get a list of databases on Firebolt. @@ -54,49 +63,93 @@ def get_many( attached_engine_name_eq: Filter for databases by an exact engine name attached_engine_name_contains: Filter for databases by engines with a name containing this substring - order_by: Method by which to order the results. - See :py:class:`firebolt.service.types.DatabaseOrder` + region_eq: Filter for database by region Returns: A list of databases matching the filters """ - - if isinstance(order_by, str): - order_by = DatabaseOrder[order_by].name - - params = { - "page.first": "1000", - "order_by": order_by, - "filter.name_contains": name_contains, - "filter.attached_engine_name_eq": attached_engine_name_eq, - "filter.attached_engine_name_contains": attached_engine_name_contains, - } - - response = self.client.get( - url=ACCOUNT_DATABASES_URL.format(account_id=self.account_id), - params=prune_dict(params), - ) - - return [Database._from_dict(d["node"], self) for d in response.json()["edges"]] + sql = self.GET_SQL + parameters = [] + if any( + ( + name_contains, + attached_engine_name_eq, + attached_engine_name_contains, + region_eq, + ) + ): + condition = [] + if name_contains: + condition.append("database_name like ?") + parameters.append(f"%{name_contains}%") + if attached_engine_name_eq: + condition.append( + "any_match(eng -> split_part(eng, ' ', 1) = ?," + " split(',', attached_engines))" + ) + parameters.append(attached_engine_name_eq) + if attached_engine_name_contains: + condition.append( + "any_match(eng -> split_part(eng, ' ', 1) like ?," + " split(',', attached_engines))" + ) + parameters.append(f"%{attached_engine_name_contains}%") + if region_eq: + condition.append("region = ?") + parameters.append(str(region_eq)) + sql += self.GET_WHERE_SQL + " AND ".join(condition) + + with self._connection.cursor() as c: + c.execute(sql, parameters) + dicts = [ + {column.name: value for column, value in zip(c.description, row)} + for row in c.fetchall() + ] + return [Database._from_dict(_dict, self) for _dict in dicts] def create( - self, name: str, region: Optional[str] = None, description: Optional[str] = None + self, + name: str, + region: Optional[str] = None, + attached_engines: Union[List[str], List[Engine], None] = None, + description: Optional[str] = None, + fail_if_exists: bool = True, ) -> Database: """ Create a new Database on Firebolt. Args: name: Name of the database + attached_engines: List of engines to attach to the database region: Region name in which to create the database + fail_if_exists: Fail is a database with provided name already exists Returns: The newly created database """ logger.info(f"Creating Database (name={name})") - response = self.client.post( - url=ACCOUNT_DATABASES_URL.format(account_id=self.account_id), - headers={"Content-type": "application/json"}, - json={}, + + sql = self.CREATE_PREFIX_SQL.format( + ("" if fail_if_exists else self.IF_NOT_EXISTS_SQL), name ) - return Database._from_dict(response.json()["database"], self) + parameters = [] + if any((region, attached_engines, description)): + sql += self.CREATE_WITH_SQL + for name, value in zip( + self.CREATE_PARAMETER_NAMES, + (region, attached_engines, description), + ): + if value: + sql += f"{name} = ? " + # Convert list of engines to a list of their names + if ( + isinstance(value, list) + and len(value) > 0 + and isinstance(value[0], Engine) + ): + value = [eng.name for eng in value] # type: ignore + parameters.append(value) + with self._connection.cursor() as c: + c.execute(sql, parameters) + return self.get(name) diff --git a/src/firebolt/service/engine.py b/src/firebolt/service/engine.py index 768ec1a856a..91f70eadc7c 100644 --- a/src/firebolt/service/engine.py +++ b/src/firebolt/service/engine.py @@ -11,7 +11,7 @@ class EngineService(BaseService): - ENGINE_DB_FIELDS = ( + DB_FIELDS = ( "engine_name", "region", "spec", @@ -25,11 +25,12 @@ class EngineService(BaseService): "type", "provisioning", ) - GET_SQL = f"SELECT {', '.join(ENGINE_DB_FIELDS)} FROM information_schema.engines" + GET_SQL = f"SELECT {', '.join(DB_FIELDS)} FROM information_schema.engines" GET_BY_NAME_SQL = GET_SQL + " WHERE engine_name=?" GET_WHERE_SQL = " WHERE " - CREATE_PREFIX_SQL = "CREATE ENGINE {}" + CREATE_PREFIX_SQL = "CREATE ENGINE {}{}" + IF_NOT_EXISTS_SQL = "IF NOT EXISTS " CREATE_WITH_SQL = " WITH " CREATE_PARAMETER_NAMES = ( "REGION", @@ -53,8 +54,7 @@ def _get_dict(self, name: str) -> dict: def get(self, name: str) -> Engine: """Get an engine from Firebolt by its name.""" - engine_dict = self._get_dict(name) - return Engine._from_dict(engine_dict, self) + return Engine._from_dict(self._get_dict(name), self) def get_many( self, @@ -71,7 +71,6 @@ def get_many( current_status_eq: Filter for engines with this status current_status_not_eq: Filter for engines that do not have this status region_eq: Filter for engines by region - order_by: Method by which to order the results. See [EngineOrder] Returns: A list of engines matching the filters @@ -96,13 +95,11 @@ def get_many( with self._connection.cursor() as c: c.execute(sql, parameters) - engine_dicts = [ - {column.name: value for column, value in zip(c.description, engine_row)} - for engine_row in c.fetchall() - ] - return [ - Engine._from_dict(engine_dict, self) for engine_dict in engine_dicts + dicts = [ + {column.name: value for column, value in zip(c.description, row)} + for row in c.fetchall() ] + return [Engine._from_dict(_dict, self) for _dict in dicts] def create( self, @@ -113,6 +110,7 @@ def create( scale: Optional[int] = None, auto_stop: Optional[int] = None, warmup: Union[str, WarmupMethod, None] = None, + fail_if_exists: bool = True, ) -> Engine: """ Create a new engine. @@ -135,6 +133,7 @@ def create( `PRELOAD_ALL_DATA` - Full data auto-load (both indexes and table data - full warmup) + fail_if_exists: Fail is an engine with provided name already exists Returns: Engine with the specified settings @@ -142,7 +141,9 @@ def create( logger.info(f"Creating engine {name}") - sql = self.CREATE_PREFIX_SQL.format(name) + sql = self.CREATE_PREFIX_SQL.format( + ("" if fail_if_exists else self.IF_NOT_EXISTS_SQL), name + ) parameters = [] if any((region, engine_type, spec, scale, auto_stop, warmup)): sql += self.CREATE_WITH_SQL From ae33e7f1a1fe767aa9644f2c1fcd45513625dae8 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Mon, 19 Jun 2023 12:44:29 +0300 Subject: [PATCH 05/18] update database model --- src/firebolt/model/database.py | 155 ++++++++------------------------- src/firebolt/service/engine.py | 14 ++- 2 files changed, 51 insertions(+), 118 deletions(-) diff --git a/src/firebolt/model/database.py b/src/firebolt/model/database.py index dc778ab0f63..0569be05523 100644 --- a/src/firebolt/model/database.py +++ b/src/firebolt/model/database.py @@ -1,37 +1,21 @@ from __future__ import annotations import logging -from dataclasses import asdict, dataclass, field +from dataclasses import dataclass, field from datetime import datetime -from typing import TYPE_CHECKING, List, Optional, Sequence - -from pydantic import Field +from typing import TYPE_CHECKING, ClassVar, List from firebolt.model import FireboltBaseModel -from firebolt.model.region import RegionKey from firebolt.service.types import EngineStatus from firebolt.utils.exception import AttachedEngineInUseError -from firebolt.utils.urls import ACCOUNT_DATABASE_URL if TYPE_CHECKING: - from firebolt.model.binding import Binding from firebolt.model.engine import Engine from firebolt.service.database import DatabaseService logger = logging.getLogger(__name__) -@dataclass -class DatabaseKey: - account_id: str - database_id: str - - -@dataclass -class FieldMask: - paths: Sequence[str] = Field(alias="paths") - - @dataclass class Database(FireboltBaseModel): """ @@ -41,134 +25,71 @@ class Database(FireboltBaseModel): but otherwise are not configurable. """ + ALTER_SQL: ClassVar[str] = "ALTER DATABASE {} SET DESCRIPTION = ?" + + DROP_SQL: ClassVar[str] = "DROP DATABASE {}" + # internal _service: DatabaseService = field() # required name: str = field(metadata={"db_name": "database_name"}) description: str = field() - compute_region_key: RegionKey = field() - - # optional - database_key: Optional[DatabaseKey] = field(default=None) - emoji: Optional[str] = field(default=None) - current_status: Optional[str] = field(default=None) - health_status: Optional[str] = field(default=None) - data_size_full: Optional[int] = field(default=None) - data_size_compressed: Optional[int] = field(default=None) - is_system_database: Optional[bool] = field(default=None) - storage_bucket_name: Optional[str] = field(default=None) - create_time: Optional[datetime] = field(default=None) - create_actor: Optional[str] = field(default=None) - last_update_time: Optional[datetime] = field(default=None) - last_update_actor: Optional[str] = field(default=None) - desired_status: Optional[str] = field(default=None) - - @property - def database_id(self) -> Optional[str]: - if self.database_key is None: - return None - return self.database_key.database_id + region: str = field() + data_size_full: int = field() + data_size_compressed: int = field() + _attached_engine_names: List[str] = field(metadata={"db_name": "attached_engines"}) + create_time: datetime = field(metadata={"db_name": "created_at"}) + create_actor: str = field(metadata={"db_name": "created_by"}) + _errors: str = field(metadata={"db_name": "errors"}) def get_attached_engines(self) -> List[Engine]: """Get a list of engines that are attached to this database.""" + return self._service.resource_manager.engines.get_many(database_name=self.name) - return self._service.resource_manager.bindings.get_engines_bound_to_database( # noqa: E501 - database=self - ) - - def attach_to_engine( - self, engine: Engine, is_default_engine: bool = False - ) -> Binding: + def attach_engine(self, engine: Engine) -> None: """ Attach an engine to this database. Args: engine: The engine to attach. - is_default_engine: - Whether this engine should be used as default for this database. - Only one engine can be set as default for a single database. - This will overwrite any existing default. """ - - return self._service.resource_manager.bindings.create( - engine=engine, database=self, is_default_engine=is_default_engine + return self._service.resource_manager.engines.attach_to_database( + engine.name, self.name ) - def delete(self) -> Database: + def update(self, description: str) -> Database: """ - Delete a database from Firebolt. - - Raises an error if there are any attached engines. + Updates a database description. """ + if not description: + return self for engine in self.get_attached_engines(): if engine.current_status in { EngineStatus.STARTING, EngineStatus.STOPPING, }: - raise AttachedEngineInUseError(method_name="delete") - - logger.info( - f"Deleting Database (database_id={self.database_id}, name={self.name})" - ) - response = self._service.client.delete( - url=ACCOUNT_DATABASE_URL.format( - account_id=self._service.account_id, database_id=self.database_id - ), - headers={"Content-type": "application/json"}, - ) - return Database._from_dict(response.json()["database"], self._service) - - def update(self, description: str) -> Database: - """ - Updates a database description. - """ - - @dataclass - class _DatabaseUpdateRequest: - """Helper model for sending Database creation requests.""" - - account_id: str = field() - database: Database = field() - database_id: Optional[str] = field() - update_mask: FieldMask = field() + raise AttachedEngineInUseError(method_name="update") + with self._service._connection.cursor() as c: + c.execute(self.ALTER_SQL, (description,)) self.description = description + return self - logger.info( - f"Updating Database (database_id={self.database_id}, " - f"name={self.name}, description={self.description})" - ) - - payload = asdict( - _DatabaseUpdateRequest( - account_id=self._service.account_id, - database=self, - database_id=self.database_id, - update_mask=FieldMask(paths=["description"]), - ) - ) + def delete(self) -> None: + """ + Delete a database from Firebolt. - response = self._service.client.patch( - url=ACCOUNT_DATABASE_URL.format( - account_id=self._service.account_id, database_id=self.database_id - ), - headers={"Content-type": "application/json"}, - json=payload, - ) + Raises an error if there are any attached engines. + """ - return Database._from_dict(response.json()["database"], self._service) + for engine in self.get_attached_engines(): + if engine.current_status in { + EngineStatus.STARTING, + EngineStatus.STOPPING, + }: + raise AttachedEngineInUseError(method_name="delete") - def get_default_engine(self) -> Optional[Engine]: - """ - Returns: default engine of the database, or None if default engine is missing - """ - rm = self._service.resource_manager - default_engines = [ - rm.engines.get(binding.engine_id) - for binding in rm.bindings.get_many(database_id=self.database_id) - if binding.is_default_engine - ] - - return None if len(default_engines) == 0 else default_engines[0] + with self._service._connection.cursor() as c: + c.execute(self.DROP_SQL.format(self.name)) diff --git a/src/firebolt/service/engine.py b/src/firebolt/service/engine.py index 91f70eadc7c..0d45b72fc9b 100644 --- a/src/firebolt/service/engine.py +++ b/src/firebolt/service/engine.py @@ -62,6 +62,7 @@ def get_many( current_status_eq: Union[str, EngineStatus, None] = None, current_status_not_eq: Union[str, EngineStatus, None] = None, region_eq: Union[str, Region, None] = None, + database_name: Optional[str] = None, ) -> List[Engine]: """ Get a list of engines on Firebolt. @@ -77,7 +78,15 @@ def get_many( """ sql = self.GET_SQL parameters = [] - if any((name_contains, current_status_eq, current_status_not_eq, region_eq)): + if any( + ( + name_contains, + current_status_eq, + current_status_not_eq, + region_eq, + database_name, + ) + ): condition = [] if name_contains: condition.append("engine_name like ?") @@ -91,6 +100,9 @@ def get_many( if region_eq: condition.append("region = ?") parameters.append(str(region_eq)) + if database_name: + condition.append("attached_to = ?") + parameters.append(database_name) sql += self.GET_WHERE_SQL + " AND ".join(condition) with self._connection.cursor() as c: From be8b3061854e83796efd8d985f80dc34dcdba75c Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Mon, 19 Jun 2023 14:09:25 +0300 Subject: [PATCH 06/18] minor issues fix --- src/firebolt/model/database.py | 8 +++++--- src/firebolt/model/engine.py | 16 ++-------------- src/firebolt/service/database.py | 8 ++++---- src/firebolt/service/engine.py | 4 ++-- src/firebolt/service/manager.py | 4 ---- 5 files changed, 13 insertions(+), 27 deletions(-) diff --git a/src/firebolt/model/database.py b/src/firebolt/model/database.py index 0569be05523..28ddc54d571 100644 --- a/src/firebolt/model/database.py +++ b/src/firebolt/model/database.py @@ -25,7 +25,7 @@ class Database(FireboltBaseModel): but otherwise are not configurable. """ - ALTER_SQL: ClassVar[str] = "ALTER DATABASE {} SET DESCRIPTION = ?" + ALTER_SQL: ClassVar[str] = "ALTER DATABASE {} WITH DESCRIPTION = ?" DROP_SQL: ClassVar[str] = "DROP DATABASE {}" @@ -36,10 +36,11 @@ class Database(FireboltBaseModel): name: str = field(metadata={"db_name": "database_name"}) description: str = field() region: str = field() + _status: str = field(metadata={"db_name": "status"}) data_size_full: int = field() data_size_compressed: int = field() _attached_engine_names: List[str] = field(metadata={"db_name": "attached_engines"}) - create_time: datetime = field(metadata={"db_name": "created_at"}) + create_time: datetime = field(metadata={"db_name": "created_on"}) create_actor: str = field(metadata={"db_name": "created_by"}) _errors: str = field(metadata={"db_name": "errors"}) @@ -72,8 +73,9 @@ def update(self, description: str) -> Database: }: raise AttachedEngineInUseError(method_name="update") + sql = self.ALTER_SQL.format(self.name) with self._service._connection.cursor() as c: - c.execute(self.ALTER_SQL, (description,)) + c.execute(sql, (description,)) self.description = description return self diff --git a/src/firebolt/model/engine.py b/src/firebolt/model/engine.py index aa548bad7e8..90c5b5a5e16 100644 --- a/src/firebolt/model/engine.py +++ b/src/firebolt/model/engine.py @@ -201,11 +201,11 @@ def update( sql = self.ALTER_PREFIX_SQL.format(self.name) parameters = [] - for name, value in zip( + for param, value in zip( self.ALTER_PARAMETER_NAMES, (scale, spec, auto_stop, name, warmup) ): if value: - sql += f"{name} = ? " + sql += f"{param} = ? " parameters.append(value) with self._service._connection.cursor() as c: @@ -213,18 +213,6 @@ def update( self.refresh() return self - @check_attached_to_database - def restart(self) -> Engine: - """ - Restart an engine. - - Returns: - The updated engine from Firebolt. - """ - self.stop() - self.start() - return self - def delete(self) -> None: """Delete an engine.""" self.refresh() diff --git a/src/firebolt/service/database.py b/src/firebolt/service/database.py index a83c39ba861..25f32a3aa9c 100644 --- a/src/firebolt/service/database.py +++ b/src/firebolt/service/database.py @@ -26,7 +26,7 @@ class DatabaseService(BaseService): GET_BY_NAME_SQL = GET_SQL + " WHERE database_name=?" GET_WHERE_SQL = " WHERE " - CREATE_PREFIX_SQL = "CREATE DATABASE {}" + CREATE_PREFIX_SQL = "CREATE DATABASE {}{}" CREATE_WITH_SQL = " WITH " IF_NOT_EXISTS_SQL = "IF NOT EXISTS " CREATE_PARAMETER_NAMES = ( @@ -128,7 +128,7 @@ def create( The newly created database """ - logger.info(f"Creating Database (name={name})") + logger.info(f"Creating database {name}") sql = self.CREATE_PREFIX_SQL.format( ("" if fail_if_exists else self.IF_NOT_EXISTS_SQL), name @@ -136,12 +136,12 @@ def create( parameters = [] if any((region, attached_engines, description)): sql += self.CREATE_WITH_SQL - for name, value in zip( + for param, value in zip( self.CREATE_PARAMETER_NAMES, (region, attached_engines, description), ): if value: - sql += f"{name} = ? " + sql += f"{param} = ? " # Convert list of engines to a list of their names if ( isinstance(value, list) diff --git a/src/firebolt/service/engine.py b/src/firebolt/service/engine.py index 0d45b72fc9b..77efb196082 100644 --- a/src/firebolt/service/engine.py +++ b/src/firebolt/service/engine.py @@ -159,12 +159,12 @@ def create( parameters = [] if any((region, engine_type, spec, scale, auto_stop, warmup)): sql += self.CREATE_WITH_SQL - for name, value in zip( + for param, value in zip( self.CREATE_PARAMETER_NAMES, (region, engine_type, spec, scale, auto_stop, warmup), ): if value: - sql += f"{name} = ? " + sql += f"{param} = ? " parameters.append(value) with self._connection.cursor() as c: c.execute(sql, parameters) diff --git a/src/firebolt/service/manager.py b/src/firebolt/service/manager.py index 971d4056d92..2773609022e 100644 --- a/src/firebolt/service/manager.py +++ b/src/firebolt/service/manager.py @@ -121,10 +121,8 @@ def __init__( def _init_services(self) -> None: # avoid circular import - from firebolt.service.binding import BindingService from firebolt.service.database import DatabaseService from firebolt.service.engine import EngineService - from firebolt.service.engine_revision import EngineRevisionService from firebolt.service.instance_type import InstanceTypeService from firebolt.service.region import RegionService @@ -136,8 +134,6 @@ def _init_services(self) -> None: # Firebolt Resources self.databases = DatabaseService(resource_manager=self) self.engines = EngineService(resource_manager=self) - self.engine_revisions = EngineRevisionService(resource_manager=self) - self.bindings = BindingService(resource_manager=self) def __del__(self) -> None: if hasattr(self, "_client"): From 86058dce983306b67da7a93f3b8b15af89ea01ab Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Wed, 21 Jun 2023 11:31:46 +0300 Subject: [PATCH 07/18] update instance_type management --- src/firebolt/model/__init__.py | 2 +- src/firebolt/model/database.py | 10 +-- src/firebolt/model/engine.py | 24 +++++-- src/firebolt/model/instance_type.py | 33 ++++----- src/firebolt/model/provider.py | 16 ----- src/firebolt/model/region.py | 23 ------- src/firebolt/service/base.py | 4 -- src/firebolt/service/engine.py | 14 ++-- src/firebolt/service/instance_type.py | 98 ++++++--------------------- src/firebolt/service/manager.py | 22 +----- src/firebolt/service/provider.py | 10 --- src/firebolt/service/region.py | 65 ------------------ src/firebolt/utils/exception.py | 7 ++ 13 files changed, 77 insertions(+), 251 deletions(-) delete mode 100644 src/firebolt/model/provider.py delete mode 100644 src/firebolt/model/region.py delete mode 100644 src/firebolt/service/provider.py delete mode 100644 src/firebolt/service/region.py diff --git a/src/firebolt/model/__init__.py b/src/firebolt/model/__init__.py index 28b6fc1f9ac..3dbb10801f5 100644 --- a/src/firebolt/model/__init__.py +++ b/src/firebolt/model/__init__.py @@ -9,7 +9,7 @@ @dataclass class FireboltBaseModel: - _service: BaseService = field() + _service: BaseService = field(repr=False) @classmethod def _get_field_overrides(cls) -> Dict[str, str]: diff --git a/src/firebolt/model/database.py b/src/firebolt/model/database.py index 28ddc54d571..067259472fd 100644 --- a/src/firebolt/model/database.py +++ b/src/firebolt/model/database.py @@ -30,19 +30,21 @@ class Database(FireboltBaseModel): DROP_SQL: ClassVar[str] = "DROP DATABASE {}" # internal - _service: DatabaseService = field() + _service: DatabaseService = field(repr=False) # required name: str = field(metadata={"db_name": "database_name"}) description: str = field() region: str = field() - _status: str = field(metadata={"db_name": "status"}) + _status: str = field(repr=False, metadata={"db_name": "status"}) data_size_full: int = field() data_size_compressed: int = field() - _attached_engine_names: List[str] = field(metadata={"db_name": "attached_engines"}) + _attached_engine_names: List[str] = field( + repr=False, metadata={"db_name": "attached_engines"} + ) create_time: datetime = field(metadata={"db_name": "created_on"}) create_actor: str = field(metadata={"db_name": "created_by"}) - _errors: str = field(metadata={"db_name": "errors"}) + _errors: str = field(repr=False, metadata={"db_name": "errors"}) def get_attached_engines(self) -> List[Engine]: """Get a list of engines that are attached to this database.""" diff --git a/src/firebolt/model/engine.py b/src/firebolt/model/engine.py index 90c5b5a5e16..92bbe392fd0 100644 --- a/src/firebolt/model/engine.py +++ b/src/firebolt/model/engine.py @@ -4,11 +4,20 @@ import logging import time from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ClassVar, + Optional, + Tuple, + Union, +) from firebolt.db import Connection, connect from firebolt.model import FireboltBaseModel from firebolt.model.database import Database +from firebolt.model.instance_type import InstanceType from firebolt.service.types import EngineStatus, WarmupMethod from firebolt.utils.exception import DatabaseNotFoundError @@ -53,14 +62,16 @@ class Engine(FireboltBaseModel): ) DROP_SQL: ClassVar[str] = "DROP ENGINE {}" - _service: EngineService = field() + _service: EngineService = field(repr=False) name: str = field(metadata={"db_name": "engine_name"}) region: str = field() spec: str = field() scale: int = field() current_status: str = field(metadata={"db_name": "status"}) - _database_name: Optional[str] = field(metadata={"db_name": "attached_to"}) + _database_name: Optional[str] = field( + repr=False, metadata={"db_name": "attached_to"} + ) version: str = field() endpoint: str = field(metadata={"db_name": "url"}) warmup: str = field() @@ -89,7 +100,8 @@ def attach_to_database(self, database_name: str) -> None: Args: database: Database to which the engine will be attached """ - return self._service.attach_to_database(self.name, database_name) + self._service.attach_to_database(self.name, database_name) + self._database_name = database_name @check_attached_to_database def get_connection(self) -> Connection: @@ -179,7 +191,7 @@ def update( self, name: Optional[str] = None, scale: Optional[int] = None, - spec: Optional[str] = None, + spec: Union[InstanceType, str, None] = None, auto_stop: Optional[int] = None, warmup: Optional[WarmupMethod] = None, ) -> Engine: @@ -206,7 +218,7 @@ def update( ): if value: sql += f"{param} = ? " - parameters.append(value) + parameters.append(str(value)) with self._service._connection.cursor() as c: c.execute(sql, parameters) diff --git a/src/firebolt/model/instance_type.py b/src/firebolt/model/instance_type.py index 5ead1b30477..4d3fdc36400 100644 --- a/src/firebolt/model/instance_type.py +++ b/src/firebolt/model/instance_type.py @@ -1,26 +1,21 @@ +from dataclasses import dataclass, field from datetime import datetime -from typing import Optional - -from pydantic import Field +from typing import Dict from firebolt.model import FireboltBaseModel -class InstanceTypeKey(FireboltBaseModel): # type: ignore - provider_id: str - region_id: str - instance_type_id: str - - +@dataclass class InstanceType(FireboltBaseModel): - key: InstanceTypeKey = Field(alias="id") - name: str + _key: Dict = field(repr=False, metadata={"db_name": "id"}) + name: str = field() + is_spot_available: bool = field() + cpu_virtual_cores_count: int = field() + memory_size_bytes: int = field() + storage_size_bytes: int = field() + price_per_hour_cents: float = field() + create_time: datetime = field() + last_update_time: datetime = field() - # optional - is_spot_available: Optional[bool] - cpu_virtual_cores_count: Optional[int] - memory_size_bytes: Optional[int] - storage_size_bytes: Optional[int] - price_per_hour_cents: Optional[float] - create_time: Optional[datetime] - last_update_time: Optional[datetime] + def __str__(self) -> str: + return self.name diff --git a/src/firebolt/model/provider.py b/src/firebolt/model/provider.py deleted file mode 100644 index 3cc041c2710..00000000000 --- a/src/firebolt/model/provider.py +++ /dev/null @@ -1,16 +0,0 @@ -from dataclasses import dataclass, field -from datetime import datetime -from typing import Optional - -from firebolt.model import FireboltBaseModel - - -@dataclass -class Provider(FireboltBaseModel): - provider_id: str = field(metadata={"db_name": "id"}) - name: str = field() - - # optional - create_time: Optional[datetime] = field(default=None) - display_name: Optional[str] = field(default=None) - last_update_time: Optional[datetime] = field(default=None) diff --git a/src/firebolt/model/region.py b/src/firebolt/model/region.py deleted file mode 100644 index b292180ce98..00000000000 --- a/src/firebolt/model/region.py +++ /dev/null @@ -1,23 +0,0 @@ -from dataclasses import dataclass -from datetime import datetime -from typing import Optional - -from pydantic import Field - -from firebolt.model import FireboltBaseModel - - -@dataclass -class RegionKey: - provider_id: str - region_id: str - - -class Region(FireboltBaseModel): - key: RegionKey = Field(alias="id") - name: str - - # optional - display_name: Optional[str] - create_time: Optional[datetime] - last_update_time: Optional[datetime] diff --git a/src/firebolt/service/base.py b/src/firebolt/service/base.py index e677098c8ec..bb892d3412d 100644 --- a/src/firebolt/service/base.py +++ b/src/firebolt/service/base.py @@ -19,10 +19,6 @@ def client(self) -> Client: def account_id(self) -> str: return self.resource_manager.account_id - @property - def _default_region(self) -> str: - return self.resource_manager.default_region - @property def _connection(self) -> Connection: return self.resource_manager._connection diff --git a/src/firebolt/service/engine.py b/src/firebolt/service/engine.py index 77efb196082..ecfae3a65b5 100644 --- a/src/firebolt/service/engine.py +++ b/src/firebolt/service/engine.py @@ -2,7 +2,7 @@ from typing import List, Optional, Union from firebolt.model.engine import Engine -from firebolt.model.region import Region +from firebolt.model.instance_type import InstanceType from firebolt.service.base import BaseService from firebolt.service.types import EngineStatus, EngineType, WarmupMethod from firebolt.utils.exception import EngineNotFoundError @@ -61,7 +61,7 @@ def get_many( name_contains: Optional[str] = None, current_status_eq: Union[str, EngineStatus, None] = None, current_status_not_eq: Union[str, EngineStatus, None] = None, - region_eq: Union[str, Region, None] = None, + region_eq: Optional[str] = None, database_name: Optional[str] = None, ) -> List[Engine]: """ @@ -99,7 +99,7 @@ def get_many( parameters.append(str(current_status_eq)) if region_eq: condition.append("region = ?") - parameters.append(str(region_eq)) + parameters.append(region_eq) if database_name: condition.append("attached_to = ?") parameters.append(database_name) @@ -116,9 +116,9 @@ def get_many( def create( self, name: str, - region: Union[str, Region, None] = None, + region: Optional[str] = None, engine_type: Union[str, EngineType] = EngineType.GENERAL_PURPOSE, - spec: Optional[str] = None, + spec: Union[InstanceType, str, None] = None, scale: Optional[int] = None, auto_stop: Optional[int] = None, warmup: Union[str, WarmupMethod, None] = None, @@ -165,7 +165,9 @@ def create( ): if value: sql += f"{param} = ? " - parameters.append(value) + if isinstance(value, EngineType): + value = value.value + parameters.append(str(value)) with self._connection.cursor() as c: c.execute(sql, parameters) return self.get(name) diff --git a/src/firebolt/service/instance_type.py b/src/firebolt/service/instance_type.py index 6bc779ef60f..47353e88f6b 100644 --- a/src/firebolt/service/instance_type.py +++ b/src/firebolt/service/instance_type.py @@ -1,19 +1,12 @@ -from typing import Dict, List, NamedTuple, Optional +from typing import List, Optional -from firebolt.model.instance_type import InstanceType, InstanceTypeKey -from firebolt.model.region import Region +from firebolt.model.instance_type import InstanceType from firebolt.service.base import BaseService +from firebolt.utils.exception import InstanceTypeNotFoundError from firebolt.utils.urls import ACCOUNT_INSTANCE_TYPES_URL from firebolt.utils.util import cached_property -class InstanceTypeLookup(NamedTuple): - """Helper tuple for looking up instance types by names.""" - - region_name: str - instance_type_name: str - - class InstanceTypeService(BaseService): @cached_property def instance_types(self) -> List[InstanceType]: @@ -23,90 +16,41 @@ def instance_types(self) -> List[InstanceType]: url=ACCOUNT_INSTANCE_TYPES_URL.format(account_id=self.account_id), params={"page.first": 5000}, ) - return [InstanceType._from_dict(i["node"]) for i in response.json()["edges"]] - - @cached_property - def instance_types_by_key(self) -> Dict[InstanceTypeKey, InstanceType]: - """Dict of {InstanceTypeKey to InstanceType}""" - return {i.key: i for i in self.instance_types} + # Only take one instance type with a specific name + instance_types, names = list(), set() + for it in [i["node"] for i in response.json()["edges"]]: + if it["name"] not in names: + names.add(it["name"]) + instance_types.append(InstanceType._from_dict(it, self)) + return instance_types @cached_property - def instance_types_by_name(self) -> Dict[InstanceTypeLookup, InstanceType]: - """Dict of {InstanceTypeLookup to InstanceType}""" - - return { - InstanceTypeLookup( - region_name=self.resource_manager.regions.get_by_id( - id_=i.key.region_id - ).name, - instance_type_name=i.name, - ): i - for i in self.instance_types - } - - def get_instance_types_per_region(self, region: Region) -> List[InstanceType]: - """List of instance types available on Firebolt in specified region.""" - - response = self.client.get( - url=ACCOUNT_INSTANCE_TYPES_URL.format(account_id=self.account_id), - params={"page.first": 5000, "filter.id_region_id_eq": region.key.region_id}, - ) - - instance_list = [ - InstanceType._from_dict(i["node"]) for i in response.json()["edges"] - ] - - # Filter out instances without storage - return [ - i - for i in instance_list - if i.storage_size_bytes and i.storage_size_bytes != "0" - ] - - def cheapest_instance_in_region(self, region: Region) -> Optional[InstanceType]: + def cheapest_instance(self) -> Optional[InstanceType]: # Get only available instances in region - instance_list = self.get_instance_types_per_region(region) - - if not instance_list: + if not self.instance_types: return None - cheapest = min( - instance_list, + return min( + self.instance_types, key=lambda x: x.price_per_hour_cents if x.price_per_hour_cents else float("Inf"), ) - return cheapest - - def get_by_key(self, instance_type_key: InstanceTypeKey) -> InstanceType: - """Get an instance type by key.""" - - return self.instance_types_by_key[instance_type_key] - def get_by_name( - self, - instance_type_name: str, - region_name: Optional[str] = None, - ) -> InstanceType: + def get(self, name: str) -> InstanceType: """ Get an instance type by name. Args: - instance_type_name: Name of the instance (eg. "i3.4xlarge") - region_name: - Name of the AWS region from which to get the instance. - If not provided, use the default region name from the client. + name: Name of the instance (eg. "i3.4xlarge") Returns: - The requested instance type + The requested instance type or None if it wasn't found """ # Will raise an error if neither set - region_name = region_name or self.resource_manager.regions.default_region.name - return self.instance_types_by_name[ - InstanceTypeLookup( - region_name=region_name, - instance_type_name=instance_type_name, - ) - ] + its = [it for it in self.instance_types if it.name == name] + if len(its) == 0: + raise InstanceTypeNotFoundError(name) + return its[0] diff --git a/src/firebolt/service/manager.py b/src/firebolt/service/manager.py index 2773609022e..4c0ed39e7cc 100644 --- a/src/firebolt/service/manager.py +++ b/src/firebolt/service/manager.py @@ -13,7 +13,6 @@ ) from firebolt.common import Settings from firebolt.db import connect -from firebolt.service.provider import get_provider_id from firebolt.utils.util import fix_url_schema DEFAULT_TIMEOUT_SECONDS: int = 60 * 2 @@ -24,7 +23,7 @@ Using Settings objects for ResourceManager intialization is deprecated. Please pass parameters directly Example: - >>> rm = ResourceManager(auth=ClientCredentials(..), default_region="us-east-1", ..) + >>> rm = ResourceManager(auth=ClientCredentials(..), account_name="my_account", ..) """ @@ -34,12 +33,9 @@ class ResourceManager: - databases - engines - - bindings (the bindings between an engine and a database) - - engine revisions (versions of an engine) Also provides listings of: - - regions (AWS regions in which engines can run) - instance types (AWS instance types which engines can use) """ @@ -47,7 +43,6 @@ class ResourceManager: "account_name", "account_id", "api_endpoint", - "default_region", "_client", "_connection", "regions", @@ -64,30 +59,22 @@ def __init__( settings: Optional[Settings] = None, auth: Optional[Auth] = None, account_name: Optional[str] = None, - default_region: Optional[str] = None, api_endpoint: str = DEFAULT_API_URL, ): if settings: logger.warning(SETTINGS_DEPRECATION_MESSAGE) - if ( - auth - or account_name - or default_region - or (api_endpoint != DEFAULT_API_URL) - ): + if auth or account_name or (api_endpoint != DEFAULT_API_URL): raise ValueError( "Other ResourceManager parameters are not allowed " "when Settings are provided" ) auth = settings.auth account_name = settings.account_name - default_region = settings.default_region api_endpoint = settings.server for param, name in ( (auth, "auth"), (account_name, "account_name"), - (default_region, "default_region"), ): if not param: raise ValueError(f"Missing {name} value") @@ -95,7 +82,6 @@ def __init__( # type checks assert auth is not None assert account_name is not None - assert default_region is not None self._client = Client( auth=auth, @@ -116,7 +102,6 @@ def __init__( self.account_name = account_name self.api_endpoint = api_endpoint self.account_id = self._client.account_id - self.default_region = default_region self._init_services() def _init_services(self) -> None: @@ -124,12 +109,9 @@ def _init_services(self) -> None: from firebolt.service.database import DatabaseService from firebolt.service.engine import EngineService from firebolt.service.instance_type import InstanceTypeService - from firebolt.service.region import RegionService # Cloud Platform Resources (AWS) - self.regions = RegionService(resource_manager=self) self.instance_types = InstanceTypeService(resource_manager=self) - self._provider_id = get_provider_id(client=self._client) # Firebolt Resources self.databases = DatabaseService(resource_manager=self) diff --git a/src/firebolt/service/provider.py b/src/firebolt/service/provider.py deleted file mode 100644 index 129c8ef896e..00000000000 --- a/src/firebolt/service/provider.py +++ /dev/null @@ -1,10 +0,0 @@ -from firebolt.client import Client -from firebolt.model.provider import Provider -from firebolt.utils.urls import PROVIDERS_URL - - -def get_provider_id(client: Client) -> str: - """Get the AWS provider_id.""" - response = client.get(url=PROVIDERS_URL) - providers = [Provider._from_dict(i["node"]) for i in response.json()["edges"]] - return providers[0].provider_id diff --git a/src/firebolt/service/region.py b/src/firebolt/service/region.py deleted file mode 100644 index 2e9f99a63ca..00000000000 --- a/src/firebolt/service/region.py +++ /dev/null @@ -1,65 +0,0 @@ -from typing import Dict, List - -from firebolt.model.region import Region, RegionKey -from firebolt.service.base import BaseService -from firebolt.service.manager import ResourceManager -from firebolt.utils.urls import REGIONS_URL -from firebolt.utils.util import cached_property - - -class RegionService(BaseService): - def __init__(self, resource_manager: ResourceManager): - """ - Service to manage AWS regions (us-east-1, etc) - - Args: - resource_manager: Resource manager to use - """ - - super().__init__(resource_manager=resource_manager) - - @cached_property - def regions(self) -> List[Region]: - """List of available AWS regions on Firebolt.""" - - response = self.client.get(url=REGIONS_URL, params={"page.first": 5000}) - return [Region._from_dict(i["node"], self) for i in response.json()["edges"]] - - @cached_property - def regions_by_name(self) -> Dict[str, Region]: - """Dict of {RegionLookup to Region}""" - - return {r.name: r for r in self.regions} - - @cached_property - def regions_by_key(self) -> Dict[RegionKey, Region]: - """Dict of {RegionKey to Region}""" - - return {r.key: r for r in self.regions} - - @cached_property - def default_region(self) -> Region: - """Default AWS region, could be provided from environment.""" - - if not self._default_region: - raise ValueError( - "The environment variable FIREBOLT_DEFAULT_REGION must be set." - ) - return self.get_by_name(name=self._default_region) - - def get_by_name(self, name: str) -> Region: - """Get an AWS region by its name (eg. us-east-1).""" - - return self.regions_by_name[name] - - def get_by_key(self, key: RegionKey) -> Region: - """Get an AWS region by its key.""" - - return self.regions_by_key[key] - - def get_by_id(self, id_: str) -> Region: - """Get an AWS region by region_id.""" - - return self.get_by_key( - RegionKey(provider_id=self.resource_manager._provider_id, region_id=id_) - ) diff --git a/src/firebolt/utils/exception.py b/src/firebolt/utils/exception.py index 28afda1926f..fee1d9f1044 100644 --- a/src/firebolt/utils/exception.py +++ b/src/firebolt/utils/exception.py @@ -27,6 +27,13 @@ def __init__(self, database_name: str): super().__init__(f"Database with name {database_name} was not found") +class InstanceTypeNotFoundError(FireboltError): + """Instance type with provided name was not found.""" + + def __init__(self, instance_type_name: str): + super().__init__(f"Instance type with name {instance_type_name} was not found") + + class NoAttachedDatabaseError(FireboltEngineError): """Engine that's being accessed is not running. From fdb97e9ce4005ea5109df7c1386e84994578213e Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Wed, 21 Jun 2023 15:25:04 +0300 Subject: [PATCH 08/18] fix instance_type test --- src/firebolt/model/__init__.py | 2 +- tests/unit/conftest.py | 81 ++----- tests/unit/service/conftest.py | 277 ++++------------------- tests/unit/service/test_engine.py | 4 +- tests/unit/service/test_instance_type.py | 31 +-- tests/unit/service/test_region.py | 25 -- tests/unit/util.py | 10 +- 7 files changed, 77 insertions(+), 353 deletions(-) delete mode 100644 tests/unit/service/test_region.py diff --git a/src/firebolt/model/__init__.py b/src/firebolt/model/__init__.py index 3dbb10801f5..24cacdd58ae 100644 --- a/src/firebolt/model/__init__.py +++ b/src/firebolt/model/__init__.py @@ -9,7 +9,7 @@ @dataclass class FireboltBaseModel: - _service: BaseService = field(repr=False) + _service: BaseService = field(repr=False, compare=False) @classmethod def _get_field_overrides(cls) -> Dict[str, str]: diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 9910471dc29..7e2e7aa548c 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,5 +1,5 @@ from re import Pattern, compile -from typing import Callable, List +from typing import Callable import httpx from httpx import Request, Response @@ -7,9 +7,6 @@ from pytest import fixture from firebolt.client.auth import Auth, ClientCredentials -from firebolt.common.settings import Settings -from firebolt.model.provider import Provider -from firebolt.model.region import Region, RegionKey from firebolt.utils.exception import ( AccountNotFoundError, DatabaseError, @@ -88,61 +85,11 @@ def access_token_2() -> str: return "mock_access_token_2" -@fixture -def provider() -> Provider: - return Provider( - provider_id="mock_provider_id", - name="mock_provider_name", - ) - - -@fixture -def mock_providers(provider) -> List[Provider]: - return [provider] - - -@fixture -def region_1(provider) -> Region: - return Region( - key=RegionKey( - provider_id=provider.provider_id, - region_id="mock_region_id_1", - ), - name="mock_region_1", - ) - - -@fixture -def region_2(provider) -> Region: - return Region( - key=RegionKey( - provider_id=provider.provider_id, - region_id="mock_region_id_2", - ), - name="mock_region_2", - ) - - -@fixture -def mock_regions(region_1, region_2) -> List[Region]: - return [region_1, region_2] - - @fixture def auth(client_id: str, client_secret: str) -> Auth: return ClientCredentials(client_id, client_secret) -@fixture -def settings(server: str, region_1: str, auth: Auth, account_name: str) -> Settings: - return Settings( - server=server, - auth=auth, - default_region=region_1.name, - account_name=account_name, - ) - - @fixture def auth_callback(auth_url: str) -> Callable: def do_mock( @@ -174,9 +121,9 @@ def db_description() -> str: @fixture -def account_id_url(settings: Settings, account_name: str) -> Pattern: +def account_id_url(server: str, account_name: str) -> Pattern: account_name_re = r"[^\\\\]*" - base = f"https://{settings.server}{ACCOUNT_BY_NAME_URL}" + base = f"https://{server}{ACCOUNT_BY_NAME_URL}" base = base.replace("/", "\\/").replace("?", "\\?") base = base.format(account_name=account_name_re) return compile(base) @@ -185,13 +132,13 @@ def account_id_url(settings: Settings, account_name: str) -> Pattern: @fixture def account_id_callback( account_id: str, - settings: Settings, + account_name: str, ) -> Callable: def do_mock( request: Request, **kwargs, ) -> Response: - if request.url.path.split("/")[-2] != settings.account_name: + if request.url.path.split("/")[-2] != account_name: raise AccountNotFoundError(request.url.path.split("/")[-2]) return Response(status_code=httpx.codes.OK, json={"id": account_id}) @@ -214,22 +161,20 @@ def engine_name() -> str: @fixture -def get_engine_name_by_id_url( - settings: Settings, account_id: str, engine_id: str -) -> str: - return f"https://{settings.server}" + ACCOUNT_ENGINE_URL.format( +def get_engine_name_by_id_url(server: str, account_id: str, engine_id: str) -> str: + return f"https://{server}" + ACCOUNT_ENGINE_URL.format( account_id=account_id, engine_id=engine_id ) @fixture -def get_engines_url(settings: Settings) -> str: - return f"https://{settings.server}{ENGINES_URL}" +def get_engines_url(server: str) -> str: + return f"https://{server}{ENGINES_URL}" @fixture -def get_databases_url(settings: Settings) -> str: - return f"https://{settings.server}{DATABASES_URL}" +def get_databases_url(server: str) -> str: + return f"https://{server}{DATABASES_URL}" @fixture @@ -238,9 +183,9 @@ def database_id() -> str: @fixture -def database_by_name_url(settings: Settings, account_id: str, db_name: str) -> str: +def database_by_name_url(server: str, account_id: str, db_name: str) -> str: return ( - f"https://{settings.server}" + f"https://{server}" f"{ACCOUNT_DATABASE_BY_NAME_URL.format(account_id=account_id)}" f"?database_name={db_name}" ) diff --git a/tests/unit/service/conftest.py b/tests/unit/service/conftest.py index 4cb7a7789a2..524e2c2af88 100644 --- a/tests/unit/service/conftest.py +++ b/tests/unit/service/conftest.py @@ -1,4 +1,5 @@ import json +from datetime import datetime from typing import Callable, List from urllib.parse import urlparse @@ -6,27 +7,16 @@ from httpx import Response from pytest import fixture -from firebolt.common.settings import Settings -from firebolt.model.binding import Binding, BindingKey -from firebolt.model.database import Database, DatabaseKey -from firebolt.model.engine import Engine, EngineKey, EngineSettings -from firebolt.model.engine_revision import ( - EngineRevision, - EngineRevisionSpecification, -) -from firebolt.model.instance_type import InstanceType, InstanceTypeKey -from firebolt.model.region import Region +from firebolt.model.database import Database +from firebolt.model.engine import Engine +from firebolt.model.instance_type import InstanceType from firebolt.utils.urls import ( - ACCOUNT_BINDINGS_URL, - ACCOUNT_DATABASE_BINDING_URL, ACCOUNT_DATABASE_BY_NAME_URL, ACCOUNT_DATABASE_URL, ACCOUNT_DATABASES_URL, ACCOUNT_ENGINE_URL, ACCOUNT_INSTANCE_TYPES_URL, ACCOUNT_LIST_ENGINES_URL, - PROVIDERS_URL, - REGIONS_URL, ) from tests.unit.util import list_to_paginated_response @@ -42,76 +32,61 @@ def engine_scale() -> int: @fixture -def engine_settings() -> EngineSettings: - return EngineSettings.default() - - -@fixture -def mock_engine(engine_name, region_1, engine_settings, account_id, settings) -> Engine: +def mock_engine(engine_name, region_1, engine_settings, account_id, server) -> Engine: return Engine( name=engine_name, compute_region_key=region_1.key, settings=engine_settings, key=EngineKey(account_id=account_id, engine_id="mock_engine_id_1"), - endpoint=f"https://{settings.server}", - ) - - -@fixture -def mock_engine_revision_spec( - instance_type_2, engine_scale -) -> EngineRevisionSpecification: - return EngineRevisionSpecification( - db_compute_instances_type_key=instance_type_2.key, - db_compute_instances_count=engine_scale, - proxy_instances_type_key=instance_type_2.key, + endpoint=f"https://{server}", ) @fixture -def mock_engine_revision(mock_engine_revision_spec) -> EngineRevision: - return EngineRevision(specification=mock_engine_revision_spec) - - -@fixture -def instance_type_1(provider, region_1) -> InstanceType: +def instance_type_1() -> InstanceType: return InstanceType( - key=InstanceTypeKey( - provider_id=provider.provider_id, - region_id=region_1.key.region_id, - instance_type_id="instance_type_id_1", - ), name="B1", - price_per_hour_cents=10, + price_per_hour_cents=40, storage_size_bytes=0, + is_spot_available=True, + cpu_virtual_cores_count=0, + memory_size_bytes=0, + create_time=datetime.now().isoformat(), + last_update_time=datetime.now().isoformat(), + _key={}, + _service=None, ) @fixture -def instance_type_2(provider, region_2) -> InstanceType: +def instance_type_2() -> InstanceType: return InstanceType( - key=InstanceTypeKey( - provider_id=provider.provider_id, - region_id=region_2.key.region_id, - instance_type_id="instance_type_id_2", - ), name="B2", price_per_hour_cents=20, storage_size_bytes=500, + is_spot_available=True, + cpu_virtual_cores_count=0, + memory_size_bytes=0, + create_time=datetime.now().isoformat(), + last_update_time=datetime.now().isoformat(), + _key={}, + _service=None, ) @fixture -def instance_type_3(provider, region_2) -> InstanceType: +def instance_type_3() -> InstanceType: return InstanceType( - key=InstanceTypeKey( - provider_id=provider.provider_id, - region_id=region_2.key.region_id, - instance_type_id="instance_type_id_2", - ), - name="B2", + name="B3", price_per_hour_cents=30, storage_size_bytes=500, + is_spot_available=True, + cpu_virtual_cores_count=0, + memory_size_bytes=0, + create_time=datetime.now().isoformat(), + last_update_time=datetime.now().isoformat(), + _key={}, + _service=None, ) @@ -127,46 +102,6 @@ def mock_instance_types( return [instance_type_1, instance_type_2, instance_type_3] -@fixture -def provider_callback(provider_url: str, mock_providers) -> Callable: - def do_mock( - request: httpx.Request = None, - **kwargs, - ) -> Response: - assert request.url == provider_url - return Response( - status_code=httpx.codes.OK, - json=list_to_paginated_response(mock_providers), - ) - - return do_mock - - -@fixture -def provider_url(settings: Settings) -> str: - return f"https://{settings.server}{PROVIDERS_URL}" - - -@fixture -def region_callback(region_url: str, mock_regions) -> Callable: - def do_mock( - request: httpx.Request = None, - **kwargs, - ) -> Response: - assert request.url == region_url - return Response( - status_code=httpx.codes.OK, - json=list_to_paginated_response(mock_regions), - ) - - return do_mock - - -@fixture -def region_url(settings: Settings) -> str: - return f"https://{settings.server}{REGIONS_URL}?page.first=5000" - - @fixture def instance_type_callback(instance_type_url: str, mock_instance_types) -> Callable: def do_mock( @@ -182,23 +117,6 @@ def do_mock( return do_mock -@fixture -def instance_type_region_1_callback( - instance_type_region_1_url: str, mock_instance_types -) -> Callable: - def do_mock( - request: httpx.Request = None, - **kwargs, - ) -> Response: - assert request.url == instance_type_region_1_url - return Response( - status_code=httpx.codes.OK, - json=list_to_paginated_response(mock_instance_types), - ) - - return do_mock - - @fixture def instance_type_empty_callback() -> Callable: def do_mock( @@ -214,36 +132,14 @@ def do_mock( @fixture -def instance_type_url(settings: Settings, account_id: str) -> str: +def instance_type_url(server: str, account_id: str) -> str: return ( - f"https://{settings.server}" + f"https://{server}" + ACCOUNT_INSTANCE_TYPES_URL.format(account_id=account_id) + "?page.first=5000" ) -@fixture -def instance_type_region_1_url( - settings: Settings, region_1: Region, account_id: str -) -> str: - return ( - f"https://{settings.server}" - + ACCOUNT_INSTANCE_TYPES_URL.format(account_id=account_id) - + f"?page.first=5000&filter.id_region_id_eq={region_1.key.region_id}" - ) - - -@fixture -def instance_type_region_2_url( - settings: Settings, region_2: Region, account_id: str -) -> str: - return ( - f"https://{settings.server}" - + ACCOUNT_INSTANCE_TYPES_URL.format(account_id=account_id) - + f"?page.first=5000&filter.id_region_id_eq={region_2.key.region_id}" - ) - - @fixture def engine_callback(engine_url: str, mock_engine) -> Callable: def do_mock( @@ -260,10 +156,8 @@ def do_mock( @fixture -def engine_url(settings: Settings, account_id) -> str: - return f"https://{settings.server}" + ACCOUNT_LIST_ENGINES_URL.format( - account_id=account_id - ) +def engine_url(server: str, account_id) -> str: + return f"https://{server}" + ACCOUNT_LIST_ENGINES_URL.format(account_id=account_id) @fixture @@ -282,8 +176,8 @@ def do_mock( @fixture -def account_engine_url(settings: Settings, account_id, mock_engine) -> str: - return f"https://{settings.server}" + ACCOUNT_ENGINE_URL.format( +def account_engine_url(server: str, account_id, mock_engine) -> str: + return f"https://{server}" + ACCOUNT_ENGINE_URL.format( account_id=account_id, engine_id=mock_engine.engine_id, ) @@ -334,10 +228,8 @@ def get_databases_callback_inner( @fixture -def databases_url(settings: Settings, account_id: str) -> str: - return f"https://{settings.server}" + ACCOUNT_DATABASES_URL.format( - account_id=account_id - ) +def databases_url(server: str, account_id: str) -> str: + return f"https://{server}" + ACCOUNT_DATABASES_URL.format(account_id=account_id) @fixture @@ -371,8 +263,8 @@ def do_mock( @fixture -def database_url(settings: Settings, account_id: str, mock_database) -> str: - return f"https://{settings.server}" + ACCOUNT_DATABASE_URL.format( +def database_url(server: str, account_id: str, mock_database) -> str: + return f"https://{server}" + ACCOUNT_DATABASE_URL.format( account_id=account_id, database_id=mock_database.database_id ) @@ -393,9 +285,9 @@ def do_mock( @fixture -def database_get_by_name_url(settings: Settings, account_id: str, mock_database) -> str: +def database_get_by_name_url(server: str, account_id: str, mock_database) -> str: return ( - f"https://{settings.server}" + f"https://{server}" + ACCOUNT_DATABASE_BY_NAME_URL.format(account_id=account_id) + f"?database_name={mock_database.name}" ) @@ -435,84 +327,7 @@ def do_mock( # duplicates database_url @fixture -def database_get_url(settings: Settings, account_id: str, mock_database) -> str: - return f"https://{settings.server}" + ACCOUNT_DATABASE_URL.format( +def database_get_url(server: str, account_id: str, mock_database) -> str: + return f"https://{server}" + ACCOUNT_DATABASE_URL.format( account_id=account_id, database_id=mock_database.database_id ) - - -@fixture -def binding(account_id, mock_engine, mock_database) -> Binding: - return Binding( - binding_key=BindingKey( - account_id=account_id, - database_id=mock_database.database_id, - engine_id=mock_engine.engine_id, - ), - is_default_engine=True, - ) - - -@fixture -def bindings_callback(bindings_url: str, binding: Binding) -> Callable: - def do_mock( - request: httpx.Request = None, - **kwargs, - ) -> Response: - assert request.url == bindings_url - return Response( - status_code=httpx.codes.OK, - json=list_to_paginated_response([binding]), - ) - - return do_mock - - -@fixture -def no_bindings_callback(bindings_url: str) -> Callable: - def do_mock( - request: httpx.Request = None, - **kwargs, - ) -> Response: - assert request.url == bindings_url - return Response( - status_code=httpx.codes.OK, - json=list_to_paginated_response([]), - ) - - return do_mock - - -@fixture -def bindings_url(settings: Settings, account_id: str, mock_engine: Engine) -> str: - return ( - f"https://{settings.server}" - + ACCOUNT_BINDINGS_URL.format(account_id=account_id) - + f"?page.first=5000&filter.id_engine_id_eq={mock_engine.engine_id}" - ) - - -@fixture -def create_binding_callback(create_binding_url: str, binding) -> Callable: - def do_mock( - request: httpx.Request = None, - **kwargs, - ) -> Response: - assert request.url == create_binding_url - return Response( - status_code=httpx.codes.OK, - json={"binding": binding.dict()}, - ) - - return do_mock - - -@fixture -def create_binding_url( - settings: Settings, account_id: str, mock_database: Database, mock_engine: Engine -) -> str: - return f"https://{settings.server}" + ACCOUNT_DATABASE_BINDING_URL.format( - account_id=account_id, - database_id=mock_database.database_id, - engine_id=mock_engine.engine_id, - ) diff --git a/tests/unit/service/test_engine.py b/tests/unit/service/test_engine.py index 7abde2d8008..656e0579752 100644 --- a/tests/unit/service/test_engine.py +++ b/tests/unit/service/test_engine.py @@ -6,10 +6,8 @@ from pytest_httpx import HTTPXMock from firebolt.common import Settings -from firebolt.model.engine import Engine, _EngineCreateRequest -from firebolt.model.engine_revision import EngineRevision +from firebolt.model.engine import Engine from firebolt.model.instance_type import InstanceType -from firebolt.model.region import Region from firebolt.service.manager import ResourceManager from firebolt.utils.exception import FireboltError, NoAttachedDatabaseError diff --git a/tests/unit/service/test_instance_type.py b/tests/unit/service/test_instance_type.py index ef961ce9212..8cfefb789f5 100644 --- a/tests/unit/service/test_instance_type.py +++ b/tests/unit/service/test_instance_type.py @@ -2,43 +2,26 @@ from pytest_httpx import HTTPXMock -from firebolt.common import Settings +from firebolt.client.auth import Auth from firebolt.model.instance_type import InstanceType -from firebolt.model.region import Region from firebolt.service.manager import ResourceManager def test_instance_type( httpx_mock: HTTPXMock, - provider_callback: Callable, - provider_url: str, + auth: Auth, + account_name: str, + server: str, instance_type_callback: Callable, - instance_type_region_1_callback: Callable, - instance_type_empty_callback: Callable, instance_type_url: str, - instance_type_region_1_url: str, - instance_type_region_2_url: str, - settings: Settings, mock_instance_types: List[InstanceType], cheapest_instance: InstanceType, - region_1: Region, - region_2: Region, mock_system_engine_connection_flow: Callable, ): mock_system_engine_connection_flow() - httpx_mock.add_callback(provider_callback, url=provider_url) httpx_mock.add_callback(instance_type_callback, url=instance_type_url) - httpx_mock.add_callback( - instance_type_region_1_callback, url=instance_type_region_1_url - ) - httpx_mock.add_callback( - instance_type_empty_callback, url=instance_type_region_2_url - ) - manager = ResourceManager(settings=settings) + manager = ResourceManager(auth=auth, account_name=account_name, api_endpoint=server) + assert manager.instance_types.instance_types == mock_instance_types - assert ( - manager.instance_types.cheapest_instance_in_region(region_1) - == cheapest_instance - ) - assert not manager.instance_types.cheapest_instance_in_region(region_2) + assert manager.instance_types.cheapest_instance == cheapest_instance diff --git a/tests/unit/service/test_region.py b/tests/unit/service/test_region.py deleted file mode 100644 index 6e6ebcda750..00000000000 --- a/tests/unit/service/test_region.py +++ /dev/null @@ -1,25 +0,0 @@ -from typing import Callable, List - -from pytest_httpx import HTTPXMock - -from firebolt.common import Settings -from firebolt.model.region import Region -from firebolt.service.manager import ResourceManager - - -def test_region( - httpx_mock: HTTPXMock, - provider_callback: Callable, - provider_url: str, - region_callback: Callable, - region_url: str, - settings: Settings, - mock_regions: List[Region], - mock_system_engine_connection_flow: Callable, -): - mock_system_engine_connection_flow() - httpx_mock.add_callback(provider_callback, url=provider_url) - httpx_mock.add_callback(region_callback, url=region_url) - - manager = ResourceManager(settings=settings) - assert manager.regions.regions == mock_regions diff --git a/tests/unit/util.py b/tests/unit/util.py index 98b72304fc1..ebf869a00a0 100644 --- a/tests/unit/util.py +++ b/tests/unit/util.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass, fields from typing import AsyncGenerator, Dict, Generator, List from httpx import Request, Response @@ -7,8 +8,15 @@ from firebolt.model import FireboltBaseModel +def to_dict(dc: dataclass) -> Dict: + return { + (f.metadata or {}).get("db_name", f.name): getattr(dc, f.name) + for f in fields(dc) + } + + def list_to_paginated_response(items: List[FireboltBaseModel]) -> Dict: - return {"edges": [{"node": i.dict()} for i in items]} + return {"edges": [{"node": to_dict(i)} for i in items]} def execute_generator_requests( From e6b9753aedb5574c8bf5c5906a094420ae2f0bc1 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Wed, 21 Jun 2023 15:31:44 +0300 Subject: [PATCH 09/18] fix resource manager tests --- tests/unit/service/test_resource_manager.py | 61 ++++++--------------- 1 file changed, 17 insertions(+), 44 deletions(-) diff --git a/tests/unit/service/test_resource_manager.py b/tests/unit/service/test_resource_manager.py index d07c23044db..8afa351474b 100644 --- a/tests/unit/service/test_resource_manager.py +++ b/tests/unit/service/test_resource_manager.py @@ -6,7 +6,6 @@ from pytest_httpx import HTTPXMock from firebolt.client.auth import Auth, ClientCredentials -from firebolt.common.settings import Settings from firebolt.service.manager import ResourceManager from firebolt.utils.exception import AccountNotFoundError from firebolt.utils.token_storage import TokenSecureStorage @@ -15,15 +14,10 @@ def test_rm_credentials( httpx_mock: HTTPXMock, + auth: Auth, + account_name: str, + server: str, check_token_callback: Callable, - check_credentials_callback: Callable, - settings: Settings, - auth_url: str, - account_id_url: Pattern, - account_id_callback: Callable, - provider_callback: Callable, - provider_url: str, - access_token: str, mock_system_engine_connection_flow: Callable, ) -> None: """Credentials, that are passed to rm are processed properly.""" @@ -31,9 +25,8 @@ def test_rm_credentials( mock_system_engine_connection_flow() httpx_mock.add_callback(check_token_callback, url=url) - httpx_mock.add_callback(provider_callback, url=provider_url) - rm = ResourceManager(settings) + rm = ResourceManager(auth=auth, account_name=account_name, api_endpoint=server) rm._client.get(url) @@ -41,14 +34,9 @@ def test_rm_credentials( def test_rm_token_cache( httpx_mock: HTTPXMock, check_token_callback: Callable, - check_credentials_callback: Callable, - settings: Settings, - auth_url: str, + auth: Auth, + server: str, account_name: str, - account_id_url: Pattern, - account_id_callback: Callable, - provider_callback: Callable, - provider_url: str, access_token: str, mock_system_engine_connection_flow: Callable, ) -> None: @@ -57,41 +45,32 @@ def test_rm_token_cache( mock_system_engine_connection_flow() httpx_mock.add_callback(check_token_callback, url=url) - httpx_mock.add_callback(provider_callback, url=provider_url) with Patcher(): - local_settings = Settings( - account_name=account_name, + rm = ResourceManager( auth=ClientCredentials( - settings.auth.client_id, - settings.auth.client_secret, - use_token_cache=True, + auth.client_id, auth.client_secret, use_token_cache=True ), - server=settings.server, - default_region=settings.default_region, + account_name=account_name, + api_endpoint=server, ) - rm = ResourceManager(local_settings) rm._client.get(url) - ts = TokenSecureStorage(settings.auth.client_id, settings.auth.client_secret) + ts = TokenSecureStorage(auth.client_id, auth.client_secret) assert ts.get_cached_token() == access_token, "Invalid token value cached" # Do the same, but with use_token_cache=False with Patcher(): - local_settings = Settings( - account_name=account_name, + rm = ResourceManager( auth=ClientCredentials( - settings.auth.client_id, - settings.auth.client_secret, - use_token_cache=False, + auth.client_id, auth.client_secret, use_token_cache=False ), - server=settings.server, - default_region=settings.default_region, + account_name=account_name, + api_endpoint=server, ) - rm = ResourceManager(local_settings) rm._client.get(url) - ts = TokenSecureStorage(settings.auth.client_id, settings.auth.client_secret) + ts = TokenSecureStorage(auth.client_id, auth.client_secret) assert ( ts.get_cached_token() is None ), "Token is cached even though caching is disabled" @@ -101,7 +80,6 @@ def test_rm_invalid_account_name( httpx_mock: HTTPXMock, auth: Auth, server: str, - region_1: str, auth_url: str, check_credentials_callback: Callable, account_id_url: Pattern, @@ -119,9 +97,4 @@ def test_rm_invalid_account_name( httpx_mock.add_callback(account_id_callback, url=account_id_url) with raises(AccountNotFoundError): - ResourceManager( - auth=auth, - account_name="invalid", - api_endpoint=server, - default_region=region_1, - ) + ResourceManager(auth=auth, account_name="invalid", api_endpoint=server) From 4decbd131858732c81fab0ed71d34e2c766e32ac Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Wed, 21 Jun 2023 16:29:41 +0300 Subject: [PATCH 10/18] update database tests --- src/firebolt/model/database.py | 4 +- tests/unit/conftest.py | 5 -- tests/unit/service/conftest.py | 129 ++++++++++++++++++++-------- tests/unit/service/test_database.py | 99 ++++++++------------- tests/unit/util.py | 11 +-- 5 files changed, 138 insertions(+), 110 deletions(-) diff --git a/src/firebolt/model/database.py b/src/firebolt/model/database.py index 067259472fd..406140a81bd 100644 --- a/src/firebolt/model/database.py +++ b/src/firebolt/model/database.py @@ -30,7 +30,7 @@ class Database(FireboltBaseModel): DROP_SQL: ClassVar[str] = "DROP DATABASE {}" # internal - _service: DatabaseService = field(repr=False) + _service: DatabaseService = field(repr=False, compare=False) # required name: str = field(metadata={"db_name": "database_name"}) @@ -39,7 +39,7 @@ class Database(FireboltBaseModel): _status: str = field(repr=False, metadata={"db_name": "status"}) data_size_full: int = field() data_size_compressed: int = field() - _attached_engine_names: List[str] = field( + _attached_engine_names: str = field( repr=False, metadata={"db_name": "attached_engines"} ) create_time: datetime = field(metadata={"db_name": "created_on"}) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 7e2e7aa548c..961fd859908 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -115,11 +115,6 @@ def db_name() -> str: return "database" -@fixture -def db_description() -> str: - return "database description" - - @fixture def account_id_url(server: str, account_name: str) -> Pattern: account_name_re = r"[^\\\\]*" diff --git a/tests/unit/service/conftest.py b/tests/unit/service/conftest.py index 524e2c2af88..f6c3c20275c 100644 --- a/tests/unit/service/conftest.py +++ b/tests/unit/service/conftest.py @@ -1,4 +1,4 @@ -import json +from dataclasses import dataclass, fields from datetime import datetime from typing import Callable, List from urllib.parse import urlparse @@ -7,9 +7,12 @@ from httpx import Response from pytest import fixture +from firebolt.client.auth import Auth +from firebolt.common._types import _InternalType from firebolt.model.database import Database from firebolt.model.engine import Engine from firebolt.model.instance_type import InstanceType +from firebolt.service.manager import ResourceManager from firebolt.utils.urls import ( ACCOUNT_DATABASE_BY_NAME_URL, ACCOUNT_DATABASE_URL, @@ -184,47 +187,75 @@ def account_engine_url(server: str, account_id, mock_engine) -> str: @fixture -def mock_database(region_1: str, account_id: str) -> Database: +def mock_database() -> Database: return Database( name="database", description="mock_db_description", - compute_region_key=region_1.key, - database_key=DatabaseKey( - account_id=account_id, database_id="mock_database_id_1" - ), + region="us-east-1", + data_size_full=0, + data_size_compressed=0, + create_time=datetime.now().isoformat(), + create_actor="", + _status="", + _attached_engine_names="-", + _errors="", + _service=None, + ) + + +@fixture +def mock_database_2() -> Database: + return Database( + name="database2", + description="completely different db", + region="us-east-1", + data_size_full=0, + data_size_compressed=0, + create_time=datetime.now().isoformat(), + create_actor="", + _status="", + _attached_engine_names="-", + _errors="", + _service=None, ) +empty_response = { + "meta": [], + "data": [], + "rows": 0, + "statistics": { + "elapsed": 39635.785423446, + "rows_read": 0, + "bytes_read": 0, + "time_before_execution": 0, + "time_to_execute": 0, + }, +} + + @fixture -def create_databases_callback(databases_url: str, mock_database) -> Callable: +def create_databases_callback( + system_engine_no_db_query_url: str, mock_database +) -> Callable: def do_mock( request: httpx.Request = None, **kwargs, ) -> Response: - database_properties = json.loads(request.read().decode("utf-8"))["database"] - - mock_database.name = database_properties["name"] - mock_database.description = database_properties["description"] - - assert request.url == databases_url + assert request.url == system_engine_no_db_query_url return Response( status_code=httpx.codes.OK, - json={"database": mock_database.dict()}, + json=empty_response, ) return do_mock @fixture -def databases_get_callback(databases_url: str, mock_database) -> Callable: - def get_databases_callback_inner( - request: httpx.Request = None, **kwargs - ) -> Response: - return Response( - status_code=httpx.codes.OK, json={"edges": [{"node": mock_database.dict()}]} - ) - - return get_databases_callback_inner +def databases_get_callback( + mock_database: Database, mock_database_2: Database +) -> Callable: + return get_objects_from_db_callback([mock_database, mock_database_2]) @fixture @@ -294,40 +325,64 @@ def database_get_by_name_url(server: str, account_id: str, mock_database) -> str @fixture -def database_update_callback(database_get_url, mock_database) -> Callable: +def database_update_callback() -> Callable: def do_mock( request: httpx.Request = None, **kwargs, ) -> Response: - database_properties = json.loads(request.read().decode("utf-8"))["database"] - - assert request.url == database_get_url return Response( status_code=httpx.codes.OK, - json={"database": database_properties}, + json=empty_response, ) return do_mock -@fixture -def database_get_callback(database_get_url, mock_database) -> Callable: +def get_objects_from_db_callback(objs: List[dataclass]) -> Callable: def do_mock( request: httpx.Request = None, **kwargs, ) -> Response: - assert request.url == database_get_url + fieldname = lambda f: (f.metadata or {}).get("db_name", f.name) + types = { + "int": _InternalType.Long.value, + "str": _InternalType.Text.value, + "datetime": _InternalType.Text.value, # we receive datetime as text from db + } + dc_fields = [f for f in fields(objs[0]) if f.name != "_service"] + query_response = { + "meta": [{"name": fieldname(f), "type": types[f.type]} for f in dc_fields], + "data": [[getattr(obj, f.name) for f in dc_fields] for obj in objs], + "rows": len(objs), + "statistics": { + "elapsed": 0.116907717, + "rows_read": 1, + "bytes_read": 61, + "time_before_execution": 0.012180623, + "time_to_execute": 0.104614307, + "scanned_bytes_cache": 0, + "scanned_bytes_storage": 0, + }, + } return Response( status_code=httpx.codes.OK, - json={"database": mock_database.dict()}, + json=query_response, ) return do_mock -# duplicates database_url @fixture -def database_get_url(server: str, account_id: str, mock_database) -> str: - return f"https://{server}" + ACCOUNT_DATABASE_URL.format( - account_id=account_id, database_id=mock_database.database_id - ) +def database_get_callback(mock_database) -> Callable: + return get_objects_from_db_callback([mock_database]) + + +@fixture +def resource_manager( + auth: Auth, + account_name: str, + server: str, + mock_system_engine_connection_flow: Callable, +) -> ResourceManager: + mock_system_engine_connection_flow() + return ResourceManager(auth=auth, account_name=account_name, api_endpoint=server) diff --git a/tests/unit/service/test_database.py b/tests/unit/service/test_database.py index 01246743a19..e5bba41801e 100644 --- a/tests/unit/service/test_database.py +++ b/tests/unit/service/test_database.py @@ -1,110 +1,87 @@ -from re import compile from typing import Callable from pytest_httpx import HTTPXMock -from firebolt.common import Settings from firebolt.model.database import Database from firebolt.service.manager import ResourceManager def test_database_create( httpx_mock: HTTPXMock, - provider_callback: Callable, - provider_url: str, - region_callback: Callable, - region_url: str, - settings: Settings, + resource_manager: ResourceManager, + database_get_callback: Callable, create_databases_callback: Callable, - databases_url: str, - db_name: str, - db_description: str, - mock_system_engine_connection_flow: Callable, + system_engine_no_db_query_url: str, + mock_database: Database, ): - mock_system_engine_connection_flow() - httpx_mock.add_callback(provider_callback, url=provider_url) - httpx_mock.add_callback(region_callback, url=region_url) - httpx_mock.add_callback(create_databases_callback, url=databases_url, method="POST") + httpx_mock.add_callback( + create_databases_callback, url=system_engine_no_db_query_url, method="POST" + ) + httpx_mock.add_callback( + database_get_callback, url=system_engine_no_db_query_url, method="POST" + ) - manager = ResourceManager(settings=settings) - database = manager.databases.create(name=db_name, description=db_description) + database = resource_manager.databases.create( + name=mock_database.name, description=mock_database.description + ) - assert database.name == db_name - assert database.description == db_description + assert database == mock_database -def test_database_get_by_name( +def test_database_get( httpx_mock: HTTPXMock, - provider_callback: Callable, - provider_url: str, - settings: Settings, - database_get_by_name_callback: Callable, - database_get_by_name_url: str, + resource_manager: ResourceManager, database_get_callback: Callable, - database_get_url: str, + system_engine_no_db_query_url: str, mock_database: Database, - mock_system_engine_connection_flow: Callable, ): - mock_system_engine_connection_flow() - httpx_mock.add_callback(provider_callback, url=provider_url) - httpx_mock.add_callback(database_get_by_name_callback, url=database_get_by_name_url) - httpx_mock.add_callback(database_get_callback, url=database_get_url) + httpx_mock.add_callback( + database_get_callback, url=system_engine_no_db_query_url, method="POST" + ) - manager = ResourceManager(settings=settings) - database = manager.databases.get_by_name(name=mock_database.name) + database = resource_manager.databases.get(mock_database.name) - assert database.name == mock_database.name + assert database == mock_database def test_database_get_many( httpx_mock: HTTPXMock, - provider_callback: Callable, - provider_url: str, - settings: Settings, - database_get_by_name_callback: Callable, - database_get_by_name_url: str, + resource_manager: ResourceManager, databases_get_callback: Callable, - databases_url: str, + system_engine_no_db_query_url: str, mock_database: Database, - mock_system_engine_connection_flow: Callable, + mock_database_2: Database, ): - mock_system_engine_connection_flow() - httpx_mock.add_callback(provider_callback, url=provider_url) httpx_mock.add_callback( databases_get_callback, - url=compile(databases_url + "?[a-zA-Z0-9=&]*"), - method="GET", + url=system_engine_no_db_query_url, + method="POST", ) - manager = ResourceManager(settings=settings) - databases = manager.databases.get_many( + databases = resource_manager.databases.get_many( name_contains=mock_database.name, attached_engine_name_eq="mockengine", attached_engine_name_contains="mockengine", + region_eq="us-east-1", ) - assert len(databases) == 1 - assert databases[0].name == mock_database.name + assert len(databases) == 2 + assert databases[0] == mock_database + assert databases[1] == mock_database_2 def test_database_update( httpx_mock: HTTPXMock, - provider_callback: Callable, - provider_url: str, - settings: Settings, + resource_manager: ResourceManager, database_update_callback: Callable, - database_url: str, + system_engine_no_db_query_url: str, mock_database: Database, - mock_system_engine_connection_flow: Callable, ): - mock_system_engine_connection_flow() - httpx_mock.add_callback(provider_callback, url=provider_url) - - httpx_mock.add_callback(database_update_callback, url=database_url, method="PATCH") - - manager = ResourceManager(settings=settings) + httpx_mock.add_callback( + database_update_callback, url=system_engine_no_db_query_url, method="POST" + ) - mock_database._service = manager.databases + mock_database._service = resource_manager.databases database = mock_database.update(description="new description") assert database.description == "new description" diff --git a/tests/unit/util.py b/tests/unit/util.py index ebf869a00a0..7c34650f15b 100644 --- a/tests/unit/util.py +++ b/tests/unit/util.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass, fields +from dataclasses import Field, dataclass, fields from typing import AsyncGenerator, Dict, Generator, List from httpx import Request, Response @@ -8,11 +8,12 @@ from firebolt.model import FireboltBaseModel +def field_name(f: Field) -> str: + return (f.metadata or {}).get("db_name", f.name) + + def to_dict(dc: dataclass) -> Dict: - return { - (f.metadata or {}).get("db_name", f.name): getattr(dc, f.name) - for f in fields(dc) - } + return {field_name(f): getattr(dc, f.name) for f in fields(dc)} def list_to_paginated_response(items: List[FireboltBaseModel]) -> Dict: From 6fd3f7b74258675fcdee6ea057ad513f063e3b00 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Thu, 22 Jun 2023 16:07:38 +0300 Subject: [PATCH 11/18] test fixes --- src/firebolt/model/database.py | 2 +- src/firebolt/model/engine.py | 65 ++++++++++++++++++++-------------- src/firebolt/service/engine.py | 14 +++++--- src/firebolt/service/types.py | 9 +++++ 4 files changed, 59 insertions(+), 31 deletions(-) diff --git a/src/firebolt/model/database.py b/src/firebolt/model/database.py index 406140a81bd..c7cb4698ab8 100644 --- a/src/firebolt/model/database.py +++ b/src/firebolt/model/database.py @@ -40,7 +40,7 @@ class Database(FireboltBaseModel): data_size_full: int = field() data_size_compressed: int = field() _attached_engine_names: str = field( - repr=False, metadata={"db_name": "attached_engines"} + repr=False, metadata={"db_name": "attached_engines"}, compare=False ) create_time: datetime = field(metadata={"db_name": "created_on"}) create_actor: str = field(metadata={"db_name": "created_by"}) diff --git a/src/firebolt/model/engine.py b/src/firebolt/model/engine.py index 92bbe392fd0..cffdcce7291 100644 --- a/src/firebolt/model/engine.py +++ b/src/firebolt/model/engine.py @@ -18,8 +18,11 @@ from firebolt.model import FireboltBaseModel from firebolt.model.database import Database from firebolt.model.instance_type import InstanceType -from firebolt.service.types import EngineStatus, WarmupMethod -from firebolt.utils.exception import DatabaseNotFoundError +from firebolt.service.types import EngineStatus, EngineType, WarmupMethod +from firebolt.utils.exception import ( + DatabaseNotFoundError, + NoAttachedDatabaseError, +) if TYPE_CHECKING: from firebolt.service.engine import EngineService @@ -32,8 +35,8 @@ def check_attached_to_database(func: Callable) -> Callable: @functools.wraps(func) def inner(self: Engine, *args: Any, **kwargs: Any) -> Any: - # if self.database is None: - # raise NoAttachedDatabaseError(method_name=func.__name__) + if self.database is None: + raise NoAttachedDatabaseError(method_name=func.__name__) return func(self, *args, **kwargs) return inner @@ -43,11 +46,6 @@ def inner(self: Engine, *args: Any, **kwargs: Any) -> Any: class Engine(FireboltBaseModel): """ A Firebolt engine. Responsible for performing work (queries, ingestion). - - Engines are configured in :py:class:`Settings - ` - and in :py:class:`EngineRevisionSpecification - `. """ START_SQL: ClassVar[str] = "START ENGINE {}" @@ -62,23 +60,35 @@ class Engine(FireboltBaseModel): ) DROP_SQL: ClassVar[str] = "DROP ENGINE {}" - _service: EngineService = field(repr=False) + _service: EngineService = field(repr=False, compare=False) name: str = field(metadata={"db_name": "engine_name"}) region: str = field() - spec: str = field() + spec: InstanceType = field() scale: int = field() - current_status: str = field(metadata={"db_name": "status"}) - _database_name: Optional[str] = field( - repr=False, metadata={"db_name": "attached_to"} - ) + current_status: EngineStatus = field(metadata={"db_name": "status"}) + _database_name: str = field(repr=False, metadata={"db_name": "attached_to"}) version: str = field() endpoint: str = field(metadata={"db_name": "url"}) - warmup: str = field() + warmup: WarmupMethod = field() auto_stop: int = field() - type: str = field() + type: EngineType = field() provisioning: str = field() + def __post_init__(self) -> None: + if isinstance(self.spec, str) and self.spec: + # Resolve engine specification + self.spec = self._service.resource_manager.instance_types.get(self.spec) + if isinstance(self.current_status, str) and self.current_status: + # Resolve engine status + self.current_status = EngineStatus(self.current_status) + if isinstance(self.warmup, str) and self.warmup: + # Resolve warmup method + self.warmup = WarmupMethod.from_display_name(self.warmup) + if isinstance(self.type, str) and self.type: + # Resolve engine type + self.type = EngineType.from_display_name(self.type) + @property def database(self) -> Optional[Database]: if self._database_name: @@ -90,18 +100,20 @@ def database(self) -> Optional[Database]: def refresh(self) -> None: """Update attributes of the instance from Firebolt.""" + field_name_overrides = self._get_field_overrides() for name, value in self._service._get_dict(self.name).items(): - setattr(self, name, value) + setattr(self, field_name_overrides.get(name, name), value) + + self.__post_init__() - def attach_to_database(self, database_name: str) -> None: + def attach_to_database(self, database: Union[Database, str]) -> None: """ Attach this engine to a database. Args: database: Database to which the engine will be attached """ - self._service.attach_to_database(self.name, database_name) - self._database_name = database_name + self._service.attach_to_database(self, database) @check_attached_to_database def get_connection(self) -> Connection: @@ -127,13 +139,13 @@ def _wait_for_start_stop(self) -> None: while self.current_status in (EngineStatus.STOPPING, EngineStatus.STARTING): logger.info( f"Engine {self.name} is currently " - f"{self.current_status.lower()}, waiting" + f"{self.current_status.value.lower()}, waiting" ) time.sleep(interval_seconds) if time.time() > timeout_time: raise TimeoutError( f"Excedeed timeout of {wait_timeout}s waiting for " - f"an engine in {self.current_status.lower()} state" + f"an engine in {self.current_status.value.lower()} state" ) logger.info(".[!n]") self.refresh() @@ -155,7 +167,7 @@ def start(self) -> Engine: if self.current_status in (EngineStatus.DROPPING, EngineStatus.REPAIRING): raise ValueError( f"Unable to start engine {self.name} because it's " - f"in {self.current_status.lower()} state" + f"in {self.current_status.value.lower()} state" ) logger.info(f"Starting engine {self.name}") @@ -179,7 +191,7 @@ def stop(self) -> Engine: if self.current_status in (EngineStatus.DROPPING, EngineStatus.REPAIRING): raise ValueError( f"Unable to stop engine {self.name} because it's " - f"in {self.current_status.lower()} state" + f"in {self.current_status.value.lower()} state" ) logger.info(f"Stopping engine {self.name}") with self._service._connection.cursor() as c: @@ -204,11 +216,12 @@ def update( # Nothing to be updated return self + self.refresh() self._wait_for_start_stop() if self.current_status in (EngineStatus.DROPPING, EngineStatus.REPAIRING): raise ValueError( f"Unable to update engine {self.name} because it's " - f"in {self.current_status.lower()} state" + f"in {self.current_status.value.lower()} state" ) sql = self.ALTER_PREFIX_SQL.format(self.name) diff --git a/src/firebolt/service/engine.py b/src/firebolt/service/engine.py index ecfae3a65b5..e34982e5024 100644 --- a/src/firebolt/service/engine.py +++ b/src/firebolt/service/engine.py @@ -1,7 +1,7 @@ from logging import getLogger from typing import List, Optional, Union -from firebolt.model.engine import Engine +from firebolt.model.engine import Database, Engine from firebolt.model.instance_type import InstanceType from firebolt.service.base import BaseService from firebolt.service.types import EngineStatus, EngineType, WarmupMethod @@ -165,13 +165,19 @@ def create( ): if value: sql += f"{param} = ? " - if isinstance(value, EngineType): - value = value.value parameters.append(str(value)) with self._connection.cursor() as c: c.execute(sql, parameters) return self.get(name) - def attach_to_database(self, engine_name: str, database_name: str) -> None: + def attach_to_database( + self, engine: Union[Engine, str], database: Union[Database, str] + ) -> None: + engine_name = engine.name if isinstance(engine, Engine) else engine + database_name = database.name if isinstance(database, Database) else database with self._connection.cursor() as c: c.execute(self.ATTACH_TO_DB_SQL.format(engine_name, database_name)) + if isinstance(engine, Engine): + engine._database_name = ( + database.name if isinstance(database, Database) else database + ) diff --git a/src/firebolt/service/types.py b/src/firebolt/service/types.py index d72cbd183c0..022bb6ca549 100644 --- a/src/firebolt/service/types.py +++ b/src/firebolt/service/types.py @@ -12,6 +12,9 @@ def from_display_name(cls, display_name: str) -> "EngineType": "Analytics": cls.DATA_ANALYTICS, }[display_name] + def __str__(self) -> str: + return self.value + class WarmupMethod(Enum): MINIMAL = "MINIMAL" @@ -26,6 +29,9 @@ def from_display_name(cls, display_name: str) -> "WarmupMethod": "All": cls.PRELOAD_ALL_DATA, }[display_name] + def __str__(self) -> str: + return self.value + class EngineStatus(Enum): """ @@ -42,6 +48,9 @@ class EngineStatus(Enum): DROPPING = "Dropping" REPAIRING = "Repairing" + def __str__(self) -> str: + return self.value + class DatabaseOrder(Enum): DATABASE_ORDER_UNSPECIFIED = "DATABASE_ORDER_UNSPECIFIED" From 386ec107d671bd8e988b58a7a07a350badf450db Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Thu, 22 Jun 2023 16:09:01 +0300 Subject: [PATCH 12/18] fix unit tests --- tests/unit/async_db/conftest.py | 3 +- tests/unit/async_db/test_connection.py | 2 - tests/unit/client/test_client.py | 7 +- tests/unit/client/test_client_async.py | 7 +- tests/unit/db/test_connection.py | 2 - tests/unit/db_conftest.py | 9 +- tests/unit/service/conftest.py | 284 +++++++-------- tests/unit/service/test_database.py | 7 +- tests/unit/service/test_engine.py | 432 +++++------------------ tests/unit/service/test_instance_type.py | 12 +- 10 files changed, 232 insertions(+), 533 deletions(-) diff --git a/tests/unit/async_db/conftest.py b/tests/unit/async_db/conftest.py index 8be96d4ccf5..3eaa5f12da2 100644 --- a/tests/unit/async_db/conftest.py +++ b/tests/unit/async_db/conftest.py @@ -5,7 +5,6 @@ from firebolt.async_db import ARRAY, DECIMAL, Connection, Cursor, connect from firebolt.client.auth import Auth -from firebolt.common.settings import Settings from tests.unit.db_conftest import * # noqa @@ -32,7 +31,7 @@ async def connection( @fixture -async def cursor(connection: Connection, settings: Settings) -> Cursor: +async def cursor(connection: Connection) -> Cursor: return connection.cursor() diff --git a/tests/unit/async_db/test_connection.py b/tests/unit/async_db/test_connection.py index 0f704597909..d2a8dfaa4da 100644 --- a/tests/unit/async_db/test_connection.py +++ b/tests/unit/async_db/test_connection.py @@ -8,7 +8,6 @@ from firebolt.async_db.connection import Connection, connect from firebolt.client.auth import Auth, ClientCredentials from firebolt.common._types import ColType -from firebolt.common.settings import Settings from firebolt.utils.exception import ( ConfigurationError, ConnectionClosedError, @@ -56,7 +55,6 @@ async def test_cursors_closed_on_close(connection: Connection) -> None: async def test_cursor_initialized( - settings: Settings, mock_query: Callable, connection: Connection, python_query_data: List[List[ColType]], diff --git a/tests/unit/client/test_client.py b/tests/unit/client/test_client.py index 0db63da2dd2..2eee11b62aa 100644 --- a/tests/unit/client/test_client.py +++ b/tests/unit/client/test_client.py @@ -9,7 +9,6 @@ from firebolt.client import Client from firebolt.client.auth import Auth, ClientCredentials from firebolt.client.resource_manager_hooks import raise_on_4xx_5xx -from firebolt.common import Settings from firebolt.utils.token_storage import TokenSecureStorage from firebolt.utils.urls import AUTH_SERVICE_ACCOUNT_URL from firebolt.utils.util import fix_url_schema @@ -96,7 +95,7 @@ def test_client_account_id( account_id_callback: Callable, auth_url: str, auth_callback: Callable, - settings: Settings, + server: str, ): httpx_mock.add_callback(account_id_callback, url=account_id_url) httpx_mock.add_callback(auth_callback, url=auth_url) @@ -104,8 +103,8 @@ def test_client_account_id( with Client( account_name=account_name, auth=auth, - base_url=fix_url_schema(settings.server), - api_endpoint=settings.server, + base_url=fix_url_schema(server), + api_endpoint=server, ) as c: assert c.account_id == account_id, "Invalid account id returned" diff --git a/tests/unit/client/test_client_async.py b/tests/unit/client/test_client_async.py index acad0c4a286..a1eefddc970 100644 --- a/tests/unit/client/test_client_async.py +++ b/tests/unit/client/test_client_async.py @@ -7,7 +7,6 @@ from firebolt.client import AsyncClient from firebolt.client.auth import Auth -from firebolt.common import Settings from firebolt.utils.urls import AUTH_SERVICE_ACCOUNT_URL from firebolt.utils.util import fix_url_schema @@ -99,7 +98,7 @@ async def test_client_account_id( account_id_callback: Callable, auth_url: str, auth_callback: Callable, - settings: Settings, + server: str, ): httpx_mock.add_callback(account_id_callback, url=account_id_url) httpx_mock.add_callback(auth_callback, url=auth_url) @@ -107,7 +106,7 @@ async def test_client_account_id( async with AsyncClient( account_name=account_name, auth=auth, - base_url=fix_url_schema(settings.server), - api_endpoint=settings.server, + base_url=fix_url_schema(server), + api_endpoint=server, ) as c: assert await c.account_id == account_id, "Invalid account id returned." diff --git a/tests/unit/db/test_connection.py b/tests/unit/db/test_connection.py index 1399922558f..b58b4cfe78a 100644 --- a/tests/unit/db/test_connection.py +++ b/tests/unit/db/test_connection.py @@ -9,7 +9,6 @@ from firebolt.client.auth import Auth, ClientCredentials from firebolt.common._types import ColType -from firebolt.common.settings import Settings from firebolt.db import Connection, connect from firebolt.utils.exception import ( ConfigurationError, @@ -58,7 +57,6 @@ def test_cursors_closed_on_close(connection: Connection) -> None: def test_cursor_initialized( - settings: Settings, mock_query: Callable, connection: Connection, python_query_data: List[List[ColType]], diff --git a/tests/unit/db_conftest.py b/tests/unit/db_conftest.py index 745eb9287f1..b538a400ad3 100644 --- a/tests/unit/db_conftest.py +++ b/tests/unit/db_conftest.py @@ -8,7 +8,6 @@ from pytest_httpx import HTTPXMock from firebolt.async_db.cursor import JSON_OUTPUT_FORMAT, ColType, Column -from firebolt.common.settings import Settings from firebolt.db import ARRAY, DECIMAL from firebolt.utils.urls import GATEWAY_HOST_BY_ACCOUNT_NAME @@ -337,16 +336,16 @@ def set_params() -> Dict: @fixture -def query_url(settings: Settings, db_name: str) -> str: +def query_url(server: str, db_name: str) -> str: return URL( - f"https://{settings.server}/", + f"https://{server}/", params={"output_format": JSON_OUTPUT_FORMAT, "database": db_name}, ) @fixture -def set_query_url(settings: Settings, db_name: str) -> str: - return URL(f"https://{settings.server}/?database={db_name}") +def set_query_url(server: str, db_name: str) -> str: + return URL(f"https://{server}/?database={db_name}") @fixture diff --git a/tests/unit/service/conftest.py b/tests/unit/service/conftest.py index f6c3c20275c..8502d383806 100644 --- a/tests/unit/service/conftest.py +++ b/tests/unit/service/conftest.py @@ -1,7 +1,6 @@ from dataclasses import dataclass, fields from datetime import datetime from typing import Callable, List -from urllib.parse import urlparse import httpx from httpx import Response @@ -13,35 +12,32 @@ from firebolt.model.engine import Engine from firebolt.model.instance_type import InstanceType from firebolt.service.manager import ResourceManager -from firebolt.utils.urls import ( - ACCOUNT_DATABASE_BY_NAME_URL, - ACCOUNT_DATABASE_URL, - ACCOUNT_DATABASES_URL, - ACCOUNT_ENGINE_URL, - ACCOUNT_INSTANCE_TYPES_URL, - ACCOUNT_LIST_ENGINES_URL, -) +from firebolt.service.types import EngineStatus, EngineType, WarmupMethod +from firebolt.utils.urls import ACCOUNT_INSTANCE_TYPES_URL from tests.unit.util import list_to_paginated_response @fixture -def engine_name() -> str: - return "my_engine" +def region() -> str: + return "us-east-1" @fixture -def engine_scale() -> int: - return 2 - - -@fixture -def mock_engine(engine_name, region_1, engine_settings, account_id, server) -> Engine: +def mock_engine(region: str, server: str, instance_type_1: InstanceType) -> Engine: return Engine( - name=engine_name, - compute_region_key=region_1.key, - settings=engine_settings, - key=EngineKey(account_id=account_id, engine_id="mock_engine_id_1"), - endpoint=f"https://{server}", + name="engine_1", + region=region, + spec=instance_type_1, + scale=2, + current_status=EngineStatus.STOPPED, + version="", + endpoint=server, + warmup=WarmupMethod.MINIMAL, + auto_stop=7200, + type=EngineType.GENERAL_PURPOSE, + provisioning="Finished", + _database_name="database", + _service=None, ) @@ -105,6 +101,40 @@ def mock_instance_types( return [instance_type_1, instance_type_2, instance_type_3] +@fixture +def mock_database(region: str) -> Database: + return Database( + name="database", + description="mock_db_description", + region=region, + data_size_full=0, + data_size_compressed=0, + create_time=datetime.now().isoformat(), + create_actor="", + _status="", + _attached_engine_names="-", + _errors="", + _service=None, + ) + + +@fixture +def mock_database_2(region: str) -> Database: + return Database( + name="database2", + description="completely different db", + region=region, + data_size_full=0, + data_size_compressed=0, + create_time=datetime.now().isoformat(), + create_actor="", + _status="", + _attached_engine_names="-", + _errors="", + _service=None, + ) + + @fixture def instance_type_callback(instance_type_url: str, mock_instance_types) -> Callable: def do_mock( @@ -143,101 +173,89 @@ def instance_type_url(server: str, account_id: str) -> str: ) -@fixture -def engine_callback(engine_url: str, mock_engine) -> Callable: +empty_response = { + "meta": [], + "data": [], + "rows": 0, + "statistics": { + "elapsed": 39635.785423446, + "rows_read": 0, + "bytes_read": 0, + "time_before_execution": 0, + "time_to_execute": 0, + }, +} + + +def get_objects_from_db_callback(objs: List[dataclass]) -> Callable: def do_mock( request: httpx.Request = None, **kwargs, ) -> Response: - assert urlparse(engine_url).path in request.url.path + fieldname = lambda f: (f.metadata or {}).get("db_name", f.name) + types = { + "int": _InternalType.Long.value, + "str": _InternalType.Text.value, + "datetime": _InternalType.Text.value, # we receive datetime as text from db + "InstanceType": _InternalType.Text.value, + "EngineType": _InternalType.Text.value, + "EngineStatus": _InternalType.Text.value, + } + dc_fields = [f for f in fields(objs[0]) if f.name != "_service"] + + def get_obj_field(obj, f): + value = getattr(obj, f.name) + if isinstance(value, (InstanceType, EngineStatus)): + return str(value) + if isinstance(value, (EngineType, WarmupMethod)): + return " ".join( + map(lambda s: s.capitalize(), str(value).lower().split("_")) + ) + return value + + query_response = { + "meta": [{"name": fieldname(f), "type": types[f.type]} for f in dc_fields], + "data": [[get_obj_field(obj, f) for f in dc_fields] for obj in objs], + "rows": len(objs), + "statistics": { + "elapsed": 0.116907717, + "rows_read": 1, + "bytes_read": 61, + "time_before_execution": 0.012180623, + "time_to_execute": 0.104614307, + "scanned_bytes_cache": 0, + "scanned_bytes_storage": 0, + }, + } return Response( status_code=httpx.codes.OK, - json={"engine": mock_engine.dict()}, + json=query_response, ) return do_mock @fixture -def engine_url(server: str, account_id) -> str: - return f"https://{server}" + ACCOUNT_LIST_ENGINES_URL.format(account_id=account_id) +def get_engine_callback(mock_engine: Engine) -> Callable: + return get_objects_from_db_callback([mock_engine]) @fixture -def account_engine_callback(account_engine_url: str, mock_engine) -> Callable: +def get_engine_not_found_callback(mock_engine: Engine) -> Callable: def do_mock( request: httpx.Request = None, **kwargs, ) -> Response: - assert request.url == account_engine_url return Response( status_code=httpx.codes.OK, - json={"engine": mock_engine.dict()}, + json=empty_response, ) return do_mock @fixture -def account_engine_url(server: str, account_id, mock_engine) -> str: - return f"https://{server}" + ACCOUNT_ENGINE_URL.format( - account_id=account_id, - engine_id=mock_engine.engine_id, - ) - - -@fixture -def mock_database() -> Database: - return Database( - name="database", - description="mock_db_description", - region="us-east-1", - data_size_full=0, - data_size_compressed=0, - create_time=datetime.now().isoformat(), - create_actor="", - _status="", - _attached_engine_names="-", - _errors="", - _service=None, - ) - - -@fixture -def mock_database_2() -> Database: - return Database( - name="database2", - description="completely different db", - region="us-east-1", - data_size_full=0, - data_size_compressed=0, - create_time=datetime.now().isoformat(), - create_actor="", - _status="", - _attached_engine_names="-", - _errors="", - _service=None, - ) - - -empty_response = { - "meta": [], - "data": [], - "rows": 0, - "statistics": { - "elapsed": 39635.785423446, - "rows_read": 0, - "bytes_read": 0, - "time_before_execution": 0, - "time_to_execute": 0, - }, -} - - -@fixture -def create_databases_callback( - system_engine_no_db_query_url: str, mock_database -) -> Callable: +def attach_engine_to_db_callback(system_engine_no_db_query_url: str) -> Callable: def do_mock( request: httpx.Request = None, **kwargs, @@ -252,78 +270,68 @@ def do_mock( @fixture -def databases_get_callback( - mock_database: Database, mock_database_2: Database -) -> Callable: - return get_objects_from_db_callback([mock_database, mock_database_2]) - - -@fixture -def databases_url(server: str, account_id: str) -> str: - return f"https://{server}" + ACCOUNT_DATABASES_URL.format(account_id=account_id) +def updated_engine_scale() -> int: + return 10 @fixture -def database_callback(database_url: str, mock_database) -> Callable: +def update_engine_callback( + system_engine_no_db_query_url: str, + mock_engine: Engine, + updated_engine_scale: int, +) -> Callable: def do_mock( request: httpx.Request = None, **kwargs, ) -> Response: - assert request.url == database_url + assert request.url == system_engine_no_db_query_url + mock_engine.scale = updated_engine_scale return Response( status_code=httpx.codes.OK, - json={"database": mock_database.dict()}, + json=empty_response, ) return do_mock @fixture -def database_not_found_callback(database_url: str) -> Callable: +def create_databases_callback( + system_engine_no_db_query_url: str, mock_database +) -> Callable: def do_mock( request: httpx.Request = None, **kwargs, ) -> Response: - assert request.url == database_url + assert request.url == system_engine_no_db_query_url return Response( status_code=httpx.codes.OK, - json={}, + json=empty_response, ) return do_mock @fixture -def database_url(server: str, account_id: str, mock_database) -> str: - return f"https://{server}" + ACCOUNT_DATABASE_URL.format( - account_id=account_id, database_id=mock_database.database_id - ) +def databases_get_callback( + mock_database: Database, mock_database_2: Database +) -> Callable: + return get_objects_from_db_callback([mock_database, mock_database_2]) @fixture -def database_get_by_name_callback(database_get_by_name_url, mock_database) -> Callable: +def database_not_found_callback() -> Callable: def do_mock( request: httpx.Request = None, **kwargs, ) -> Response: - assert request.url == database_get_by_name_url return Response( status_code=httpx.codes.OK, - json={"database_id": {"database_id": mock_database.database_id}}, + json=empty_response, ) return do_mock -@fixture -def database_get_by_name_url(server: str, account_id: str, mock_database) -> str: - return ( - f"https://{server}" - + ACCOUNT_DATABASE_BY_NAME_URL.format(account_id=account_id) - + f"?database_name={mock_database.name}" - ) - - @fixture def database_update_callback() -> Callable: def do_mock( @@ -338,40 +346,6 @@ def do_mock( return do_mock -def get_objects_from_db_callback(objs: List[dataclass]) -> Callable: - def do_mock( - request: httpx.Request = None, - **kwargs, - ) -> Response: - fieldname = lambda f: (f.metadata or {}).get("db_name", f.name) - types = { - "int": _InternalType.Long.value, - "str": _InternalType.Text.value, - "datetime": _InternalType.Text.value, # we receive datetime as text from db - } - dc_fields = [f for f in fields(objs[0]) if f.name != "_service"] - query_response = { - "meta": [{"name": fieldname(f), "type": types[f.type]} for f in dc_fields], - "data": [[getattr(obj, f.name) for f in dc_fields] for obj in objs], - "rows": len(objs), - "statistics": { - "elapsed": 0.116907717, - "rows_read": 1, - "bytes_read": 61, - "time_before_execution": 0.012180623, - "time_to_execute": 0.104614307, - "scanned_bytes_cache": 0, - "scanned_bytes_storage": 0, - }, - } - return Response( - status_code=httpx.codes.OK, - json=query_response, - ) - - return do_mock - - @fixture def database_get_callback(mock_database) -> Callable: return get_objects_from_db_callback([mock_database]) diff --git a/tests/unit/service/test_database.py b/tests/unit/service/test_database.py index e5bba41801e..df7038440ec 100644 --- a/tests/unit/service/test_database.py +++ b/tests/unit/service/test_database.py @@ -3,6 +3,7 @@ from pytest_httpx import HTTPXMock from firebolt.model.database import Database +from firebolt.model.engine import Engine from firebolt.service.manager import ResourceManager @@ -13,6 +14,7 @@ def test_database_create( create_databases_callback: Callable, system_engine_no_db_query_url: str, mock_database: Database, + mock_engine: Engine, ): httpx_mock.add_callback( create_databases_callback, url=system_engine_no_db_query_url, method="POST" @@ -22,7 +24,10 @@ def test_database_create( ) database = resource_manager.databases.create( - name=mock_database.name, description=mock_database.description + name=mock_database.name, + region=mock_database.region, + attached_engines=[mock_engine], + description=mock_database.description, ) assert database == mock_database diff --git a/tests/unit/service/test_engine.py b/tests/unit/service/test_engine.py index 656e0579752..ad45fdfca31 100644 --- a/tests/unit/service/test_engine.py +++ b/tests/unit/service/test_engine.py @@ -1,244 +1,73 @@ -from re import Pattern -from typing import Callable, List +from typing import Callable -from pydantic import ValidationError from pytest import raises from pytest_httpx import HTTPXMock -from firebolt.common import Settings +from firebolt.model.database import Database from firebolt.model.engine import Engine -from firebolt.model.instance_type import InstanceType from firebolt.service.manager import ResourceManager -from firebolt.utils.exception import FireboltError, NoAttachedDatabaseError +from firebolt.utils.exception import ( + EngineNotFoundError, + NoAttachedDatabaseError, +) def test_engine_create( httpx_mock: HTTPXMock, - provider_callback: Callable, - provider_url: str, - instance_type_region_1_callback: Callable, - instance_type_region_1_url: str, - region_callback: Callable, - region_url: str, - settings: Settings, - mock_instance_types: List[InstanceType], - mock_regions, + resource_manager: ResourceManager, + instance_type_callback: Callable, + instance_type_url: str, + get_engine_callback: Callable, mock_engine: Engine, - engine_name: str, - account_id_callback: Callable, - account_id_url: Pattern, - engine_callback: Callable, - engine_url: str, - mock_system_engine_connection_flow: Callable, + system_engine_no_db_query_url: str, ): - mock_system_engine_connection_flow() - httpx_mock.add_callback(provider_callback, url=provider_url) - httpx_mock.add_callback( - instance_type_region_1_callback, url=instance_type_region_1_url - ) - httpx_mock.add_callback(region_callback, url=region_url) - httpx_mock.add_callback(engine_callback, url=engine_url, method="POST") - - manager = ResourceManager(settings=settings) - engine = manager.engines.create(name=engine_name) - - assert engine.name == engine_name - - -def test_engine_create_with_kwargs( - httpx_mock: HTTPXMock, - provider_callback: Callable, - provider_url: str, - instance_type_region_1_callback: Callable, - instance_type_region_1_url: str, - region_callback: Callable, - region_url: str, - settings: Settings, - mock_engine: Engine, - engine_name: str, - account_id_callback: Callable, - account_id_url: Pattern, - engine_callback: Callable, - engine_url: str, - account_id: str, - mock_engine_revision: EngineRevision, - mock_system_engine_connection_flow: Callable, -): - mock_system_engine_connection_flow() - httpx_mock.add_callback(provider_callback, url=provider_url) - httpx_mock.add_callback( - instance_type_region_1_callback, url=instance_type_region_1_url - ) - httpx_mock.add_callback(region_callback, url=region_url) - # Setting to manager.engines.create defaults - mock_engine.key = None - mock_engine.description = "" - mock_engine.endpoint = None - # Testing kwargs - mock_engine.settings.minimum_logging_level = "ENGINE_SETTINGS_LOGGING_LEVEL_DEBUG" - mock_engine_revision.specification.proxy_version = "0.2.3" - engine_content = _EngineCreateRequest( - account_id=account_id, engine=mock_engine, engine_revision=mock_engine_revision - ) - httpx_mock.add_callback( - engine_callback, - url=engine_url, - method="POST", - match_content=engine_content.json(by_alias=True).encode("ascii"), - ) - - manager = ResourceManager(settings=settings) - engine_settings_kwargs = { - "minimum_logging_level": "ENGINE_SETTINGS_LOGGING_LEVEL_DEBUG" - } - revision_spec_kwargs = {"proxy_version": "0.2.3"} - engine = manager.engines.create( - name=engine_name, - engine_settings_kwargs=engine_settings_kwargs, - revision_spec_kwargs=revision_spec_kwargs, - ) - - assert engine.name == engine_name - -def test_engine_create_with_kwargs_fail( - httpx_mock: HTTPXMock, - provider_callback: Callable, - provider_url: str, - instance_type_region_1_callback: Callable, - instance_type_region_1_url: str, - region_callback: Callable, - region_url: str, - settings: Settings, - engine_name: str, - account_id_callback: Callable, - account_id_url: Pattern, - mock_system_engine_connection_flow: Callable, -): - mock_system_engine_connection_flow() - httpx_mock.add_callback(provider_callback, url=provider_url) - httpx_mock.add_callback( - instance_type_region_1_callback, url=instance_type_region_1_url + httpx_mock.add_callback(instance_type_callback, url=instance_type_url) + httpx_mock.add_callback(get_engine_callback, url=system_engine_no_db_query_url) + + engine = resource_manager.engines.create( + name=mock_engine.name, + region=mock_engine.region, + engine_type=mock_engine.type, + spec=mock_engine.spec, + scale=mock_engine.scale, + auto_stop=mock_engine.auto_stop, + warmup=mock_engine.warmup, ) - httpx_mock.add_callback(region_callback, url=region_url) - manager = ResourceManager(settings=settings) - revision_spec_kwargs = {"incorrect_kwarg": "val"} - with raises(ValidationError): - manager.engines.create( - name=engine_name, revision_spec_kwargs=revision_spec_kwargs - ) + assert engine == mock_engine - engine_settings_kwargs = {"incorrect_kwarg": "val"} - with raises(TypeError): - manager.engines.create( - name=engine_name, engine_settings_kwargs=engine_settings_kwargs - ) - -def test_engine_create_no_available_types( +def test_engine_not_found( httpx_mock: HTTPXMock, - provider_callback: Callable, - provider_url: str, - instance_type_empty_callback: Callable, - instance_type_region_2_url: str, - settings: Settings, - mock_instance_types: List[InstanceType], - engine_name: str, - account_id_callback: Callable, - account_id_url: Pattern, - engine_url: str, - region_2: Region, - mock_system_engine_connection_flow: Callable, + resource_manager: ResourceManager, + get_engine_not_found_callback: Callable, + system_engine_no_db_query_url: str, ): - mock_system_engine_connection_flow() - httpx_mock.add_callback(provider_callback, url=provider_url) httpx_mock.add_callback( - instance_type_empty_callback, url=instance_type_region_2_url + get_engine_not_found_callback, url=system_engine_no_db_query_url ) - manager = ResourceManager(settings=settings) - - with raises(FireboltError): - manager.engines.create(name=engine_name, region=region_2) + with raises(EngineNotFoundError): + resource_manager.engines.get("invalid name") def test_engine_no_attached_database( httpx_mock: HTTPXMock, - provider_callback: Callable, - provider_url: str, - instance_type_region_1_callback: Callable, - instance_type_region_1_url: str, - region_callback: Callable, - region_url: str, - settings: Settings, - mock_instance_types: List[InstanceType], - mock_regions, - mock_engine: Engine, - engine_name: str, - account_id_callback: Callable, - account_id_url: Pattern, - engine_callback: Callable, - engine_url: str, - account_engine_callback: Callable, - account_engine_url: str, - database_callback: Callable, - database_url: str, - no_bindings_callback: Callable, - bindings_url: str, - mock_system_engine_connection_flow: Callable, -): - mock_system_engine_connection_flow() - httpx_mock.add_callback(provider_callback, url=provider_url) - httpx_mock.add_callback( - instance_type_region_1_callback, url=instance_type_region_1_url - ) - httpx_mock.add_callback(region_callback, url=region_url) - httpx_mock.add_callback(engine_callback, url=engine_url, method="POST") - httpx_mock.add_callback(no_bindings_callback, url=bindings_url) - - manager = ResourceManager(settings=settings) - engine = manager.engines.create(name=engine_name) - - with raises(NoAttachedDatabaseError): - engine.start() - - -def test_engine_start_binding_to_missing_database( - httpx_mock: HTTPXMock, - provider_callback: Callable, - provider_url: str, - instance_type_region_1_callback: Callable, - instance_type_region_1_url: str, - region_callback: Callable, - region_url: str, - settings: Settings, - mock_instance_types: List[InstanceType], - mock_regions, - mock_engine: Engine, - engine_name: str, - account_id_callback: Callable, - account_id_url: Pattern, - engine_callback: Callable, - engine_url: str, + resource_manager: ResourceManager, + instance_type_callback: Callable, + instance_type_url: str, + get_engine_callback: Callable, database_not_found_callback: Callable, - database_url: str, - bindings_callback: Callable, - bindings_url: str, - mock_system_engine_connection_flow: Callable, + system_engine_no_db_query_url: str, ): - mock_system_engine_connection_flow() - httpx_mock.add_callback(provider_callback, url=provider_url) + httpx_mock.add_callback(instance_type_callback, url=instance_type_url) + httpx_mock.add_callback(get_engine_callback, url=system_engine_no_db_query_url) httpx_mock.add_callback( - instance_type_region_1_callback, url=instance_type_region_1_url + database_not_found_callback, url=system_engine_no_db_query_url ) - httpx_mock.add_callback(region_callback, url=region_url) - httpx_mock.add_callback(engine_callback, url=engine_url, method="POST") - httpx_mock.add_callback(bindings_callback, url=bindings_url) - httpx_mock.add_callback(database_not_found_callback, url=database_url) - manager = ResourceManager(settings=settings) - engine = manager.engines.create(name=engine_name) + engine = resource_manager.engines.get("engine_name") with raises(NoAttachedDatabaseError): engine.start() @@ -246,166 +75,71 @@ def test_engine_start_binding_to_missing_database( def test_get_connection( httpx_mock: HTTPXMock, - provider_callback: Callable, - provider_url: str, - instance_type_region_1_callback: Callable, - instance_type_region_1_url: str, - region_callback: Callable, - region_url: str, - settings: Settings, - mock_instance_types: List[InstanceType], - mock_regions, - mock_engine: Engine, - engine_name: str, - account_id_callback: Callable, - account_id_url: Pattern, - engine_callback: Callable, - engine_url: str, - db_name: str, - database_callback: Callable, - database_url: str, - bindings_callback: Callable, - bindings_url: str, - mock_connection_flow: Callable, + resource_manager: ResourceManager, + instance_type_callback: Callable, + instance_type_url: str, + get_engine_callback: Callable, + database_get_callback: Callable, + system_engine_no_db_query_url: str, + system_engine_query_url: str, + get_engine_url_callback: Callable, + mock_query: Callable, ): - httpx_mock.add_callback(provider_callback, url=provider_url) - httpx_mock.add_callback( - instance_type_region_1_callback, url=instance_type_region_1_url - ) - httpx_mock.add_callback(region_callback, url=region_url) - httpx_mock.add_callback(engine_callback, url=engine_url, method="POST") - httpx_mock.add_callback(bindings_callback, url=bindings_url) - - httpx_mock.add_callback(database_callback, url=database_url) - mock_connection_flow() + httpx_mock.add_callback(instance_type_callback, url=instance_type_url) + httpx_mock.add_callback(get_engine_callback, url=system_engine_no_db_query_url) + httpx_mock.add_callback(database_get_callback, url=system_engine_no_db_query_url) + httpx_mock.add_callback(get_engine_url_callback, url=system_engine_query_url) + mock_query() - manager = ResourceManager(settings=settings) - engine = manager.engines.create(name=engine_name) + engine = resource_manager.engines.get("engine_name") with engine.get_connection() as connection: - assert connection + connection.cursor().execute("select 1") def test_attach_to_database( httpx_mock: HTTPXMock, - provider_callback: Callable, - provider_url: str, - region_callback: Callable, - region_url: str, - instance_type_region_1_callback: Callable, - instance_type_region_1_url: str, - settings: Settings, - account_id_callback: Callable, - account_id_url: Pattern, - create_databases_callback: Callable, - databases_url: str, + resource_manager: ResourceManager, + instance_type_callback: Callable, + instance_type_url: str, + mock_database: Database, + mock_engine: Engine, + get_engine_callback: Callable, database_get_callback: Callable, - database_get_url: str, - database_not_found_callback: Callable, - database_url: str, - db_name: str, - engine_name: str, - engine_callback: Callable, - engine_url: str, - create_binding_callback: Callable, - create_binding_url: str, - bindings_callback: Callable, - bindings_url: str, - mock_system_engine_connection_flow: Callable, + attach_engine_to_db_callback: Callable, + system_engine_no_db_query_url: str, ): - mock_system_engine_connection_flow() - httpx_mock.add_callback(provider_callback, url=provider_url) + httpx_mock.add_callback(instance_type_callback, url=instance_type_url) + httpx_mock.add_callback(database_get_callback, url=system_engine_no_db_query_url) + httpx_mock.add_callback(get_engine_callback, url=system_engine_no_db_query_url) httpx_mock.add_callback( - instance_type_region_1_callback, url=instance_type_region_1_url + attach_engine_to_db_callback, url=system_engine_no_db_query_url ) - httpx_mock.add_callback(bindings_callback, url=bindings_url) - httpx_mock.add_callback(create_databases_callback, url=databases_url, method="POST") - httpx_mock.add_callback(database_not_found_callback, url=database_url, method="GET") - # create engine - httpx_mock.add_callback(region_callback, url=region_url) - httpx_mock.add_callback(engine_callback, url=engine_url, method="POST") - - # attach - httpx_mock.add_callback(database_get_callback, url=database_get_url) - httpx_mock.add_callback( - create_binding_callback, url=create_binding_url, method="POST" - ) + database = resource_manager.databases.get("database") + engine = resource_manager.engines.get("engine") - manager = ResourceManager(settings=settings) - database = manager.databases.create(name=db_name) + engine._service = resource_manager.engines - engine = manager.engines.create(name=engine_name) - engine.attach_to_database(database=database) + engine.attach_to_database(database) - assert engine.database == database + assert engine._database_name == database.name def test_engine_update( httpx_mock: HTTPXMock, - provider_callback: Callable, - provider_url: str, - instance_type_region_1_callback: Callable, - instance_type_region_1_url: str, - region_callback: Callable, - region_url: str, - settings: Settings, - mock_instance_types: List[InstanceType], - mock_regions, + resource_manager: ResourceManager, mock_engine: Engine, - engine_name: str, - account_id_callback: Callable, - account_id_url: Pattern, - engine_callback: Callable, - engine_url: str, - account_engine_url: str, - account_engine_callback: Callable, - mock_system_engine_connection_flow: Callable, + get_engine_callback: Callable, + update_engine_callback: Callable, + system_engine_no_db_query_url: str, + updated_engine_scale: int, ): - mock_system_engine_connection_flow() - httpx_mock.add_callback(provider_callback, url=provider_url) - httpx_mock.add_callback( - account_engine_callback, url=account_engine_url, method="PATCH" - ) - manager = ResourceManager(settings=settings) - - mock_engine._service = manager.engines - engine = mock_engine.update( - name="new_engine_name", description="new engine description" - ) - - assert engine.name == "new_engine_name" - assert engine.description == "new engine description" - - -def test_engine_restart( - httpx_mock: HTTPXMock, - provider_callback: Callable, - provider_url: str, - settings: Settings, - mock_engine: Engine, - account_id_callback: Callable, - account_id_url: Pattern, - engine_callback: Callable, - account_engine_url: str, - bindings_callback: Callable, - bindings_url: str, - database_callback: Callable, - database_url: str, - mock_system_engine_connection_flow: Callable, -): - mock_system_engine_connection_flow() - httpx_mock.add_callback(provider_callback, url=provider_url) - - httpx_mock.add_callback( - engine_callback, url=f"{account_engine_url}:restart", method="POST" - ) - httpx_mock.add_callback(bindings_callback, url=bindings_url) - httpx_mock.add_callback(database_callback, url=database_url) - - manager = ResourceManager(settings=settings) + httpx_mock.add_callback(get_engine_callback, url=system_engine_no_db_query_url) + httpx_mock.add_callback(update_engine_callback, url=system_engine_no_db_query_url) + httpx_mock.add_callback(get_engine_callback, url=system_engine_no_db_query_url) - mock_engine._service = manager.engines - engine = mock_engine.restart(wait_for_startup=False) + mock_engine._service = resource_manager.engines + mock_engine.update(scale=updated_engine_scale) - assert engine.name == mock_engine.name + assert mock_engine.scale == updated_engine_scale diff --git a/tests/unit/service/test_instance_type.py b/tests/unit/service/test_instance_type.py index 8cfefb789f5..2c2e0610fba 100644 --- a/tests/unit/service/test_instance_type.py +++ b/tests/unit/service/test_instance_type.py @@ -2,26 +2,20 @@ from pytest_httpx import HTTPXMock -from firebolt.client.auth import Auth from firebolt.model.instance_type import InstanceType from firebolt.service.manager import ResourceManager def test_instance_type( httpx_mock: HTTPXMock, - auth: Auth, - account_name: str, - server: str, + resource_manager: ResourceManager, instance_type_callback: Callable, instance_type_url: str, mock_instance_types: List[InstanceType], cheapest_instance: InstanceType, mock_system_engine_connection_flow: Callable, ): - mock_system_engine_connection_flow() httpx_mock.add_callback(instance_type_callback, url=instance_type_url) - manager = ResourceManager(auth=auth, account_name=account_name, api_endpoint=server) - - assert manager.instance_types.instance_types == mock_instance_types - assert manager.instance_types.cheapest_instance == cheapest_instance + assert resource_manager.instance_types.instance_types == mock_instance_types + assert resource_manager.instance_types.cheapest_instance == cheapest_instance From 6f86d3cef0eed0455abe9c6fddaca2d0e38904c3 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Thu, 22 Jun 2023 16:10:08 +0300 Subject: [PATCH 13/18] fix integration tests --- tests/integration/conftest.py | 11 -- .../resource_manager/test_database.py | 65 +++++++- .../resource_manager/test_engine.py | 147 ++++++------------ 3 files changed, 108 insertions(+), 115 deletions(-) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 05a86144f58..de382b8c3ac 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -4,7 +4,6 @@ from pytest import fixture from firebolt.client.auth import ClientCredentials -from firebolt.service.manager import Settings LOGGER = getLogger(__name__) @@ -23,16 +22,6 @@ def must_env(var_name: str) -> str: return environ[var_name] -@fixture(scope="session") -def rm_settings(api_endpoint, auth, account_name) -> Settings: - return Settings( - account_name=account_name, - server=api_endpoint, - auth=auth, - default_region="us-east-1", - ) - - @fixture(scope="session") def engine_name() -> str: return must_env(ENGINE_NAME_ENV) diff --git a/tests/integration/resource_manager/test_database.py b/tests/integration/resource_manager/test_database.py index f1bbec2dc5e..08732925f5d 100644 --- a/tests/integration/resource_manager/test_database.py +++ b/tests/integration/resource_manager/test_database.py @@ -1,9 +1,11 @@ -from firebolt.common import Settings +from firebolt.client.auth import Auth from firebolt.service.manager import ResourceManager def test_database_get_default_engine( - rm_settings: Settings, + auth: Auth, + account_name: str, + api_endpoint: str, database_name: str, stopped_engine_name: str, engine_name: str, @@ -11,13 +13,64 @@ def test_database_get_default_engine( """ Checks that the default engine is either running or stopped engine """ - rm = ResourceManager(rm_settings) + rm = ResourceManager( + auth=auth, account_name=account_name, api_endpoint=api_endpoint + ) - db = rm.databases.get_by_name(database_name) + db = rm.databases.get(database_name) - engine = db.get_default_engine() - assert engine is not None, "default engine is None, but shouldn't" + engine = db.get_attached_engines()[0] + assert engine is not None, "engine is None, but shouldn't be" assert engine.name in [ stopped_engine_name, engine_name, ], "Returned default engine name is neither of known engines" + + +def test_databases_get_many( + auth: Auth, + account_name: str, + api_endpoint: str, + database_name: str, + engine_name: str, +): + rm = ResourceManager( + auth=auth, account_name=account_name, api_endpoint=api_endpoint + ) + + # get all databases, at least one should be returned + databases = rm.databases.get_many() + assert len(databases) > 0 + assert database_name in {db.name for db in databases} + + # get all databases, with name_contains + databases = rm.databases.get_many(name_contains=database_name) + assert len(databases) > 0 + assert database_name in {db.name for db in databases} + + # get all databases, with name_contains + databases = rm.databases.get_many(attached_engine_name_eq=engine_name) + assert len(databases) > 0 + assert database_name in {db.name for db in databases} + + # get all databases, with name_contains + databases = rm.databases.get_many(attached_engine_name_contains=engine_name) + assert len(databases) > 0 + assert database_name in {db.name for db in databases} + + region = [db for db in databases if db.name == database_name][0].region + + # get all databases, with region_eq + databases = rm.databases.get_many(region_eq=region) + assert len(databases) > 0 + assert database_name in {db.name for db in databases} + + # get all databases, with all filters + databases = rm.databases.get_many( + name_contains=database_name, + attached_engine_name_eq=engine_name, + attached_engine_name_contains=engine_name, + region_eq=region, + ) + assert len(databases) > 0 + assert database_name in {db.name for db in databases} diff --git a/tests/integration/resource_manager/test_engine.py b/tests/integration/resource_manager/test_engine.py index 62ae8d8d6d9..a0064cfde25 100644 --- a/tests/integration/resource_manager/test_engine.py +++ b/tests/integration/resource_manager/test_engine.py @@ -1,23 +1,24 @@ from collections import namedtuple -import pytest - -from firebolt.model.engine import Engine -from firebolt.service.manager import ResourceManager, Settings -from firebolt.service.types import ( - EngineStatusSummary, - EngineType, - WarmupMethod, -) +from firebolt.client.auth import Auth +from firebolt.service.manager import ResourceManager +from firebolt.service.types import EngineStatus def make_engine_name(database_name: str, suffix: str) -> str: return f"{database_name}_{suffix}" -@pytest.mark.skip(reason="manual test") -def test_create_start_stop_engine(database_name: str): - rm = ResourceManager() +# @pytest.mark.skip(reason="manual test") +def test_create_start_stop_engine( + auth: Auth, + account_name: str, + api_endpoint: str, + database_name: str, +): + rm = ResourceManager( + auth=auth, account_name=account_name, api_endpoint=api_endpoint + ) name = make_engine_name(database_name, "start_stop") engine = rm.engines.create(name=name) @@ -26,127 +27,77 @@ def test_create_start_stop_engine(database_name: str): database = rm.databases.create(name=name) assert database.name == name - engine.attach_to_database(database=database) + engine.attach_to_database(database) assert engine.database == database - engine = engine.start() - assert ( - engine.current_status_summary - == EngineStatusSummary.ENGINE_STATUS_SUMMARY_RUNNING - ) - - engine = engine.stop() - assert engine.current_status_summary in { - EngineStatusSummary.ENGINE_STATUS_SUMMARY_STOPPING, - EngineStatusSummary.ENGINE_STATUS_SUMMARY_STOPPED, - } - - -@pytest.mark.skip(reason="manual test") -def test_copy_engine(database_name): - rm = ResourceManager() - name = make_engine_name(database_name, "copy") - - engine = rm.engines.create(name=name) - assert engine.name == name - - engine.name = f"{engine.name}_copy" - engine_copy = rm.engines._send_create_engine( - engine=engine, - engine_revision=rm.engine_revisions.get_by_key(engine.latest_revision_key), - ) - assert engine_copy - - -def test_databases_get_many(rm_settings: Settings, database_name, engine_name): - rm = ResourceManager(rm_settings) - - # get all databases, at least one should be returned - databases = rm.databases.get_many() - assert len(databases) > 0 - assert database_name in {db.name for db in databases} - - # get all databases, with name_contains - databases = rm.databases.get_many(name_contains=database_name) - assert len(databases) > 0 - assert database_name in {db.name for db in databases} - - # get all databases, with name_contains - databases = rm.databases.get_many(attached_engine_name_eq=engine_name) - assert len(databases) > 0 - assert database_name in {db.name for db in databases} - - # get all databases, with name_contains - databases = rm.databases.get_many(attached_engine_name_contains=engine_name) - assert len(databases) > 0 - assert database_name in {db.name for db in databases} + engine.start() + assert engine.current_status == EngineStatus.RUNNING + engine.stop() + assert engine.current_status in {EngineStatus.STOPPING, EngineStatus.STOPPED} -def get_engine_params(rm: ResourceManager, engine: Engine): - engine_revision = rm.engine_revisions.get_by_key(engine.latest_revision_key) - instance_type = rm.instance_types.get_by_key( - engine_revision.specification.db_compute_instances_type_key - ) - - return { - "engine_type": engine.settings.preset, - "scale": engine_revision.specification.db_compute_instances_count, - "spec": instance_type.name, - "auto_stop": engine.settings.auto_stop_delay_duration, - "warmup": engine.settings.warm_up, - "description": engine.description, - } + engine.delete() + database.delete() ParamValue = namedtuple("ParamValue", "set expected") ENGINE_UPDATE_PARAMS = { - "engine_type": ParamValue( - EngineType.DATA_ANALYTICS, "ENGINE_SETTINGS_PRESET_DATA_ANALYTICS" - ), - "scale": ParamValue(23, 23), - "spec": ParamValue("B1", "B1"), - "auto_stop": ParamValue(123, "7380s"), - "warmup": ParamValue(WarmupMethod.PRELOAD_ALL_DATA, "ENGINE_SETTINGS_WARM_UP_ALL"), - "description": ParamValue("new db description", "new db description"), + # commented parameters are not available yet + # "scale": ParamValue(23, 23), + # "spec": ParamValue("B1", "B1"), + "auto_stop": ParamValue(123, 7380), + # "warmup": ParamValue(WarmupMethod.PRELOAD_ALL_DATA, WarmupMethod.PRELOAD_ALL_DATA), } -def test_engine_update_single_parameter(rm_settings: Settings, database_name: str): - rm = ResourceManager(rm_settings) +def test_engine_update_single_parameter( + auth: Auth, + account_name: str, + api_endpoint: str, + database_name: str, +): + rm = ResourceManager( + auth=auth, account_name=account_name, api_endpoint=api_endpoint + ) name = make_engine_name(database_name, "single_param") engine = rm.engines.create(name=name) - engine.attach_to_database(database=rm.databases.get_by_name(database_name)) + engine.attach_to_database(rm.databases.get(database_name)) assert engine.database.name == database_name for param, value in ENGINE_UPDATE_PARAMS.items(): engine.update(**{param: value.set}) - engine = rm.engines.get_by_name(name) - new_params = get_engine_params(rm, engine) - assert new_params[param] == value.expected + engine_new = rm.engines.get(name) + assert getattr(engine_new, param) == value.expected, f"Invalid {param} value" engine.delete() -def test_engine_update_multiple_parameters(rm_settings: Settings, database_name: str): - rm = ResourceManager(rm_settings) +def test_engine_update_multiple_parameters( + auth: Auth, + account_name: str, + api_endpoint: str, + database_name: str, +): + rm = ResourceManager( + auth=auth, account_name=account_name, api_endpoint=api_endpoint + ) name = make_engine_name(database_name, "multi_param") engine = rm.engines.create(name=name) - engine.attach_to_database(database=rm.databases.get_by_name(database_name)) + engine.attach_to_database(rm.databases.get(database_name)) assert engine.database.name == database_name engine.update( **dict({(param, value.set) for param, value in ENGINE_UPDATE_PARAMS.items()}) ) - engine = rm.engines.get_by_name(name) - new_params = get_engine_params(rm, engine) + engine_new = rm.engines.get(name) for param, value in ENGINE_UPDATE_PARAMS.items(): - assert new_params[param] == value.expected + assert getattr(engine_new, param) == value.expected, f"Invalid {param} value" engine.delete() From 962cfaa68f0fe0983610110a4bc9ed78634b20ab Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Tue, 27 Jun 2023 13:39:08 +0300 Subject: [PATCH 14/18] fix unit tests --- tests/unit/async_db/test_cursor.py | 5 ++--- tests/unit/db/test_cursor.py | 5 ++--- tests/unit/db/test_util.py | 3 +++ tests/unit/service/conftest.py | 1 + tests/unit/service/test_engine.py | 3 +++ 5 files changed, 11 insertions(+), 6 deletions(-) diff --git a/tests/unit/async_db/test_cursor.py b/tests/unit/async_db/test_cursor.py index 95cdb0b6b1c..b4d220fac7b 100644 --- a/tests/unit/async_db/test_cursor.py +++ b/tests/unit/async_db/test_cursor.py @@ -8,7 +8,6 @@ from firebolt.async_db import Cursor from firebolt.common._types import Column from firebolt.common.base_cursor import ColType, CursorState, QueryStatus -from firebolt.common.settings import Settings from firebolt.utils.exception import ( AsyncExecutionUnavailableError, CursorClosedError, @@ -192,9 +191,9 @@ async def test_cursor_execute( async def test_cursor_execute_error( httpx_mock: HTTPXMock, + server: str, query_url: str, get_engines_url: str, - settings: Settings, db_name: str, query_statistics: Dict[str, Any], cursor: Cursor, @@ -315,7 +314,7 @@ def http_error(*args, **kwargs): with raises(EngineNotRunningError) as excinfo: await query() assert cursor._state == CursorState.ERROR - assert settings.server in str(excinfo) + assert server in str(excinfo) # Engine does not exist httpx_mock.add_response( diff --git a/tests/unit/db/test_cursor.py b/tests/unit/db/test_cursor.py index d61bc23a531..c94ba052d74 100644 --- a/tests/unit/db/test_cursor.py +++ b/tests/unit/db/test_cursor.py @@ -5,7 +5,6 @@ from pytest import raises from pytest_httpx import HTTPXMock -from firebolt.common.settings import Settings from firebolt.db import Cursor from firebolt.db.cursor import ColType, Column, CursorState, QueryStatus from firebolt.utils.exception import ( @@ -188,7 +187,7 @@ def test_cursor_execute( def test_cursor_execute_error( httpx_mock: HTTPXMock, get_engines_url: str, - settings: Settings, + server: str, db_name: str, query_url: str, query_statistics: Dict[str, Any], @@ -309,7 +308,7 @@ def http_error(*args, **kwargs): with raises(EngineNotRunningError) as excinfo: query() assert cursor._state == CursorState.ERROR - assert settings.server in str(excinfo) + assert server in str(excinfo) # Engine does not exist httpx_mock.add_response( diff --git a/tests/unit/db/test_util.py b/tests/unit/db/test_util.py index 452c09e6430..3f053bcd6b7 100644 --- a/tests/unit/db/test_util.py +++ b/tests/unit/db/test_util.py @@ -50,6 +50,9 @@ def test_is_engine_running_system( ): # System engine is always running assert is_engine_running(system_connection, "dummy") == True + # We didn't resolve account id since since we run no query + # We need to skip the mocked endpoint + httpx_mock.reset(False) # We haven't used account id endpoint since we didn't run any query, ignoring it httpx_mock.reset(False) diff --git a/tests/unit/service/conftest.py b/tests/unit/service/conftest.py index 8502d383806..0dcb74cd948 100644 --- a/tests/unit/service/conftest.py +++ b/tests/unit/service/conftest.py @@ -200,6 +200,7 @@ def do_mock( "InstanceType": _InternalType.Text.value, "EngineType": _InternalType.Text.value, "EngineStatus": _InternalType.Text.value, + "WarmupMethod": _InternalType.Text.value, } dc_fields = [f for f in fields(objs[0]) if f.name != "_service"] diff --git a/tests/unit/service/test_engine.py b/tests/unit/service/test_engine.py index ad45fdfca31..0097153c8bb 100644 --- a/tests/unit/service/test_engine.py +++ b/tests/unit/service/test_engine.py @@ -129,12 +129,15 @@ def test_attach_to_database( def test_engine_update( httpx_mock: HTTPXMock, resource_manager: ResourceManager, + instance_type_callback: Callable, + instance_type_url: str, mock_engine: Engine, get_engine_callback: Callable, update_engine_callback: Callable, system_engine_no_db_query_url: str, updated_engine_scale: int, ): + httpx_mock.add_callback(instance_type_callback, url=instance_type_url) httpx_mock.add_callback(get_engine_callback, url=system_engine_no_db_query_url) httpx_mock.add_callback(update_engine_callback, url=system_engine_no_db_query_url) httpx_mock.add_callback(get_engine_callback, url=system_engine_no_db_query_url) From f10b5b1f058b830648353e4aa903bb92c8e2dc77 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Tue, 27 Jun 2023 13:48:03 +0300 Subject: [PATCH 15/18] address comments --- src/firebolt/model/__init__.py | 1 + src/firebolt/model/database.py | 6 +++--- tests/integration/resource_manager/test_engine.py | 1 - 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/firebolt/model/__init__.py b/src/firebolt/model/__init__.py index 24cacdd58ae..b43a03c52af 100644 --- a/src/firebolt/model/__init__.py +++ b/src/firebolt/model/__init__.py @@ -13,6 +13,7 @@ class FireboltBaseModel: @classmethod def _get_field_overrides(cls) -> Dict[str, str]: + """Create a mapping of db field name to class name where they are different.""" return { f.metadata["db_name"]: f.name for f in fields(cls) diff --git a/src/firebolt/model/database.py b/src/firebolt/model/database.py index c7cb4698ab8..d85b13cf5a3 100644 --- a/src/firebolt/model/database.py +++ b/src/firebolt/model/database.py @@ -69,9 +69,9 @@ def update(self, description: str) -> Database: return self for engine in self.get_attached_engines(): - if engine.current_status in { - EngineStatus.STARTING, - EngineStatus.STOPPING, + if engine.current_status not in { + EngineStatus.RUNNING, + EngineStatus.STOPPED, }: raise AttachedEngineInUseError(method_name="update") diff --git a/tests/integration/resource_manager/test_engine.py b/tests/integration/resource_manager/test_engine.py index a0064cfde25..f21bc4dd1f3 100644 --- a/tests/integration/resource_manager/test_engine.py +++ b/tests/integration/resource_manager/test_engine.py @@ -9,7 +9,6 @@ def make_engine_name(database_name: str, suffix: str) -> str: return f"{database_name}_{suffix}" -# @pytest.mark.skip(reason="manual test") def test_create_start_stop_engine( auth: Auth, account_name: str, From ac9c7c7d9bc7579da872312de27c4cecad77c9ad Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Mon, 3 Jul 2023 10:19:11 +0300 Subject: [PATCH 16/18] address more comments --- src/firebolt/model/database.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/firebolt/model/database.py b/src/firebolt/model/database.py index d85b13cf5a3..823437c2947 100644 --- a/src/firebolt/model/database.py +++ b/src/firebolt/model/database.py @@ -89,7 +89,7 @@ def delete(self) -> None: """ for engine in self.get_attached_engines(): - if engine.current_status in { + if engine.current_status not in { EngineStatus.STARTING, EngineStatus.STOPPING, }: From 1efaa154c2b602ea4c389e6d4042566875c0c871 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Tue, 4 Jul 2023 15:20:33 +0300 Subject: [PATCH 17/18] remove pydantic references --- docs/requirements.txt | 3 +-- setup.cfg | 6 ------ src/firebolt/common/base_cursor.py | 21 +++++++++++++++++---- src/firebolt/utils/usage_tracker.py | 26 ++++++++++++++++++++------ tests/unit/utils/test_usage_tracker.py | 3 +-- 5 files changed, 39 insertions(+), 20 deletions(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index 716eb2f6c1a..279ab19b0ce 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,4 +1,3 @@ -pydantic[dotenv] httpx[http2] aiorwlock async-property @@ -6,4 +5,4 @@ readerwriterlock sqlparse appdirs appdirs-stubs -cryptography \ No newline at end of file +cryptography diff --git a/setup.cfg b/setup.cfg index fee2c304831..d14eaba9a18 100755 --- a/setup.cfg +++ b/setup.cfg @@ -29,7 +29,6 @@ install_requires = async-property>=0.2.1 cryptography>=3.4.0 httpx[http2]==0.24.0 - pydantic[dotenv]>=1.8.2 python-dateutil>=2.8.2 readerwriterlock>=1.0.9 sqlparse>=0.4.2 @@ -66,15 +65,10 @@ dev = firebolt = py.typed [mypy] -plugins = pydantic.mypy disallow_untyped_defs = True show_error_codes = True files = src/ -[pydantic-mypy] -warn_required_dynamic_aliases = True -warn_untyped_fields = True - [flake8] exclude = tests/* max-line-length = 88 diff --git a/src/firebolt/common/base_cursor.py b/src/firebolt/common/base_cursor.py index 7b278595a03..b6c690bfce3 100644 --- a/src/firebolt/common/base_cursor.py +++ b/src/firebolt/common/base_cursor.py @@ -1,13 +1,13 @@ from __future__ import annotations import logging +from dataclasses import dataclass, fields from enum import Enum from functools import wraps from types import TracebackType from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union from httpx import Response -from pydantic import BaseModel from firebolt.common._types import ( ColType, @@ -51,7 +51,8 @@ class QueryStatus(Enum): EXECUTION_ERROR = 8 -class Statistics(BaseModel): +@dataclass +class Statistics: """ Class for query execution statistics. """ @@ -61,8 +62,20 @@ class Statistics(BaseModel): bytes_read: int time_before_execution: float time_to_execute: float - scanned_bytes_cache: Optional[float] - scanned_bytes_storage: Optional[float] + scanned_bytes_cache: Optional[float] = None + scanned_bytes_storage: Optional[float] = None + + def __post_init__(self) -> None: + for field in fields(self): + value = getattr(self, field.name) + _type = eval(field.type) # type: ignore + + # Unpack Optional + if hasattr(_type, "__args__"): + _type = _type.__args__[0] + if value is not None and not isinstance(value, _type): + # convert values to proper types + setattr(self, field.name, _type(value)) def check_not_closed(func: Callable) -> Callable: diff --git a/src/firebolt/utils/usage_tracker.py b/src/firebolt/utils/usage_tracker.py index c5e9f4f2509..d56a7c071a2 100644 --- a/src/firebolt/utils/usage_tracker.py +++ b/src/firebolt/utils/usage_tracker.py @@ -1,17 +1,17 @@ import inspect import logging +from dataclasses import dataclass from importlib import import_module from pathlib import Path from platform import python_version, release, system from sys import modules from typing import Dict, List, Optional, Tuple -from pydantic import BaseModel - from firebolt import __version__ -class ConnectorVersions(BaseModel): +@dataclass +class ConnectorVersions: """ Verify correct parameter types """ @@ -19,6 +19,20 @@ class ConnectorVersions(BaseModel): clients: List[Tuple[str, str]] drivers: List[Tuple[str, str]] + def __post_init__(self) -> None: + if any( + [(not isinstance(pair, tuple) or len(pair) != 2) for pair in self.clients] + ): + raise ValueError("Invalid clients value: expected tuples of length 2") + if any([not isinstance(item, str) for pair in self.clients for item in pair]): + raise ValueError("Invalid clients value: expected tuples of strings") + if any( + [(not isinstance(pair, tuple) or len(pair) != 2) for pair in self.drivers] + ): + raise ValueError("Invalid drivers value: expected tuples of length 2") + if any([not isinstance(item, str) for pair in self.drivers for item in pair]): + raise ValueError("Invalid drivers value: expected tuples of strings") + logger = logging.getLogger(__name__) @@ -169,8 +183,8 @@ def format_as_user_agent(drivers: Dict[str, str], clients: Dict[str, str]) -> st def get_user_agent_header( - user_drivers: Optional[List[Tuple[str, str]]] = [], - user_clients: Optional[List[Tuple[str, str]]] = [], + user_drivers: Optional[List[Tuple[str, str]]] = None, + user_clients: Optional[List[Tuple[str, str]]] = None, ) -> str: """ Return a user agent header with connector stack and system information. @@ -194,7 +208,7 @@ def get_user_agent_header( "Detected running with drivers: %s and clients %s ", str(drivers), str(clients) ) # Override auto-detected connectors with info provided manually - versions = ConnectorVersions(clients=user_clients, drivers=user_drivers) + versions = ConnectorVersions(clients=user_clients or [], drivers=user_drivers or []) for name, version in versions.clients: clients[name] = version for name, version in versions.drivers: diff --git a/tests/unit/utils/test_usage_tracker.py b/tests/unit/utils/test_usage_tracker.py index 1ab2f72db0c..2949b49757e 100644 --- a/tests/unit/utils/test_usage_tracker.py +++ b/tests/unit/utils/test_usage_tracker.py @@ -1,7 +1,6 @@ from collections import namedtuple from unittest.mock import MagicMock, patch -from pydantic import ValidationError from pytest import mark, raises from firebolt.utils.usage_tracker import ( @@ -199,5 +198,5 @@ def test_user_agent(drivers, clients, expected_string): MagicMock(return_value=("1", "2", "Win", "ciso")), ) def test_incorrect_user_agent(drivers, clients): - with raises(ValidationError): + with raises(ValueError): get_user_agent_header(drivers, clients) From 2d20ca5a5706421f41b7709187d7aa937b2f5324 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Tue, 4 Jul 2023 15:47:45 +0300 Subject: [PATCH 18/18] fix missing dependency --- setup.cfg | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.cfg b/setup.cfg index d14eaba9a18..f1fd9b6e343 100755 --- a/setup.cfg +++ b/setup.cfg @@ -26,6 +26,7 @@ install_requires = aiorwlock==1.1.0 appdirs>=1.4.4 appdirs-stubs>=0.1.0 + async-generator>=1.10 async-property>=0.2.1 cryptography>=3.4.0 httpx[http2]==0.24.0