diff --git a/changelog/+add_numberpool_support_protocols.added.md b/changelog/+add_numberpool_support_protocols.added.md new file mode 100644 index 00000000..aef27a24 --- /dev/null +++ b/changelog/+add_numberpool_support_protocols.added.md @@ -0,0 +1 @@ +add support for NumberPool attributes in generated protocols diff --git a/changelog/6882.fixed.md b/changelog/6882.fixed.md new file mode 100644 index 00000000..c0c8cebc --- /dev/null +++ b/changelog/6882.fixed.md @@ -0,0 +1 @@ +Fix value lookup using a flat notation like `foo__bar__value` with relationships of cardinality one \ No newline at end of file diff --git a/infrahub_sdk/node/node.py b/infrahub_sdk/node/node.py index f69c6231..242281b5 100644 --- a/infrahub_sdk/node/node.py +++ b/infrahub_sdk/node/node.py @@ -8,7 +8,7 @@ from ..exceptions import FeatureNotSupportedError, NodeNotFoundError, ResourceNotDefinedError, SchemaNotFoundError from ..graphql import Mutation, Query from ..schema import GenericSchemaAPI, RelationshipCardinality, RelationshipKind -from ..utils import compare_lists, generate_short_id, get_flat_value +from ..utils import compare_lists, generate_short_id from .attribute import Attribute from .constants import ( ARTIFACT_DEFINITION_GENERATE_FEATURE_NOT_SUPPORTED_MESSAGE, @@ -418,14 +418,6 @@ def generate_query_data_init( return data - def extract(self, params: dict[str, str]) -> dict[str, Any]: - """Extract some datapoints defined in a flat notation.""" - result: dict[str, Any] = {} - for key, value in params.items(): - result[key] = get_flat_value(self, key=value) - - return result - def __hash__(self) -> int: return hash(self.id) @@ -1036,6 +1028,46 @@ def _get_relationship_one(self, name: str) -> RelatedNode: raise ResourceNotDefinedError(message=f"The node doesn't have a cardinality=one relationship for {name}") + async def get_flat_value(self, key: str, separator: str = "__") -> Any: + """Query recursively a value defined in a flat notation (string), on a hierarchy of objects + + Examples: + name__value + module.object.value + """ + if separator not in key: + return getattr(self, key) + + first, remaining = key.split(separator, maxsplit=1) + + if first in self._schema.attribute_names: + attr = getattr(self, first) + for part in remaining.split(separator): + attr = getattr(attr, part) + return attr + + try: + rel = self._schema.get_relationship(name=first) + except ValueError as exc: + raise ValueError(f"No attribute or relationship named '{first}' for '{self._schema.kind}'") from exc + + if rel.cardinality != RelationshipCardinality.ONE: + raise ValueError( + f"Can only look up flat value for relationships of cardinality {RelationshipCardinality.ONE.value}" + ) + + related_node: RelatedNode = getattr(self, first) + await related_node.fetch() + return await related_node.peer.get_flat_value(key=remaining, separator=separator) + + async def extract(self, params: dict[str, str]) -> dict[str, Any]: + """Extract some datapoints defined in a flat notation.""" + result: dict[str, Any] = {} + for key, value in params.items(): + result[key] = await self.get_flat_value(key=value) + + return result + def __dir__(self) -> Iterable[str]: base = list(super().__dir__()) return sorted( @@ -1622,6 +1654,46 @@ def _get_relationship_one(self, name: str) -> RelatedNode | RelatedNodeSync: raise ResourceNotDefinedError(message=f"The node doesn't have a cardinality=one relationship for {name}") + def get_flat_value(self, key: str, separator: str = "__") -> Any: + """Query recursively a value defined in a flat notation (string), on a hierarchy of objects + + Examples: + name__value + module.object.value + """ + if separator not in key: + return getattr(self, key) + + first, remaining = key.split(separator, maxsplit=1) + + if first in self._schema.attribute_names: + attr = getattr(self, first) + for part in remaining.split(separator): + attr = getattr(attr, part) + return attr + + try: + rel = self._schema.get_relationship(name=first) + except ValueError as exc: + raise ValueError(f"No attribute or relationship named '{first}' for '{self._schema.kind}'") from exc + + if rel.cardinality != RelationshipCardinality.ONE: + raise ValueError( + f"Can only look up flat value for relationships of cardinality {RelationshipCardinality.ONE.value}" + ) + + related_node: RelatedNodeSync = getattr(self, first) + related_node.fetch() + return related_node.peer.get_flat_value(key=remaining, separator=separator) + + def extract(self, params: dict[str, str]) -> dict[str, Any]: + """Extract some datapoints defined in a flat notation.""" + result: dict[str, Any] = {} + for key, value in params.items(): + result[key] = self.get_flat_value(key=value) + + return result + def __dir__(self) -> Iterable[str]: base = list(super().__dir__()) return sorted( diff --git a/infrahub_sdk/protocols_base.py b/infrahub_sdk/protocols_base.py index a47d95ef..a3daa1fb 100644 --- a/infrahub_sdk/protocols_base.py +++ b/infrahub_sdk/protocols_base.py @@ -204,8 +204,6 @@ def is_resource_pool(self) -> bool: ... def get_raw_graphql_data(self) -> dict | None: ... - def extract(self, params: dict[str, str]) -> dict[str, Any]: ... - @runtime_checkable class CoreNode(CoreNodeBase, Protocol): diff --git a/infrahub_sdk/protocols_generator/constants.py b/infrahub_sdk/protocols_generator/constants.py index d0bdb076..63c3dbb6 100644 --- a/infrahub_sdk/protocols_generator/constants.py +++ b/infrahub_sdk/protocols_generator/constants.py @@ -22,6 +22,7 @@ "List": "ListAttribute", "JSON": "JSONAttribute", "Any": "AnyAttribute", + "NumberPool": "Integer", } # The order of the classes in the list determines the order of the classes in the generated code diff --git a/infrahub_sdk/utils.py b/infrahub_sdk/utils.py index b505dbeb..9232b32d 100644 --- a/infrahub_sdk/utils.py +++ b/infrahub_sdk/utils.py @@ -190,23 +190,6 @@ def str_to_bool(value: str) -> bool: raise ValueError(f"{value} can not be converted into a boolean") from exc -def get_flat_value(obj: Any, key: str, separator: str = "__") -> Any: - """Query recursively an value defined in a flat notation (string), on a hierarchy of objects - - Examples: - name__value - module.object.value - """ - if separator not in key: - return getattr(obj, key) - - first_part, remaining_part = key.split(separator, maxsplit=1) - sub_obj = getattr(obj, first_part) - if not sub_obj: - return None - return get_flat_value(obj=sub_obj, key=remaining_part, separator=separator) - - def generate_request_filename(request: httpx.Request) -> str: """Return a filename for a request sent to the Infrahub API diff --git a/tests/unit/sdk/test_node.py b/tests/unit/sdk/test_node.py index b7f5eb38..4df1bd6a 100644 --- a/tests/unit/sdk/test_node.py +++ b/tests/unit/sdk/test_node.py @@ -1960,23 +1960,52 @@ async def test_node_IPNetwork_deserialization(client, ipnetwork_schema, client_t @pytest.mark.parametrize("client_type", client_types) -async def test_node_extract(client, location_schema, location_data01, client_type): +async def test_get_flat_value( + httpx_mock: HTTPXMock, mock_schema_query_01, clients, location_schema, location_data01, client_type +): + httpx_mock.add_response( + method="POST", + json={"data": {"BuiltinTag": {"count": 1, "edges": [location_data01["node"]["primary_tag"]]}}}, + match_headers={"X-Infrahub-Tracker": "query-builtintag-page1"}, + is_reusable=True, + ) + if client_type == "standard": - node = InfrahubNode(client=client, schema=location_schema, data=location_data01) + tag = InfrahubNode(client=clients.standard, schema=location_schema, data=location_data01) + assert await tag.get_flat_value(key="name__value") == "DFW" + assert await tag.get_flat_value(key="primary_tag__display_label") == "red" + assert await tag.get_flat_value(key="primary_tag.display_label", separator=".") == "red" + + with pytest.raises(ValueError, match="Can only look up flat value for relationships of cardinality one"): + assert await tag.get_flat_value(key="tags__display_label") == "red" else: - node = InfrahubNodeSync(client=client, schema=location_schema, data=location_data01) + tag = InfrahubNodeSync(client=clients.sync, schema=location_schema, data=location_data01) + assert tag.get_flat_value(key="name__value") == "DFW" + assert tag.get_flat_value(key="primary_tag__display_label") == "red" + assert tag.get_flat_value(key="primary_tag.display_label", separator=".") == "red" - params = { - "identifier": "id", - "name": "name__value", - "description": "description__value", - } + with pytest.raises(ValueError, match="Can only look up flat value for relationships of cardinality one"): + assert tag.get_flat_value(key="tags__display_label") == "red" - assert node.extract(params=params) == { - "description": None, - "identifier": "llllllll-llll-llll-llll-llllllllllll", - "name": "DFW", - } + +@pytest.mark.parametrize("client_type", client_types) +async def test_node_extract(clients, location_schema, location_data01, client_type): + params = {"identifier": "id", "name": "name__value", "description": "description__value"} + if client_type == "standard": + node = InfrahubNode(client=clients.standard, schema=location_schema, data=location_data01) + assert await node.extract(params=params) == { + "description": None, + "identifier": "llllllll-llll-llll-llll-llllllllllll", + "name": "DFW", + } + + else: + node = InfrahubNodeSync(client=clients.sync, schema=location_schema, data=location_data01) + assert node.extract(params=params) == { + "description": None, + "identifier": "llllllll-llll-llll-llll-llllllllllll", + "name": "DFW", + } @pytest.mark.parametrize("client_type", client_types) diff --git a/tests/unit/sdk/test_utils.py b/tests/unit/sdk/test_utils.py index 88c25644..7a220cd2 100644 --- a/tests/unit/sdk/test_utils.py +++ b/tests/unit/sdk/test_utils.py @@ -6,7 +6,6 @@ from graphql import parse from whenever import Instant -from infrahub_sdk.node import InfrahubNode from infrahub_sdk.utils import ( base16decode, base16encode, @@ -19,7 +18,6 @@ duplicates, extract_fields, generate_short_id, - get_flat_value, is_valid_url, is_valid_uuid, str_to_bool, @@ -143,13 +141,6 @@ def test_base16(): assert base16decode(base16encode(1412823931503067241)) == 1412823931503067241 -def test_get_flat_value(client, tag_schema, tag_green_data): - tag = InfrahubNode(client=client, schema=tag_schema, data=tag_green_data) - assert get_flat_value(obj=tag, key="name__value") == "green" - assert get_flat_value(obj=tag, key="name__source__display_label") == "CRM" - assert get_flat_value(obj=tag, key="name.source.display_label", separator=".") == "CRM" - - def test_dict_hash(): assert dict_hash({"a": 1, "b": 2}) == "608de49a4600dbb5b173492759792e4a" assert dict_hash({"b": 2, "a": 1}) == "608de49a4600dbb5b173492759792e4a"