diff --git a/docs/docs/python-sdk/topics/object_file.mdx b/docs/docs/python-sdk/topics/object_file.mdx index c5033eb5..8de01740 100644 --- a/docs/docs/python-sdk/topics/object_file.mdx +++ b/docs/docs/python-sdk/topics/object_file.mdx @@ -60,12 +60,24 @@ apiVersion: infrahub.app/v1 kind: Object spec: kind: + strategy: # Optional, defaults to normal data: - [...] ``` > Multiple documents in a single YAML file are also supported, each document will be loaded separately. Documents are separated by `---` +### Data Processing Strategies + +The `strategy` field controls how the data in the object file is processed before loading into Infrahub: + +| Strategy | Description | Default | +|----------|-------------|---------| +| `normal` | No data manipulation is performed. Objects are loaded as-is. | Yes | +| `range_expand` | Range patterns (e.g., `[1-5]`) in string fields are expanded into multiple objects. | No | + +When `strategy` is not specified, it defaults to `normal`. + ### Relationship of cardinality one A relationship of cardinality one can either reference an existing node via its HFID or create a new node if it doesn't exist. @@ -198,7 +210,19 @@ Metadata support is planned for future releases. Currently, the Object file does ## Range Expansion in Object Files -The Infrahub Python SDK supports **range expansion** for string fields in object files. This feature allows you to specify a range pattern (e.g., `[1-5]`) in any string value, and the SDK will automatically expand it into multiple objects during validation and processing. +The Infrahub Python SDK supports **range expansion** for string fields in object files when the `strategy` is set to `range_expand`. This feature allows you to specify a range pattern (e.g., `[1-5]`) in any string value, and the SDK will automatically expand it into multiple objects during validation and processing. 
class ObjectStrategy(str, Enum):
    """Strategies controlling how the ``data`` list of an object file is pre-processed."""

    # Items are loaded exactly as written in the file.
    NORMAL = "normal"
    # Items are passed through range expansion before loading.
    RANGE_EXPAND = "range_expand"


class DataProcessor(ABC):
    """Abstract base class for data processing strategies."""

    @abstractmethod
    def process_data(self, data: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Return ``data`` transformed according to the strategy."""


class SingleDataProcessor(DataProcessor):
    """Processor for ``ObjectStrategy.NORMAL``: hands the data back without any expansion."""

    def process_data(self, data: list[dict[str, Any]]) -> list[dict[str, Any]]:
        # The very same list object is returned; no copy is made.
        return data


class RangeExpandDataProcessor(DataProcessor):
    """Processor for ``ObjectStrategy.RANGE_EXPAND``: delegates to ``expand_data_with_ranges``."""

    def process_data(self, data: list[dict[str, Any]]) -> list[dict[str, Any]]:
        return expand_data_with_ranges(data)


class DataProcessorFactory:
    """Factory returning the data processor registered for a given strategy."""

    # Class-level registry shared by every caller; extend it with ``register_processor``.
    _processors: ClassVar[dict[ObjectStrategy, type[DataProcessor]]] = {
        ObjectStrategy.NORMAL: SingleDataProcessor,
        ObjectStrategy.RANGE_EXPAND: RangeExpandDataProcessor,
    }

    @classmethod
    def get_processor(cls, strategy: ObjectStrategy) -> DataProcessor:
        """Instantiate the processor registered for ``strategy``.

        Raises:
            ValueError: If the strategy has no entry in the registry, or its entry is ``None``.
        """
        processor_class = cls._processors.get(strategy)
        if processor_class is None:
            # Render plain values instead of interpolating enum members: how a ``(str, Enum)``
            # member is formatted in an f-string differs between Python versions, and the
            # repr of the members is noisy for end users.
            requested = getattr(strategy, "value", strategy)
            # Only advertise strategies that actually have a processor behind them.
            valid = [
                getattr(key, "value", key) for key, registered in cls._processors.items() if registered is not None
            ]
            raise ValueError(f"Unknown strategy: {requested} - no processor found. Valid strategies are: {valid}")
        return processor_class()

    @classmethod
    def register_processor(cls, strategy: ObjectStrategy, processor_class: type[DataProcessor]) -> None:
        """Register (or replace) the processor class used for ``strategy``."""
        cls._processors[strategy] = processor_class
@property
def spec(self) -> InfrahubObjectFileData:
    """Return the parsed ``spec`` section, building and caching it on first access.

    Raises:
        ValidationError: If the ``spec`` section cannot be turned into an ``InfrahubObjectFileData``.
    """
    if not self._spec:
        self._spec = self._build_spec()
    return self._spec

def _build_spec(self) -> InfrahubObjectFileData:
    """Build the spec model from ``self.data.spec``, wrapping any failure in a ValidationError."""
    try:
        return InfrahubObjectFileData(**self.data.spec)
    except Exception as exc:
        # Chain the original error so the underlying cause stays visible in the traceback.
        raise ValidationError(identifier=str(self.location), message=str(exc)) from exc

def validate_content(self) -> None:
    """Validate the file content and (re)build the cached spec.

    Raises:
        ValueError: If the file kind is not ``InfrahubFileKind.OBJECT``.
        ValidationError: If the ``spec`` section cannot be turned into an ``InfrahubObjectFileData``.
    """
    super().validate_content()
    if self.kind != InfrahubFileKind.OBJECT:
        raise ValueError("File is not an Infrahub Object file")
    # Always rebuild here, even when a spec is already cached.
    self._spec = self._build_spec()
async def test_invalid_object_expansion_processor(
    client: InfrahubClient, mock_schema_query_01: HTTPXMock, location_expansion
) -> None:
    """A strategy whose registry entry holds no processor must raise ``ValueError`` during validation."""
    # ObjectStrategy is already imported at module level; only the factory is needed locally.
    from infrahub_sdk.spec.object import DataProcessorFactory  # noqa: PLC0415

    obj = ObjectFile(location="some/path", content=location_expansion)

    # Blank out the processor registered for RANGE_EXPAND (the strategy itself stays valid),
    # keeping a copy of the registry so it can be restored whatever the outcome.
    original_processors = DataProcessorFactory._processors.copy()
    try:
        DataProcessorFactory._processors[ObjectStrategy.RANGE_EXPAND] = None
        with pytest.raises(ValueError, match="Unknown strategy"):
            await obj.validate_format(client=client)
    finally:
        DataProcessorFactory._processors = original_processors
async def test_invalid_object_expansion_strategy(client: InfrahubClient, location_expansion) -> None:
    """Validating a file whose ``strategy`` is not a known value raises ``ValidationError``."""
    location_expansion["spec"]["strategy"] = "InvalidStrategy"
    object_file = ObjectFile(location="some/path", content=location_expansion)

    with pytest.raises(ValidationError) as exc_info:
        await object_file.validate_format(client=client)

    error_text = str(exc_info.value)
    assert "Input should be" in error_text