Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 63 additions & 9 deletions src/firebolt/model/V1/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,70 @@
import json
from typing import Any
from typing import Any, Callable, Final, Type, TypeVar, Union

from pydantic import BaseModel
import pydantic
from typing_extensions import TypeAlias

Model = TypeVar("Model", bound="pydantic.BaseModel")

class FireboltBaseModel(BaseModel):
GenericCallable: TypeAlias = Callable[..., Any]

# Using Pydantic 1.* config class for backwards compatibility
class Config:
extra = "forbid"
allow_population_by_field_name = True # Pydantic 1.8
populate_by_name = True # Pydantic 2.0
# Using `.VERSION` instead of `.__version__` for backward compatibility:
PYDANTIC_VERSION: Final[int] = int(pydantic.VERSION[0])


def use_if_version_ge(
version_ge: int,
obj: Union[pydantic.BaseModel, Type[Model]],
previous_method: str,
latest_method: str,
) -> GenericCallable:
"""
Utility function to get desired method from base model.

Args:
version_ge: The version number that will be used to determine
the desired method.
obj: The object on which the method will be taken from
previous_method: The method previously available in a version
smaller than `version_ge`.
latest_method: The method available from `version_ge` onwards.

"""
if PYDANTIC_VERSION >= version_ge:
return getattr(obj, latest_method)
else:
return getattr(obj, previous_method)


if PYDANTIC_VERSION >= 2:
# This import can only happen outside the BaseModel,
# or it will raise PydanticUserError
from pydantic import ConfigDict


class FireboltBaseModel(pydantic.BaseModel):
if PYDANTIC_VERSION >= 2:
# Pydantic V2 config
model_config = ConfigDict(populate_by_name=True, from_attributes=True)

else:
# Using Pydantic 1.* config class for backwards compatibility
class Config:
extra = "forbid"
allow_population_by_field_name = True # Pydantic 1.8

def model_dict(self, *args: Any, **kwargs: Any) -> dict:
"""Pydantic V2 and V1 compatible method for `dict` -> `model_dump`."""
return use_if_version_ge(2, self, "dict", "model_dump")(*args, **kwargs)

@classmethod
def parse_model(cls: Type[Model], *args: Any, **kwargs: Any) -> Model:
"""Pydantic V2 and V1 compatible method for `parse_obj` -> `model_validate`."""
return use_if_version_ge(2, cls, "parse_obj", "model_validate")(*args, **kwargs)

def model_json(self, *args: Any, **kwargs: Any) -> str:
"""Pydantic V2 and V1 compatible method for `json` -> `model_dump_json`."""
return use_if_version_ge(2, self, "json", "model_dump_json")(*args, **kwargs)

def jsonable_dict(self, *args: Any, **kwargs: Any) -> dict:
"""
Expand All @@ -24,4 +78,4 @@ def jsonable_dict(self, *args: Any, **kwargs: Any) -> dict:
expects to take in a dictionary of primitives as input to the JSON parameter
of its request function. See: https://www.python-httpx.org/api/#helper-functions
"""
return json.loads(self.json(*args, **kwargs))
return json.loads(self.model_json(*args, **kwargs))
2 changes: 1 addition & 1 deletion src/firebolt/model/V1/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class Database(FireboltBaseModel):
def parse_obj_with_service(
cls, obj: Any, database_service: DatabaseService
) -> Database:
database = cls.parse_obj(obj)
database = cls.parse_model(obj)
database._service = database_service
return database

Expand Down
2 changes: 1 addition & 1 deletion src/firebolt/model/V1/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ class Engine(FireboltBaseModel):

@classmethod
def parse_obj_with_service(cls, obj: Any, engine_service: EngineService) -> Engine:
engine = cls.parse_obj(obj)
engine = cls.parse_model(obj)
engine._service = engine_service
return engine

Expand Down
10 changes: 5 additions & 5 deletions src/firebolt/service/V1/binding.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def get_by_key(self, binding_key: BindingKey) -> Binding:
)
)
binding: dict = response.json()["binding"]
return Binding.parse_obj(binding)
return Binding.parse_model(binding)

def get_many(
self,
Expand Down Expand Up @@ -66,7 +66,7 @@ def get_many(
}
),
)
return [Binding.parse_obj(i["node"]) for i in response.json()["edges"]]
return [Binding.parse_model(i["node"]) for i in response.json()["edges"]]

def get_database_bound_to_engine(self, engine: Engine) -> Optional[Database]:
"""Get the database to which an engine is bound, if any."""
Expand Down Expand Up @@ -127,12 +127,12 @@ def create(
)
assert database.database_id is not None, "Database must have database_id"
binding = Binding(
binding_key=BindingKey(
binding_key=BindingKey( # type: ignore[call-arg]
account_id=self.account_id,
database_id=database.database_id,
engine_id=engine.engine_id,
),
is_default_engine=is_default_engine,
is_default_engine=is_default_engine, # type: ignore[call-arg]
)

response = self.client.post(
Expand All @@ -145,4 +145,4 @@ def create(
by_alias=True, include={"binding_key": ..., "is_default_engine": ...}
),
)
return Binding.parse_obj(response.json()["binding"])
return Binding.parse_model(response.json()["binding"])
4 changes: 3 additions & 1 deletion src/firebolt/service/V1/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,9 @@ class _DatabaseCreateRequest(FireboltBaseModel):
else:
region_key = self.resource_manager.regions.get_by_name(name=region).key
database = Database(
name=name, compute_region_key=region_key, description=description
name=name,
compute_region_key=region_key, # type: ignore[call-arg]
description=description,
)

logger.info(f"Creating Database (name={name})")
Expand Down
2 changes: 1 addition & 1 deletion src/firebolt/service/V1/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@
def get_provider_id(client: Client) -> str:
"""Get the AWS provider_id."""
response = client.get(url=PROVIDERS_URL)
providers = [Provider.parse_obj(i["node"]) for i in response.json()["edges"]]
providers = [Provider.parse_model(i["node"]) for i in response.json()["edges"]]
return providers[0].provider_id
2 changes: 1 addition & 1 deletion src/firebolt/service/V1/region.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def regions(self) -> List[Region]:
"""List of available AWS regions on Firebolt."""

response = self.client.get(url=REGIONS_URL, params={"page.first": 5000})
return [Region.parse_obj(i["node"]) for i in response.json()["edges"]]
return [Region.parse_model(i["node"]) for i in response.json()["edges"]]

@cached_property
def regions_by_name(self) -> Dict[str, Region]:
Expand Down
15 changes: 8 additions & 7 deletions tests/unit/service/V1/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def do_mock(
assert urlparse(engine_url).path in request.url.path
return Response(
status_code=httpx.codes.OK,
json={"engine": mock_engine.dict()},
json={"engine": mock_engine.model_dict()},
)

return do_mock
Expand All @@ -254,7 +254,7 @@ def do_mock(
assert request.url == account_engine_url
return Response(
status_code=httpx.codes.OK,
json={"engine": mock_engine.dict()},
json={"engine": mock_engine.model_dict()},
)

return do_mock
Expand Down Expand Up @@ -297,7 +297,7 @@ def do_mock(
assert request.url == databases_url
return Response(
status_code=httpx.codes.OK,
json={"database": mock_database.dict()},
json={"database": mock_database.model_dict()},
)

return do_mock
Expand All @@ -309,7 +309,8 @@ def get_databases_callback_inner(
request: httpx.Request = None, **kwargs
) -> Response:
return Response(
status_code=httpx.codes.OK, json={"edges": [{"node": mock_database.dict()}]}
status_code=httpx.codes.OK,
json={"edges": [{"node": mock_database.model_dict()}]},
)

return get_databases_callback_inner
Expand All @@ -329,7 +330,7 @@ def do_mock(
assert request.url == database_url
return Response(
status_code=httpx.codes.OK,
json={"database": mock_database.dict()},
json={"database": mock_database.model_dict()},
)

return do_mock
Expand Down Expand Up @@ -407,7 +408,7 @@ def do_mock(
assert request.url == database_get_url
return Response(
status_code=httpx.codes.OK,
json={"database": mock_database.dict()},
json={"database": mock_database.model_dict()},
)

return do_mock
Expand Down Expand Up @@ -474,7 +475,7 @@ def do_mock(
assert request.url == create_binding_url
return Response(
status_code=httpx.codes.OK,
json={"binding": binding.dict()},
json={"binding": binding.model_dict()},
)

return do_mock
Expand Down
6 changes: 4 additions & 2 deletions tests/unit/service/V1/test_bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def test_create_binding(
httpx_mock.add_callback(account_id_callback, url=account_id_url)
httpx_mock.add_response(url=bindings_url, method="GET", json={"edges": []})
httpx_mock.add_response(
url=create_binding_url, method="POST", json={"binding": binding.dict()}
url=create_binding_url, method="POST", json={"binding": binding.model_dict()}
)

resource_manager = ResourceManager(settings=settings)
Expand Down Expand Up @@ -116,7 +116,9 @@ def test_get_engines_bound_to_db(
httpx_mock.add_callback(account_id_callback, url=account_id_url)
httpx_mock.add_callback(bindings_database_callback, url=database_bindings_url)
httpx_mock.add_response(
url=engines_by_id_url, method="POST", json={"engines": [mock_engine.dict()]}
url=engines_by_id_url,
method="POST",
json={"engines": [mock_engine.model_dict()]},
)

resource_manager = ResourceManager(settings=settings)
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def list_to_paginated_response(items: List[FireboltBaseModel]) -> Dict:


def list_to_paginated_response_v1(items: List[FireboltBaseModelV1]) -> Dict:
return {"edges": [{"node": i.dict()} for i in items]}
return {"edges": [{"node": i.model_dict()} for i in items]}


def execute_generator_requests(
Expand Down