From 189f34a8c4f1447282ceefd49fb7d547808dbd84 Mon Sep 17 00:00:00 2001 From: Eric Gustavson Date: Fri, 8 Oct 2021 10:52:04 -0700 Subject: [PATCH] No longer hardcode provider --- src/firebolt/common/constants.py | 1 - src/firebolt/model/instance_type.py | 3 +- src/firebolt/model/provider.py | 16 ++++++++++ src/firebolt/model/region.py | 3 +- src/firebolt/service/manager.py | 2 ++ src/firebolt/service/provider.py | 9 ++++++ src/firebolt/service/region.py | 6 +++- tests/conftest.py | 47 ++++++++++++++++++++++++++--- tests/model/test_instance_type.py | 3 ++ tests/model/test_region.py | 3 ++ 10 files changed, 83 insertions(+), 10 deletions(-) delete mode 100644 src/firebolt/common/constants.py create mode 100644 src/firebolt/model/provider.py create mode 100644 src/firebolt/service/provider.py diff --git a/src/firebolt/common/constants.py b/src/firebolt/common/constants.py deleted file mode 100644 index b2d39bd93dd..00000000000 --- a/src/firebolt/common/constants.py +++ /dev/null @@ -1 +0,0 @@ -AWS_PROVIDER_ID: str = "402a51bb-1c8e-4dc4-9e05-ced3c1e2186e" diff --git a/src/firebolt/model/instance_type.py b/src/firebolt/model/instance_type.py index f36e814c4b4..ed99fcfda9f 100644 --- a/src/firebolt/model/instance_type.py +++ b/src/firebolt/model/instance_type.py @@ -3,12 +3,11 @@ from pydantic import Field -from firebolt.common.constants import AWS_PROVIDER_ID from firebolt.model import FireboltBaseModel class InstanceTypeKey(FireboltBaseModel, frozen=True): # type: ignore - provider_id: str = AWS_PROVIDER_ID + provider_id: str region_id: str instance_type_id: str diff --git a/src/firebolt/model/provider.py b/src/firebolt/model/provider.py new file mode 100644 index 00000000000..242856ef218 --- /dev/null +++ b/src/firebolt/model/provider.py @@ -0,0 +1,16 @@ +from datetime import datetime +from typing import Optional + +from pydantic import Field + +from firebolt.model import FireboltBaseModel + + +class Provider(FireboltBaseModel, frozen=True): # type: ignore + provider_id: str = Field(alias="id") + name: str + + # optional + create_time: Optional[datetime] + display_name: Optional[str] + last_update_time: Optional[datetime] diff --git a/src/firebolt/model/region.py b/src/firebolt/model/region.py index c474ed87f3e..c3b476da26c 100644 --- a/src/firebolt/model/region.py +++ b/src/firebolt/model/region.py @@ -3,12 +3,11 @@ from pydantic import Field -from firebolt.common.constants import AWS_PROVIDER_ID from firebolt.model import FireboltBaseModel class RegionKey(FireboltBaseModel, frozen=True): # type: ignore - provider_id: str = AWS_PROVIDER_ID + provider_id: str region_id: str diff --git a/src/firebolt/service/manager.py b/src/firebolt/service/manager.py index 4b1cbc2663c..fe2f775eb16 100644 --- a/src/firebolt/service/manager.py +++ b/src/firebolt/service/manager.py @@ -2,6 +2,7 @@ from firebolt.client import Client, log_request, log_response, raise_on_4xx_5xx from firebolt.common import Settings +from firebolt.service.provider import get_provider_id class ResourceManager: @@ -47,6 +48,7 @@ def _init_services(self, default_region_name: str) -> None: resource_manager=self, default_region_name=default_region_name ) self.instance_types = InstanceTypeService(resource_manager=self) + self.provider_id = get_provider_id(client=self.client) # Firebolt Resources self.databases = DatabaseService(resource_manager=self) diff --git a/src/firebolt/service/provider.py b/src/firebolt/service/provider.py new file mode 100644 index 00000000000..47c463b031c --- /dev/null +++ b/src/firebolt/service/provider.py @@ -0,0 +1,9 @@ +from firebolt.client import Client +from firebolt.model.provider import Provider + + +def get_provider_id(client: Client) -> str: + """Get the AWS provider_id.""" + response = client.get(url="/compute/v1/providers") + providers = [Provider.parse_obj(i["node"]) for i in response.json()["edges"]] + return providers[0].provider_id diff --git a/src/firebolt/service/region.py b/src/firebolt/service/region.py index 6931bdf5cca..b5640556d4d 100644 --- a/src/firebolt/service/region.py +++ b/src/firebolt/service/region.py @@ -57,4 +57,8 @@ def get_by_key(self, region_key: RegionKey) -> Region: def get_by_id(self, region_id: str) -> Region: """Get an AWS Region by region_id.""" - return self.get_by_key(RegionKey(region_id=region_id)) + return self.get_by_key( + RegionKey( + provider_id=self.resource_manager.provider_id, region_id=region_id + ) + ) diff --git a/tests/conftest.py b/tests/conftest.py index 3d7053b0944..a65d03e9b66 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,7 +8,9 @@ from firebolt.common.settings import Settings from firebolt.model.instance_type import InstanceType, InstanceTypeKey +from firebolt.model.provider import Provider from firebolt.model.region import Region, RegionKey +from tests.util import list_to_paginated_response @pytest.fixture @@ -27,9 +29,23 @@ def access_token() -> str: @pytest.fixture -def region_1() -> Region: +def provider() -> Provider: + return Provider( + provider_id="mock_provider_id", + name="mock_provider_name", + ) + + +@pytest.fixture +def mock_providers(provider) -> list[Provider]: + return [provider] + + +@pytest.fixture +def region_1(provider) -> Region: return Region( key=RegionKey( + provider_id=provider.provider_id, region_id="mock_region_id_1", ), name="mock_region_1", @@ -37,9 +53,10 @@ def region_1() -> Region: @pytest.fixture -def region_2() -> Region: +def region_2(provider) -> Region: return Region( key=RegionKey( + provider_id=provider.provider_id, region_id="mock_region_id_2", ), name="mock_region_2", @@ -52,9 +69,10 @@ def mock_regions(region_1, region_2) -> list[Region]: @pytest.fixture -def instance_type_1(region_1) -> InstanceType: +def instance_type_1(provider, region_1) -> InstanceType: return InstanceType( key=InstanceTypeKey( + provider_id=provider.provider_id, region_id=region_1.key.region_id, instance_type_id="instance_type_id_1", ), @@ -63,9 +81,10 @@ def instance_type_1(region_1) -> InstanceType: @pytest.fixture -def instance_type_2(region_2) -> InstanceType: +def instance_type_2(provider, region_2) -> InstanceType: return InstanceType( key=InstanceTypeKey( + provider_id=provider.provider_id, region_id=region_2.key.region_id, instance_type_id="instance_type_id_2", ), @@ -108,6 +127,26 @@ def auth_url(settings: Settings) -> str: return f"https://{settings.server}/auth/v1/login" +@pytest.fixture +def provider_callback(provider_url: str, mock_providers) -> Callable: + def do_mock( + request: httpx.Request = None, + **kwargs, + ) -> Response: + assert request.url == provider_url + return to_response( + status_code=httpx.codes.OK, + json=list_to_paginated_response(mock_providers), + ) + + return do_mock + + +@pytest.fixture +def provider_url(settings: Settings) -> str: + return f"https://{settings.server}/compute/v1/providers" + + @pytest.fixture def db_name() -> str: return "database" diff --git a/tests/model/test_instance_type.py b/tests/model/test_instance_type.py index 09a60639dfc..629589cb50e 100644 --- a/tests/model/test_instance_type.py +++ b/tests/model/test_instance_type.py @@ -12,9 +12,12 @@ def test_instance_type( httpx_mock: HTTPXMock, auth_callback: Callable, + provider_callback: Callable, settings: Settings, mock_instance_types: List[InstanceType], ): + httpx_mock.add_callback(auth_callback) + httpx_mock.add_callback(provider_callback) httpx_mock.add_callback(auth_callback) httpx_mock.add_response( url=f"https://{settings.server}/compute/v1/instanceTypes?page.first=5000", diff --git a/tests/model/test_region.py b/tests/model/test_region.py index 1275c649b28..f054a311e7e 100644 --- a/tests/model/test_region.py +++ b/tests/model/test_region.py @@ -12,9 +12,12 @@ def test_region( httpx_mock: HTTPXMock, auth_callback: Callable, + provider_callback: Callable, settings: Settings, mock_regions: List[Region], ): + httpx_mock.add_callback(auth_callback) + httpx_mock.add_callback(provider_callback) httpx_mock.add_callback(auth_callback) httpx_mock.add_response( url=f"https://{settings.server}/compute/v1/regions?page.first=5000",