diff --git a/src/firebolt/model/V1/__init__.py b/src/firebolt/model/V1/__init__.py index 1a1114c232..23123723c3 100644 --- a/src/firebolt/model/V1/__init__.py +++ b/src/firebolt/model/V1/__init__.py @@ -1,16 +1,70 @@ import json -from typing import Any +from typing import Any, Callable, Final, Type, TypeVar, Union -from pydantic import BaseModel +import pydantic +from typing_extensions import TypeAlias +Model = TypeVar("Model", bound="pydantic.BaseModel") -class FireboltBaseModel(BaseModel): +GenericCallable: TypeAlias = Callable[..., Any] - # Using Pydantic 1.* config class for backwards compatibility - class Config: - extra = "forbid" - allow_population_by_field_name = True # Pydantic 1.8 - populate_by_name = True # Pydantic 2.0 +# Using `.VERSION` instead of `.__version__` for backward compatibility: +PYDANTIC_VERSION: Final[int] = int(pydantic.VERSION[0]) + + +def use_if_version_ge( + version_ge: int, + obj: Union[pydantic.BaseModel, Type[Model]], + previous_method: str, + latest_method: str, +) -> GenericCallable: + """ + Utility function to get desired method from base model. + + Args: + version_ge: The version number that will be used to determine + the desired method. + obj: The object on which the method will be taken from + previous_method: The method previously available in a version + smaller than `version_ge`. + latest_method: The method available from `version_ge` onwards. + + """ + if PYDANTIC_VERSION >= version_ge: + return getattr(obj, latest_method) + else: + return getattr(obj, previous_method) + + +if PYDANTIC_VERSION >= 2: + # This import can only happen outside the BaseModel, + # or it will raise PydanticUserError + from pydantic import ConfigDict + + +class FireboltBaseModel(pydantic.BaseModel): + if PYDANTIC_VERSION >= 2: + # Pydantic V2 config + model_config = ConfigDict(populate_by_name=True, from_attributes=True) + + else: + # Using Pydantic 1.* config class for backwards compatibility + class Config: + extra = "forbid" + allow_population_by_field_name = True # Pydantic 1.8 + + def model_dict(self, *args: Any, **kwargs: Any) -> dict: + """Pydantic V2 and V1 compatible method for `dict` -> `model_dump`.""" + return use_if_version_ge(2, self, "dict", "model_dump")(*args, **kwargs) + + @classmethod + def parse_model(cls: Type[Model], *args: Any, **kwargs: Any) -> Model: + """Pydantic V2 and V1 compatible method for `parse_obj` -> `model_validate`.""" + return use_if_version_ge(2, cls, "parse_obj", "model_validate")(*args, **kwargs) + + def model_json(self, *args: Any, **kwargs: Any) -> str: + """Pydantic V2 and V1 compatible method for `json` -> `model_dump_json`.""" + return use_if_version_ge(2, self, "json", "model_dump_json")(*args, **kwargs) def jsonable_dict(self, *args: Any, **kwargs: Any) -> dict: """ @@ -24,4 +78,4 @@ def jsonable_dict(self, *args: Any, **kwargs: Any) -> dict: expects to take in a dictionary of primitives as input to the JSON parameter of its request function. See: https://www.python-httpx.org/api/#helper-functions """ - return json.loads(self.json(*args, **kwargs)) + return json.loads(self.model_json(*args, **kwargs)) diff --git a/src/firebolt/model/V1/database.py b/src/firebolt/model/V1/database.py index 6ad19c1ef6..33b0eb2aa5 100644 --- a/src/firebolt/model/V1/database.py +++ b/src/firebolt/model/V1/database.py @@ -65,7 +65,7 @@ class Database(FireboltBaseModel): def parse_obj_with_service( cls, obj: Any, database_service: DatabaseService ) -> Database: - database = cls.parse_obj(obj) + database = cls.parse_model(obj) database._service = database_service return database diff --git a/src/firebolt/model/V1/engine.py b/src/firebolt/model/V1/engine.py index 4a97eb4c4c..458f1db089 100644 --- a/src/firebolt/model/V1/engine.py +++ b/src/firebolt/model/V1/engine.py @@ -127,7 +127,7 @@ class Engine(FireboltBaseModel): @classmethod def parse_obj_with_service(cls, obj: Any, engine_service: EngineService) -> Engine: - engine = cls.parse_obj(obj) + engine = cls.parse_model(obj) engine._service = engine_service return engine diff --git a/src/firebolt/service/V1/binding.py b/src/firebolt/service/V1/binding.py index f21b09cea9..f141ebee6c 100644 --- a/src/firebolt/service/V1/binding.py +++ b/src/firebolt/service/V1/binding.py @@ -28,7 +28,7 @@ def get_by_key(self, binding_key: BindingKey) -> Binding: ) ) binding: dict = response.json()["binding"] - return Binding.parse_obj(binding) + return Binding.parse_model(binding) def get_many( self, @@ -66,7 +66,7 @@ def get_many( } ), ) - return [Binding.parse_obj(i["node"]) for i in response.json()["edges"]] + return [Binding.parse_model(i["node"]) for i in response.json()["edges"]] def get_database_bound_to_engine(self, engine: Engine) -> Optional[Database]: """Get the database to which an engine is bound, if any.""" @@ -127,12 +127,12 @@ def create( ) assert database.database_id is not None, "Database must have database_id" binding = Binding( - binding_key=BindingKey( + binding_key=BindingKey( # type: ignore[call-arg] account_id=self.account_id, database_id=database.database_id, engine_id=engine.engine_id, ), - is_default_engine=is_default_engine, + is_default_engine=is_default_engine, # type: ignore[call-arg] ) response = self.client.post( @@ -145,4 +145,4 @@ def create( by_alias=True, include={"binding_key": ..., "is_default_engine": ...} ), ) - return Binding.parse_obj(response.json()["binding"]) + return Binding.parse_model(response.json()["binding"]) diff --git a/src/firebolt/service/V1/database.py b/src/firebolt/service/V1/database.py index 2673f0a740..4ff7523f19 100644 --- a/src/firebolt/service/V1/database.py +++ b/src/firebolt/service/V1/database.py @@ -110,7 +110,9 @@ class _DatabaseCreateRequest(FireboltBaseModel): else: region_key = self.resource_manager.regions.get_by_name(name=region).key database = Database( - name=name, compute_region_key=region_key, description=description + name=name, + compute_region_key=region_key, # type: ignore[call-arg] + description=description, ) logger.info(f"Creating Database (name={name})") diff --git a/src/firebolt/service/V1/provider.py b/src/firebolt/service/V1/provider.py index 2d9ddec269..fa53e8129d 100644 --- a/src/firebolt/service/V1/provider.py +++ b/src/firebolt/service/V1/provider.py @@ -6,5 +6,5 @@ def get_provider_id(client: Client) -> str: """Get the AWS provider_id.""" response = client.get(url=PROVIDERS_URL) - providers = [Provider.parse_obj(i["node"]) for i in response.json()["edges"]] + providers = [Provider.parse_model(i["node"]) for i in response.json()["edges"]] return providers[0].provider_id diff --git a/src/firebolt/service/V1/region.py b/src/firebolt/service/V1/region.py index f688e26fa3..8191a7f494 100644 --- a/src/firebolt/service/V1/region.py +++ b/src/firebolt/service/V1/region.py @@ -25,7 +25,7 @@ def regions(self) -> List[Region]: """List of available AWS regions on Firebolt.""" response = self.client.get(url=REGIONS_URL, params={"page.first": 5000}) - return [Region.parse_obj(i["node"]) for i in response.json()["edges"]] + return [Region.parse_model(i["node"]) for i in response.json()["edges"]] @cached_property def regions_by_name(self) -> Dict[str, Region]: diff --git a/tests/unit/service/V1/conftest.py b/tests/unit/service/V1/conftest.py index d65f7f9b1c..bc0c409add 100644 --- a/tests/unit/service/V1/conftest.py +++ b/tests/unit/service/V1/conftest.py @@ -234,7 +234,7 @@ def do_mock( assert urlparse(engine_url).path in request.url.path return Response( status_code=httpx.codes.OK, - json={"engine": mock_engine.dict()}, + json={"engine": mock_engine.model_dict()}, ) return do_mock @@ -254,7 +254,7 @@ def do_mock( assert request.url == account_engine_url return Response( status_code=httpx.codes.OK, - json={"engine": mock_engine.dict()}, + json={"engine": mock_engine.model_dict()}, ) return do_mock @@ -297,7 +297,7 @@ def do_mock( assert request.url == databases_url return Response( status_code=httpx.codes.OK, - json={"database": mock_database.dict()}, + json={"database": mock_database.model_dict()}, ) return do_mock @@ -309,7 +309,8 @@ def get_databases_callback_inner( request: httpx.Request = None, **kwargs ) -> Response: return Response( - status_code=httpx.codes.OK, json={"edges": [{"node": mock_database.dict()}]} + status_code=httpx.codes.OK, + json={"edges": [{"node": mock_database.model_dict()}]}, ) return get_databases_callback_inner @@ -329,7 +330,7 @@ def do_mock( assert request.url == database_url return Response( status_code=httpx.codes.OK, - json={"database": mock_database.dict()}, + json={"database": mock_database.model_dict()}, ) return do_mock @@ -407,7 +408,7 @@ def do_mock( assert request.url == database_get_url return Response( status_code=httpx.codes.OK, - json={"database": mock_database.dict()}, + json={"database": mock_database.model_dict()}, ) return do_mock @@ -474,7 +475,7 @@ def do_mock( assert request.url == create_binding_url return Response( status_code=httpx.codes.OK, - json={"binding": binding.dict()}, + json={"binding": binding.model_dict()}, ) return do_mock diff --git a/tests/unit/service/V1/test_bindings.py b/tests/unit/service/V1/test_bindings.py index b4c3bf6054..95328520a8 100644 --- a/tests/unit/service/V1/test_bindings.py +++ b/tests/unit/service/V1/test_bindings.py @@ -56,7 +56,7 @@ def test_create_binding( httpx_mock.add_callback(account_id_callback, url=account_id_url) httpx_mock.add_response(url=bindings_url, method="GET", json={"edges": []}) httpx_mock.add_response( - url=create_binding_url, method="POST", json={"binding": binding.dict()} + url=create_binding_url, method="POST", json={"binding": binding.model_dict()} ) resource_manager = ResourceManager(settings=settings) @@ -116,7 +116,9 @@ def test_get_engines_bound_to_db( httpx_mock.add_callback(account_id_callback, url=account_id_url) httpx_mock.add_callback(bindings_database_callback, url=database_bindings_url) httpx_mock.add_response( - url=engines_by_id_url, method="POST", json={"engines": [mock_engine.dict()]} + url=engines_by_id_url, + method="POST", + json={"engines": [mock_engine.model_dict()]}, ) resource_manager = ResourceManager(settings=settings) diff --git a/tests/unit/util.py b/tests/unit/util.py index 0cf3fbc11a..df5b672ac9 100644 --- a/tests/unit/util.py +++ b/tests/unit/util.py @@ -23,7 +23,7 @@ def list_to_paginated_response(items: List[FireboltBaseModel]) -> Dict: def list_to_paginated_response_v1(items: List[FireboltBaseModelV1]) -> Dict: - return {"edges": [{"node": i.dict()} for i in items]} + return {"edges": [{"node": i.model_dict()} for i in items]} def execute_generator_requests(