diff --git a/docs/requirements.txt b/docs/requirements.txt index 716eb2f6c1a..279ab19b0ce 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,4 +1,3 @@ -pydantic[dotenv] httpx[http2] aiorwlock async-property @@ -6,4 +5,4 @@ readerwriterlock sqlparse appdirs appdirs-stubs -cryptography \ No newline at end of file +cryptography diff --git a/setup.cfg b/setup.cfg index fee2c304831..f1fd9b6e343 100755 --- a/setup.cfg +++ b/setup.cfg @@ -26,10 +26,10 @@ install_requires = aiorwlock==1.1.0 appdirs>=1.4.4 appdirs-stubs>=0.1.0 + async-generator>=1.10 async-property>=0.2.1 cryptography>=3.4.0 httpx[http2]==0.24.0 - pydantic[dotenv]>=1.8.2 python-dateutil>=2.8.2 readerwriterlock>=1.0.9 sqlparse>=0.4.2 @@ -66,15 +66,10 @@ dev = firebolt = py.typed [mypy] -plugins = pydantic.mypy disallow_untyped_defs = True show_error_codes = True files = src/ -[pydantic-mypy] -warn_required_dynamic_aliases = True -warn_untyped_fields = True - [flake8] exclude = tests/* max-line-length = 88 diff --git a/src/firebolt/common/base_cursor.py b/src/firebolt/common/base_cursor.py index 7b278595a03..b6c690bfce3 100644 --- a/src/firebolt/common/base_cursor.py +++ b/src/firebolt/common/base_cursor.py @@ -1,13 +1,13 @@ from __future__ import annotations import logging +from dataclasses import dataclass, fields from enum import Enum from functools import wraps from types import TracebackType from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union from httpx import Response -from pydantic import BaseModel from firebolt.common._types import ( ColType, @@ -51,7 +51,8 @@ class QueryStatus(Enum): EXECUTION_ERROR = 8 -class Statistics(BaseModel): +@dataclass +class Statistics: """ Class for query execution statistics. """ @@ -61,8 +62,20 @@ class Statistics(BaseModel): bytes_read: int time_before_execution: float time_to_execute: float - scanned_bytes_cache: Optional[float] - scanned_bytes_storage: Optional[float] + scanned_bytes_cache: Optional[float] = None + scanned_bytes_storage: Optional[float] = None + + def __post_init__(self) -> None: + for field in fields(self): + value = getattr(self, field.name) + _type = eval(field.type) # type: ignore + + # Unpack Optional + if hasattr(_type, "__args__"): + _type = _type.__args__[0] + if value is not None and not isinstance(value, _type): + # convert values to proper types + setattr(self, field.name, _type(value)) def check_not_closed(func: Callable) -> Callable: diff --git a/src/firebolt/model/__init__.py b/src/firebolt/model/__init__.py index 83992de4b6e..b43a03c52af 100644 --- a/src/firebolt/model/__init__.py +++ b/src/firebolt/model/__init__.py @@ -1,24 +1,29 @@ import json -from typing import Any +from dataclasses import dataclass, field, fields +from typing import ClassVar, Dict, Optional, Type, TypeVar -from pydantic import BaseModel +from firebolt.service.base import BaseService +Model = TypeVar("Model", bound="FireboltBaseModel") -class FireboltBaseModel(BaseModel): - class Config: - allow_population_by_field_name = True - extra = "forbid" - def jsonable_dict(self, *args: Any, **kwargs: Any) -> dict: - """ - Generate a dictionary representation of the service that contains serialized - primitive types, and is therefore JSON-ready. +@dataclass +class FireboltBaseModel: + _service: BaseService = field(repr=False, compare=False) - This could be replaced with something native once this issue is resolved: - https://github.com/samuelcolvin/pydantic/issues/1409 + @classmethod + def _get_field_overrides(cls) -> Dict[str, str]: + """Create a mapping of db field name to class name where they are different.""" + return { + f.metadata["db_name"]: f.name + for f in fields(cls) + if "db_name" in f.metadata + } - This function is intended to improve the compatibility with HTTPX, which - expects to take in a dictionary of primitives as input to the JSON parameter - of its request function. See: https://www.python-httpx.org/api/#helper-functions - """ - return json.loads(self.json(*args, **kwargs)) + @classmethod + def _from_dict( + cls: Type[Model], data: dict, service: Optional[BaseService] = None + ) -> Model: + data["_service"] = service + field_name_overrides = cls._get_field_overrides() + return cls(**{field_name_overrides.get(k, k): v for k, v in data.items()}) diff --git a/src/firebolt/model/binding.py b/src/firebolt/model/binding.py deleted file mode 100644 index 26407d44a78..00000000000 --- a/src/firebolt/model/binding.py +++ /dev/null @@ -1,38 +0,0 @@ -from __future__ import annotations - -from datetime import datetime -from typing import Optional - -from pydantic import BaseModel, Field - -from firebolt.model import FireboltBaseModel - - -class BindingKey(BaseModel): - account_id: str - database_id: str - engine_id: str - - -class Binding(FireboltBaseModel): - """A binding between an engine and a database.""" - - binding_key: BindingKey = Field(alias="id") - is_default_engine: bool = Field(alias="engine_is_default") - - # optional - current_status: Optional[str] - health_status: Optional[str] - create_time: Optional[datetime] - create_actor: Optional[str] - last_update_time: Optional[datetime] - last_update_actor: Optional[str] - desired_status: Optional[str] - - @property - def database_id(self) -> str: - return self.binding_key.database_id - - @property - def engine_id(self) -> str: - return self.binding_key.engine_id diff --git a/src/firebolt/model/database.py b/src/firebolt/model/database.py index 9cc92264c28..823437c2947 100644 --- a/src/firebolt/model/database.py +++ b/src/firebolt/model/database.py @@ -1,34 +1,22 @@ from __future__ import annotations import logging +from dataclasses import dataclass, field from datetime import datetime -from typing import TYPE_CHECKING, Any, List, Optional, Sequence - -from pydantic import Field, PrivateAttr +from typing import TYPE_CHECKING, ClassVar, List from firebolt.model import FireboltBaseModel -from firebolt.model.region import RegionKey -from firebolt.service.types import EngineStatusSummary +from firebolt.service.types import EngineStatus from firebolt.utils.exception import AttachedEngineInUseError -from firebolt.utils.urls import ACCOUNT_DATABASE_URL if TYPE_CHECKING: - from firebolt.model.binding import Binding from firebolt.model.engine import Engine from firebolt.service.database import DatabaseService logger = logging.getLogger(__name__) -class DatabaseKey(FireboltBaseModel): - account_id: str - database_id: str - - -class FieldMask(FireboltBaseModel): - paths: Sequence[str] = Field(alias="paths") - - +@dataclass class Database(FireboltBaseModel): """ A Firebolt database. @@ -37,143 +25,75 @@ class Database(FireboltBaseModel): but otherwise are not configurable. """ + ALTER_SQL: ClassVar[str] = "ALTER DATABASE {} WITH DESCRIPTION = ?" + + DROP_SQL: ClassVar[str] = "DROP DATABASE {}" + # internal - _service: DatabaseService = PrivateAttr() + _service: DatabaseService = field(repr=False, compare=False) # required - name: str = Field(min_length=1, max_length=255, regex=r"^[0-9a-zA-Z_]+$") - compute_region_key: RegionKey = Field(alias="compute_region_id") - - # optional - database_key: Optional[DatabaseKey] = Field(None, alias="id") - description: Optional[str] = Field(None, max_length=255) - emoji: Optional[str] = Field(None, max_length=255) - current_status: Optional[str] - health_status: Optional[str] - data_size_full: Optional[int] - data_size_compressed: Optional[int] - is_system_database: Optional[bool] - storage_bucket_name: Optional[str] - create_time: Optional[datetime] - create_actor: Optional[str] - last_update_time: Optional[datetime] - last_update_actor: Optional[str] - desired_status: Optional[str] - - @classmethod - def parse_obj_with_service( - cls, obj: Any, database_service: DatabaseService - ) -> Database: - database = cls.parse_obj(obj) - database._service = database_service - return database - - @property - def database_id(self) -> Optional[str]: - if self.database_key is None: - return None - return self.database_key.database_id + name: str = field(metadata={"db_name": "database_name"}) + description: str = field() + region: str = field() + _status: str = field(repr=False, metadata={"db_name": "status"}) + data_size_full: int = field() + data_size_compressed: int = field() + _attached_engine_names: str = field( + repr=False, metadata={"db_name": "attached_engines"}, compare=False + ) + create_time: datetime = field(metadata={"db_name": "created_on"}) + create_actor: str = field(metadata={"db_name": "created_by"}) + _errors: str = field(repr=False, metadata={"db_name": "errors"}) def get_attached_engines(self) -> List[Engine]: """Get a list of engines that are attached to this database.""" + return self._service.resource_manager.engines.get_many(database_name=self.name) - return self._service.resource_manager.bindings.get_engines_bound_to_database( # noqa: E501 - database=self - ) - - def attach_to_engine( - self, engine: Engine, is_default_engine: bool = False - ) -> Binding: + def attach_engine(self, engine: Engine) -> None: """ Attach an engine to this database. Args: engine: The engine to attach. - is_default_engine: - Whether this engine should be used as default for this database. - Only one engine can be set as default for a single database. - This will overwrite any existing default. """ - - return self._service.resource_manager.bindings.create( - engine=engine, database=self, is_default_engine=is_default_engine - ) - - def delete(self) -> Database: - """ - Delete a database from Firebolt. - - Raises an error if there are any attached engines. - """ - - for engine in self.get_attached_engines(): - if engine.current_status_summary in { - EngineStatusSummary.ENGINE_STATUS_SUMMARY_STARTING, - EngineStatusSummary.ENGINE_STATUS_SUMMARY_STOPPING, - }: - raise AttachedEngineInUseError(method_name="delete") - - logger.info( - f"Deleting Database (database_id={self.database_id}, name={self.name})" - ) - response = self._service.client.delete( - url=ACCOUNT_DATABASE_URL.format( - account_id=self._service.account_id, database_id=self.database_id - ), - headers={"Content-type": "application/json"}, - ) - return Database.parse_obj_with_service( - response.json()["database"], self._service + return self._service.resource_manager.engines.attach_to_database( + engine.name, self.name ) def update(self, description: str) -> Database: """ Updates a database description. """ + if not description: + return self - class _DatabaseUpdateRequest(FireboltBaseModel): - """Helper model for sending Database creation requests.""" - - account_id: str - database: Database - database_id: str - update_mask: FieldMask + for engine in self.get_attached_engines(): + if engine.current_status not in { + EngineStatus.RUNNING, + EngineStatus.STOPPED, + }: + raise AttachedEngineInUseError(method_name="update") + sql = self.ALTER_SQL.format(self.name) + with self._service._connection.cursor() as c: + c.execute(sql, (description,)) self.description = description + return self - logger.info( - f"Updating Database (database_id={self.database_id}, " - f"name={self.name}, description={self.description})" - ) + def delete(self) -> None: + """ + Delete a database from Firebolt. - payload = _DatabaseUpdateRequest( - account_id=self._service.account_id, - database=self, - database_id=self.database_id, - update_mask=FieldMask(paths=["description"]), - ).jsonable_dict(by_alias=True) - - response = self._service.client.patch( - url=ACCOUNT_DATABASE_URL.format( - account_id=self._service.account_id, database_id=self.database_id - ), - headers={"Content-type": "application/json"}, - json=payload, - ) + Raises an error if there are any attached engines. + """ - return Database.parse_obj_with_service( - response.json()["database"], self._service - ) + for engine in self.get_attached_engines(): + if engine.current_status not in { + EngineStatus.STARTING, + EngineStatus.STOPPING, + }: + raise AttachedEngineInUseError(method_name="delete") - def get_default_engine(self) -> Optional[Engine]: - """ - Returns: default engine of the database, or None if default engine is missing - """ - rm = self._service.resource_manager - default_engines = [ - rm.engines.get(binding.engine_id) - for binding in rm.bindings.get_many(database_id=self.database_id) - if binding.is_default_engine - ] - - return None if len(default_engines) == 0 else default_engines[0] + with self._service._connection.cursor() as c: + c.execute(self.DROP_SQL.format(self.name)) diff --git a/src/firebolt/model/engine.py b/src/firebolt/model/engine.py index 04c9ef378d9..cffdcce7291 100644 --- a/src/firebolt/model/engine.py +++ b/src/firebolt/model/engine.py @@ -3,31 +3,26 @@ import functools import logging import time -from datetime import datetime -from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence - -from pydantic import Field, PrivateAttr +from dataclasses import dataclass, field +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ClassVar, + Optional, + Tuple, + Union, +) from firebolt.db import Connection, connect from firebolt.model import FireboltBaseModel -from firebolt.model.binding import Binding from firebolt.model.database import Database -from firebolt.model.engine_revision import EngineRevision, EngineRevisionKey -from firebolt.model.region import RegionKey -from firebolt.service.types import ( - EngineStatus, - EngineStatusSummary, - EngineType, - WarmupMethod, -) -from firebolt.utils.exception import NoAttachedDatabaseError -from firebolt.utils.urls import ( - ACCOUNT_ENGINE_RESTART_URL, - ACCOUNT_ENGINE_START_URL, - ACCOUNT_ENGINE_STOP_URL, - ACCOUNT_ENGINE_URL, +from firebolt.model.instance_type import InstanceType +from firebolt.service.types import EngineStatus, EngineType, WarmupMethod +from firebolt.utils.exception import ( + DatabaseNotFoundError, + NoAttachedDatabaseError, ) -from firebolt.utils.util import prune_dict if TYPE_CHECKING: from firebolt.service.engine import EngineService @@ -35,58 +30,6 @@ logger = logging.getLogger(__name__) -class EngineKey(FireboltBaseModel): - account_id: str - engine_id: str - - -def wait(seconds: int, timeout_time: float, error_message: str, verbose: bool) -> None: - time.sleep(seconds) - if time.time() > timeout_time: - raise TimeoutError(error_message) - if verbose: - print(".", end="") - - -class EngineSettings(FireboltBaseModel): - """ - Engine settings. - - See also: :py:class:`EngineRevisionSpecification - ` - which also contains engine configuration. - """ - - preset: str - auto_stop_delay_duration: str = Field(regex=r"^[0-9]+[sm]$|^0$") - minimum_logging_level: str - is_read_only: bool - warm_up: str - - @classmethod - def default( - cls, - engine_type: EngineType = EngineType.GENERAL_PURPOSE, - auto_stop_delay_duration: str = "1200s", - warm_up: WarmupMethod = WarmupMethod.PRELOAD_INDEXES, - minimum_logging_level: str = "ENGINE_SETTINGS_LOGGING_LEVEL_INFO", - ) -> EngineSettings: - if engine_type == EngineType.GENERAL_PURPOSE: - preset = engine_type.GENERAL_PURPOSE.api_settings_preset_name # type: ignore # noqa: E501 - is_read_only = False - else: - preset = engine_type.DATA_ANALYTICS.api_settings_preset_name # type: ignore - is_read_only = True - - return cls( - preset=preset, - auto_stop_delay_duration=auto_stop_delay_duration, - minimum_logging_level=minimum_logging_level, - is_read_only=is_read_only, - warm_up=warm_up.api_name, - ) - - def check_attached_to_database(func: Callable) -> Callable: """(Decorator) Ensure the engine is attached to a database.""" @@ -99,90 +42,78 @@ def inner(self: Engine, *args: Any, **kwargs: Any) -> Any: return inner -class FieldMask(FireboltBaseModel): - paths: Sequence[str] = Field(alias="paths") - - +@dataclass class Engine(FireboltBaseModel): """ A Firebolt engine. Responsible for performing work (queries, ingestion). - - Engines are configured in :py:class:`Settings - ` - and in :py:class:`EngineRevisionSpecification - `. """ - # internal - _service: EngineService = PrivateAttr() - - # required - name: str = Field(min_length=1, max_length=255, regex=r"^[0-9a-zA-Z_]+$") - compute_region_key: RegionKey = Field(alias="compute_region_id") - settings: EngineSettings - - # optional - key: Optional[EngineKey] = Field(None, alias="id") - description: Optional[str] - emoji: Optional[str] - current_status: Optional[EngineStatus] - current_status_summary: Optional[EngineStatusSummary] - latest_revision_key: Optional[EngineRevisionKey] = Field( - None, alias="latest_revision_id" - ) - endpoint: Optional[str] - endpoint_serving_revision_key: Optional[EngineRevisionKey] = Field( - None, alias="endpoint_serving_revision_id" - ) - create_time: Optional[datetime] - create_actor: Optional[str] - last_update_time: Optional[datetime] - last_update_actor: Optional[str] - last_use_time: Optional[datetime] - desired_status: Optional[str] - health_status: Optional[str] - endpoint_desired_revision_key: Optional[EngineRevisionKey] = Field( - None, alias="endpoint_desired_revision_id" + START_SQL: ClassVar[str] = "START ENGINE {}" + STOP_SQL: ClassVar[str] = "STOP ENGINE {}" + ALTER_PREFIX_SQL: ClassVar[str] = "ALTER ENGINE {} SET " + ALTER_PARAMETER_NAMES: ClassVar[Tuple] = ( + "SCALE", + "SPEC", + "AUTO_STOP", + "RENAME_TO", + "WARMUP", ) - - @classmethod - def parse_obj_with_service(cls, obj: Any, engine_service: EngineService) -> Engine: - engine = cls.parse_obj(obj) - engine._service = engine_service - return engine - - @property - def engine_id(self) -> str: - if self.key is None: - raise ValueError("engine key is None") - return self.key.engine_id + DROP_SQL: ClassVar[str] = "DROP ENGINE {}" + + _service: EngineService = field(repr=False, compare=False) + + name: str = field(metadata={"db_name": "engine_name"}) + region: str = field() + spec: InstanceType = field() + scale: int = field() + current_status: EngineStatus = field(metadata={"db_name": "status"}) + _database_name: str = field(repr=False, metadata={"db_name": "attached_to"}) + version: str = field() + endpoint: str = field(metadata={"db_name": "url"}) + warmup: WarmupMethod = field() + auto_stop: int = field() + type: EngineType = field() + provisioning: str = field() + + def __post_init__(self) -> None: + if isinstance(self.spec, str) and self.spec: + # Resolve engine specification + self.spec = self._service.resource_manager.instance_types.get(self.spec) + if isinstance(self.current_status, str) and self.current_status: + # Resolve engine status + self.current_status = EngineStatus(self.current_status) + if isinstance(self.warmup, str) and self.warmup: + # Resolve warmup method + self.warmup = WarmupMethod.from_display_name(self.warmup) + if isinstance(self.type, str) and self.type: + # Resolve engine type + self.type = EngineType.from_display_name(self.type) @property def database(self) -> Optional[Database]: - return self._service.resource_manager.bindings.get_database_bound_to_engine( - engine=self - ) - - def get_latest(self) -> Engine: - """Get an up-to-date instance of the engine from Firebolt.""" - return self._service.get(id_=self.engine_id) - - def attach_to_database( - self, database: Database, is_default_engine: bool = False - ) -> Binding: + if self._database_name: + try: + return self._service.resource_manager.databases.get(self._database_name) + except DatabaseNotFoundError: + pass + return None + + def refresh(self) -> None: + """Update attributes of the instance from Firebolt.""" + field_name_overrides = self._get_field_overrides() + for name, value in self._service._get_dict(self.name).items(): + setattr(self, field_name_overrides.get(name, name), value) + + self.__post_init__() + + def attach_to_database(self, database: Union[Database, str]) -> None: """ Attach this engine to a database. Args: database: Database to which the engine will be attached - is_default_engine: - Whether this engine should be used as default for this database. - Only one engine can be set as default for a single database. - This will overwrite any existing default. """ - return self._service.resource_manager.bindings.create( - engine=self, database=database, is_default_engine=is_default_engine - ) + self._service.attach_to_database(self, database) @check_attached_to_database def get_connection(self) -> Connection: @@ -193,300 +124,124 @@ def get_connection(self) -> Connection: """ return connect( - database=self.database.name, # type: ignore # already checked by decorator + database=self._database_name, # type: ignore # already checked by decorator # we always have firebolt Auth as a client auth auth=self._service.client.auth, # type: ignore engine_name=self.name, - account_name=self._service.settings.account_name, - api_endpoint=self._service.settings.server, + account_name=self._service.resource_manager.account_name, + api_endpoint=self._service.resource_manager.api_endpoint, ) + def _wait_for_start_stop(self) -> None: + wait_timeout = 3600 + interval_seconds = 5 + timeout_time = time.time() + wait_timeout + while self.current_status in (EngineStatus.STOPPING, EngineStatus.STARTING): + logger.info( + f"Engine {self.name} is currently " + f"{self.current_status.value.lower()}, waiting" + ) + time.sleep(interval_seconds) + if time.time() > timeout_time: + raise TimeoutError( + f"Excedeed timeout of {wait_timeout}s waiting for " + f"an engine in {self.current_status.value.lower()} state" + ) + logger.info(".[!n]") + self.refresh() + @check_attached_to_database - def start( - self, - wait_for_startup: bool = True, - wait_timeout_seconds: int = 3600, - verbose: bool = False, - ) -> Engine: + def start(self) -> Engine: """ Start an engine. If it's already started, do nothing. - Args: - wait_for_startup: - If True, wait for startup to complete. - If False, return immediately after requesting startup. - wait_timeout_seconds: - Number of seconds to wait for startup to complete - before raising a TimeoutError - verbose: - If True, print dots periodically while waiting for engine start. - If False, do not print any dots. - Returns: - The updated engine from Firebolt. + The updated engine instance. """ - timeout_time = time.time() + wait_timeout_seconds - - engine = self.get_latest() - if ( - engine.current_status_summary - == EngineStatusSummary.ENGINE_STATUS_SUMMARY_RUNNING - ): - logger.info( - f"Engine (engine_id={self.engine_id}, name={self.name}) " - "is already running." - ) - return engine - - # wait for engine to stop first, if it's already stopping - # FUTURE: revisit logging and consider consolidating this if & the while below. - elif ( - engine.current_status_summary - == EngineStatusSummary.ENGINE_STATUS_SUMMARY_STOPPING - ): - logger.info( - f"Engine (engine_id={engine.engine_id}, name={engine.name}) " - "is in currently stopping, waiting for it to stop first." - ) - while ( - engine.current_status_summary - != EngineStatusSummary.ENGINE_STATUS_SUMMARY_STOPPED - ): - wait( - seconds=5, - timeout_time=timeout_time, - error_message=( - "Engine " - f"(engine_id={engine.engine_id}, name={engine.name}) " - f"did not stop within {wait_timeout_seconds} seconds." - ), - verbose=True, - ) - engine = engine.get_latest() - logger.info( - f"Engine (engine_id={engine.engine_id}, name={engine.name}) stopped." + self.refresh() + self._wait_for_start_stop() + if self.current_status == EngineStatus.RUNNING: + logger.info(f"Engine {self.name} is already running.") + return self + if self.current_status in (EngineStatus.DROPPING, EngineStatus.REPAIRING): + raise ValueError( + f"Unable to start engine {self.name} because it's " + f"in {self.current_status.value.lower()} state" ) - engine = self._send_engine_request(ACCOUNT_ENGINE_START_URL) - logger.info( - f"Starting Engine (engine_id={engine.engine_id}, name={engine.name})" - ) - - # wait for engine to start - while wait_for_startup and engine.current_status_summary not in { - EngineStatusSummary.ENGINE_STATUS_SUMMARY_RUNNING, - EngineStatusSummary.ENGINE_STATUS_SUMMARY_FAILED, - }: - wait( - seconds=5, - timeout_time=timeout_time, - error_message=( # noqa: E501 - f"Could not start engine within {wait_timeout_seconds} seconds." - ), - verbose=verbose, - ) - previous_status_summary = engine.current_status_summary - engine = engine.get_latest() - if engine.current_status_summary != previous_status_summary: - logger.info( - "Engine status_summary=" - f"{getattr(engine.current_status_summary, 'name')}" - ) - - return engine + logger.info(f"Starting engine {self.name}") + with self._service._connection.cursor() as c: + c.execute(self.START_SQL.format(self.name)) + self.refresh() + return self @check_attached_to_database - def stop( - self, wait_for_stop: bool = False, wait_timeout_seconds: int = 3600 - ) -> Engine: - """Stop an Engine running on Firebolt.""" - timeout_time = time.time() + wait_timeout_seconds - - engine = self._send_engine_request(ACCOUNT_ENGINE_STOP_URL) - logger.info(f"Stopping Engine (engine_id={self.engine_id}, name={self.name})") - - while wait_for_stop and engine.current_status_summary not in { - EngineStatusSummary.ENGINE_STATUS_SUMMARY_STOPPED, - EngineStatusSummary.ENGINE_STATUS_SUMMARY_FAILED, - }: - wait( - seconds=5, - timeout_time=timeout_time, - error_message=( # noqa: E501 - f"Could not stop engine within {wait_timeout_seconds} seconds." - ), - verbose=False, - ) + def stop(self) -> Engine: + """Stop an engine. If it's already stopped, do nothing. - engine = engine.get_latest() - - return engine + Returns: + The updated engine instance. + """ + self.refresh() + self._wait_for_start_stop() + if self.current_status == EngineStatus.STOPPED: + logger.info(f"Engine {self.name} is already stopped.") + return self + if self.current_status in (EngineStatus.DROPPING, EngineStatus.REPAIRING): + raise ValueError( + f"Unable to stop engine {self.name} because it's " + f"in {self.current_status.value.lower()} state" + ) + logger.info(f"Stopping engine {self.name}") + with self._service._connection.cursor() as c: + c.execute(self.STOP_SQL.format(self.name)) + self.refresh() + return self def update( self, name: Optional[str] = None, - engine_type: Optional[EngineType] = None, scale: Optional[int] = None, - spec: Optional[str] = None, + spec: Union[InstanceType, str, None] = None, auto_stop: Optional[int] = None, warmup: Optional[WarmupMethod] = None, - description: Optional[str] = None, - use_spot: Optional[bool] = None, ) -> Engine: """ Updates the engine and returns an updated version of the engine. If all parameters are set to None, old engine parameter values remain. """ - class _EngineUpdateRequest(FireboltBaseModel): - """Helper model for sending engine update requests.""" - - account_id: str - desired_revision: Optional[EngineRevision] - engine: Engine - engine_id: str - update_mask: FieldMask - - # Update the engine parameters - self.name = name if name else self.name - self.description = description - - # Update engine settings - engine_settings_params = { - "engine_type": engine_type, - "auto_stop_delay_duration": f"{auto_stop * 60}s" - if auto_stop is not None - else None, - "warm_up": warmup, - } - self.settings = EngineSettings.default(**prune_dict(engine_settings_params)) - - # Update the engine desired_revision if needed - desired_revision = None - if (scale or spec or use_spot is not None) and self.latest_revision_key: - rm = self._service.resource_manager - desired_revision = rm.engine_revisions.get_by_key(self.latest_revision_key) - - if spec: - instance_type_key = rm.instance_types.get_by_name( - instance_type_name=spec, - region_name=rm.regions.regions_by_key[self.compute_region_key].name, - ).key - - desired_revision.specification.db_compute_instances_type_key = ( - instance_type_key - ) - desired_revision.specification.proxy_instances_type_key = ( - instance_type_key - ) - - if scale: - desired_revision.specification.db_compute_instances_count = scale - - if use_spot is not None: - desired_revision.specification.db_compute_instances_use_spot = use_spot - - update_mask_paths = list( - prune_dict( - { - "name": name, - "description": description, - "settings.auto_stop_delay_duration": auto_stop, - "settings.warm_up": warmup, - "settings.is_read_only": engine_type, - "settings.preset": engine_type, - } - ).keys() - ) - - # Send the update request - response = self._service.client.patch( - url=ACCOUNT_ENGINE_URL.format( - engine_id=self.engine_id, account_id=self._service.account_id - ), - headers={"Content-type": "application/json"}, - json=_EngineUpdateRequest( - account_id=self._service.account_id, - engine=self, - desired_revision=desired_revision, - engine_id=self.engine_id, - update_mask=FieldMask(paths=update_mask_paths), - ).jsonable_dict(by_alias=True), - ) - - return Engine.parse_obj_with_service( - obj=response.json()["engine"], engine_service=self._service - ) - - @check_attached_to_database - def restart( - self, - wait_for_startup: bool = True, - wait_timeout_seconds: int = 3600, - ) -> Engine: - """ - Restart an engine. - - Args: - wait_for_startup: - If True, wait for startup to complete. - If False, return immediately after requesting startup. - wait_timeout_seconds: - Number of seconds to wait for startup to complete - before raising a TimeoutError. + if not any((name, scale, spec, auto_stop, warmup)): + # Nothing to be updated + return self - Returns: - The updated engine from Firebolt. - """ - timeout_time = time.time() + wait_timeout_seconds - - engine = self._send_engine_request(ACCOUNT_ENGINE_RESTART_URL) - logger.info(f"Stopping Engine (engine_id={self.engine_id}, name={self.name})") - - while wait_for_startup and engine.current_status_summary not in { - EngineStatusSummary.ENGINE_STATUS_SUMMARY_RUNNING, - EngineStatusSummary.ENGINE_STATUS_SUMMARY_FAILED, - }: - wait( - seconds=5, - timeout_time=timeout_time, - error_message=( # noqa: E501 - f"Could not restart engine within {wait_timeout_seconds} seconds." - ), - verbose=False, + self.refresh() + self._wait_for_start_stop() + if self.current_status in (EngineStatus.DROPPING, EngineStatus.REPAIRING): + raise ValueError( + f"Unable to update engine {self.name} because it's " + f"in {self.current_status.value.lower()} state" ) - engine = engine.get_latest() - - return engine - - def delete(self) -> Engine: - """Delete an engine from Firebolt.""" - response = self._service.client.delete( - url=ACCOUNT_ENGINE_URL.format( - account_id=self._service.account_id, engine_id=self.engine_id - ), - ) - logger.info(f"Deleting Engine (engine_id={self.engine_id}, name={self.name})") - return Engine.parse_obj_with_service( - obj=response.json()["engine"], engine_service=self._service - ) - - def _send_engine_request(self, url: str) -> Engine: - response = self._service.client.post( - url=url.format( - account_id=self._service.account_id, engine_id=self.engine_id - ) - ) - return Engine.parse_obj_with_service( - obj=response.json()["engine"], engine_service=self._service - ) - - -class _EngineCreateRequest(FireboltBaseModel): - """Helper model for sending engine create requests.""" - - account_id: str - engine: Engine - engine_revision: Optional[EngineRevision] + sql = self.ALTER_PREFIX_SQL.format(self.name) + parameters = [] + for param, value in zip( + self.ALTER_PARAMETER_NAMES, (scale, spec, auto_stop, name, warmup) + ): + if value: + sql += f"{param} = ? " + parameters.append(str(value)) + + with self._service._connection.cursor() as c: + c.execute(sql, parameters) + self.refresh() + return self + + def delete(self) -> None: + """Delete an engine.""" + self.refresh() + if self.current_status == EngineStatus.DROPPING: + return + with self._service._connection.cursor() as c: + c.execute(self.DROP_SQL.format(self.name)) diff --git a/src/firebolt/model/engine_revision.py b/src/firebolt/model/engine_revision.py deleted file mode 100644 index d3257f230c3..00000000000 --- a/src/firebolt/model/engine_revision.py +++ /dev/null @@ -1,58 +0,0 @@ -from __future__ import annotations - -from datetime import datetime -from typing import Optional - -from pydantic import Field, PositiveInt - -from firebolt.model import FireboltBaseModel -from firebolt.model.instance_type import InstanceTypeKey - - -class EngineRevisionKey(FireboltBaseModel): - account_id: str - engine_id: str - engine_revision_id: str - - -class EngineRevisionSpecification(FireboltBaseModel): - """ - An EngineRevision specification. - - Determines which instance types and how many of them its engine gets. - - See Also: :py:class:`Settings - `, - which also contains engine configuration. - """ - - db_compute_instances_type_key: InstanceTypeKey = Field( - alias="db_compute_instances_type_id" - ) - db_compute_instances_count: PositiveInt - db_compute_instances_use_spot: bool = False - db_version: str = "" - proxy_instances_type_key: InstanceTypeKey = Field(alias="proxy_instances_type_id") - proxy_instances_count: PositiveInt = 1 - proxy_version: str = "" - - -class EngineRevision(FireboltBaseModel): - """ - A Firebolt engine revision, - which contains a specification (instance types, counts). - - As engines are updated with new settings, revisions are created. - """ - - specification: EngineRevisionSpecification - - # optional - key: Optional[EngineRevisionKey] = Field(None, alias="id") - current_status: Optional[str] - create_time: Optional[datetime] - create_actor: Optional[str] - last_update_time: Optional[datetime] - last_update_actor: Optional[str] - desired_status: Optional[str] - health_status: Optional[str] diff --git a/src/firebolt/model/instance_type.py b/src/firebolt/model/instance_type.py index 24decd63006..4d3fdc36400 100644 --- a/src/firebolt/model/instance_type.py +++ b/src/firebolt/model/instance_type.py @@ -1,26 +1,21 @@ +from dataclasses import dataclass, field from datetime import datetime -from typing import Optional - -from pydantic import Field +from typing import Dict from firebolt.model import FireboltBaseModel -class InstanceTypeKey(FireboltBaseModel, frozen=True): # type: ignore - provider_id: str - region_id: str - instance_type_id: str - - +@dataclass class InstanceType(FireboltBaseModel): - key: InstanceTypeKey = Field(alias="id") - name: str + _key: Dict = field(repr=False, metadata={"db_name": "id"}) + name: str = field() + is_spot_available: bool = field() + cpu_virtual_cores_count: int = field() + memory_size_bytes: int = field() + storage_size_bytes: int = field() + price_per_hour_cents: float = field() + create_time: datetime = field() + last_update_time: datetime = field() - # optional - is_spot_available: Optional[bool] - cpu_virtual_cores_count: Optional[int] - memory_size_bytes: Optional[int] - storage_size_bytes: Optional[int] - price_per_hour_cents: Optional[float] - create_time: Optional[datetime] - last_update_time: Optional[datetime] + def __str__(self) -> str: + return self.name diff --git a/src/firebolt/model/provider.py b/src/firebolt/model/provider.py deleted file mode 100644 index 242856ef218..00000000000 --- a/src/firebolt/model/provider.py +++ /dev/null @@ -1,16 +0,0 @@ -from datetime import datetime -from typing import Optional - -from pydantic import Field - -from firebolt.model import FireboltBaseModel - - -class Provider(FireboltBaseModel, frozen=True): # type: ignore - provider_id: str = Field(alias="id") - name: str - - # optional - create_time: Optional[datetime] - display_name: Optional[str] - last_update_time: Optional[datetime] diff --git a/src/firebolt/model/region.py b/src/firebolt/model/region.py deleted file mode 100644 index c3b476da26c..00000000000 --- a/src/firebolt/model/region.py +++ /dev/null @@ -1,21 +0,0 @@ -from datetime import datetime -from typing import Optional - -from pydantic import Field - -from firebolt.model import FireboltBaseModel - - -class RegionKey(FireboltBaseModel, frozen=True): # type: ignore - provider_id: str - region_id: str - - -class Region(FireboltBaseModel): - key: RegionKey = Field(alias="id") - name: str - - # optional - display_name: Optional[str] - create_time: Optional[datetime] - last_update_time: Optional[datetime] diff --git a/src/firebolt/service/base.py b/src/firebolt/service/base.py index d6617a43c34..bb892d3412d 100644 --- a/src/firebolt/service/base.py +++ b/src/firebolt/service/base.py @@ -1,20 +1,24 @@ +from typing import TYPE_CHECKING + from firebolt.client import Client -from firebolt.common import Settings -from firebolt.service.manager import ResourceManager +from firebolt.db import Connection + +if TYPE_CHECKING: + from firebolt.service.manager import ResourceManager class BaseService: - def __init__(self, resource_manager: ResourceManager): + def __init__(self, resource_manager: "ResourceManager"): self.resource_manager = resource_manager @property def client(self) -> Client: - return self.resource_manager.client + return self.resource_manager._client @property def account_id(self) -> str: return self.resource_manager.account_id @property - def settings(self) -> Settings: - return self.resource_manager.settings + def _connection(self) -> Connection: + return self.resource_manager._connection diff --git a/src/firebolt/service/binding.py b/src/firebolt/service/binding.py deleted file mode 100644 index a12aff81e4e..00000000000 --- a/src/firebolt/service/binding.py +++ /dev/null @@ -1,139 +0,0 @@ -import logging -from typing import List, Optional - -from firebolt.model.binding import Binding, BindingKey -from firebolt.model.database import Database -from firebolt.model.engine import Engine -from firebolt.service.base import BaseService -from firebolt.utils.exception import AlreadyBoundError -from firebolt.utils.urls import ( - ACCOUNT_BINDINGS_URL, - ACCOUNT_DATABASE_BINDING_URL, -) -from firebolt.utils.util import prune_dict - -logger = logging.getLogger(__name__) - - -class BindingService(BaseService): - def get_by_key(self, binding_key: BindingKey) -> Binding: - """Get a binding by its BindingKey""" - response = self.client.get( - url=ACCOUNT_DATABASE_BINDING_URL.format( - account_id=binding_key.account_id, - database_id=binding_key.database_id, - engine_id=binding_key.engine_id, - ) - ) - binding: dict = response.json()["binding"] - return Binding.parse_obj(binding) - - def get_many( - self, - database_id: Optional[str] = None, - engine_id: Optional[str] = None, - is_system_database: Optional[bool] = None, - ) -> List[Binding]: - """ - List bindings on Firebolt, optionally filtering by database and engine. - - Args: - database_id: - Return bindings matching the database_id. - If None, match any databases. - engine_id: - Return bindings matching the engine_id. - If None, match any engines. - is_system_database: - If True, return only system databases. - If False, return only non-system databases. - If None, do not filter on this parameter. - - Returns: - List of bindings matching the filter parameters - """ - - response = self.client.get( - url=ACCOUNT_BINDINGS_URL.format(account_id=self.account_id), - params=prune_dict( - { - "page.first": 5000, # FUTURE: pagination support w/ generator - "filter.id_database_id_eq": database_id, - "filter.id_engine_id_eq": engine_id, - "filter.is_system_database_eq": is_system_database, - } - ), - ) - return [Binding.parse_obj(i["node"]) for i in response.json()["edges"]] - - def get_database_bound_to_engine(self, engine: Engine) -> Optional[Database]: - """Get the database to which an engine is bound, if any.""" - try: - binding = self.get_many(engine_id=engine.engine_id)[0] - except IndexError: - return None - try: - return self.resource_manager.databases.get(id_=binding.database_id) - except (KeyError, IndexError): - return None - - def get_engines_bound_to_database(self, database: Database) -> List[Engine]: - """Get a list of engines that are bound to a database.""" - - bindings = self.get_many(database_id=database.database_id) - if not bindings: - return [] - return self.resource_manager.engines.get_by_ids( - ids=[b.engine_id for b in bindings] - ) - - def create( - self, engine: Engine, database: Database, is_default_engine: bool - ) -> Binding: - """ - Create a new binding between an engine and a database. - - Args: - engine: Engine to bind. - database: Database to bind. - is_default_engine: - Whether this engine should be used as default for this database. - Only one engine can be set as default for a single database. - This will overwrite any existing default. - - Returns: - New binding between the engine and database. - """ - - existing_database = self.get_database_bound_to_engine(engine=engine) - if existing_database is not None: - raise AlreadyBoundError( - f"The engine {engine.name} is already bound " - f"to {existing_database.name}!" - ) - - logger.info( - f"Attaching Engine (engine_id={engine.engine_id}, name={engine.name}) " - f"to Database (database_id={database.database_id}, " - f"name={database.name})" - ) - binding = Binding( - binding_key=BindingKey( - account_id=self.account_id, - database_id=database.database_id, - engine_id=engine.engine_id, - ), - is_default_engine=is_default_engine, - ) - - response = self.client.post( - url=ACCOUNT_DATABASE_BINDING_URL.format( - account_id=self.account_id, - database_id=database.database_id, - engine_id=engine.engine_id, - ), - json=binding.jsonable_dict( - by_alias=True, include={"binding_key": ..., "is_default_engine": ...} - ), - ) - return Binding.parse_obj(response.json()["binding"]) diff --git a/src/firebolt/service/database.py b/src/firebolt/service/database.py index 71d58450017..25f32a3aa9c 100644 --- a/src/firebolt/service/database.py +++ b/src/firebolt/service/database.py @@ -1,53 +1,59 @@ import logging from typing import List, Optional, Union -from firebolt.model import FireboltBaseModel from firebolt.model.database import Database +from firebolt.model.engine import Engine from firebolt.service.base import BaseService -from firebolt.service.types import DatabaseOrder -from firebolt.utils.urls import ( - ACCOUNT_DATABASE_BY_NAME_URL, - ACCOUNT_DATABASE_URL, - ACCOUNT_DATABASES_URL, -) -from firebolt.utils.util import prune_dict +from firebolt.utils.exception import DatabaseNotFoundError logger = logging.getLogger(__name__) class DatabaseService(BaseService): - def get(self, id_: str) -> Database: - """Get a Database from Firebolt by its ID.""" - - response = self.client.get( - url=ACCOUNT_DATABASE_URL.format(account_id=self.account_id, database_id=id_) - ) - return Database.parse_obj_with_service( - obj=response.json()["database"], database_service=self - ) - - def get_by_name(self, name: str) -> Database: - """Get a database from Firebolt by its name.""" - - database_id = self.get_id_by_name(name=name) - return self.get(id_=database_id) - - def get_id_by_name(self, name: str) -> str: - """Get a database ID from Firebolt by its name.""" - - response = self.client.get( - url=ACCOUNT_DATABASE_BY_NAME_URL.format(account_id=self.account_id), - params={"database_name": name}, - ) - database_id = response.json()["database_id"]["database_id"] - return database_id + DB_FIELDS = ( + "database_name", + "description", + "region", + "status", + "data_size_full", + "data_size_compressed", + "attached_engines", + "created_on", + "created_by", + "errors", + ) + GET_SQL = f"SELECT {', '.join(DB_FIELDS)} FROM information_schema.databases" + GET_BY_NAME_SQL = GET_SQL + " WHERE database_name=?" + GET_WHERE_SQL = " WHERE " + + CREATE_PREFIX_SQL = "CREATE DATABASE {}{}" + CREATE_WITH_SQL = " WITH " + IF_NOT_EXISTS_SQL = "IF NOT EXISTS " + CREATE_PARAMETER_NAMES = ( + "REGION", + "ATTACHED_ENGINES", + "DESCRIPTION", + ) + + def _get_dict(self, name: str) -> dict: + with self._connection.cursor() as c: + count = c.execute(self.GET_BY_NAME_SQL, (name,)) + if count == 0: + raise DatabaseNotFoundError(name) + return { + column.name: value for column, value in zip(c.description, c.fetchone()) + } + + def get(self, name: str) -> Database: + """Get a Database from Firebolt by its name.""" + return Database._from_dict(self._get_dict(name), self) def get_many( self, name_contains: Optional[str] = None, attached_engine_name_eq: Optional[str] = None, attached_engine_name_contains: Optional[str] = None, - order_by: Optional[Union[str, DatabaseOrder]] = None, + region_eq: Optional[str] = None, ) -> List[Database]: """ Get a list of databases on Firebolt. @@ -57,71 +63,93 @@ def get_many( attached_engine_name_eq: Filter for databases by an exact engine name attached_engine_name_contains: Filter for databases by engines with a name containing this substring - order_by: Method by which to order the results. - See :py:class:`firebolt.service.types.DatabaseOrder` + region_eq: Filter for database by region Returns: A list of databases matching the filters """ - - if isinstance(order_by, str): - order_by = DatabaseOrder[order_by].name - - params = { - "page.first": "1000", - "order_by": order_by, - "filter.name_contains": name_contains, - "filter.attached_engine_name_eq": attached_engine_name_eq, - "filter.attached_engine_name_contains": attached_engine_name_contains, - } - - response = self.client.get( - url=ACCOUNT_DATABASES_URL.format(account_id=self.account_id), - params=prune_dict(params), - ) - - return [ - Database.parse_obj_with_service(obj=d["node"], database_service=self) - for d in response.json()["edges"] - ] + sql = self.GET_SQL + parameters = [] + if any( + ( + name_contains, + attached_engine_name_eq, + attached_engine_name_contains, + region_eq, + ) + ): + condition = [] + if name_contains: + condition.append("database_name like ?") + parameters.append(f"%{name_contains}%") + if attached_engine_name_eq: + condition.append( + "any_match(eng -> split_part(eng, ' ', 1) = ?," + " split(',', attached_engines))" + ) + parameters.append(attached_engine_name_eq) + if attached_engine_name_contains: + condition.append( + "any_match(eng -> split_part(eng, ' ', 1) like ?," + " split(',', attached_engines))" + ) + parameters.append(f"%{attached_engine_name_contains}%") + if region_eq: + condition.append("region = ?") + parameters.append(str(region_eq)) + sql += self.GET_WHERE_SQL + " AND ".join(condition) + + with self._connection.cursor() as c: + c.execute(sql, parameters) + dicts = [ + {column.name: value for column, value in zip(c.description, row)} + for row in c.fetchall() + ] + return [Database._from_dict(_dict, self) for _dict in dicts] def create( - self, name: str, region: Optional[str] = None, description: Optional[str] = None + self, + name: str, + region: Optional[str] = None, + attached_engines: Union[List[str], List[Engine], None] = None, + description: Optional[str] = None, + fail_if_exists: bool = True, ) -> Database: """ Create a new Database on Firebolt. Args: name: Name of the database + attached_engines: List of engines to attach to the database region: Region name in which to create the database + fail_if_exists: Fail is a database with provided name already exists Returns: The newly created database """ - class _DatabaseCreateRequest(FireboltBaseModel): - """Helper model for sending database creation requests.""" + logger.info(f"Creating database {name}") - account_id: str - database: Database - - if region is None: - region_key = self.resource_manager.regions.default_region.key - else: - region_key = self.resource_manager.regions.get_by_name(name=region).key - database = Database( - name=name, compute_region_key=region_key, description=description - ) - - logger.info(f"Creating Database (name={name})") - response = self.client.post( - url=ACCOUNT_DATABASES_URL.format(account_id=self.account_id), - headers={"Content-type": "application/json"}, - json=_DatabaseCreateRequest( - account_id=self.account_id, - database=database, - ).jsonable_dict(by_alias=True), - ) - return Database.parse_obj_with_service( - obj=response.json()["database"], database_service=self + sql = self.CREATE_PREFIX_SQL.format( + ("" if fail_if_exists else self.IF_NOT_EXISTS_SQL), name ) + parameters = [] + if any((region, attached_engines, description)): + sql += self.CREATE_WITH_SQL + for param, value in zip( + self.CREATE_PARAMETER_NAMES, + (region, attached_engines, description), + ): + if value: + sql += f"{param} = ? " + # Convert list of engines to a list of their names + if ( + isinstance(value, list) + and len(value) > 0 + and isinstance(value[0], Engine) + ): + value = [eng.name for eng in value] # type: ignore + parameters.append(value) + with self._connection.cursor() as c: + c.execute(sql, parameters) + return self.get(name) diff --git a/src/firebolt/service/engine.py b/src/firebolt/service/engine.py index dbee674e88a..e34982e5024 100644 --- a/src/firebolt/service/engine.py +++ b/src/firebolt/service/engine.py @@ -1,69 +1,68 @@ from logging import getLogger -from typing import Any, Dict, List, Optional, Union - -from firebolt.model.engine import Engine, EngineSettings, _EngineCreateRequest -from firebolt.model.engine_revision import ( - EngineRevision, - EngineRevisionSpecification, -) -from firebolt.model.region import Region +from typing import List, Optional, Union + +from firebolt.model.engine import Database, Engine +from firebolt.model.instance_type import InstanceType from firebolt.service.base import BaseService -from firebolt.service.types import EngineOrder, EngineType, WarmupMethod -from firebolt.utils.exception import FireboltError -from firebolt.utils.urls import ( - ACCOUNT_ENGINE_ID_BY_NAME_URL, - ACCOUNT_ENGINE_URL, - ACCOUNT_LIST_ENGINES_URL, - ENGINES_BY_IDS_URL, -) -from firebolt.utils.util import prune_dict +from firebolt.service.types import EngineStatus, EngineType, WarmupMethod +from firebolt.utils.exception import EngineNotFoundError logger = getLogger(__name__) class EngineService(BaseService): - def get(self, id_: str) -> Engine: - """Get an engine from Firebolt by its ID.""" - - response = self.client.get( - url=ACCOUNT_ENGINE_URL.format(account_id=self.account_id, engine_id=id_), - ) - engine_entry: dict = response.json()["engine"] - return Engine.parse_obj_with_service(obj=engine_entry, engine_service=self) - - def get_by_ids(self, ids: List[str]) -> List[Engine]: - """Get multiple engines from Firebolt by ID.""" - response = self.client.post( - url=ENGINES_BY_IDS_URL, - json={ - "engine_ids": [ - {"account_id": self.account_id, "engine_id": engine_id} - for engine_id in ids - ] - }, - ) - return [ - Engine.parse_obj_with_service(obj=e, engine_service=self) - for e in response.json()["engines"] - ] - - def get_by_name(self, name: str) -> Engine: + DB_FIELDS = ( + "engine_name", + "region", + "spec", + "scale", + "status", + "attached_to", + "version", + "url", + "warmup", + "auto_stop", + "type", + "provisioning", + ) + GET_SQL = f"SELECT {', '.join(DB_FIELDS)} FROM information_schema.engines" + GET_BY_NAME_SQL = GET_SQL + " WHERE engine_name=?" + GET_WHERE_SQL = " WHERE " + + CREATE_PREFIX_SQL = "CREATE ENGINE {}{}" + IF_NOT_EXISTS_SQL = "IF NOT EXISTS " + CREATE_WITH_SQL = " WITH " + CREATE_PARAMETER_NAMES = ( + "REGION", + "ENGINE_TYPE", + "SPEC", + "SCALE", + "AUTO_STOP", + "WARMUP", + ) + + ATTACH_TO_DB_SQL = "ATTACH ENGINE {} TO {}" + + def _get_dict(self, name: str) -> dict: + with self._connection.cursor() as c: + count = c.execute(self.GET_BY_NAME_SQL, (name,)) + if count == 0: + raise EngineNotFoundError(name) + return { + column.name: value for column, value in zip(c.description, c.fetchone()) + } + + def get(self, name: str) -> Engine: """Get an engine from Firebolt by its name.""" - - response = self.client.get( - url=ACCOUNT_ENGINE_ID_BY_NAME_URL.format(account_id=self.account_id), - params={"engine_name": name}, - ) - engine_id = response.json()["engine_id"]["engine_id"] - return self.get(id_=engine_id) + return Engine._from_dict(self._get_dict(name), self) def get_many( self, name_contains: Optional[str] = None, - current_status_eq: Optional[str] = None, - current_status_not_eq: Optional[str] = None, + current_status_eq: Union[str, EngineStatus, None] = None, + current_status_not_eq: Union[str, EngineStatus, None] = None, region_eq: Optional[str] = None, - order_by: Optional[Union[str, EngineOrder]] = None, + database_name: Optional[str] = None, ) -> List[Engine]: """ Get a list of engines on Firebolt. @@ -73,50 +72,57 @@ def get_many( current_status_eq: Filter for engines with this status current_status_not_eq: Filter for engines that do not have this status region_eq: Filter for engines by region - order_by: Method by which to order the results. See [EngineOrder] Returns: A list of engines matching the filters """ - - if isinstance(order_by, str): - order_by = EngineOrder[order_by].name - - if region_eq is not None: - region_eq = self.resource_manager.regions.get_by_name( - name=region_eq - ).key.region_id - - response = self.client.get( - url=ACCOUNT_LIST_ENGINES_URL.format(account_id=self.account_id), - params=prune_dict( - { - "page.first": 5000, # FUTURE: pagination support w/ generator - "filter.name_contains": name_contains, - "filter.current_status_eq": current_status_eq, - "filter.current_status_not_eq": current_status_not_eq, - "filter.compute_region_id_region_id_eq": region_eq, - "order_by": order_by, - } - ), - ) - return [ - Engine.parse_obj_with_service(obj=e["node"], engine_service=self) - for e in response.json()["edges"] - ] + sql = self.GET_SQL + parameters = [] + if any( + ( + name_contains, + current_status_eq, + current_status_not_eq, + region_eq, + database_name, + ) + ): + condition = [] + if name_contains: + condition.append("engine_name like ?") + parameters.append(f"%{name_contains}%") + if current_status_eq: + condition.append("status = ?") + parameters.append(str(current_status_eq)) + if current_status_not_eq: + condition.append("status != ?") + parameters.append(str(current_status_eq)) + if region_eq: + condition.append("region = ?") + parameters.append(region_eq) + if database_name: + condition.append("attached_to = ?") + parameters.append(database_name) + sql += self.GET_WHERE_SQL + " AND ".join(condition) + + with self._connection.cursor() as c: + c.execute(sql, parameters) + dicts = [ + {column.name: value for column, value in zip(c.description, row)} + for row in c.fetchall() + ] + return [Engine._from_dict(_dict, self) for _dict in dicts] def create( self, name: str, - region: Union[str, Region, None] = None, + region: Optional[str] = None, engine_type: Union[str, EngineType] = EngineType.GENERAL_PURPOSE, - scale: int = 2, - spec: Optional[str] = None, - auto_stop: int = 20, - warmup: Union[str, WarmupMethod] = WarmupMethod.PRELOAD_INDEXES, - description: str = "", - engine_settings_kwargs: Dict[str, Any] = {}, - revision_spec_kwargs: Dict[str, Any] = {}, + spec: Union[InstanceType, str, None] = None, + scale: Optional[int] = None, + auto_stop: Optional[int] = None, + warmup: Union[str, WarmupMethod, None] = None, + fail_if_exists: bool = True, ) -> Engine: """ Create a new engine. @@ -125,10 +131,10 @@ def create( name: An identifier that specifies the name of the engine region: The AWS region in which the engine runs engine_type: The engine type. GENERAL_PURPOSE or DATA_ANALYTICS - scale: The number of compute instances on the engine. - The scale can be any int from 1 to 128. spec: Firebolt instance type. If not set, will default to the cheapest instance. + scale: The number of compute instances on the engine. + The scale can be any int from 1 to 128. auto_stop: The amount of time (in minutes) after which the engine automatically stops warmup: The warmup method that should be used: @@ -139,85 +145,39 @@ def create( `PRELOAD_ALL_DATA` - Full data auto-load (both indexes and table data - full warmup) - description: A short description of the engine's purpose + fail_if_exists: Fail is an engine with provided name already exists Returns: Engine with the specified settings """ - logger.info(f"Creating Engine (name={name})") - - if isinstance(engine_type, str): - engine_type = EngineType[engine_type] - if isinstance(warmup, str): - warmup = WarmupMethod[warmup] - - if region is None: - region = self.resource_manager.regions.default_region - else: - if isinstance(region, str): - region = self.resource_manager.regions.get_by_name(name=region) - - engine = Engine( - name=name, - description=description, - compute_region_key=region.key, - settings=EngineSettings.default( - engine_type=engine_type, - auto_stop_delay_duration=f"{auto_stop * 60}s", - warm_up=warmup, - **engine_settings_kwargs, - ), - ) - - if spec: - instance_type_key = self.resource_manager.instance_types.get_by_name( - instance_type_name=spec, region_name=region.name - ).key - else: - instance_type = ( - self.resource_manager.instance_types.cheapest_instance_in_region(region) - ) - if not instance_type: - raise FireboltError( - f"No suitable default instances found in region {region}" - ) - instance_type_key = instance_type.key - - engine_revision = EngineRevision( - specification=EngineRevisionSpecification( - db_compute_instances_type_key=instance_type_key, - db_compute_instances_count=scale, - proxy_instances_type_key=instance_type_key, - **revision_spec_kwargs, - ) - ) - - return self._send_create_engine(engine=engine, engine_revision=engine_revision) - - def _send_create_engine( - self, engine: Engine, engine_revision: Optional[EngineRevision] = None - ) -> Engine: - """ - Create a new Engine on Firebolt from the local Engine object. - - Args: - engine: The engine to create - engine_revision: EngineRevision to use for configuring the engine + logger.info(f"Creating engine {name}") - Returns: - The newly created engine - """ - - response = self.client.post( - url=ACCOUNT_LIST_ENGINES_URL.format(account_id=self.account_id), - headers={"Content-type": "application/json"}, - json=_EngineCreateRequest( - account_id=self.account_id, - engine=engine, - engine_revision=engine_revision, - ).jsonable_dict(by_alias=True), - ) - return Engine.parse_obj_with_service( - obj=response.json()["engine"], engine_service=self + sql = self.CREATE_PREFIX_SQL.format( + ("" if fail_if_exists else self.IF_NOT_EXISTS_SQL), name ) + parameters = [] + if any((region, engine_type, spec, scale, auto_stop, warmup)): + sql += self.CREATE_WITH_SQL + for param, value in zip( + self.CREATE_PARAMETER_NAMES, + (region, engine_type, spec, scale, auto_stop, warmup), + ): + if value: + sql += f"{param} = ? " + parameters.append(str(value)) + with self._connection.cursor() as c: + c.execute(sql, parameters) + return self.get(name) + + def attach_to_database( + self, engine: Union[Engine, str], database: Union[Database, str] + ) -> None: + engine_name = engine.name if isinstance(engine, Engine) else engine + database_name = database.name if isinstance(database, Database) else database + with self._connection.cursor() as c: + c.execute(self.ATTACH_TO_DB_SQL.format(engine_name, database_name)) + if isinstance(engine, Engine): + engine._database_name = ( + database.name if isinstance(database, Database) else database + ) diff --git a/src/firebolt/service/engine_revision.py b/src/firebolt/service/engine_revision.py deleted file mode 100644 index b79964d62b1..00000000000 --- a/src/firebolt/service/engine_revision.py +++ /dev/null @@ -1,39 +0,0 @@ -from firebolt.model.engine_revision import EngineRevision, EngineRevisionKey -from firebolt.service.base import BaseService -from firebolt.utils.urls import ACCOUNT_ENGINE_REVISION_URL - - -class EngineRevisionService(BaseService): - def get_by_id(self, engine_id: str, engine_revision_id: str) -> EngineRevision: - """ - Get an EngineRevision from Firebolt by engine_id and engine_revision_id. - """ - - return self.get_by_key( - EngineRevisionKey( - account_id=self.account_id, - engine_id=engine_id, - engine_revision_id=engine_revision_id, - ) - ) - - def get_by_key(self, key: EngineRevisionKey) -> EngineRevision: - """ - Fetch an EngineRevision from Firebolt by its key. - - Args: - key: Key of the desired EngineRevision - - Returns: - The requested EngineRevision - """ - - response = self.client.get( - url=ACCOUNT_ENGINE_REVISION_URL.format( - account_id=key.account_id, - engine_id=key.engine_id, - revision_id=key.engine_revision_id, - ), - ) - engine_spec: dict = response.json()["engine_revision"] - return EngineRevision.parse_obj(engine_spec) diff --git a/src/firebolt/service/instance_type.py b/src/firebolt/service/instance_type.py index 47a3ea26d45..47353e88f6b 100644 --- a/src/firebolt/service/instance_type.py +++ b/src/firebolt/service/instance_type.py @@ -1,19 +1,12 @@ -from typing import Dict, List, NamedTuple, Optional +from typing import List, Optional -from firebolt.model.instance_type import InstanceType, InstanceTypeKey -from firebolt.model.region import Region +from firebolt.model.instance_type import InstanceType from firebolt.service.base import BaseService +from firebolt.utils.exception import InstanceTypeNotFoundError from firebolt.utils.urls import ACCOUNT_INSTANCE_TYPES_URL from firebolt.utils.util import cached_property -class InstanceTypeLookup(NamedTuple): - """Helper tuple for looking up instance types by names.""" - - region_name: str - instance_type_name: str - - class InstanceTypeService(BaseService): @cached_property def instance_types(self) -> List[InstanceType]: @@ -23,90 +16,41 @@ def instance_types(self) -> List[InstanceType]: url=ACCOUNT_INSTANCE_TYPES_URL.format(account_id=self.account_id), params={"page.first": 5000}, ) - return [InstanceType.parse_obj(i["node"]) for i in response.json()["edges"]] - - @cached_property - def instance_types_by_key(self) -> Dict[InstanceTypeKey, InstanceType]: - """Dict of {InstanceTypeKey to InstanceType}""" - return {i.key: i for i in self.instance_types} + # Only take one instance type with a specific name + instance_types, names = list(), set() + for it in [i["node"] for i in response.json()["edges"]]: + if it["name"] not in names: + names.add(it["name"]) + instance_types.append(InstanceType._from_dict(it, self)) + return instance_types @cached_property - def instance_types_by_name(self) -> Dict[InstanceTypeLookup, InstanceType]: - """Dict of {InstanceTypeLookup to InstanceType}""" - - return { - InstanceTypeLookup( - region_name=self.resource_manager.regions.get_by_id( - id_=i.key.region_id - ).name, - instance_type_name=i.name, - ): i - for i in self.instance_types - } - - def get_instance_types_per_region(self, region: Region) -> List[InstanceType]: - """List of instance types available on Firebolt in specified region.""" - - response = self.client.get( - url=ACCOUNT_INSTANCE_TYPES_URL.format(account_id=self.account_id), - params={"page.first": 5000, "filter.id_region_id_eq": region.key.region_id}, - ) - - instance_list = [ - InstanceType.parse_obj(i["node"]) for i in response.json()["edges"] - ] - - # Filter out instances without storage - return [ - i - for i in instance_list - if i.storage_size_bytes and i.storage_size_bytes != "0" - ] - - def cheapest_instance_in_region(self, region: Region) -> Optional[InstanceType]: + def cheapest_instance(self) -> Optional[InstanceType]: # Get only available instances in region - instance_list = self.get_instance_types_per_region(region) - - if not instance_list: + if not self.instance_types: return None - cheapest = min( - instance_list, + return min( + self.instance_types, key=lambda x: x.price_per_hour_cents if x.price_per_hour_cents else float("Inf"), ) - return cheapest - - def get_by_key(self, instance_type_key: InstanceTypeKey) -> InstanceType: - """Get an instance type by key.""" - - return self.instance_types_by_key[instance_type_key] - def get_by_name( - self, - instance_type_name: str, - region_name: Optional[str] = None, - ) -> InstanceType: + def get(self, name: str) -> InstanceType: """ Get an instance type by name. Args: - instance_type_name: Name of the instance (eg. "i3.4xlarge") - region_name: - Name of the AWS region from which to get the instance. - If not provided, use the default region name from the client. + name: Name of the instance (eg. "i3.4xlarge") Returns: - The requested instance type + The requested instance type or None if it wasn't found """ # Will raise an error if neither set - region_name = region_name or self.resource_manager.regions.default_region.name - return self.instance_types_by_name[ - InstanceTypeLookup( - region_name=region_name, - instance_type_name=instance_type_name, - ) - ] + its = [it for it in self.instance_types if it.name == name] + if len(its) == 0: + raise InstanceTypeNotFoundError(name) + return its[0] diff --git a/src/firebolt/service/manager.py b/src/firebolt/service/manager.py index c8d646bfb78..4c0ed39e7cc 100644 --- a/src/firebolt/service/manager.py +++ b/src/firebolt/service/manager.py @@ -1,14 +1,31 @@ +import logging from typing import Optional from httpx import Timeout -from firebolt.client import Client, log_request, log_response, raise_on_4xx_5xx +from firebolt.client import ( + DEFAULT_API_URL, + Auth, + Client, + log_request, + log_response, + raise_on_4xx_5xx, +) from firebolt.common import Settings -from firebolt.service.provider import get_provider_id +from firebolt.db import connect from firebolt.utils.util import fix_url_schema DEFAULT_TIMEOUT_SECONDS: int = 60 * 2 +logger = logging.getLogger(__name__) + +SETTINGS_DEPRECATION_MESSAGE = """ +Using Settings objects for ResourceManager intialization is deprecated. +Please pass parameters directly +Example: + >>> rm = ResourceManager(auth=ClientCredentials(..), account_name="my_account", ..) +""" + class ResourceManager: """ @@ -16,47 +33,92 @@ class ResourceManager: - databases - engines - - bindings (the bindings between an engine and a database) - - engine revisions (versions of an engine) Also provides listings of: - - regions (AWS regions in which engines can run) - instance types (AWS instance types which engines can use) """ - def __init__(self, settings: Optional[Settings] = None): - self.settings = settings or Settings() - self.client = Client( - auth=self.settings.auth, - base_url=fix_url_schema(self.settings.server), - account_name=self.settings.account_name, - api_endpoint=self.settings.server, + __slots__ = ( + "account_name", + "account_id", + "api_endpoint", + "_client", + "_connection", + "regions", + "instance_types", + "_provider_id", + "databases", + "engines", + "engine_revisions", + "bindings", + ) + + def __init__( + self, + settings: Optional[Settings] = None, + auth: Optional[Auth] = None, + account_name: Optional[str] = None, + api_endpoint: str = DEFAULT_API_URL, + ): + if settings: + logger.warning(SETTINGS_DEPRECATION_MESSAGE) + if auth or account_name or (api_endpoint != DEFAULT_API_URL): + raise ValueError( + "Other ResourceManager parameters are not allowed " + "when Settings are provided" + ) + auth = settings.auth + account_name = settings.account_name + api_endpoint = settings.server + + for param, name in ( + (auth, "auth"), + (account_name, "account_name"), + ): + if not param: + raise ValueError(f"Missing {name} value") + + # type checks + assert auth is not None + assert account_name is not None + + self._client = Client( + auth=auth, + base_url=fix_url_schema(api_endpoint), + account_name=account_name, + api_endpoint=api_endpoint, timeout=Timeout(DEFAULT_TIMEOUT_SECONDS), event_hooks={ "request": [log_request], "response": [raise_on_4xx_5xx, log_response], }, ) - self.account_id = self.client.account_id + self._connection = connect( + auth=auth, + account_name=account_name, + api_endpoint=api_endpoint, + ) + self.account_name = account_name + self.api_endpoint = api_endpoint + self.account_id = self._client.account_id self._init_services() def _init_services(self) -> None: # avoid circular import - from firebolt.service.binding import BindingService from firebolt.service.database import DatabaseService from firebolt.service.engine import EngineService - from firebolt.service.engine_revision import EngineRevisionService from firebolt.service.instance_type import InstanceTypeService - from firebolt.service.region import RegionService # Cloud Platform Resources (AWS) - self.regions = RegionService(resource_manager=self) self.instance_types = InstanceTypeService(resource_manager=self) - self.provider_id = get_provider_id(client=self.client) # Firebolt Resources self.databases = DatabaseService(resource_manager=self) self.engines = EngineService(resource_manager=self) - self.engine_revisions = EngineRevisionService(resource_manager=self) - self.bindings = BindingService(resource_manager=self) + + def __del__(self) -> None: + if hasattr(self, "_client"): + self._client.close() + if hasattr(self, "_connection"): + self._connection.close() diff --git a/src/firebolt/service/provider.py b/src/firebolt/service/provider.py deleted file mode 100644 index 9dcba914cc1..00000000000 --- a/src/firebolt/service/provider.py +++ /dev/null @@ -1,10 +0,0 @@ -from firebolt.client import Client -from firebolt.model.provider import Provider -from firebolt.utils.urls import PROVIDERS_URL - - -def get_provider_id(client: Client) -> str: - """Get the AWS provider_id.""" - response = client.get(url=PROVIDERS_URL) - providers = [Provider.parse_obj(i["node"]) for i in response.json()["edges"]] - return providers[0].provider_id diff --git a/src/firebolt/service/region.py b/src/firebolt/service/region.py deleted file mode 100644 index ff614d56288..00000000000 --- a/src/firebolt/service/region.py +++ /dev/null @@ -1,65 +0,0 @@ -from typing import Dict, List - -from firebolt.model.region import Region, RegionKey -from firebolt.service.base import BaseService -from firebolt.service.manager import ResourceManager -from firebolt.utils.urls import REGIONS_URL -from firebolt.utils.util import cached_property - - -class RegionService(BaseService): - def __init__(self, resource_manager: ResourceManager): - """ - Service to manage AWS regions (us-east-1, etc) - - Args: - resource_manager: Resource manager to use - """ - - super().__init__(resource_manager=resource_manager) - - @cached_property - def regions(self) -> List[Region]: - """List of available AWS regions on Firebolt.""" - - response = self.client.get(url=REGIONS_URL, params={"page.first": 5000}) - return [Region.parse_obj(i["node"]) for i in response.json()["edges"]] - - @cached_property - def regions_by_name(self) -> Dict[str, Region]: - """Dict of {RegionLookup to Region}""" - - return {r.name: r for r in self.regions} - - @cached_property - def regions_by_key(self) -> Dict[RegionKey, Region]: - """Dict of {RegionKey to Region}""" - - return {r.key: r for r in self.regions} - - @cached_property - def default_region(self) -> Region: - """Default AWS region, could be provided from environment.""" - - if not self.settings.default_region: - raise ValueError( - "The environment variable FIREBOLT_DEFAULT_REGION must be set." - ) - return self.get_by_name(name=self.settings.default_region) - - def get_by_name(self, name: str) -> Region: - """Get an AWS region by its name (eg. us-east-1).""" - - return self.regions_by_name[name] - - def get_by_key(self, key: RegionKey) -> Region: - """Get an AWS region by its key.""" - - return self.regions_by_key[key] - - def get_by_id(self, id_: str) -> Region: - """Get an AWS region by region_id.""" - - return self.get_by_key( - RegionKey(provider_id=self.resource_manager.provider_id, region_id=id_) - ) diff --git a/src/firebolt/service/types.py b/src/firebolt/service/types.py index 0ccdaef3573..022bb6ca549 100644 --- a/src/firebolt/service/types.py +++ b/src/firebolt/service/types.py @@ -1,17 +1,19 @@ from enum import Enum -from types import DynamicClassAttribute class EngineType(Enum): GENERAL_PURPOSE = "GENERAL_PURPOSE" DATA_ANALYTICS = "DATA_ANALYTICS" - @DynamicClassAttribute - def api_settings_preset_name(self) -> str: + @classmethod + def from_display_name(cls, display_name: str) -> "EngineType": return { - EngineType.GENERAL_PURPOSE: "ENGINE_SETTINGS_PRESET_GENERAL_PURPOSE", - EngineType.DATA_ANALYTICS: "ENGINE_SETTINGS_PRESET_DATA_ANALYTICS", - }[self] + "General Purpose": cls.GENERAL_PURPOSE, + "Analytics": cls.DATA_ANALYTICS, + }[display_name] + + def __str__(self) -> str: + return self.value class WarmupMethod(Enum): @@ -19,205 +21,35 @@ class WarmupMethod(Enum): PRELOAD_INDEXES = "PRELOAD_INDEXES" PRELOAD_ALL_DATA = "PRELOAD_ALL_DATA" - @DynamicClassAttribute - def api_name(self) -> str: + @classmethod + def from_display_name(cls, display_name: str) -> "WarmupMethod": return { - WarmupMethod.MINIMAL: "ENGINE_SETTINGS_WARM_UP_MINIMAL", - WarmupMethod.PRELOAD_INDEXES: "ENGINE_SETTINGS_WARM_UP_INDEXES", - WarmupMethod.PRELOAD_ALL_DATA: "ENGINE_SETTINGS_WARM_UP_ALL", - }[self] + "Minimal": cls.MINIMAL, + "Indexes": cls.PRELOAD_INDEXES, + "All": cls.PRELOAD_ALL_DATA, + }[display_name] + + def __str__(self) -> str: + return self.value class EngineStatus(Enum): """ Detailed engine status. - See: https://api.dev.firebolt.io/devDocs#operation/coreV1GetEngine - """ - - ENGINE_STATUS_UNSPECIFIED = "ENGINE_STATUS_UNSPECIFIED" - """ Logical record is created, however, underlying infrastructure - is not initialized. - In other words, this means that engine is stopped.""" - - ENGINE_STATUS_CREATED = "ENGINE_STATUS_CREATED" - """Engine status was created.""" - - ENGINE_STATUS_PROVISIONING_PENDING = "ENGINE_STATUS_PROVISIONING_PENDING" - """ Engine initialization request was sent.""" - - ENGINE_STATUS_PROVISIONING_STARTED = "ENGINE_STATUS_PROVISIONING_STARTED" - """ Engine initialization request was received - and initialization process started.""" - - ENGINE_STATUS_PROVISIONING_FINISHED = "ENGINE_STATUS_PROVISIONING_FINISHED" - """ Engine initialization was finished successfully.""" - - ENGINE_STATUS_PROVISIONING_FAILED = "ENGINE_STATUS_PROVISIONING_FAILED" - """ Engine initialization failed due to error.""" - - ENGINE_STATUS_RUNNING_IDLE = "ENGINE_STATUS_RUNNING_IDLE" - """ Engine is initialized, - but there are no running or starting engine revisions.""" - - ENGINE_STATUS_RUNNING_REVISION_STARTING = "ENGINE_STATUS_RUNNING_REVISION_STARTING" - """ Engine is initialized, - there are no running engine revisions, but it's starting.""" - - ENGINE_STATUS_RUNNING_REVISION_STARTUP_FAILED = ( - "ENGINE_STATUS_RUNNING_REVISION_STARTUP_FAILED" - ) - """ Engine is initialized; - initial revision failed to provision or start.""" - - ENGINE_STATUS_RUNNING_REVISION_SERVING = "ENGINE_STATUS_RUNNING_REVISION_SERVING" - """ Engine is ready (serves an engine revision). """ - - ENGINE_STATUS_RUNNING_REVISION_CHANGING = "ENGINE_STATUS_RUNNING_REVISION_CHANGING" - """ Engine is ready (serves an engine revision); - zero-downtime replacement revision is starting.""" - - ENGINE_STATUS_RUNNING_REVISION_CHANGE_FAILED = ( - "ENGINE_STATUS_RUNNING_REVISION_CHANGE_FAILED" - ) - """ Engine is ready (serves an engine revision); - replacement revision failed to provision or start.""" - - ENGINE_STATUS_RUNNING_REVISION_RESTARTING = ( - "ENGINE_STATUS_RUNNING_REVISION_RESTARTING" - ) - """ Engine is initialized; - replacement of the revision with a downtime is in progress.""" - - ENGINE_STATUS_RUNNING_REVISION_RESTART_FAILED = ( - "ENGINE_STATUS_RUNNING_REVISION_RESTART_FAILED" - ) - """ Engine is initialized; - replacement revision failed to provision or start.""" - - ENGINE_STATUS_RUNNING_REVISIONS_TERMINATING = ( - "ENGINE_STATUS_RUNNING_REVISIONS_TERMINATING" - ) - """ Engine is initialized; - all child revisions are being terminated.""" - - # Engine termination request was sent. - ENGINE_STATUS_TERMINATION_PENDING = "ENGINE_STATUS_TERMINATION_PENDING" - """ Engine termination request was sent.""" - - ENGINE_STATUS_TERMINATION_ST = "ENGINE_STATUS_TERMINATION_STARTED" - """ Engine termination started.""" - - ENGINE_STATUS_TERMINATION_FIN = "ENGINE_STATUS_TERMINATION_FINISHED" - """ Engine termination finished.""" - - ENGINE_STATUS_TERMINATION_F = "ENGINE_STATUS_TERMINATION_FAILED" - """ Engine termination failed.""" + See: https://docs.firebolt.io/working-with-engines/understanding-engine-fundamentals.html + """ # noqa - ENGINE_STATUS_DELETED = "ENGINE_STATUS_DELETED" - """ Engine is soft-deleted.""" + STARTING = "Starting" + STARTED = "Started" + RUNNING = "Running" + STOPPING = "Stopping" + STOPPED = "Stopped" + DROPPING = "Dropping" + REPAIRING = "Repairing" - -class EngineStatusSummary(Enum): - """ - Engine summary status. - - See: https://api.dev.firebolt.io/devDocs#operation/coreV1GetEngine - """ - - ENGINE_STATUS_SUMMARY_UNSPECIFIED = "ENGINE_STATUS_SUMMARY_UNSPECIFIED" - """Status unspecified""" - - ENGINE_STATUS_SUMMARY_STOPPED = "ENGINE_STATUS_SUMMARY_STOPPED" - """ Fully stopped.""" - - ENGINE_STATUS_SUMMARY_STARTING = "ENGINE_STATUS_SUMMARY_STARTING" - """ Provisioning process is in progress; - creating cloud infra for this engine.""" - - ENGINE_STATUS_SUMMARY_STARTING_INITIALIZING = ( - "ENGINE_STATUS_SUMMARY_STARTING_INITIALIZING" - ) - """ Provisioning process is complete; - waiting for PackDB cluster to initialize and start.""" - - ENGINE_STATUS_SUMMARY_RUNNING = "ENGINE_STATUS_SUMMARY_RUNNING" - """ Fully started; - engine is ready to serve requests.""" - - ENGINE_STATUS_SUMMARY_UPGRADING = "ENGINE_STATUS_SUMMARY_UPGRADING" - """ Version of the PackDB is changing. - This is zero downtime operation that does not affect engine work. - This status is reserved for future use (not used fow now).""" - - ENGINE_STATUS_SUMMARY_RESTARTING = "ENGINE_STATUS_SUMMARY_RESTARTING" - """ Hard restart (full stop/start cycle) is in progress. - Underlying infrastructure is being recreated.""" - - ENGINE_STATUS_SUMMARY_RESTARTING_INITIALIZING = ( - "ENGINE_STATUS_SUMMARY_RESTARTING_INITIALIZING" - ) - """ Hard restart (full stop/start cycle) is in progress. - Underlying infrastructure is ready. Waiting for - PackDB cluster to initialize and start. - This status is logically the same as ENGINE_STATUS_SUMMARY_STARTING_INITIALIZING, - but used during restart cycle.""" - - ENGINE_STATUS_SUMMARY_REPAIRING = "ENGINE_STATUS_SUMMARY_REPAIRING" - """ Underlying infrastructure has issues and is being repaired. - Engine is still running, but it's not fully healthy and some queries may fail.""" - - ENGINE_STATUS_SUMMARY_STOPPING = "ENGINE_STATUS_SUMMARY_STOPPING" - """ Stop is in progress.""" - - ENGINE_STATUS_SUMMARY_DELETING = "ENGINE_STATUS_SUMMARY_DELETING" - """ Termination is in progress. - All infrastructure that belongs to this engine will be completely destroyed.""" - - ENGINE_STATUS_SUMMARY_DELETED = "ENGINE_STATUS_SUMMARY_DELETED" - """ Infrastructure is terminated, engine data is deleted.""" - - ENGINE_STATUS_SUMMARY_FAILED = "ENGINE_STATUS_SUMMARY_FAILED" - """ Failed to start or stop. - This status only indicates that there were issues during provisioning operations. - If engine enters this status, - all infrastructure should be stopped/terminated already.""" - - -class EngineOrder(Enum): - ENGINE_ORDER_UNSPECIFIED = "ENGINE_ORDER_UNSPECIFIED" - ENGINE_ORDER_NAME_ASC = "ENGINE_ORDER_NAME_ASC" - ENGINE_ORDER_NAME_DESC = "ENGINE_ORDER_NAME_DESC" - ENGINE_ORDER_COMPUTE_REGION_ID_ASC = "ENGINE_ORDER_COMPUTE_REGION_ID_ASC" - ENGINE_ORDER_COMPUTE_REGION_ID_DESC = "ENGINE_ORDER_COMPUTE_REGION_ID_DESC" - ENGINE_ORDER_CURRENT_STATUS_ASC = "ENGINE_ORDER_CURRENT_STATUS_ASC" - ENGINE_ORDER_CURRENT_STATUS_DESC = "ENGINE_ORDER_CURRENT_STATUS_DESC" - ENGINE_ORDER_CREATE_TIME_ASC = "ENGINE_ORDER_CREATE_TIME_ASC" - ENGINE_ORDER_CREATE_TIME_DESC = "ENGINE_ORDER_CREATE_TIME_DESC" - ENGINE_ORDER_CREATE_ACTOR_ASC = "ENGINE_ORDER_CREATE_ACTOR_ASC" - ENGINE_ORDER_CREATE_ACTOR_DESC = "ENGINE_ORDER_CREATE_ACTOR_DESC" - ENGINE_ORDER_LAST_UPDATE_TIME_ASC = "ENGINE_ORDER_LAST_UPDATE_TIME_ASC" - ENGINE_ORDER_LAST_UPDATE_TIME_DESC = "ENGINE_ORDER_LAST_UPDATE_TIME_DESC" - ENGINE_ORDER_LAST_UPDATE_ACTOR_ASC = "ENGINE_ORDER_LAST_UPDATE_ACTOR_ASC" - ENGINE_ORDER_LAST_UPDATE_ACTOR_DESC = "ENGINE_ORDER_LAST_UPDATE_ACTOR_DESC" - ENGINE_ORDER_LATEST_REVISION_CURRENT_STATUS_ASC = ( - "ENGINE_ORDER_LATEST_REVISION_CURRENT_STATUS_ASC" - ) - ENGINE_ORDER_LATEST_REVISION_CURRENT_STATUS_DESC = ( - "ENGINE_ORDER_LATEST_REVISION_CURRENT_STATUS_DESC" - ) - ENGINE_ORDER_LATEST_REVISION_SPECIFICATION_DB_COMPUTE_INSTANCES_COUNT_ASC = ( - "ENGINE_ORDER_LATEST_REVISION_SPECIFICATION_DB_COMPUTE_INSTANCES_COUNT_ASC" - ) - ENGINE_ORDER_LATEST_REVISION_SPECIFICATION_DB_COMPUTE_INSTANCES_COUNT_DESC = ( - "ENGINE_ORDER_LATEST_REVISION_SPECIFICATION_DB_COMPUTE_INSTANCES_COUNT_DESC" - ) - ENGINE_ORDER_LATEST_REVISION_SPECIFICATION_DB_COMPUTE_INSTANCES_TYPE_ID_ASC = ( - "ENGINE_ORDER_LATEST_REVISION_SPECIFICATION_DB_COMPUTE_INSTANCES_TYPE_ID_ASC" - ) - ENGINE_ORDER_LATEST_REVISION_SPECIFICATION_DB_COMPUTE_INSTANCES_TYPE_ID_DESC = ( - "ENGINE_ORDER_LATEST_REVISION_SPECIFICATION_DB_COMPUTE_INSTANCES_TYPE_ID_DESC" - ) + def __str__(self) -> str: + return self.value class DatabaseOrder(Enum): diff --git a/src/firebolt/utils/exception.py b/src/firebolt/utils/exception.py index a5e87212ecb..fee1d9f1044 100644 --- a/src/firebolt/utils/exception.py +++ b/src/firebolt/utils/exception.py @@ -13,6 +13,27 @@ def __init__(self, engine_name: str): super().__init__(f"Engine {engine_name} is not running") +class EngineNotFoundError(FireboltEngineError): + """Engine with provided name was not found.""" + + def __init__(self, engine_name: str): + super().__init__(f"Engine with name {engine_name} was not found") + + +class DatabaseNotFoundError(FireboltError): + """Database with provided name was not found.""" + + def __init__(self, database_name: str): + super().__init__(f"Database with name {database_name} was not found") + + +class InstanceTypeNotFoundError(FireboltError): + """Instance type with provided name was not found.""" + + def __init__(self, instance_type_name: str): + super().__init__(f"Instance type with name {instance_type_name} was not found") + + class NoAttachedDatabaseError(FireboltEngineError): """Engine that's being accessed is not running. diff --git a/src/firebolt/utils/usage_tracker.py b/src/firebolt/utils/usage_tracker.py index c5e9f4f2509..d56a7c071a2 100644 --- a/src/firebolt/utils/usage_tracker.py +++ b/src/firebolt/utils/usage_tracker.py @@ -1,17 +1,17 @@ import inspect import logging +from dataclasses import dataclass from importlib import import_module from pathlib import Path from platform import python_version, release, system from sys import modules from typing import Dict, List, Optional, Tuple -from pydantic import BaseModel - from firebolt import __version__ -class ConnectorVersions(BaseModel): +@dataclass +class ConnectorVersions: """ Verify correct parameter types """ @@ -19,6 +19,20 @@ class ConnectorVersions(BaseModel): clients: List[Tuple[str, str]] drivers: List[Tuple[str, str]] + def __post_init__(self) -> None: + if any( + [(not isinstance(pair, tuple) or len(pair) != 2) for pair in self.clients] + ): + raise ValueError("Invalid clients value: expected tuples of length 2") + if any([not isinstance(item, str) for pair in self.clients for item in pair]): + raise ValueError("Invalid clients value: expected tuples of strings") + if any( + [(not isinstance(pair, tuple) or len(pair) != 2) for pair in self.drivers] + ): + raise ValueError("Invalid drivers value: expected tuples of length 2") + if any([not isinstance(item, str) for pair in self.drivers for item in pair]): + raise ValueError("Invalid drivers value: expected tuples of strings") + logger = logging.getLogger(__name__) @@ -169,8 +183,8 @@ def format_as_user_agent(drivers: Dict[str, str], clients: Dict[str, str]) -> st def get_user_agent_header( - user_drivers: Optional[List[Tuple[str, str]]] = [], - user_clients: Optional[List[Tuple[str, str]]] = [], + user_drivers: Optional[List[Tuple[str, str]]] = None, + user_clients: Optional[List[Tuple[str, str]]] = None, ) -> str: """ Return a user agent header with connector stack and system information. @@ -194,7 +208,7 @@ def get_user_agent_header( "Detected running with drivers: %s and clients %s ", str(drivers), str(clients) ) # Override auto-detected connectors with info provided manually - versions = ConnectorVersions(clients=user_clients, drivers=user_drivers) + versions = ConnectorVersions(clients=user_clients or [], drivers=user_drivers or []) for name, version in versions.clients: clients[name] = version for name, version in versions.drivers: diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 05a86144f58..de382b8c3ac 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -4,7 +4,6 @@ from pytest import fixture from firebolt.client.auth import ClientCredentials -from firebolt.service.manager import Settings LOGGER = getLogger(__name__) @@ -23,16 +22,6 @@ def must_env(var_name: str) -> str: return environ[var_name] -@fixture(scope="session") -def rm_settings(api_endpoint, auth, account_name) -> Settings: - return Settings( - account_name=account_name, - server=api_endpoint, - auth=auth, - default_region="us-east-1", - ) - - @fixture(scope="session") def engine_name() -> str: return must_env(ENGINE_NAME_ENV) diff --git a/tests/integration/resource_manager/test_database.py b/tests/integration/resource_manager/test_database.py index f1bbec2dc5e..08732925f5d 100644 --- a/tests/integration/resource_manager/test_database.py +++ b/tests/integration/resource_manager/test_database.py @@ -1,9 +1,11 @@ -from firebolt.common import Settings +from firebolt.client.auth import Auth from firebolt.service.manager import ResourceManager def test_database_get_default_engine( - rm_settings: Settings, + auth: Auth, + account_name: str, + api_endpoint: str, database_name: str, stopped_engine_name: str, engine_name: str, @@ -11,13 +13,64 @@ def test_database_get_default_engine( """ Checks that the default engine is either running or stopped engine """ - rm = ResourceManager(rm_settings) + rm = ResourceManager( + auth=auth, account_name=account_name, api_endpoint=api_endpoint + ) - db = rm.databases.get_by_name(database_name) + db = rm.databases.get(database_name) - engine = db.get_default_engine() - assert engine is not None, "default engine is None, but shouldn't" + engine = db.get_attached_engines()[0] + assert engine is not None, "engine is None, but shouldn't be" assert engine.name in [ stopped_engine_name, engine_name, ], "Returned default engine name is neither of known engines" + + +def test_databases_get_many( + auth: Auth, + account_name: str, + api_endpoint: str, + database_name: str, + engine_name: str, +): + rm = ResourceManager( + auth=auth, account_name=account_name, api_endpoint=api_endpoint + ) + + # get all databases, at least one should be returned + databases = rm.databases.get_many() + assert len(databases) > 0 + assert database_name in {db.name for db in databases} + + # get all databases, with name_contains + databases = rm.databases.get_many(name_contains=database_name) + assert len(databases) > 0 + assert database_name in {db.name for db in databases} + + # get all databases, with name_contains + databases = rm.databases.get_many(attached_engine_name_eq=engine_name) + assert len(databases) > 0 + assert database_name in {db.name for db in databases} + + # get all databases, with name_contains + databases = rm.databases.get_many(attached_engine_name_contains=engine_name) + assert len(databases) > 0 + assert database_name in {db.name for db in databases} + + region = [db for db in databases if db.name == database_name][0].region + + # get all databases, with region_eq + databases = rm.databases.get_many(region_eq=region) + assert len(databases) > 0 + assert database_name in {db.name for db in databases} + + # get all databases, with all filters + databases = rm.databases.get_many( + name_contains=database_name, + attached_engine_name_eq=engine_name, + attached_engine_name_contains=engine_name, + region_eq=region, + ) + assert len(databases) > 0 + assert database_name in {db.name for db in databases} diff --git a/tests/integration/resource_manager/test_engine.py b/tests/integration/resource_manager/test_engine.py index 62ae8d8d6d9..f21bc4dd1f3 100644 --- a/tests/integration/resource_manager/test_engine.py +++ b/tests/integration/resource_manager/test_engine.py @@ -1,23 +1,23 @@ from collections import namedtuple -import pytest - -from firebolt.model.engine import Engine -from firebolt.service.manager import ResourceManager, Settings -from firebolt.service.types import ( - EngineStatusSummary, - EngineType, - WarmupMethod, -) +from firebolt.client.auth import Auth +from firebolt.service.manager import ResourceManager +from firebolt.service.types import EngineStatus def make_engine_name(database_name: str, suffix: str) -> str: return f"{database_name}_{suffix}" -@pytest.mark.skip(reason="manual test") -def test_create_start_stop_engine(database_name: str): - rm = ResourceManager() +def test_create_start_stop_engine( + auth: Auth, + account_name: str, + api_endpoint: str, + database_name: str, +): + rm = ResourceManager( + auth=auth, account_name=account_name, api_endpoint=api_endpoint + ) name = make_engine_name(database_name, "start_stop") engine = rm.engines.create(name=name) @@ -26,127 +26,77 @@ def test_create_start_stop_engine(database_name: str): database = rm.databases.create(name=name) assert database.name == name - engine.attach_to_database(database=database) + engine.attach_to_database(database) assert engine.database == database - engine = engine.start() - assert ( - engine.current_status_summary - == EngineStatusSummary.ENGINE_STATUS_SUMMARY_RUNNING - ) - - engine = engine.stop() - assert engine.current_status_summary in { - EngineStatusSummary.ENGINE_STATUS_SUMMARY_STOPPING, - EngineStatusSummary.ENGINE_STATUS_SUMMARY_STOPPED, - } - - -@pytest.mark.skip(reason="manual test") -def test_copy_engine(database_name): - rm = ResourceManager() - name = make_engine_name(database_name, "copy") - - engine = rm.engines.create(name=name) - assert engine.name == name - - engine.name = f"{engine.name}_copy" - engine_copy = rm.engines._send_create_engine( - engine=engine, - engine_revision=rm.engine_revisions.get_by_key(engine.latest_revision_key), - ) - assert engine_copy - - -def test_databases_get_many(rm_settings: Settings, database_name, engine_name): - rm = ResourceManager(rm_settings) - - # get all databases, at least one should be returned - databases = rm.databases.get_many() - assert len(databases) > 0 - assert database_name in {db.name for db in databases} - - # get all databases, with name_contains - databases = rm.databases.get_many(name_contains=database_name) - assert len(databases) > 0 - assert database_name in {db.name for db in databases} - - # get all databases, with name_contains - databases = rm.databases.get_many(attached_engine_name_eq=engine_name) - assert len(databases) > 0 - assert database_name in {db.name for db in databases} - - # get all databases, with name_contains - databases = rm.databases.get_many(attached_engine_name_contains=engine_name) - assert len(databases) > 0 - assert database_name in {db.name for db in databases} + engine.start() + assert engine.current_status == EngineStatus.RUNNING + engine.stop() + assert engine.current_status in {EngineStatus.STOPPING, EngineStatus.STOPPED} -def get_engine_params(rm: ResourceManager, engine: Engine): - engine_revision = rm.engine_revisions.get_by_key(engine.latest_revision_key) - instance_type = rm.instance_types.get_by_key( - engine_revision.specification.db_compute_instances_type_key - ) - - return { - "engine_type": engine.settings.preset, - "scale": engine_revision.specification.db_compute_instances_count, - "spec": instance_type.name, - "auto_stop": engine.settings.auto_stop_delay_duration, - "warmup": engine.settings.warm_up, - "description": engine.description, - } + engine.delete() + database.delete() ParamValue = namedtuple("ParamValue", "set expected") ENGINE_UPDATE_PARAMS = { - "engine_type": ParamValue( - EngineType.DATA_ANALYTICS, "ENGINE_SETTINGS_PRESET_DATA_ANALYTICS" - ), - "scale": ParamValue(23, 23), - "spec": ParamValue("B1", "B1"), - "auto_stop": ParamValue(123, "7380s"), - "warmup": ParamValue(WarmupMethod.PRELOAD_ALL_DATA, "ENGINE_SETTINGS_WARM_UP_ALL"), - "description": ParamValue("new db description", "new db description"), + # commented parameters are not available yet + # "scale": ParamValue(23, 23), + # "spec": ParamValue("B1", "B1"), + "auto_stop": ParamValue(123, 7380), + # "warmup": ParamValue(WarmupMethod.PRELOAD_ALL_DATA, WarmupMethod.PRELOAD_ALL_DATA), } -def test_engine_update_single_parameter(rm_settings: Settings, database_name: str): - rm = ResourceManager(rm_settings) +def test_engine_update_single_parameter( + auth: Auth, + account_name: str, + api_endpoint: str, + database_name: str, +): + rm = ResourceManager( + auth=auth, account_name=account_name, api_endpoint=api_endpoint + ) name = make_engine_name(database_name, "single_param") engine = rm.engines.create(name=name) - engine.attach_to_database(database=rm.databases.get_by_name(database_name)) + engine.attach_to_database(rm.databases.get(database_name)) assert engine.database.name == database_name for param, value in ENGINE_UPDATE_PARAMS.items(): engine.update(**{param: value.set}) - engine = rm.engines.get_by_name(name) - new_params = get_engine_params(rm, engine) - assert new_params[param] == value.expected + engine_new = rm.engines.get(name) + assert getattr(engine_new, param) == value.expected, f"Invalid {param} value" engine.delete() -def test_engine_update_multiple_parameters(rm_settings: Settings, database_name: str): - rm = ResourceManager(rm_settings) +def test_engine_update_multiple_parameters( + auth: Auth, + account_name: str, + api_endpoint: str, + database_name: str, +): + rm = ResourceManager( + auth=auth, account_name=account_name, api_endpoint=api_endpoint + ) name = make_engine_name(database_name, "multi_param") engine = rm.engines.create(name=name) - engine.attach_to_database(database=rm.databases.get_by_name(database_name)) + engine.attach_to_database(rm.databases.get(database_name)) assert engine.database.name == database_name engine.update( **dict({(param, value.set) for param, value in ENGINE_UPDATE_PARAMS.items()}) ) - engine = rm.engines.get_by_name(name) - new_params = get_engine_params(rm, engine) + engine_new = rm.engines.get(name) for param, value in ENGINE_UPDATE_PARAMS.items(): - assert new_params[param] == value.expected + assert getattr(engine_new, param) == value.expected, f"Invalid {param} value" engine.delete() diff --git a/tests/unit/async_db/conftest.py b/tests/unit/async_db/conftest.py index 8be96d4ccf5..3eaa5f12da2 100644 --- a/tests/unit/async_db/conftest.py +++ b/tests/unit/async_db/conftest.py @@ -5,7 +5,6 @@ from firebolt.async_db import ARRAY, DECIMAL, Connection, Cursor, connect from firebolt.client.auth import Auth -from firebolt.common.settings import Settings from tests.unit.db_conftest import * # noqa @@ -32,7 +31,7 @@ async def connection( @fixture -async def cursor(connection: Connection, settings: Settings) -> Cursor: +async def cursor(connection: Connection) -> Cursor: return connection.cursor() diff --git a/tests/unit/async_db/test_connection.py b/tests/unit/async_db/test_connection.py index 0f704597909..d2a8dfaa4da 100644 --- a/tests/unit/async_db/test_connection.py +++ b/tests/unit/async_db/test_connection.py @@ -8,7 +8,6 @@ from firebolt.async_db.connection import Connection, connect from firebolt.client.auth import Auth, ClientCredentials from firebolt.common._types import ColType -from firebolt.common.settings import Settings from firebolt.utils.exception import ( ConfigurationError, ConnectionClosedError, @@ -56,7 +55,6 @@ async def test_cursors_closed_on_close(connection: Connection) -> None: async def test_cursor_initialized( - settings: Settings, mock_query: Callable, connection: Connection, python_query_data: List[List[ColType]], diff --git a/tests/unit/async_db/test_cursor.py b/tests/unit/async_db/test_cursor.py index 95cdb0b6b1c..b4d220fac7b 100644 --- a/tests/unit/async_db/test_cursor.py +++ b/tests/unit/async_db/test_cursor.py @@ -8,7 +8,6 @@ from firebolt.async_db import Cursor from firebolt.common._types import Column from firebolt.common.base_cursor import ColType, CursorState, QueryStatus -from firebolt.common.settings import Settings from firebolt.utils.exception import ( AsyncExecutionUnavailableError, CursorClosedError, @@ -192,9 +191,9 @@ async def test_cursor_execute( async def test_cursor_execute_error( httpx_mock: HTTPXMock, + server: str, query_url: str, get_engines_url: str, - settings: Settings, db_name: str, query_statistics: Dict[str, Any], cursor: Cursor, @@ -315,7 +314,7 @@ def http_error(*args, **kwargs): with raises(EngineNotRunningError) as excinfo: await query() assert cursor._state == CursorState.ERROR - assert settings.server in str(excinfo) + assert server in str(excinfo) # Engine does not exist httpx_mock.add_response( diff --git a/tests/unit/client/test_client.py b/tests/unit/client/test_client.py index 0db63da2dd2..2eee11b62aa 100644 --- a/tests/unit/client/test_client.py +++ b/tests/unit/client/test_client.py @@ -9,7 +9,6 @@ from firebolt.client import Client from firebolt.client.auth import Auth, ClientCredentials from firebolt.client.resource_manager_hooks import raise_on_4xx_5xx -from firebolt.common import Settings from firebolt.utils.token_storage import TokenSecureStorage from firebolt.utils.urls import AUTH_SERVICE_ACCOUNT_URL from firebolt.utils.util import fix_url_schema @@ -96,7 +95,7 @@ def test_client_account_id( account_id_callback: Callable, auth_url: str, auth_callback: Callable, - settings: Settings, + server: str, ): httpx_mock.add_callback(account_id_callback, url=account_id_url) httpx_mock.add_callback(auth_callback, url=auth_url) @@ -104,8 +103,8 @@ def test_client_account_id( with Client( account_name=account_name, auth=auth, - base_url=fix_url_schema(settings.server), - api_endpoint=settings.server, + base_url=fix_url_schema(server), + api_endpoint=server, ) as c: assert c.account_id == account_id, "Invalid account id returned" diff --git a/tests/unit/client/test_client_async.py b/tests/unit/client/test_client_async.py index acad0c4a286..a1eefddc970 100644 --- a/tests/unit/client/test_client_async.py +++ b/tests/unit/client/test_client_async.py @@ -7,7 +7,6 @@ from firebolt.client import AsyncClient from firebolt.client.auth import Auth -from firebolt.common import Settings from firebolt.utils.urls import AUTH_SERVICE_ACCOUNT_URL from firebolt.utils.util import fix_url_schema @@ -99,7 +98,7 @@ async def test_client_account_id( account_id_callback: Callable, auth_url: str, auth_callback: Callable, - settings: Settings, + server: str, ): httpx_mock.add_callback(account_id_callback, url=account_id_url) httpx_mock.add_callback(auth_callback, url=auth_url) @@ -107,7 +106,7 @@ async def test_client_account_id( async with AsyncClient( account_name=account_name, auth=auth, - base_url=fix_url_schema(settings.server), - api_endpoint=settings.server, + base_url=fix_url_schema(server), + api_endpoint=server, ) as c: assert await c.account_id == account_id, "Invalid account id returned." diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 9910471dc29..961fd859908 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,5 +1,5 @@ from re import Pattern, compile -from typing import Callable, List +from typing import Callable import httpx from httpx import Request, Response @@ -7,9 +7,6 @@ from pytest import fixture from firebolt.client.auth import Auth, ClientCredentials -from firebolt.common.settings import Settings -from firebolt.model.provider import Provider -from firebolt.model.region import Region, RegionKey from firebolt.utils.exception import ( AccountNotFoundError, DatabaseError, @@ -88,61 +85,11 @@ def access_token_2() -> str: return "mock_access_token_2" -@fixture -def provider() -> Provider: - return Provider( - provider_id="mock_provider_id", - name="mock_provider_name", - ) - - -@fixture -def mock_providers(provider) -> List[Provider]: - return [provider] - - -@fixture -def region_1(provider) -> Region: - return Region( - key=RegionKey( - provider_id=provider.provider_id, - region_id="mock_region_id_1", - ), - name="mock_region_1", - ) - - -@fixture -def region_2(provider) -> Region: - return Region( - key=RegionKey( - provider_id=provider.provider_id, - region_id="mock_region_id_2", - ), - name="mock_region_2", - ) - - -@fixture -def mock_regions(region_1, region_2) -> List[Region]: - return [region_1, region_2] - - @fixture def auth(client_id: str, client_secret: str) -> Auth: return ClientCredentials(client_id, client_secret) -@fixture -def settings(server: str, region_1: str, auth: Auth, account_name: str) -> Settings: - return Settings( - server=server, - auth=auth, - default_region=region_1.name, - account_name=account_name, - ) - - @fixture def auth_callback(auth_url: str) -> Callable: def do_mock( @@ -169,14 +116,9 @@ def db_name() -> str: @fixture -def db_description() -> str: - return "database description" - - -@fixture -def account_id_url(settings: Settings, account_name: str) -> Pattern: +def account_id_url(server: str, account_name: str) -> Pattern: account_name_re = r"[^\\\\]*" - base = f"https://{settings.server}{ACCOUNT_BY_NAME_URL}" + base = f"https://{server}{ACCOUNT_BY_NAME_URL}" base = base.replace("/", "\\/").replace("?", "\\?") base = base.format(account_name=account_name_re) return compile(base) @@ -185,13 +127,13 @@ def account_id_url(settings: Settings, account_name: str) -> Pattern: @fixture def account_id_callback( account_id: str, - settings: Settings, + account_name: str, ) -> Callable: def do_mock( request: Request, **kwargs, ) -> Response: - if request.url.path.split("/")[-2] != settings.account_name: + if request.url.path.split("/")[-2] != account_name: raise AccountNotFoundError(request.url.path.split("/")[-2]) return Response(status_code=httpx.codes.OK, json={"id": account_id}) @@ -214,22 +156,20 @@ def engine_name() -> str: @fixture -def get_engine_name_by_id_url( - settings: Settings, account_id: str, engine_id: str -) -> str: - return f"https://{settings.server}" + ACCOUNT_ENGINE_URL.format( +def get_engine_name_by_id_url(server: str, account_id: str, engine_id: str) -> str: + return f"https://{server}" + ACCOUNT_ENGINE_URL.format( account_id=account_id, engine_id=engine_id ) @fixture -def get_engines_url(settings: Settings) -> str: - return f"https://{settings.server}{ENGINES_URL}" +def get_engines_url(server: str) -> str: + return f"https://{server}{ENGINES_URL}" @fixture -def get_databases_url(settings: Settings) -> str: - return f"https://{settings.server}{DATABASES_URL}" +def get_databases_url(server: str) -> str: + return f"https://{server}{DATABASES_URL}" @fixture @@ -238,9 +178,9 @@ def database_id() -> str: @fixture -def database_by_name_url(settings: Settings, account_id: str, db_name: str) -> str: +def database_by_name_url(server: str, account_id: str, db_name: str) -> str: return ( - f"https://{settings.server}" + f"https://{server}" f"{ACCOUNT_DATABASE_BY_NAME_URL.format(account_id=account_id)}" f"?database_name={db_name}" ) diff --git a/tests/unit/db/conftest.py b/tests/unit/db/conftest.py index ec8d299f9f8..90bd7dc3080 100644 --- a/tests/unit/db/conftest.py +++ b/tests/unit/db/conftest.py @@ -32,9 +32,9 @@ def system_connection( db_name: str, auth: Auth, account_name: str, - mock_system_connection_flow: Callable, + mock_system_engine_connection_flow: Callable, ) -> Connection: - mock_system_connection_flow() + mock_system_engine_connection_flow() with connect( database=db_name, auth=auth, diff --git a/tests/unit/db/test_connection.py b/tests/unit/db/test_connection.py index 1399922558f..b58b4cfe78a 100644 --- a/tests/unit/db/test_connection.py +++ b/tests/unit/db/test_connection.py @@ -9,7 +9,6 @@ from firebolt.client.auth import Auth, ClientCredentials from firebolt.common._types import ColType -from firebolt.common.settings import Settings from firebolt.db import Connection, connect from firebolt.utils.exception import ( ConfigurationError, @@ -58,7 +57,6 @@ def test_cursors_closed_on_close(connection: Connection) -> None: def test_cursor_initialized( - settings: Settings, mock_query: Callable, connection: Connection, python_query_data: List[List[ColType]], diff --git a/tests/unit/db/test_cursor.py b/tests/unit/db/test_cursor.py index d61bc23a531..c94ba052d74 100644 --- a/tests/unit/db/test_cursor.py +++ b/tests/unit/db/test_cursor.py @@ -5,7 +5,6 @@ from pytest import raises from pytest_httpx import HTTPXMock -from firebolt.common.settings import Settings from firebolt.db import Cursor from firebolt.db.cursor import ColType, Column, CursorState, QueryStatus from firebolt.utils.exception import ( @@ -188,7 +187,7 @@ def test_cursor_execute( def test_cursor_execute_error( httpx_mock: HTTPXMock, get_engines_url: str, - settings: Settings, + server: str, db_name: str, query_url: str, query_statistics: Dict[str, Any], @@ -309,7 +308,7 @@ def http_error(*args, **kwargs): with raises(EngineNotRunningError) as excinfo: query() assert cursor._state == CursorState.ERROR - assert settings.server in str(excinfo) + assert server in str(excinfo) # Engine does not exist httpx_mock.add_response( diff --git a/tests/unit/db/test_util.py b/tests/unit/db/test_util.py index a21ac69b73a..3f053bcd6b7 100644 --- a/tests/unit/db/test_util.py +++ b/tests/unit/db/test_util.py @@ -45,10 +45,17 @@ def test_is_db_not_available( def test_is_engine_running_system( + httpx_mock: HTTPXMock, system_connection: Connection, ): # System engine is always running assert is_engine_running(system_connection, "dummy") == True + # We didn't resolve account id since since we run no query + # We need to skip the mocked endpoint + httpx_mock.reset(False) + + # We haven't used account id endpoint since we didn't run any query, ignoring it + httpx_mock.reset(False) def test_is_engine_running( diff --git a/tests/unit/db_conftest.py b/tests/unit/db_conftest.py index 5ed5a951758..b538a400ad3 100644 --- a/tests/unit/db_conftest.py +++ b/tests/unit/db_conftest.py @@ -8,7 +8,6 @@ from pytest_httpx import HTTPXMock from firebolt.async_db.cursor import JSON_OUTPUT_FORMAT, ColType, Column -from firebolt.common.settings import Settings from firebolt.db import ARRAY, DECIMAL from firebolt.utils.urls import GATEWAY_HOST_BY_ACCOUNT_NAME @@ -337,16 +336,16 @@ def set_params() -> Dict: @fixture -def query_url(settings: Settings, db_name: str) -> str: +def query_url(server: str, db_name: str) -> str: return URL( - f"https://{settings.server}/", + f"https://{server}/", params={"output_format": JSON_OUTPUT_FORMAT, "database": db_name}, ) @fixture -def set_query_url(settings: Settings, db_name: str) -> str: - return URL(f"https://{settings.server}/?database={db_name}") +def set_query_url(server: str, db_name: str) -> str: + return URL(f"https://{server}/?database={db_name}") @fixture @@ -494,16 +493,19 @@ def inner() -> None: @fixture -def mock_system_connection_flow( +def mock_system_engine_connection_flow( httpx_mock: HTTPXMock, auth_url: str, check_credentials_callback: Callable, get_system_engine_url: str, get_system_engine_callback: Callable, + account_id_url: str, + account_id_callback: Callable, ) -> Callable: def inner() -> None: httpx_mock.add_callback(check_credentials_callback, url=auth_url) httpx_mock.add_callback(get_system_engine_callback, url=get_system_engine_url) + httpx_mock.add_callback(account_id_callback, url=account_id_url) return inner diff --git a/tests/unit/service/conftest.py b/tests/unit/service/conftest.py index 4cb7a7789a2..0dcb74cd948 100644 --- a/tests/unit/service/conftest.py +++ b/tests/unit/service/conftest.py @@ -1,117 +1,91 @@ -import json +from dataclasses import dataclass, fields +from datetime import datetime from typing import Callable, List -from urllib.parse import urlparse import httpx from httpx import Response from pytest import fixture -from firebolt.common.settings import Settings -from firebolt.model.binding import Binding, BindingKey -from firebolt.model.database import Database, DatabaseKey -from firebolt.model.engine import Engine, EngineKey, EngineSettings -from firebolt.model.engine_revision import ( - EngineRevision, - EngineRevisionSpecification, -) -from firebolt.model.instance_type import InstanceType, InstanceTypeKey -from firebolt.model.region import Region -from firebolt.utils.urls import ( - ACCOUNT_BINDINGS_URL, - ACCOUNT_DATABASE_BINDING_URL, - ACCOUNT_DATABASE_BY_NAME_URL, - ACCOUNT_DATABASE_URL, - ACCOUNT_DATABASES_URL, - ACCOUNT_ENGINE_URL, - ACCOUNT_INSTANCE_TYPES_URL, - ACCOUNT_LIST_ENGINES_URL, - PROVIDERS_URL, - REGIONS_URL, -) +from firebolt.client.auth import Auth +from firebolt.common._types import _InternalType +from firebolt.model.database import Database +from firebolt.model.engine import Engine +from firebolt.model.instance_type import InstanceType +from firebolt.service.manager import ResourceManager +from firebolt.service.types import EngineStatus, EngineType, WarmupMethod +from firebolt.utils.urls import ACCOUNT_INSTANCE_TYPES_URL from tests.unit.util import list_to_paginated_response @fixture -def engine_name() -> str: - return "my_engine" +def region() -> str: + return "us-east-1" @fixture -def engine_scale() -> int: - return 2 - - -@fixture -def engine_settings() -> EngineSettings: - return EngineSettings.default() - - -@fixture -def mock_engine(engine_name, region_1, engine_settings, account_id, settings) -> Engine: +def mock_engine(region: str, server: str, instance_type_1: InstanceType) -> Engine: return Engine( - name=engine_name, - compute_region_key=region_1.key, - settings=engine_settings, - key=EngineKey(account_id=account_id, engine_id="mock_engine_id_1"), - endpoint=f"https://{settings.server}", + name="engine_1", + region=region, + spec=instance_type_1, + scale=2, + current_status=EngineStatus.STOPPED, + version="", + endpoint=server, + warmup=WarmupMethod.MINIMAL, + auto_stop=7200, + type=EngineType.GENERAL_PURPOSE, + provisioning="Finished", + _database_name="database", + _service=None, ) @fixture -def mock_engine_revision_spec( - instance_type_2, engine_scale -) -> EngineRevisionSpecification: - return EngineRevisionSpecification( - db_compute_instances_type_key=instance_type_2.key, - db_compute_instances_count=engine_scale, - proxy_instances_type_key=instance_type_2.key, - ) - - -@fixture -def mock_engine_revision(mock_engine_revision_spec) -> EngineRevision: - return EngineRevision(specification=mock_engine_revision_spec) - - -@fixture -def instance_type_1(provider, region_1) -> InstanceType: +def instance_type_1() -> InstanceType: return InstanceType( - key=InstanceTypeKey( - provider_id=provider.provider_id, - region_id=region_1.key.region_id, - instance_type_id="instance_type_id_1", - ), name="B1", - price_per_hour_cents=10, + price_per_hour_cents=40, storage_size_bytes=0, + is_spot_available=True, + cpu_virtual_cores_count=0, + memory_size_bytes=0, + create_time=datetime.now().isoformat(), + last_update_time=datetime.now().isoformat(), + _key={}, + _service=None, ) @fixture -def instance_type_2(provider, region_2) -> InstanceType: +def instance_type_2() -> InstanceType: return InstanceType( - key=InstanceTypeKey( - provider_id=provider.provider_id, - region_id=region_2.key.region_id, - instance_type_id="instance_type_id_2", - ), name="B2", price_per_hour_cents=20, storage_size_bytes=500, + is_spot_available=True, + cpu_virtual_cores_count=0, + memory_size_bytes=0, + create_time=datetime.now().isoformat(), + last_update_time=datetime.now().isoformat(), + _key={}, + _service=None, ) @fixture -def instance_type_3(provider, region_2) -> InstanceType: +def instance_type_3() -> InstanceType: return InstanceType( - key=InstanceTypeKey( - provider_id=provider.provider_id, - region_id=region_2.key.region_id, - instance_type_id="instance_type_id_2", - ), - name="B2", + name="B3", price_per_hour_cents=30, storage_size_bytes=500, + is_spot_available=True, + cpu_virtual_cores_count=0, + memory_size_bytes=0, + create_time=datetime.now().isoformat(), + last_update_time=datetime.now().isoformat(), + _key={}, + _service=None, ) @@ -128,43 +102,37 @@ def mock_instance_types( @fixture -def provider_callback(provider_url: str, mock_providers) -> Callable: - def do_mock( - request: httpx.Request = None, - **kwargs, - ) -> Response: - assert request.url == provider_url - return Response( - status_code=httpx.codes.OK, - json=list_to_paginated_response(mock_providers), - ) - - return do_mock - - -@fixture -def provider_url(settings: Settings) -> str: - return f"https://{settings.server}{PROVIDERS_URL}" - - -@fixture -def region_callback(region_url: str, mock_regions) -> Callable: - def do_mock( - request: httpx.Request = None, - **kwargs, - ) -> Response: - assert request.url == region_url - return Response( - status_code=httpx.codes.OK, - json=list_to_paginated_response(mock_regions), - ) - - return do_mock +def mock_database(region: str) -> Database: + return Database( + name="database", + description="mock_db_description", + region=region, + data_size_full=0, + data_size_compressed=0, + create_time=datetime.now().isoformat(), + create_actor="", + _status="", + _attached_engine_names="-", + _errors="", + _service=None, + ) @fixture -def region_url(settings: Settings) -> str: - return f"https://{settings.server}{REGIONS_URL}?page.first=5000" +def mock_database_2(region: str) -> Database: + return Database( + name="database2", + description="completely different db", + region=region, + data_size_full=0, + data_size_compressed=0, + create_time=datetime.now().isoformat(), + create_actor="", + _status="", + _attached_engine_names="-", + _errors="", + _service=None, + ) @fixture @@ -182,23 +150,6 @@ def do_mock( return do_mock -@fixture -def instance_type_region_1_callback( - instance_type_region_1_url: str, mock_instance_types -) -> Callable: - def do_mock( - request: httpx.Request = None, - **kwargs, - ) -> Response: - assert request.url == instance_type_region_1_url - return Response( - status_code=httpx.codes.OK, - json=list_to_paginated_response(mock_instance_types), - ) - - return do_mock - - @fixture def instance_type_empty_callback() -> Callable: def do_mock( @@ -214,305 +165,199 @@ def do_mock( @fixture -def instance_type_url(settings: Settings, account_id: str) -> str: +def instance_type_url(server: str, account_id: str) -> str: return ( - f"https://{settings.server}" + f"https://{server}" + ACCOUNT_INSTANCE_TYPES_URL.format(account_id=account_id) + "?page.first=5000" ) -@fixture -def instance_type_region_1_url( - settings: Settings, region_1: Region, account_id: str -) -> str: - return ( - f"https://{settings.server}" - + ACCOUNT_INSTANCE_TYPES_URL.format(account_id=account_id) - + f"?page.first=5000&filter.id_region_id_eq={region_1.key.region_id}" - ) - - -@fixture -def instance_type_region_2_url( - settings: Settings, region_2: Region, account_id: str -) -> str: - return ( - f"https://{settings.server}" - + ACCOUNT_INSTANCE_TYPES_URL.format(account_id=account_id) - + f"?page.first=5000&filter.id_region_id_eq={region_2.key.region_id}" - ) +empty_response = { + "meta": [], + "data": [], + "rows": 0, + "statistics": { + "elapsed": 39635.785423446, + "rows_read": 0, + "bytes_read": 0, + "time_before_execution": 0, + "time_to_execute": 0, + }, +} -@fixture -def engine_callback(engine_url: str, mock_engine) -> Callable: +def get_objects_from_db_callback(objs: List[dataclass]) -> Callable: def do_mock( request: httpx.Request = None, **kwargs, ) -> Response: - assert urlparse(engine_url).path in request.url.path + fieldname = lambda f: (f.metadata or {}).get("db_name", f.name) + types = { + "int": _InternalType.Long.value, + "str": _InternalType.Text.value, + "datetime": _InternalType.Text.value, # we receive datetime as text from db + "InstanceType": _InternalType.Text.value, + "EngineType": _InternalType.Text.value, + "EngineStatus": _InternalType.Text.value, + "WarmupMethod": _InternalType.Text.value, + } + dc_fields = [f for f in fields(objs[0]) if f.name != "_service"] + + def get_obj_field(obj, f): + value = getattr(obj, f.name) + if isinstance(value, (InstanceType, EngineStatus)): + return str(value) + if isinstance(value, (EngineType, WarmupMethod)): + return " ".join( + map(lambda s: s.capitalize(), str(value).lower().split("_")) + ) + return value + + query_response = { + "meta": [{"name": fieldname(f), "type": types[f.type]} for f in dc_fields], + "data": [[get_obj_field(obj, f) for f in dc_fields] for obj in objs], + "rows": len(objs), + "statistics": { + "elapsed": 0.116907717, + "rows_read": 1, + "bytes_read": 61, + "time_before_execution": 0.012180623, + "time_to_execute": 0.104614307, + "scanned_bytes_cache": 0, + "scanned_bytes_storage": 0, + }, + } return Response( status_code=httpx.codes.OK, - json={"engine": mock_engine.dict()}, + json=query_response, ) return do_mock @fixture -def engine_url(settings: Settings, account_id) -> str: - return f"https://{settings.server}" + ACCOUNT_LIST_ENGINES_URL.format( - account_id=account_id - ) +def get_engine_callback(mock_engine: Engine) -> Callable: + return get_objects_from_db_callback([mock_engine]) @fixture -def account_engine_callback(account_engine_url: str, mock_engine) -> Callable: +def get_engine_not_found_callback(mock_engine: Engine) -> Callable: def do_mock( request: httpx.Request = None, **kwargs, ) -> Response: - assert request.url == account_engine_url return Response( status_code=httpx.codes.OK, - json={"engine": mock_engine.dict()}, + json=empty_response, ) return do_mock @fixture -def account_engine_url(settings: Settings, account_id, mock_engine) -> str: - return f"https://{settings.server}" + ACCOUNT_ENGINE_URL.format( - account_id=account_id, - engine_id=mock_engine.engine_id, - ) - - -@fixture -def mock_database(region_1: str, account_id: str) -> Database: - return Database( - name="database", - description="mock_db_description", - compute_region_key=region_1.key, - database_key=DatabaseKey( - account_id=account_id, database_id="mock_database_id_1" - ), - ) - - -@fixture -def create_databases_callback(databases_url: str, mock_database) -> Callable: +def attach_engine_to_db_callback(system_engine_no_db_query_url: str) -> Callable: def do_mock( request: httpx.Request = None, **kwargs, ) -> Response: - database_properties = json.loads(request.read().decode("utf-8"))["database"] - - mock_database.name = database_properties["name"] - mock_database.description = database_properties["description"] - - assert request.url == databases_url - return Response( - status_code=httpx.codes.OK, - json={"database": mock_database.dict()}, - ) - - return do_mock - - -@fixture -def databases_get_callback(databases_url: str, mock_database) -> Callable: - def get_databases_callback_inner( - request: httpx.Request = None, **kwargs - ) -> Response: - return Response( - status_code=httpx.codes.OK, json={"edges": [{"node": mock_database.dict()}]} - ) - - return get_databases_callback_inner - - -@fixture -def databases_url(settings: Settings, account_id: str) -> str: - return f"https://{settings.server}" + ACCOUNT_DATABASES_URL.format( - account_id=account_id - ) - - -@fixture -def database_callback(database_url: str, mock_database) -> Callable: - def do_mock( - request: httpx.Request = None, - **kwargs, - ) -> Response: - assert request.url == database_url - return Response( - status_code=httpx.codes.OK, - json={"database": mock_database.dict()}, - ) - - return do_mock - - -@fixture -def database_not_found_callback(database_url: str) -> Callable: - def do_mock( - request: httpx.Request = None, - **kwargs, - ) -> Response: - assert request.url == database_url - return Response( - status_code=httpx.codes.OK, - json={}, - ) - - return do_mock - - -@fixture -def database_url(settings: Settings, account_id: str, mock_database) -> str: - return f"https://{settings.server}" + ACCOUNT_DATABASE_URL.format( - account_id=account_id, database_id=mock_database.database_id - ) - - -@fixture -def database_get_by_name_callback(database_get_by_name_url, mock_database) -> Callable: - def do_mock( - request: httpx.Request = None, - **kwargs, - ) -> Response: - assert request.url == database_get_by_name_url + assert request.url == system_engine_no_db_query_url return Response( status_code=httpx.codes.OK, - json={"database_id": {"database_id": mock_database.database_id}}, + json=empty_response, ) return do_mock @fixture -def database_get_by_name_url(settings: Settings, account_id: str, mock_database) -> str: - return ( - f"https://{settings.server}" - + ACCOUNT_DATABASE_BY_NAME_URL.format(account_id=account_id) - + f"?database_name={mock_database.name}" - ) +def updated_engine_scale() -> int: + return 10 @fixture -def database_update_callback(database_get_url, mock_database) -> Callable: +def update_engine_callback( + system_engine_no_db_query_url: str, + mock_engine: Engine, + updated_engine_scale: int, +) -> Callable: def do_mock( request: httpx.Request = None, **kwargs, ) -> Response: - database_properties = json.loads(request.read().decode("utf-8"))["database"] - - assert request.url == database_get_url + assert request.url == system_engine_no_db_query_url + mock_engine.scale = updated_engine_scale return Response( status_code=httpx.codes.OK, - json={"database": database_properties}, + json=empty_response, ) return do_mock @fixture -def database_get_callback(database_get_url, mock_database) -> Callable: +def create_databases_callback( + system_engine_no_db_query_url: str, mock_database +) -> Callable: def do_mock( request: httpx.Request = None, **kwargs, ) -> Response: - assert request.url == database_get_url + assert request.url == system_engine_no_db_query_url return Response( status_code=httpx.codes.OK, - json={"database": mock_database.dict()}, + json=empty_response, ) return do_mock -# duplicates database_url -@fixture -def database_get_url(settings: Settings, account_id: str, mock_database) -> str: - return f"https://{settings.server}" + ACCOUNT_DATABASE_URL.format( - account_id=account_id, database_id=mock_database.database_id - ) - - @fixture -def binding(account_id, mock_engine, mock_database) -> Binding: - return Binding( - binding_key=BindingKey( - account_id=account_id, - database_id=mock_database.database_id, - engine_id=mock_engine.engine_id, - ), - is_default_engine=True, - ) +def databases_get_callback( + mock_database: Database, mock_database_2: Database +) -> Callable: + return get_objects_from_db_callback([mock_database, mock_database_2]) @fixture -def bindings_callback(bindings_url: str, binding: Binding) -> Callable: +def database_not_found_callback() -> Callable: def do_mock( request: httpx.Request = None, **kwargs, ) -> Response: - assert request.url == bindings_url return Response( status_code=httpx.codes.OK, - json=list_to_paginated_response([binding]), + json=empty_response, ) return do_mock @fixture -def no_bindings_callback(bindings_url: str) -> Callable: +def database_update_callback() -> Callable: def do_mock( request: httpx.Request = None, **kwargs, ) -> Response: - assert request.url == bindings_url return Response( status_code=httpx.codes.OK, - json=list_to_paginated_response([]), + json=empty_response, ) return do_mock @fixture -def bindings_url(settings: Settings, account_id: str, mock_engine: Engine) -> str: - return ( - f"https://{settings.server}" - + ACCOUNT_BINDINGS_URL.format(account_id=account_id) - + f"?page.first=5000&filter.id_engine_id_eq={mock_engine.engine_id}" - ) - - -@fixture -def create_binding_callback(create_binding_url: str, binding) -> Callable: - def do_mock( - request: httpx.Request = None, - **kwargs, - ) -> Response: - assert request.url == create_binding_url - return Response( - status_code=httpx.codes.OK, - json={"binding": binding.dict()}, - ) - - return do_mock +def database_get_callback(mock_database) -> Callable: + return get_objects_from_db_callback([mock_database]) @fixture -def create_binding_url( - settings: Settings, account_id: str, mock_database: Database, mock_engine: Engine -) -> str: - return f"https://{settings.server}" + ACCOUNT_DATABASE_BINDING_URL.format( - account_id=account_id, - database_id=mock_database.database_id, - engine_id=mock_engine.engine_id, - ) +def resource_manager( + auth: Auth, + account_name: str, + server: str, + mock_system_engine_connection_flow: Callable, +) -> ResourceManager: + mock_system_engine_connection_flow() + return ResourceManager(auth=auth, account_name=account_name, api_endpoint=server) diff --git a/tests/unit/service/test_database.py b/tests/unit/service/test_database.py index 35dec331ad6..df7038440ec 100644 --- a/tests/unit/service/test_database.py +++ b/tests/unit/service/test_database.py @@ -1,131 +1,92 @@ -from re import Pattern, compile from typing import Callable from pytest_httpx import HTTPXMock -from firebolt.common import Settings from firebolt.model.database import Database +from firebolt.model.engine import Engine from firebolt.service.manager import ResourceManager def test_database_create( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - provider_callback: Callable, - provider_url: str, - region_callback: Callable, - region_url: str, - settings: Settings, - account_id_callback: Callable, - account_id_url: Pattern, + resource_manager: ResourceManager, + database_get_callback: Callable, create_databases_callback: Callable, - databases_url: str, - db_name: str, - db_description: str, + system_engine_no_db_query_url: str, + mock_database: Database, + mock_engine: Engine, ): - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(provider_callback, url=provider_url) - httpx_mock.add_callback(account_id_callback, url=account_id_url) - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(region_callback, url=region_url) - httpx_mock.add_callback(create_databases_callback, url=databases_url, method="POST") + httpx_mock.add_callback( + create_databases_callback, url=system_engine_no_db_query_url, method="POST" + ) + httpx_mock.add_callback( + database_get_callback, url=system_engine_no_db_query_url, method="POST" + ) - manager = ResourceManager(settings=settings) - database = manager.databases.create(name=db_name, description=db_description) + database = resource_manager.databases.create( + name=mock_database.name, + region=mock_database.region, + attached_engines=[mock_engine], + description=mock_database.description, + ) - assert database.name == db_name - assert database.description == db_description + assert database == mock_database -def test_database_get_by_name( +def test_database_get( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - provider_callback: Callable, - provider_url: str, - settings: Settings, - account_id_callback: Callable, - account_id_url: Pattern, - database_get_by_name_callback: Callable, - database_get_by_name_url: str, + resource_manager: ResourceManager, database_get_callback: Callable, - database_get_url: str, + system_engine_no_db_query_url: str, mock_database: Database, ): + httpx_mock.add_callback( + database_get_callback, url=system_engine_no_db_query_url, method="POST" + ) - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(provider_callback, url=provider_url) - httpx_mock.add_callback(account_id_callback, url=account_id_url) - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(database_get_by_name_callback, url=database_get_by_name_url) - httpx_mock.add_callback(database_get_callback, url=database_get_url) - - manager = ResourceManager(settings=settings) - database = manager.databases.get_by_name(name=mock_database.name) + database = resource_manager.databases.get(mock_database.name) - assert database.name == mock_database.name + assert database == mock_database def test_database_get_many( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - provider_callback: Callable, - provider_url: str, - settings: Settings, - account_id_callback: Callable, - account_id_url: Pattern, - database_get_by_name_callback: Callable, - database_get_by_name_url: str, + resource_manager: ResourceManager, databases_get_callback: Callable, - databases_url: str, + system_engine_no_db_query_url: str, mock_database: Database, + mock_database_2: Database, ): - - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(provider_callback, url=provider_url) - httpx_mock.add_callback(account_id_callback, url=account_id_url) - httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback( databases_get_callback, - url=compile(databases_url + "?[a-zA-Z0-9=&]*"), - method="GET", + url=system_engine_no_db_query_url, + method="POST", ) - manager = ResourceManager(settings=settings) - databases = manager.databases.get_many( + databases = resource_manager.databases.get_many( name_contains=mock_database.name, attached_engine_name_eq="mockengine", attached_engine_name_contains="mockengine", + region_eq="us-east-1", ) - assert len(databases) == 1 - assert databases[0].name == mock_database.name + assert len(databases) == 2 + assert databases[0] == mock_database + assert databases[1] == mock_database_2 def test_database_update( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - provider_callback: Callable, - provider_url: str, - settings: Settings, - account_id_callback: Callable, - account_id_url: Pattern, + resource_manager: ResourceManager, database_update_callback: Callable, - database_url: str, + system_engine_no_db_query_url: str, mock_database: Database, ): - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(provider_callback, url=provider_url) - httpx_mock.add_callback(account_id_callback, url=account_id_url) - - httpx_mock.add_callback(database_update_callback, url=database_url, method="PATCH") - - manager = ResourceManager(settings=settings) + httpx_mock.add_callback( + database_update_callback, url=system_engine_no_db_query_url, method="POST" + ) - mock_database._service = manager + mock_database._service = resource_manager.databases database = mock_database.update(description="new description") assert database.description == "new description" diff --git a/tests/unit/service/test_engine.py b/tests/unit/service/test_engine.py index 9686ecbf7ff..0097153c8bb 100644 --- a/tests/unit/service/test_engine.py +++ b/tests/unit/service/test_engine.py @@ -1,264 +1,73 @@ -from re import Pattern -from typing import Callable, List +from typing import Callable -from pydantic import ValidationError from pytest import raises from pytest_httpx import HTTPXMock -from firebolt.common import Settings -from firebolt.model.engine import Engine, _EngineCreateRequest -from firebolt.model.engine_revision import EngineRevision -from firebolt.model.instance_type import InstanceType -from firebolt.model.region import Region +from firebolt.model.database import Database +from firebolt.model.engine import Engine from firebolt.service.manager import ResourceManager -from firebolt.utils.exception import FireboltError, NoAttachedDatabaseError +from firebolt.utils.exception import ( + EngineNotFoundError, + NoAttachedDatabaseError, +) def test_engine_create( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - provider_callback: Callable, - provider_url: str, - instance_type_region_1_callback: Callable, - instance_type_region_1_url: str, - region_callback: Callable, - region_url: str, - settings: Settings, - mock_instance_types: List[InstanceType], - mock_regions, + resource_manager: ResourceManager, + instance_type_callback: Callable, + instance_type_url: str, + get_engine_callback: Callable, mock_engine: Engine, - engine_name: str, - account_id_callback: Callable, - account_id_url: Pattern, - engine_callback: Callable, - engine_url: str, + system_engine_no_db_query_url: str, ): - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(provider_callback, url=provider_url) - httpx_mock.add_callback( - instance_type_region_1_callback, url=instance_type_region_1_url - ) - httpx_mock.add_callback(account_id_callback, url=account_id_url) - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(region_callback, url=region_url) - httpx_mock.add_callback(engine_callback, url=engine_url, method="POST") - - manager = ResourceManager(settings=settings) - engine = manager.engines.create(name=engine_name) - - assert engine.name == engine_name - - -def test_engine_create_with_kwargs( - httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - provider_callback: Callable, - provider_url: str, - instance_type_region_1_callback: Callable, - instance_type_region_1_url: str, - region_callback: Callable, - region_url: str, - settings: Settings, - mock_engine: Engine, - engine_name: str, - account_id_callback: Callable, - account_id_url: Pattern, - engine_callback: Callable, - engine_url: str, - account_id: str, - mock_engine_revision: EngineRevision, -): - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(provider_callback, url=provider_url) - httpx_mock.add_callback( - instance_type_region_1_callback, url=instance_type_region_1_url - ) - httpx_mock.add_callback(account_id_callback, url=account_id_url) - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(region_callback, url=region_url) - # Setting to manager.engines.create defaults - mock_engine.key = None - mock_engine.description = "" - mock_engine.endpoint = None - # Testing kwargs - mock_engine.settings.minimum_logging_level = "ENGINE_SETTINGS_LOGGING_LEVEL_DEBUG" - mock_engine_revision.specification.proxy_version = "0.2.3" - engine_content = _EngineCreateRequest( - account_id=account_id, engine=mock_engine, engine_revision=mock_engine_revision - ) - httpx_mock.add_callback( - engine_callback, - url=engine_url, - method="POST", - match_content=engine_content.json(by_alias=True).encode("ascii"), - ) - manager = ResourceManager(settings=settings) - engine_settings_kwargs = { - "minimum_logging_level": "ENGINE_SETTINGS_LOGGING_LEVEL_DEBUG" - } - revision_spec_kwargs = {"proxy_version": "0.2.3"} - engine = manager.engines.create( - name=engine_name, - engine_settings_kwargs=engine_settings_kwargs, - revision_spec_kwargs=revision_spec_kwargs, + httpx_mock.add_callback(instance_type_callback, url=instance_type_url) + httpx_mock.add_callback(get_engine_callback, url=system_engine_no_db_query_url) + + engine = resource_manager.engines.create( + name=mock_engine.name, + region=mock_engine.region, + engine_type=mock_engine.type, + spec=mock_engine.spec, + scale=mock_engine.scale, + auto_stop=mock_engine.auto_stop, + warmup=mock_engine.warmup, ) - assert engine.name == engine_name + assert engine == mock_engine -def test_engine_create_with_kwargs_fail( +def test_engine_not_found( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - provider_callback: Callable, - provider_url: str, - instance_type_region_1_callback: Callable, - instance_type_region_1_url: str, - region_callback: Callable, - region_url: str, - settings: Settings, - engine_name: str, - account_id_callback: Callable, - account_id_url: Pattern, + resource_manager: ResourceManager, + get_engine_not_found_callback: Callable, + system_engine_no_db_query_url: str, ): - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(provider_callback, url=provider_url) httpx_mock.add_callback( - instance_type_region_1_callback, url=instance_type_region_1_url + get_engine_not_found_callback, url=system_engine_no_db_query_url ) - httpx_mock.add_callback(account_id_callback, url=account_id_url) - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(region_callback, url=region_url) - manager = ResourceManager(settings=settings) - revision_spec_kwargs = {"incorrect_kwarg": "val"} - with raises(ValidationError): - manager.engines.create( - name=engine_name, revision_spec_kwargs=revision_spec_kwargs - ) - - engine_settings_kwargs = {"incorrect_kwarg": "val"} - with raises(TypeError): - manager.engines.create( - name=engine_name, engine_settings_kwargs=engine_settings_kwargs - ) - - -def test_engine_create_no_available_types( - httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - provider_callback: Callable, - provider_url: str, - instance_type_empty_callback: Callable, - instance_type_region_2_url: str, - settings: Settings, - mock_instance_types: List[InstanceType], - engine_name: str, - account_id_callback: Callable, - account_id_url: Pattern, - engine_url: str, - region_2: Region, -): - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(provider_callback, url=provider_url) - httpx_mock.add_callback( - instance_type_empty_callback, url=instance_type_region_2_url - ) - httpx_mock.add_callback(account_id_callback, url=account_id_url) - httpx_mock.add_callback(auth_callback, url=auth_url) - - manager = ResourceManager(settings=settings) - - with raises(FireboltError): - manager.engines.create(name=engine_name, region=region_2) + with raises(EngineNotFoundError): + resource_manager.engines.get("invalid name") def test_engine_no_attached_database( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - provider_callback: Callable, - provider_url: str, - instance_type_region_1_callback: Callable, - instance_type_region_1_url: str, - region_callback: Callable, - region_url: str, - settings: Settings, - mock_instance_types: List[InstanceType], - mock_regions, - mock_engine: Engine, - engine_name: str, - account_id_callback: Callable, - account_id_url: Pattern, - engine_callback: Callable, - engine_url: str, - account_engine_callback: Callable, - account_engine_url: str, - database_callback: Callable, - database_url: str, - no_bindings_callback: Callable, - bindings_url: str, -): - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(provider_callback, url=provider_url) - httpx_mock.add_callback( - instance_type_region_1_callback, url=instance_type_region_1_url - ) - httpx_mock.add_callback(account_id_callback, url=account_id_url) - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(region_callback, url=region_url) - httpx_mock.add_callback(engine_callback, url=engine_url, method="POST") - httpx_mock.add_callback(no_bindings_callback, url=bindings_url) - - manager = ResourceManager(settings=settings) - engine = manager.engines.create(name=engine_name) - - with raises(NoAttachedDatabaseError): - engine.start() - - -def test_engine_start_binding_to_missing_database( - httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - provider_callback: Callable, - provider_url: str, - instance_type_region_1_callback: Callable, - instance_type_region_1_url: str, - region_callback: Callable, - region_url: str, - settings: Settings, - mock_instance_types: List[InstanceType], - mock_regions, - mock_engine: Engine, - engine_name: str, - account_id_callback: Callable, - account_id_url: Pattern, - engine_callback: Callable, - engine_url: str, + resource_manager: ResourceManager, + instance_type_callback: Callable, + instance_type_url: str, + get_engine_callback: Callable, database_not_found_callback: Callable, - database_url: str, - bindings_callback: Callable, - bindings_url: str, + system_engine_no_db_query_url: str, ): - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(provider_callback, url=provider_url) + httpx_mock.add_callback(instance_type_callback, url=instance_type_url) + httpx_mock.add_callback(get_engine_callback, url=system_engine_no_db_query_url) httpx_mock.add_callback( - instance_type_region_1_callback, url=instance_type_region_1_url + database_not_found_callback, url=system_engine_no_db_query_url ) - httpx_mock.add_callback(account_id_callback, url=account_id_url) - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(region_callback, url=region_url) - httpx_mock.add_callback(engine_callback, url=engine_url, method="POST") - httpx_mock.add_callback(bindings_callback, url=bindings_url) - httpx_mock.add_callback(database_not_found_callback, url=database_url) - manager = ResourceManager(settings=settings) - engine = manager.engines.create(name=engine_name) + engine = resource_manager.engines.get("engine_name") with raises(NoAttachedDatabaseError): engine.start() @@ -266,182 +75,74 @@ def test_engine_start_binding_to_missing_database( def test_get_connection( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - provider_callback: Callable, - provider_url: str, - instance_type_region_1_callback: Callable, - instance_type_region_1_url: str, - region_callback: Callable, - region_url: str, - settings: Settings, - mock_instance_types: List[InstanceType], - mock_regions, - mock_engine: Engine, - engine_name: str, - account_id_callback: Callable, - account_id_url: Pattern, - engine_callback: Callable, - engine_url: str, - db_name: str, - database_callback: Callable, - database_url: str, - bindings_callback: Callable, - bindings_url: str, - mock_connection_flow: Callable, + resource_manager: ResourceManager, + instance_type_callback: Callable, + instance_type_url: str, + get_engine_callback: Callable, + database_get_callback: Callable, + system_engine_no_db_query_url: str, + system_engine_query_url: str, + get_engine_url_callback: Callable, + mock_query: Callable, ): - httpx_mock.add_callback(provider_callback, url=provider_url) - httpx_mock.add_callback( - instance_type_region_1_callback, url=instance_type_region_1_url - ) - httpx_mock.add_callback(account_id_callback, url=account_id_url) - httpx_mock.add_callback(region_callback, url=region_url) - httpx_mock.add_callback(engine_callback, url=engine_url, method="POST") - httpx_mock.add_callback(bindings_callback, url=bindings_url) + httpx_mock.add_callback(instance_type_callback, url=instance_type_url) + httpx_mock.add_callback(get_engine_callback, url=system_engine_no_db_query_url) + httpx_mock.add_callback(database_get_callback, url=system_engine_no_db_query_url) + httpx_mock.add_callback(get_engine_url_callback, url=system_engine_query_url) + mock_query() - httpx_mock.add_callback(database_callback, url=database_url) - mock_connection_flow() - - manager = ResourceManager(settings=settings) - engine = manager.engines.create(name=engine_name) + engine = resource_manager.engines.get("engine_name") with engine.get_connection() as connection: - assert connection + connection.cursor().execute("select 1") def test_attach_to_database( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - provider_callback: Callable, - provider_url: str, - region_callback: Callable, - region_url: str, - instance_type_region_1_callback: Callable, - instance_type_region_1_url: str, - settings: Settings, - account_id_callback: Callable, - account_id_url: Pattern, - create_databases_callback: Callable, - databases_url: str, + resource_manager: ResourceManager, + instance_type_callback: Callable, + instance_type_url: str, + mock_database: Database, + mock_engine: Engine, + get_engine_callback: Callable, database_get_callback: Callable, - database_get_url: str, - database_not_found_callback: Callable, - database_url: str, - db_name: str, - engine_name: str, - engine_callback: Callable, - engine_url: str, - create_binding_callback: Callable, - create_binding_url: str, - bindings_callback: Callable, - bindings_url: str, + attach_engine_to_db_callback: Callable, + system_engine_no_db_query_url: str, ): - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(provider_callback, url=provider_url) + httpx_mock.add_callback(instance_type_callback, url=instance_type_url) + httpx_mock.add_callback(database_get_callback, url=system_engine_no_db_query_url) + httpx_mock.add_callback(get_engine_callback, url=system_engine_no_db_query_url) httpx_mock.add_callback( - instance_type_region_1_callback, url=instance_type_region_1_url + attach_engine_to_db_callback, url=system_engine_no_db_query_url ) - httpx_mock.add_callback(account_id_callback, url=account_id_url) - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(bindings_callback, url=bindings_url) - httpx_mock.add_callback(create_databases_callback, url=databases_url, method="POST") - httpx_mock.add_callback(database_not_found_callback, url=database_url, method="GET") - - # create engine - httpx_mock.add_callback(region_callback, url=region_url) - httpx_mock.add_callback(engine_callback, url=engine_url, method="POST") - # attach - httpx_mock.add_callback(database_get_callback, url=database_get_url) - httpx_mock.add_callback( - create_binding_callback, url=create_binding_url, method="POST" - ) + database = resource_manager.databases.get("database") + engine = resource_manager.engines.get("engine") - manager = ResourceManager(settings=settings) - database = manager.databases.create(name=db_name) + engine._service = resource_manager.engines - engine = manager.engines.create(name=engine_name) - engine.attach_to_database(database=database) + engine.attach_to_database(database) - assert engine.database == database + assert engine._database_name == database.name def test_engine_update( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - provider_callback: Callable, - provider_url: str, - instance_type_region_1_callback: Callable, - instance_type_region_1_url: str, - region_callback: Callable, - region_url: str, - settings: Settings, - mock_instance_types: List[InstanceType], - mock_regions, + resource_manager: ResourceManager, + instance_type_callback: Callable, + instance_type_url: str, mock_engine: Engine, - engine_name: str, - account_id_callback: Callable, - account_id_url: Pattern, - engine_callback: Callable, - engine_url: str, - account_engine_url: str, - account_engine_callback: Callable, + get_engine_callback: Callable, + update_engine_callback: Callable, + system_engine_no_db_query_url: str, + updated_engine_scale: int, ): + httpx_mock.add_callback(instance_type_callback, url=instance_type_url) + httpx_mock.add_callback(get_engine_callback, url=system_engine_no_db_query_url) + httpx_mock.add_callback(update_engine_callback, url=system_engine_no_db_query_url) + httpx_mock.add_callback(get_engine_callback, url=system_engine_no_db_query_url) - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(provider_callback, url=provider_url) - - httpx_mock.add_callback(account_id_callback, url=account_id_url) - httpx_mock.add_callback(auth_callback, url=auth_url) - # - httpx_mock.add_callback( - account_engine_callback, url=account_engine_url, method="PATCH" - ) - manager = ResourceManager(settings=settings) - - mock_engine._service = manager.engines - engine = mock_engine.update( - name="new_engine_name", description="new engine description" - ) - - assert engine.name == "new_engine_name" - assert engine.description == "new engine description" - - -def test_engine_restart( - httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - provider_callback: Callable, - provider_url: str, - settings: Settings, - mock_engine: Engine, - account_id_callback: Callable, - account_id_url: Pattern, - engine_callback: Callable, - account_engine_url: str, - bindings_callback: Callable, - bindings_url: str, - database_callback: Callable, - database_url: str, -): - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(provider_callback, url=provider_url) - - httpx_mock.add_callback(account_id_callback, url=account_id_url) - httpx_mock.add_callback(auth_callback, url=auth_url) - - httpx_mock.add_callback( - engine_callback, url=f"{account_engine_url}:restart", method="POST" - ) - httpx_mock.add_callback(bindings_callback, url=bindings_url) - httpx_mock.add_callback(database_callback, url=database_url) - - manager = ResourceManager(settings=settings) - - mock_engine._service = manager.engines - engine = mock_engine.restart(wait_for_startup=False) + mock_engine._service = resource_manager.engines + mock_engine.update(scale=updated_engine_scale) - assert engine.name == mock_engine.name + assert mock_engine.scale == updated_engine_scale diff --git a/tests/unit/service/test_instance_type.py b/tests/unit/service/test_instance_type.py index dce917c1f42..2c2e0610fba 100644 --- a/tests/unit/service/test_instance_type.py +++ b/tests/unit/service/test_instance_type.py @@ -1,50 +1,21 @@ -from re import Pattern from typing import Callable, List from pytest_httpx import HTTPXMock -from firebolt.common import Settings from firebolt.model.instance_type import InstanceType -from firebolt.model.region import Region from firebolt.service.manager import ResourceManager def test_instance_type( httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - provider_callback: Callable, - provider_url: str, + resource_manager: ResourceManager, instance_type_callback: Callable, - instance_type_region_1_callback: Callable, - instance_type_empty_callback: Callable, instance_type_url: str, - instance_type_region_1_url: str, - instance_type_region_2_url: str, - account_id_callback: Callable, - account_id_url: Pattern, - settings: Settings, mock_instance_types: List[InstanceType], cheapest_instance: InstanceType, - region_1: Region, - region_2: Region, + mock_system_engine_connection_flow: Callable, ): - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(provider_callback, url=provider_url) - httpx_mock.add_callback(auth_callback, url=auth_url) httpx_mock.add_callback(instance_type_callback, url=instance_type_url) - httpx_mock.add_callback( - instance_type_region_1_callback, url=instance_type_region_1_url - ) - httpx_mock.add_callback( - instance_type_empty_callback, url=instance_type_region_2_url - ) - httpx_mock.add_callback(account_id_callback, url=account_id_url) - manager = ResourceManager(settings=settings) - assert manager.instance_types.instance_types == mock_instance_types - assert ( - manager.instance_types.cheapest_instance_in_region(region_1) - == cheapest_instance - ) - assert not manager.instance_types.cheapest_instance_in_region(region_2) + assert resource_manager.instance_types.instance_types == mock_instance_types + assert resource_manager.instance_types.cheapest_instance == cheapest_instance diff --git a/tests/unit/service/test_region.py b/tests/unit/service/test_region.py deleted file mode 100644 index 939238f6bfe..00000000000 --- a/tests/unit/service/test_region.py +++ /dev/null @@ -1,31 +0,0 @@ -from re import Pattern -from typing import Callable, List - -from pytest_httpx import HTTPXMock - -from firebolt.common import Settings -from firebolt.model.region import Region -from firebolt.service.manager import ResourceManager - - -def test_region( - httpx_mock: HTTPXMock, - auth_callback: Callable, - auth_url: str, - provider_callback: Callable, - provider_url: str, - region_callback: Callable, - region_url: str, - account_id_callback: Callable, - account_id_url: Pattern, - settings: Settings, - mock_regions: List[Region], -): - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(provider_callback, url=provider_url) - httpx_mock.add_callback(auth_callback, url=auth_url) - httpx_mock.add_callback(region_callback, url=region_url) - httpx_mock.add_callback(account_id_callback, url=account_id_url) - - manager = ResourceManager(settings=settings) - assert manager.regions.regions == mock_regions diff --git a/tests/unit/service/test_resource_manager.py b/tests/unit/service/test_resource_manager.py index 9ddfeacb468..8afa351474b 100644 --- a/tests/unit/service/test_resource_manager.py +++ b/tests/unit/service/test_resource_manager.py @@ -6,91 +6,71 @@ from pytest_httpx import HTTPXMock from firebolt.client.auth import Auth, ClientCredentials -from firebolt.common.settings import Settings from firebolt.service.manager import ResourceManager from firebolt.utils.exception import AccountNotFoundError from firebolt.utils.token_storage import TokenSecureStorage +from firebolt.utils.urls import GATEWAY_HOST_BY_ACCOUNT_NAME def test_rm_credentials( httpx_mock: HTTPXMock, + auth: Auth, + account_name: str, + server: str, check_token_callback: Callable, - check_credentials_callback: Callable, - settings: Settings, - auth_url: str, - account_id_url: Pattern, - account_id_callback: Callable, - provider_callback: Callable, - provider_url: str, - access_token: str, + mock_system_engine_connection_flow: Callable, ) -> None: """Credentials, that are passed to rm are processed properly.""" url = "https://url" - httpx_mock.add_callback(check_credentials_callback, url=auth_url) + mock_system_engine_connection_flow() httpx_mock.add_callback(check_token_callback, url=url) - httpx_mock.add_callback(provider_callback, url=provider_url) - httpx_mock.add_callback(account_id_callback, url=account_id_url) - rm = ResourceManager(settings) - rm.client.get(url) + rm = ResourceManager(auth=auth, account_name=account_name, api_endpoint=server) + rm._client.get(url) @mark.nofakefs def test_rm_token_cache( httpx_mock: HTTPXMock, check_token_callback: Callable, - check_credentials_callback: Callable, - settings: Settings, - auth_url: str, + auth: Auth, + server: str, account_name: str, - account_id_url: Pattern, - account_id_callback: Callable, - provider_callback: Callable, - provider_url: str, access_token: str, + mock_system_engine_connection_flow: Callable, ) -> None: """Credentials, that are passed to rm are cached properly.""" url = "https://url" - httpx_mock.add_callback(check_credentials_callback, url=auth_url) + mock_system_engine_connection_flow() httpx_mock.add_callback(check_token_callback, url=url) - httpx_mock.add_callback(provider_callback, url=provider_url) - httpx_mock.add_callback(account_id_callback, url=account_id_url) with Patcher(): - local_settings = Settings( - account_name=account_name, + rm = ResourceManager( auth=ClientCredentials( - settings.auth.client_id, - settings.auth.client_secret, - use_token_cache=True, + auth.client_id, auth.client_secret, use_token_cache=True ), - server=settings.server, - default_region=settings.default_region, + account_name=account_name, + api_endpoint=server, ) - rm = ResourceManager(local_settings) - rm.client.get(url) + rm._client.get(url) - ts = TokenSecureStorage(settings.auth.client_id, settings.auth.client_secret) + ts = TokenSecureStorage(auth.client_id, auth.client_secret) assert ts.get_cached_token() == access_token, "Invalid token value cached" # Do the same, but with use_token_cache=False with Patcher(): - local_settings = Settings( - account_name=account_name, + rm = ResourceManager( auth=ClientCredentials( - settings.auth.client_id, - settings.auth.client_secret, - use_token_cache=False, + auth.client_id, auth.client_secret, use_token_cache=False ), - server=settings.server, - default_region=settings.default_region, + account_name=account_name, + api_endpoint=server, ) - rm = ResourceManager(local_settings) - rm.client.get(url) + rm._client.get(url) - ts = TokenSecureStorage(settings.auth.client_id, settings.auth.client_secret) + ts = TokenSecureStorage(auth.client_id, auth.client_secret) assert ( ts.get_cached_token() is None ), "Token is cached even though caching is disabled" @@ -99,22 +79,22 @@ def test_rm_token_cache( def test_rm_invalid_account_name( httpx_mock: HTTPXMock, auth: Auth, - settings: Settings, - check_credentials_callback: Callable, + server: str, auth_url: str, + check_credentials_callback: Callable, account_id_url: Pattern, account_id_callback: Callable, + get_system_engine_callback: Callable, ) -> None: """Resource manager raises an error on invalid account name.""" + get_system_engine_url = ( + f"https://{server}" + f"{GATEWAY_HOST_BY_ACCOUNT_NAME.format(account_name='invalid')}" + ) + httpx_mock.add_callback(check_credentials_callback, url=auth_url) + httpx_mock.add_callback(get_system_engine_callback, url=get_system_engine_url) httpx_mock.add_callback(account_id_callback, url=account_id_url) - local_settings = Settings( - auth=auth, - account_name="invalid", - server=settings.server, - default_region=settings.default_region, - ) - with raises(AccountNotFoundError): - ResourceManager(local_settings) + ResourceManager(auth=auth, account_name="invalid", api_endpoint=server) diff --git a/tests/unit/util.py b/tests/unit/util.py index 98b72304fc1..7c34650f15b 100644 --- a/tests/unit/util.py +++ b/tests/unit/util.py @@ -1,3 +1,4 @@ +from dataclasses import Field, dataclass, fields from typing import AsyncGenerator, Dict, Generator, List from httpx import Request, Response @@ -7,8 +8,16 @@ from firebolt.model import FireboltBaseModel +def field_name(f: Field) -> str: + return (f.metadata or {}).get("db_name", f.name) + + +def to_dict(dc: dataclass) -> Dict: + return {field_name(f): getattr(dc, f.name) for f in fields(dc)} + + def list_to_paginated_response(items: List[FireboltBaseModel]) -> Dict: - return {"edges": [{"node": i.dict()} for i in items]} + return {"edges": [{"node": to_dict(i)} for i in items]} def execute_generator_requests( diff --git a/tests/unit/utils/test_usage_tracker.py b/tests/unit/utils/test_usage_tracker.py index 1ab2f72db0c..2949b49757e 100644 --- a/tests/unit/utils/test_usage_tracker.py +++ b/tests/unit/utils/test_usage_tracker.py @@ -1,7 +1,6 @@ from collections import namedtuple from unittest.mock import MagicMock, patch -from pydantic import ValidationError from pytest import mark, raises from firebolt.utils.usage_tracker import ( @@ -199,5 +198,5 @@ def test_user_agent(drivers, clients, expected_string): MagicMock(return_value=("1", "2", "Win", "ciso")), ) def test_incorrect_user_agent(drivers, clients): - with raises(ValidationError): + with raises(ValueError): get_user_agent_header(drivers, clients)