Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/firebolt/common/constants.py

This file was deleted.

3 changes: 1 addition & 2 deletions src/firebolt/model/instance_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@

from pydantic import Field

from firebolt.common.constants import AWS_PROVIDER_ID
from firebolt.model import FireboltBaseModel


class InstanceTypeKey(FireboltBaseModel, frozen=True): # type: ignore
provider_id: str = AWS_PROVIDER_ID
provider_id: str
region_id: str
instance_type_id: str

Expand Down
16 changes: 16 additions & 0 deletions src/firebolt/model/provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from datetime import datetime
from typing import Optional

from pydantic import Field

from firebolt.model import FireboltBaseModel


class Provider(FireboltBaseModel, frozen=True): # type: ignore
provider_id: str = Field(alias="id")
name: str

# optional
create_time: Optional[datetime]
display_name: Optional[str]
last_update_time: Optional[datetime]
3 changes: 1 addition & 2 deletions src/firebolt/model/region.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@

from pydantic import Field

from firebolt.common.constants import AWS_PROVIDER_ID
from firebolt.model import FireboltBaseModel


class RegionKey(FireboltBaseModel, frozen=True): # type: ignore
provider_id: str = AWS_PROVIDER_ID
provider_id: str
region_id: str


Expand Down
2 changes: 2 additions & 0 deletions src/firebolt/service/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from firebolt.client import Client, log_request, log_response, raise_on_4xx_5xx
from firebolt.common import Settings
from firebolt.service.provider import get_provider_id


class ResourceManager:
Expand Down Expand Up @@ -47,6 +48,7 @@ def _init_services(self, default_region_name: str) -> None:
resource_manager=self, default_region_name=default_region_name
)
self.instance_types = InstanceTypeService(resource_manager=self)
self.provider_id = get_provider_id(client=self.client)

# Firebolt Resources
self.databases = DatabaseService(resource_manager=self)
Expand Down
9 changes: 9 additions & 0 deletions src/firebolt/service/provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from firebolt.client import Client
from firebolt.model.provider import Provider


def get_provider_id(client: Client) -> str:
"""Get the AWS provider_id."""
response = client.get(url="/compute/v1/providers")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add error handling here. Check that json doesn't contain "error"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be handled by this hook: https://github.com/firebolt-analytics/firebolt-sdk/blob/189f34a8c4f1447282ceefd49fb7d547808dbd84/src/firebolt/client/hooks.py#L26

Looking at the docs: https://api.dev.firebolt.io/devDocs#operation/computeV1ListProviders

The hook parses out the "message" associated with an error. However, it currently does not examine the "error" field.

This kind of error handling needs to happen on pretty much all requests, so let's go over it in our Wednesday meeting to ensure we've got a good pattern, then I will apply it everywhere in a follow-on PR.

providers = [Provider.parse_obj(i["node"]) for i in response.json()["edges"]]
return providers[0].provider_id
6 changes: 5 additions & 1 deletion src/firebolt/service/region.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,8 @@ def get_by_key(self, region_key: RegionKey) -> Region:

def get_by_id(self, region_id: str) -> Region:
"""Get an AWS Region by region_id."""
return self.get_by_key(RegionKey(region_id=region_id))
return self.get_by_key(
RegionKey(
provider_id=self.resource_manager.provider_id, region_id=region_id
)
)
47 changes: 43 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

from firebolt.common.settings import Settings
from firebolt.model.instance_type import InstanceType, InstanceTypeKey
from firebolt.model.provider import Provider
from firebolt.model.region import Region, RegionKey
from tests.util import list_to_paginated_response


@pytest.fixture
Expand All @@ -27,19 +29,34 @@ def access_token() -> str:


@pytest.fixture
def region_1() -> Region:
def provider() -> Provider:
return Provider(
provider_id="mock_provider_id",
name="mock_provider_name",
)


@pytest.fixture
def mock_providers(provider) -> list[Provider]:
return [provider]


@pytest.fixture
def region_1(provider) -> Region:
return Region(
key=RegionKey(
provider_id=provider.provider_id,
region_id="mock_region_id_1",
),
name="mock_region_1",
)


@pytest.fixture
def region_2() -> Region:
def region_2(provider) -> Region:
return Region(
key=RegionKey(
provider_id=provider.provider_id,
region_id="mock_region_id_2",
),
name="mock_region_2",
Expand All @@ -52,9 +69,10 @@ def mock_regions(region_1, region_2) -> list[Region]:


@pytest.fixture
def instance_type_1(region_1) -> InstanceType:
def instance_type_1(provider, region_1) -> InstanceType:
return InstanceType(
key=InstanceTypeKey(
provider_id=provider.provider_id,
region_id=region_1.key.region_id,
instance_type_id="instance_type_id_1",
),
Expand All @@ -63,9 +81,10 @@ def instance_type_1(region_1) -> InstanceType:


@pytest.fixture
def instance_type_2(region_2) -> InstanceType:
def instance_type_2(provider, region_2) -> InstanceType:
return InstanceType(
key=InstanceTypeKey(
provider_id=provider.provider_id,
region_id=region_2.key.region_id,
instance_type_id="instance_type_id_2",
),
Expand Down Expand Up @@ -108,6 +127,26 @@ def auth_url(settings: Settings) -> str:
return f"https://{settings.server}/auth/v1/login"


@pytest.fixture
def provider_callback(provider_url: str, mock_providers) -> Callable:
def do_mock(
request: httpx.Request = None,
**kwargs,
) -> Response:
assert request.url == provider_url
return to_response(
status_code=httpx.codes.OK,
json=list_to_paginated_response(mock_providers),
)

return do_mock


@pytest.fixture
def provider_url(settings: Settings) -> str:
return f"https://{settings.server}/compute/v1/providers"


@pytest.fixture
def db_name() -> str:
return "database"
3 changes: 3 additions & 0 deletions tests/model/test_instance_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@
def test_instance_type(
httpx_mock: HTTPXMock,
auth_callback: Callable,
provider_callback: Callable,
settings: Settings,
mock_instance_types: List[InstanceType],
):
httpx_mock.add_callback(auth_callback)
httpx_mock.add_callback(provider_callback)
httpx_mock.add_callback(auth_callback)
httpx_mock.add_response(
url=f"https://{settings.server}/compute/v1/instanceTypes?page.first=5000",
Expand Down
3 changes: 3 additions & 0 deletions tests/model/test_region.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@
def test_region(
httpx_mock: HTTPXMock,
auth_callback: Callable,
provider_callback: Callable,
settings: Settings,
mock_regions: List[Region],
):
httpx_mock.add_callback(auth_callback)
httpx_mock.add_callback(provider_callback)
httpx_mock.add_callback(auth_callback)
httpx_mock.add_response(
url=f"https://{settings.server}/compute/v1/regions?page.first=5000",
Expand Down