diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e296b4c9..0b0bd593 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -91,7 +91,7 @@ jobs: - name: "Check out repository code" uses: "actions/checkout@v4" - name: "Linting: markdownlint" - uses: DavidAnson/markdownlint-cli2-action@v18 + uses: DavidAnson/markdownlint-cli2-action@v19 with: config: .markdownlint.yaml globs: | diff --git a/CHANGELOG.md b/CHANGELOG.md index c2f8fe0a..1e004a12 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,50 @@ This project uses [*towncrier*](https://towncrier.readthedocs.io/) and the chang +## [1.3.0](https://github.com/opsmill/infrahub-sdk-python/tree/v1.3.0) - 2024-12-30 + +### Added + +#### Testing library (**Alpha**) + +A new collection of tools and utilities to help with testing is available under `infrahub_sdk.testing`. + +The first component available is a `TestInfrahubDockerClient`, a pytest Class designed to help create integration tests based on Infrahub. See a simple example below to help you get started. + +> the installation of `infrahub-testcontainers` is required + +```python +import pytest + +from infrahub_sdk import InfrahubClient +from infrahub_sdk.testing.docker import TestInfrahubDockerClient + +class TestInfrahubNode(TestInfrahubDockerClient): + + @pytest.fixture(scope="class") + def infrahub_version(self) -> str: + """Required (for now) to define the version of infrahub to use.""" + return "1.0.10" + + @pytest.fixture(scope="class") + async def test_create_tag(self, default_branch: str, client: InfrahubClient) -> None: + obj = await client.create(kind="BuiltinTag", name="Blue") + await obj.save() + assert obj.id +``` + +### Changed + +- The Pydantic models for the schema have been split into multiple versions to align better with the different phases of the lifecycle of the schema. 
+ - User input: includes only the options available for a user to define (NodeSchema, AttributeSchema, RelationshipSchema, GenericSchema) + - API: Format of the schema as exposed by the API in infrahub with some read only settings (NodeSchemaAPI, AttributeSchemaAPI, RelationshipSchemaAPI, GenericSchemaAPI) + +### Fixed + +- Fix behaviour of attribute value coming from resource pools for async client ([#66](https://github.com/opsmill/infrahub-sdk-python/issues/66)) +- Convert import_root to a string if it was submitted as a Path object to ensure that anything added to sys.path is a string +- Fix relative imports for the pytest plugin, note that the relative imports can't be at the top level of the repository alongside .infrahub.yml. They have to be located within a subfolder. ([#166](https://github.com/opsmill/infrahub-sdk-python/issues/166)) + ## [1.2.0](https://github.com/opsmill/infrahub-sdk-python/tree/v1.2.0) - 2024-12-19 ### Added @@ -60,7 +104,7 @@ This project uses [*towncrier*](https://towncrier.readthedocs.io/) and the chang ### Removed -- Breaking change: Removed all exports from infrahub_sdk/__init__.py except InfrahubClient, InfrahubClientSync and Config. If you previously imported other classes such as InfrahubNode from the root level these need to change to instead be an absolute path. +- Breaking change: Removed all exports from `infrahub_sdk/__init__.py` except InfrahubClient, InfrahubClientSync and Config. If you previously imported other classes such as InfrahubNode from the root level these need to change to instead be an absolute path. ### Added diff --git a/infrahub_sdk/_importer.py b/infrahub_sdk/_importer.py index 071388d6..b45beb85 100644 --- a/infrahub_sdk/_importer.py +++ b/infrahub_sdk/_importer.py @@ -16,12 +16,23 @@ def import_module( module_path: Path, import_root: Optional[str] = None, relative_path: Optional[str] = None ) -> ModuleType: + """Imports a python module. 
+ + Args: + module_path (Path): Absolute path of the module to import. + import_root (Optional[str]): Absolute string path to the current repository. + relative_path (Optional[str]): Relative string path between module_path and import_root. + """ import_root = import_root or str(module_path.parent) file_on_disk = module_path if import_root and relative_path: file_on_disk = Path(import_root, relative_path, module_path.name) + # This is a temporary workaround to account for issues if "import_root" is a Path instead of a str + # Later we should rework this so that import_root and relative_path are all Path objects. Here we must + # ensure that anything we add to sys.path is a string and not a Path or PosixPath object. + import_root = str(import_root) if import_root not in sys.path: sys.path.append(import_root) diff --git a/infrahub_sdk/checks.py b/infrahub_sdk/checks.py index 85f9ca1e..5b6c2c55 100644 --- a/infrahub_sdk/checks.py +++ b/infrahub_sdk/checks.py @@ -17,7 +17,7 @@ from pathlib import Path from . 
import InfrahubClient - from .schema import InfrahubCheckDefinitionConfig + from .schema.repository import InfrahubCheckDefinitionConfig INFRAHUB_CHECK_VARIABLE_TO_IMPORT = "INFRAHUB_CHECKS" diff --git a/infrahub_sdk/client.py b/infrahub_sdk/client.py index 00bdd19c..9fca57b7 100644 --- a/infrahub_sdk/client.py +++ b/infrahub_sdk/client.py @@ -50,7 +50,7 @@ from .protocols_base import CoreNode, CoreNodeSync from .queries import get_commit_update_mutation from .query_groups import InfrahubGroupContext, InfrahubGroupContextSync -from .schema import InfrahubSchema, InfrahubSchemaSync, NodeSchema +from .schema import InfrahubSchema, InfrahubSchemaSync, NodeSchemaAPI from .store import NodeStore, NodeStoreSync from .timestamp import Timestamp from .types import AsyncRequester, HTTPMethod, SyncRequester @@ -448,12 +448,12 @@ async def get( filters: MutableMapping[str, Any] = {} if id: - if not is_valid_uuid(id) and isinstance(schema, NodeSchema) and schema.default_filter: + if not is_valid_uuid(id) and isinstance(schema, NodeSchemaAPI) and schema.default_filter: filters[schema.default_filter] = id else: filters["ids"] = [id] if hfid: - if isinstance(schema, NodeSchema) and schema.human_friendly_id: + if isinstance(schema, NodeSchemaAPI) and schema.human_friendly_id: filters["hfid"] = hfid else: raise ValueError("Cannot filter by HFID if the node doesn't have an HFID defined") @@ -1916,12 +1916,12 @@ def get( filters: MutableMapping[str, Any] = {} if id: - if not is_valid_uuid(id) and isinstance(schema, NodeSchema) and schema.default_filter: + if not is_valid_uuid(id) and isinstance(schema, NodeSchemaAPI) and schema.default_filter: filters[schema.default_filter] = id else: filters["ids"] = [id] if hfid: - if isinstance(schema, NodeSchema) and schema.human_friendly_id: + if isinstance(schema, NodeSchemaAPI) and schema.human_friendly_id: filters["hfid"] = hfid else: raise ValueError("Cannot filter by HFID if the node doesn't have an HFID defined") diff --git 
a/infrahub_sdk/code_generator.py b/infrahub_sdk/code_generator.py index 667ba27d..dc0c2ac9 100644 --- a/infrahub_sdk/code_generator.py +++ b/infrahub_sdk/code_generator.py @@ -1,17 +1,19 @@ from collections.abc import Mapping -from typing import Any, Optional +from typing import Any, Optional, Union import jinja2 from . import protocols as sdk_protocols from .ctl.constants import PROTOCOLS_TEMPLATE from .schema import ( - AttributeSchema, + AttributeSchemaAPI, GenericSchema, - MainSchemaTypes, + GenericSchemaAPI, + MainSchemaTypesAll, NodeSchema, - ProfileSchema, - RelationshipSchema, + NodeSchemaAPI, + ProfileSchemaAPI, + RelationshipSchemaAPI, ) ATTRIBUTE_KIND_MAP = { @@ -40,17 +42,17 @@ class CodeGenerator: - def __init__(self, schema: dict[str, MainSchemaTypes]): - self.generics: dict[str, GenericSchema] = {} - self.nodes: dict[str, NodeSchema] = {} - self.profiles: dict[str, ProfileSchema] = {} + def __init__(self, schema: dict[str, MainSchemaTypesAll]): + self.generics: dict[str, Union[GenericSchemaAPI, GenericSchema]] = {} + self.nodes: dict[str, Union[NodeSchemaAPI, NodeSchema]] = {} + self.profiles: dict[str, ProfileSchemaAPI] = {} for name, schema_type in schema.items(): - if isinstance(schema_type, GenericSchema): + if isinstance(schema_type, (GenericSchemaAPI, GenericSchema)): self.generics[name] = schema_type - if isinstance(schema_type, NodeSchema): + if isinstance(schema_type, (NodeSchemaAPI, NodeSchema)): self.nodes[name] = schema_type - if isinstance(schema_type, ProfileSchema): + if isinstance(schema_type, ProfileSchemaAPI): self.profiles[name] = schema_type self.base_protocols = [ @@ -92,7 +94,7 @@ def _jinja2_filter_inheritance(value: dict[str, Any]) -> str: return ", ".join(inherit_from) @staticmethod - def _jinja2_filter_render_attribute(value: AttributeSchema) -> str: + def _jinja2_filter_render_attribute(value: AttributeSchemaAPI) -> str: attribute_kind: str = ATTRIBUTE_KIND_MAP[value.kind] if value.optional: @@ -101,7 +103,7 @@ def 
_jinja2_filter_render_attribute(value: AttributeSchema) -> str: return f"{value.name}: {attribute_kind}" @staticmethod - def _jinja2_filter_render_relationship(value: RelationshipSchema, sync: bool = False) -> str: + def _jinja2_filter_render_relationship(value: RelationshipSchemaAPI, sync: bool = False) -> str: name = value.name cardinality = value.cardinality @@ -116,12 +118,12 @@ def _jinja2_filter_render_relationship(value: RelationshipSchema, sync: bool = F @staticmethod def _sort_and_filter_models( - models: Mapping[str, MainSchemaTypes], filters: Optional[list[str]] = None - ) -> list[MainSchemaTypes]: + models: Mapping[str, MainSchemaTypesAll], filters: Optional[list[str]] = None + ) -> list[MainSchemaTypesAll]: if filters is None: filters = ["CoreNode"] - filtered: list[MainSchemaTypes] = [] + filtered: list[MainSchemaTypesAll] = [] for name, model in models.items(): if name in filters: continue diff --git a/infrahub_sdk/ctl/check.py b/infrahub_sdk/ctl/check.py index b7836f91..aa964807 100644 --- a/infrahub_sdk/ctl/check.py +++ b/infrahub_sdk/ctl/check.py @@ -17,7 +17,7 @@ from ..ctl.repository import get_repository_config from ..ctl.utils import catch_exception, execute_graphql_query from ..exceptions import ModuleImportError -from ..schema import InfrahubCheckDefinitionConfig, InfrahubRepositoryConfig +from ..schema.repository import InfrahubCheckDefinitionConfig, InfrahubRepositoryConfig app = typer.Typer() console = Console() diff --git a/infrahub_sdk/ctl/cli_commands.py b/infrahub_sdk/ctl/cli_commands.py index 04d0d8f1..945a24f6 100644 --- a/infrahub_sdk/ctl/cli_commands.py +++ b/infrahub_sdk/ctl/cli_commands.py @@ -38,11 +38,8 @@ from ..ctl.validate import app as validate_app from ..exceptions import GraphQLError, ModuleImportError from ..jinja2 import identify_faulty_jinja_code -from ..schema import ( - InfrahubRepositoryConfig, - MainSchemaTypes, - SchemaRoot, -) +from ..schema import MainSchemaTypesAll, SchemaRoot +from ..schema.repository import 
InfrahubRepositoryConfig from ..utils import get_branch, write_to_file from ..yaml import SchemaFile from .exporter import dump @@ -364,7 +361,7 @@ def protocols( ) -> None: """Export Python protocols corresponding to a schema.""" - schema: dict[str, MainSchemaTypes] = {} + schema: dict[str, MainSchemaTypesAll] = {} if schemas: schemas_data = load_yamlfile_from_disk_and_exit(paths=schemas, file_type=SchemaFile, console=console) diff --git a/infrahub_sdk/ctl/generator.py b/infrahub_sdk/ctl/generator.py index 9081d42f..5f737f67 100644 --- a/infrahub_sdk/ctl/generator.py +++ b/infrahub_sdk/ctl/generator.py @@ -10,7 +10,7 @@ from ..ctl.utils import execute_graphql_query, parse_cli_vars from ..exceptions import ModuleImportError from ..node import InfrahubNode -from ..schema import InfrahubRepositoryConfig +from ..schema.repository import InfrahubRepositoryConfig async def run( diff --git a/infrahub_sdk/ctl/render.py b/infrahub_sdk/ctl/render.py index 6e769c86..05122102 100644 --- a/infrahub_sdk/ctl/render.py +++ b/infrahub_sdk/ctl/render.py @@ -1,6 +1,6 @@ from rich.console import Console -from ..schema import InfrahubRepositoryConfig +from ..schema.repository import InfrahubRepositoryConfig def list_jinja2_transforms(config: InfrahubRepositoryConfig) -> None: diff --git a/infrahub_sdk/ctl/repository.py b/infrahub_sdk/ctl/repository.py index 6f69f9a5..08fcd571 100644 --- a/infrahub_sdk/ctl/repository.py +++ b/infrahub_sdk/ctl/repository.py @@ -12,7 +12,7 @@ from ..ctl.exceptions import FileNotValidError from ..ctl.utils import init_logging from ..graphql import Mutation -from ..schema import InfrahubRepositoryConfig +from ..schema.repository import InfrahubRepositoryConfig from ._file import read_file from .parameters import CONFIG_PARAM diff --git a/infrahub_sdk/ctl/transform.py b/infrahub_sdk/ctl/transform.py index e0a85ec2..1cda0940 100644 --- a/infrahub_sdk/ctl/transform.py +++ b/infrahub_sdk/ctl/transform.py @@ -1,6 +1,6 @@ from rich.console import Console -from 
..schema import InfrahubRepositoryConfig +from ..schema.repository import InfrahubRepositoryConfig def list_transforms(config: InfrahubRepositoryConfig) -> None: diff --git a/infrahub_sdk/ctl/utils.py b/infrahub_sdk/ctl/utils.py index 02ada5a6..2016283e 100644 --- a/infrahub_sdk/ctl/utils.py +++ b/infrahub_sdk/ctl/utils.py @@ -26,7 +26,7 @@ ServerNotReachableError, ServerNotResponsiveError, ) -from ..schema import InfrahubRepositoryConfig +from ..schema.repository import InfrahubRepositoryConfig from ..yaml import YamlFile from .client import initialize_client_sync diff --git a/infrahub_sdk/node.py b/infrahub_sdk/node.py index 0e4f1e57..275cc76a 100644 --- a/infrahub_sdk/node.py +++ b/infrahub_sdk/node.py @@ -14,7 +14,7 @@ UninitializedError, ) from .graphql import Mutation, Query -from .schema import GenericSchema, RelationshipCardinality, RelationshipKind +from .schema import GenericSchemaAPI, RelationshipCardinality, RelationshipKind from .utils import compare_lists, get_flat_value from .uuidt import UUIDT @@ -22,7 +22,7 @@ from typing_extensions import Self from .client import InfrahubClient, InfrahubClientSync - from .schema import AttributeSchema, MainSchemaTypes, RelationshipSchema + from .schema import AttributeSchemaAPI, MainSchemaTypesAPI, RelationshipSchemaAPI # pylint: disable=too-many-lines @@ -46,7 +46,7 @@ class Attribute: """Represents an attribute of a Node, including its schema, value, and properties.""" - def __init__(self, name: str, schema: AttributeSchema, data: Union[Any, dict]): + def __init__(self, name: str, schema: AttributeSchemaAPI, data: Union[Any, dict]): """ Args: name (str): The name of the attribute. 
@@ -143,7 +143,7 @@ def _generate_mutation_query(self) -> dict[str, Any]: class RelatedNodeBase: """Base class for representing a related node in a relationship.""" - def __init__(self, branch: str, schema: RelationshipSchema, data: Union[Any, dict], name: Optional[str] = None): + def __init__(self, branch: str, schema: RelationshipSchemaAPI, data: Union[Any, dict], name: Optional[str] = None): """ Args: branch (str): The branch where the related node resides. @@ -300,7 +300,7 @@ def __init__( self, client: InfrahubClient, branch: str, - schema: RelationshipSchema, + schema: RelationshipSchemaAPI, data: Union[Any, dict], name: Optional[str] = None, ): @@ -347,7 +347,7 @@ def __init__( self, client: InfrahubClientSync, branch: str, - schema: RelationshipSchema, + schema: RelationshipSchemaAPI, data: Union[Any, dict], name: Optional[str] = None, ): @@ -390,7 +390,7 @@ def get(self) -> InfrahubNodeSync: class RelationshipManagerBase: """Base class for RelationshipManager and RelationshipManagerSync""" - def __init__(self, name: str, branch: str, schema: RelationshipSchema): + def __init__(self, name: str, branch: str, schema: RelationshipSchemaAPI): """ Args: name (str): The name of the relationship. 
@@ -473,7 +473,7 @@ def __init__( client: InfrahubClient, node: InfrahubNode, branch: str, - schema: RelationshipSchema, + schema: RelationshipSchemaAPI, data: Union[Any, dict], ): """ @@ -568,7 +568,7 @@ def __init__( client: InfrahubClientSync, node: InfrahubNodeSync, branch: str, - schema: RelationshipSchema, + schema: RelationshipSchemaAPI, data: Union[Any, dict], ): """ @@ -657,12 +657,12 @@ def remove(self, data: Union[str, RelatedNodeSync, dict]) -> None: class InfrahubNodeBase: """Base class for InfrahubNode and InfrahubNodeSync""" - def __init__(self, schema: MainSchemaTypes, branch: str, data: Optional[dict] = None) -> None: + def __init__(self, schema: MainSchemaTypesAPI, branch: str, data: Optional[dict] = None) -> None: """ Args: - schema (MainSchemaTypes): The schema of the node. - branch (str): The branch where the node resides. - data (Optional[dict]): Optional data to initialize the node. + schema: The schema of the node. + branch: The branch where the node resides. + data: Optional data to initialize the node. """ self._schema = schema self._data = data @@ -1035,16 +1035,16 @@ class InfrahubNode(InfrahubNodeBase): def __init__( self, client: InfrahubClient, - schema: MainSchemaTypes, + schema: MainSchemaTypesAPI, branch: Optional[str] = None, data: Optional[dict] = None, ) -> None: """ Args: - client (InfrahubClient): The client used to interact with the backend. - schema (MainSchemaTypes): The schema of the node. - branch (Optional[str]): The branch where the node resides. - data (Optional[dict]): Optional data to initialize the node. + client: The client used to interact with the backend. + schema: The schema of the node. + branch: The branch where the node resides. + data: Optional data to initialize the node. 
""" self._client = client self.__class__ = type(f"{schema.kind}InfrahubNode", (self.__class__,), {}) @@ -1060,7 +1060,7 @@ async def from_graphql( client: InfrahubClient, branch: str, data: dict, - schema: Optional[MainSchemaTypes] = None, + schema: Optional[MainSchemaTypesAPI] = None, timeout: Optional[int] = None, ) -> Self: if not schema: @@ -1146,7 +1146,7 @@ async def save( if update_group_context is None and self._client.mode == InfrahubClientMode.TRACKING: update_group_context = True - if not isinstance(self._schema, GenericSchema): + if not isinstance(self._schema, GenericSchemaAPI): if "CoreGroup" in self._schema.inherit_from: await self._client.group_context.add_related_groups( ids=[self.id], update_group_context=update_group_context @@ -1183,7 +1183,7 @@ async def generate_query_data( ) ) - if isinstance(self._schema, GenericSchema) and fragment: + if isinstance(self._schema, GenericSchemaAPI) and fragment: for child in self._schema.used_by: child_schema = await self._client.schema.get(kind=child) child_node = InfrahubNode(client=self._client, schema=child_schema) @@ -1341,7 +1341,7 @@ async def _process_mutation_result( continue # Process allocated resource from a pool and update attribute - attr.value = object_response[attr_name] + attr.value = object_response[attr_name]["value"] for rel_name in self._relationships: rel = getattr(self, rel_name) @@ -1540,7 +1540,7 @@ class InfrahubNodeSync(InfrahubNodeBase): def __init__( self, client: InfrahubClientSync, - schema: MainSchemaTypes, + schema: MainSchemaTypesAPI, branch: Optional[str] = None, data: Optional[dict] = None, ) -> None: @@ -1565,7 +1565,7 @@ def from_graphql( client: InfrahubClientSync, branch: str, data: dict, - schema: Optional[MainSchemaTypes] = None, + schema: Optional[MainSchemaTypesAPI] = None, timeout: Optional[int] = None, ) -> Self: if not schema: @@ -1648,7 +1648,7 @@ def save( if update_group_context is None and self._client.mode == InfrahubClientMode.TRACKING: update_group_context 
= True - if not isinstance(self._schema, GenericSchema): + if not isinstance(self._schema, GenericSchemaAPI): if "CoreGroup" in self._schema.inherit_from: self._client.group_context.add_related_groups(ids=[self.id], update_group_context=update_group_context) else: @@ -1681,7 +1681,7 @@ def generate_query_data( ) ) - if isinstance(self._schema, GenericSchema) and fragment: + if isinstance(self._schema, GenericSchemaAPI) and fragment: for child in self._schema.used_by: child_schema = self._client.schema.get(kind=child) child_node = InfrahubNodeSync(client=self._client, schema=child_schema) diff --git a/infrahub_sdk/pytest_plugin/items/base.py b/infrahub_sdk/pytest_plugin/items/base.py index 1452737a..a27948c7 100644 --- a/infrahub_sdk/pytest_plugin/items/base.py +++ b/infrahub_sdk/pytest_plugin/items/base.py @@ -1,6 +1,7 @@ from __future__ import annotations import difflib +from pathlib import Path from typing import TYPE_CHECKING, Any, Optional, Union import pytest @@ -11,11 +12,11 @@ from ..models import InfrahubInputOutputTest if TYPE_CHECKING: - from pathlib import Path - - from ...schema import InfrahubRepositoryConfigElement + from ...schema.repository import InfrahubRepositoryConfigElement from ..models import InfrahubTest +_infrahub_config_path_attribute = "infrahub_config_path" + class InfrahubItem(pytest.Item): def __init__( @@ -74,3 +75,16 @@ def repr_failure(self, excinfo: pytest.ExceptionInfo, style: Optional[str] = Non def reportinfo(self) -> tuple[Union[Path, str], Optional[int], str]: return self.path, 0, f"resource: {self.name}" + + @property + def repository_base(self) -> str: + """Return the path to the root of the repository + + This will be an absolute path if --infrahub-config-path is an absolute path as happens when + tests are started from within Infrahub server. 
+ """ + config_path: Path = getattr(self.session, _infrahub_config_path_attribute) + if config_path.is_absolute(): + return str(config_path.parent) + + return str(Path.cwd()) diff --git a/infrahub_sdk/pytest_plugin/items/check.py b/infrahub_sdk/pytest_plugin/items/check.py index dd4194f0..551cc50e 100644 --- a/infrahub_sdk/pytest_plugin/items/check.py +++ b/infrahub_sdk/pytest_plugin/items/check.py @@ -1,12 +1,12 @@ from __future__ import annotations import asyncio +from pathlib import Path from typing import TYPE_CHECKING, Any, Optional import ujson from httpx import HTTPStatusError -from ...checks import get_check_class_instance from ..exceptions import CheckDefinitionError, CheckResultError from ..models import InfrahubTestExpectedResult from .base import InfrahubItem @@ -15,7 +15,7 @@ from pytest import ExceptionInfo from ...checks import InfrahubCheck - from ...schema import InfrahubRepositoryConfigElement + from ...schema.repository import InfrahubRepositoryConfigElement from ..models import InfrahubTest @@ -33,9 +33,12 @@ def __init__( self.check_instance: InfrahubCheck def instantiate_check(self) -> None: - self.check_instance = get_check_class_instance( - check_config=self.resource_config, # type: ignore[arg-type] - search_path=self.session.infrahub_config_path.parent, # type: ignore[attr-defined] + relative_path = ( + str(self.resource_config.file_path.parent) if self.resource_config.file_path.parent != Path() else None # type: ignore[attr-defined] + ) + + self.check_instance = self.resource_config.load_class( # type: ignore[attr-defined] + import_root=self.repository_base, relative_path=relative_path ) def run_check(self, variables: dict[str, Any]) -> Any: diff --git a/infrahub_sdk/pytest_plugin/items/python_transform.py b/infrahub_sdk/pytest_plugin/items/python_transform.py index 6249d8ba..28a92e14 100644 --- a/infrahub_sdk/pytest_plugin/items/python_transform.py +++ b/infrahub_sdk/pytest_plugin/items/python_transform.py @@ -1,12 +1,12 @@ from 
__future__ import annotations import asyncio +from pathlib import Path from typing import TYPE_CHECKING, Any, Optional import ujson from httpx import HTTPStatusError -from ...transforms import get_transform_class_instance from ..exceptions import OutputMatchError, PythonTransformDefinitionError from ..models import InfrahubTestExpectedResult from .base import InfrahubItem @@ -14,7 +14,7 @@ if TYPE_CHECKING: from pytest import ExceptionInfo - from ...schema import InfrahubRepositoryConfigElement + from ...schema.repository import InfrahubRepositoryConfigElement from ...transforms import InfrahubTransform from ..models import InfrahubTest @@ -33,9 +33,11 @@ def __init__( self.transform_instance: InfrahubTransform def instantiate_transform(self) -> None: - self.transform_instance = get_transform_class_instance( - transform_config=self.resource_config, # type: ignore[arg-type] - search_path=self.session.infrahub_config_path.parent, # type: ignore[attr-defined] + relative_path = ( + str(self.resource_config.file_path.parent) if self.resource_config.file_path.parent != Path() else None # type: ignore[attr-defined] + ) + self.transform_instance = self.resource_config.load_class( # type: ignore[attr-defined] + import_root=self.repository_base, relative_path=relative_path ) def run_transform(self, variables: dict[str, Any]) -> Any: diff --git a/infrahub_sdk/pytest_plugin/utils.py b/infrahub_sdk/pytest_plugin/utils.py index 249525ee..2875c23d 100644 --- a/infrahub_sdk/pytest_plugin/utils.py +++ b/infrahub_sdk/pytest_plugin/utils.py @@ -2,7 +2,7 @@ import yaml -from ..schema import InfrahubRepositoryConfig +from ..schema.repository import InfrahubRepositoryConfig from .exceptions import FileNotValidError diff --git a/infrahub_sdk/query_groups.py b/infrahub_sdk/query_groups.py index eb980e6d..4a21165e 100644 --- a/infrahub_sdk/query_groups.py +++ b/infrahub_sdk/query_groups.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: from .client import InfrahubClient, InfrahubClientSync from .node 
import InfrahubNode, InfrahubNodeSync, RelatedNodeBase - from .schema import MainSchemaTypes + from .schema import MainSchemaTypesAPI class InfrahubGroupContextBase: @@ -63,7 +63,7 @@ def _generate_group_name(self, suffix: Optional[str] = None) -> str: return group_name - def _generate_group_description(self, schema: MainSchemaTypes) -> str: + def _generate_group_description(self, schema: MainSchemaTypesAPI) -> str: """Generate the description of the group from the params and ensure it's not longer than the maximum length of the description field.""" if not self.params: diff --git a/infrahub_sdk/schema.py b/infrahub_sdk/schema/__init__.py similarity index 54% rename from infrahub_sdk/schema.py rename to infrahub_sdk/schema/__init__.py index f04444ee..e69fa407 100644 --- a/infrahub_sdk/schema.py +++ b/infrahub_sdk/schema/__init__.py @@ -4,36 +4,63 @@ from collections import defaultdict from collections.abc import MutableMapping from enum import Enum -from pathlib import Path from time import sleep -from typing import TYPE_CHECKING, Any, Optional, TypedDict, TypeVar, Union +from typing import TYPE_CHECKING, Any, Optional, TypedDict, Union from urllib.parse import urlencode import httpx -from pydantic import BaseModel, ConfigDict, Field, field_validator +from pydantic import BaseModel, Field from typing_extensions import TypeAlias -from ._importer import import_module -from .checks import InfrahubCheck -from .exceptions import ( +from ..exceptions import ( InvalidResponseError, - ModuleImportError, - ResourceNotDefinedError, SchemaNotFoundError, ValidationError, ) -from .generator import InfrahubGenerator -from .graphql import Mutation -from .queries import SCHEMA_HASH_SYNC_STATUS -from .transforms import InfrahubTransform -from .utils import duplicates +from ..graphql import Mutation +from ..queries import SCHEMA_HASH_SYNC_STATUS +from .main import ( + AttributeSchema, + AttributeSchemaAPI, + BranchSupportType, + GenericSchema, + GenericSchemaAPI, + NodeSchema, + 
NodeSchemaAPI, + ProfileSchemaAPI, + RelationshipCardinality, + RelationshipKind, + RelationshipSchema, + RelationshipSchemaAPI, + SchemaRoot, + SchemaRootAPI, +) if TYPE_CHECKING: - from .client import InfrahubClient, InfrahubClientSync, SchemaType, SchemaTypeSync - from .node import InfrahubNode, InfrahubNodeSync + from ..client import InfrahubClient, InfrahubClientSync, SchemaType, SchemaTypeSync + from ..node import InfrahubNode, InfrahubNodeSync InfrahubNodeTypes = Union[InfrahubNode, InfrahubNodeSync] + +__all__ = [ + "AttributeSchema", + "AttributeSchemaAPI", + "BranchSupportType", + "GenericSchema", + "GenericSchemaAPI", + "NodeSchema", + "NodeSchemaAPI", + "ProfileSchemaAPI", + "RelationshipCardinality", + "RelationshipKind", + "RelationshipSchema", + "RelationshipSchemaAPI", + "SchemaRoot", + "SchemaRootAPI", +] + + # pylint: disable=redefined-builtin @@ -43,267 +70,6 @@ class DropdownMutationOptionalArgs(TypedDict): label: Optional[str] -ResourceClass = TypeVar("ResourceClass") - -# --------------------------------------------------------------------------------- -# Repository Configuration file -# --------------------------------------------------------------------------------- - - -class InfrahubRepositoryConfigElement(BaseModel): - """Class to regroup all elements of the infrahub configuration for a repository for typing purpose.""" - - -class InfrahubRepositoryArtifactDefinitionConfig(InfrahubRepositoryConfigElement): - model_config = ConfigDict(extra="forbid") - name: str = Field(..., description="The name of the artifact definition") - artifact_name: Optional[str] = Field(default=None, description="Name of the artifact created from this definition") - parameters: dict[str, Any] = Field(..., description="The input parameters required to render this artifact") - content_type: str = Field(..., description="The content type of the rendered artifact") - targets: str = Field(..., description="The group to target when creating artifacts") - 
transformation: str = Field(..., description="The transformation to use.") - - -class InfrahubJinja2TransformConfig(InfrahubRepositoryConfigElement): - model_config = ConfigDict(extra="forbid") - name: str = Field(..., description="The name of the transform") - query: str = Field(..., description="The name of the GraphQL Query") - template_path: Path = Field(..., description="The path within the repository of the template file") - description: Optional[str] = Field(default=None, description="Description for this transform") - - @property - def template_path_value(self) -> str: - return str(self.template_path) - - @property - def payload(self) -> dict[str, str]: - data = self.model_dump(exclude_none=True) - data["template_path"] = self.template_path_value - return data - - -class InfrahubCheckDefinitionConfig(InfrahubRepositoryConfigElement): - model_config = ConfigDict(extra="forbid") - name: str = Field(..., description="The name of the Check Definition") - file_path: Path = Field(..., description="The file within the repository with the check code.") - parameters: dict[str, Any] = Field( - default_factory=dict, description="The input parameters required to run this check" - ) - targets: Optional[str] = Field( - default=None, description="The group to target when running this check, leave blank for global checks" - ) - class_name: str = Field(default="Check", description="The name of the check class to run.") - - def load_class(self, import_root: Optional[str] = None, relative_path: Optional[str] = None) -> type[InfrahubCheck]: - module = import_module(module_path=self.file_path, import_root=import_root, relative_path=relative_path) - - if self.class_name not in dir(module): - raise ModuleImportError(message=f"The specified class {self.class_name} was not found within the module") - - check_class = getattr(module, self.class_name) - - if not issubclass(check_class, InfrahubCheck): - raise ModuleImportError(message=f"The specified class {self.class_name} is not an 
Infrahub Check") - - return check_class - - -class InfrahubGeneratorDefinitionConfig(InfrahubRepositoryConfigElement): - model_config = ConfigDict(extra="forbid") - name: str = Field(..., description="The name of the Generator Definition") - file_path: Path = Field(..., description="The file within the repository with the generator code.") - query: str = Field(..., description="The GraphQL query to use as input.") - parameters: dict[str, Any] = Field( - default_factory=dict, description="The input parameters required to run this check" - ) - targets: str = Field(..., description="The group to target when running this generator") - class_name: str = Field(default="Generator", description="The name of the generator class to run.") - convert_query_response: bool = Field( - default=False, - description="Decide if the generator should convert the result of the GraphQL query to SDK InfrahubNode objects.", - ) - - def load_class( - self, import_root: Optional[str] = None, relative_path: Optional[str] = None - ) -> type[InfrahubGenerator]: - module = import_module(module_path=self.file_path, import_root=import_root, relative_path=relative_path) - - if self.class_name not in dir(module): - raise ModuleImportError(message=f"The specified class {self.class_name} was not found within the module") - - generator_class = getattr(module, self.class_name) - - if not issubclass(generator_class, InfrahubGenerator): - raise ModuleImportError(message=f"The specified class {self.class_name} is not an Infrahub Generator") - - return generator_class - - -class InfrahubPythonTransformConfig(InfrahubRepositoryConfigElement): - model_config = ConfigDict(extra="forbid") - name: str = Field(..., description="The name of the Transform") - file_path: Path = Field(..., description="The file within the repository with the transform code.") - class_name: str = Field(default="Transform", description="The name of the transform class to run.") - - def load_class( - self, import_root: Optional[str] = 
None, relative_path: Optional[str] = None - ) -> type[InfrahubTransform]: - module = import_module(module_path=self.file_path, import_root=import_root, relative_path=relative_path) - - if self.class_name not in dir(module): - raise ModuleImportError(message=f"The specified class {self.class_name} was not found within the module") - - transform_class = getattr(module, self.class_name) - - if not issubclass(transform_class, InfrahubTransform): - raise ModuleImportError(message=f"The specified class {self.class_name} is not an Infrahub Transform") - - return transform_class - - -class InfrahubRepositoryGraphQLConfig(InfrahubRepositoryConfigElement): - model_config = ConfigDict(extra="forbid") - name: str = Field(..., description="The name of the GraphQL Query") - file_path: Path = Field(..., description="The file within the repository with the query code.") - - def load_query(self, relative_path: str = ".") -> str: - file_name = Path(f"{relative_path}/{self.file_path}") - with file_name.open("r", encoding="UTF-8") as file: - return file.read() - - -RESOURCE_MAP: dict[Any, str] = { - InfrahubJinja2TransformConfig: "jinja2_transforms", - InfrahubCheckDefinitionConfig: "check_definitions", - InfrahubRepositoryArtifactDefinitionConfig: "artifact_definitions", - InfrahubPythonTransformConfig: "python_transforms", - InfrahubGeneratorDefinitionConfig: "generator_definitions", - InfrahubRepositoryGraphQLConfig: "queries", -} - - -class InfrahubRepositoryConfig(BaseModel): - model_config = ConfigDict(extra="forbid") - check_definitions: list[InfrahubCheckDefinitionConfig] = Field( - default_factory=list, description="User defined checks" - ) - schemas: list[Path] = Field(default_factory=list, description="Schema files") - jinja2_transforms: list[InfrahubJinja2TransformConfig] = Field( - default_factory=list, description="Jinja2 data transformations" - ) - artifact_definitions: list[InfrahubRepositoryArtifactDefinitionConfig] = Field( - default_factory=list, 
description="Artifact definitions" - ) - python_transforms: list[InfrahubPythonTransformConfig] = Field( - default_factory=list, description="Python data transformations" - ) - generator_definitions: list[InfrahubGeneratorDefinitionConfig] = Field( - default_factory=list, description="Generator definitions" - ) - queries: list[InfrahubRepositoryGraphQLConfig] = Field(default_factory=list, description="GraphQL Queries") - - @field_validator( - "check_definitions", - "jinja2_transforms", - "artifact_definitions", - "python_transforms", - "generator_definitions", - "queries", - ) - @classmethod - def unique_items(cls, v: list[Any]) -> list[Any]: - names = [item.name for item in v] - if dups := duplicates(names): - raise ValueError(f"Found multiples element with the same names: {dups}") - return v - - def _has_resource(self, resource_id: str, resource_type: type[ResourceClass], resource_field: str = "name") -> bool: - for item in getattr(self, RESOURCE_MAP[resource_type]): - if getattr(item, resource_field) == resource_id: - return True - return False - - def _get_resource( - self, resource_id: str, resource_type: type[ResourceClass], resource_field: str = "name" - ) -> ResourceClass: - for item in getattr(self, RESOURCE_MAP[resource_type]): - if getattr(item, resource_field) == resource_id: - return item - raise ResourceNotDefinedError(f"Unable to find {resource_id!r} in {RESOURCE_MAP[resource_type]!r}") - - def has_jinja2_transform(self, name: str) -> bool: - return self._has_resource(resource_id=name, resource_type=InfrahubJinja2TransformConfig) - - def get_jinja2_transform(self, name: str) -> InfrahubJinja2TransformConfig: - return self._get_resource(resource_id=name, resource_type=InfrahubJinja2TransformConfig) - - def has_check_definition(self, name: str) -> bool: - return self._has_resource(resource_id=name, resource_type=InfrahubCheckDefinitionConfig) - - def get_check_definition(self, name: str) -> InfrahubCheckDefinitionConfig: - return 
self._get_resource(resource_id=name, resource_type=InfrahubCheckDefinitionConfig) - - def has_artifact_definition(self, name: str) -> bool: - return self._has_resource(resource_id=name, resource_type=InfrahubRepositoryArtifactDefinitionConfig) - - def get_artifact_definition(self, name: str) -> InfrahubRepositoryArtifactDefinitionConfig: - return self._get_resource(resource_id=name, resource_type=InfrahubRepositoryArtifactDefinitionConfig) - - def has_generator_definition(self, name: str) -> bool: - return self._has_resource(resource_id=name, resource_type=InfrahubGeneratorDefinitionConfig) - - def get_generator_definition(self, name: str) -> InfrahubGeneratorDefinitionConfig: - return self._get_resource(resource_id=name, resource_type=InfrahubGeneratorDefinitionConfig) - - def has_python_transform(self, name: str) -> bool: - return self._has_resource(resource_id=name, resource_type=InfrahubPythonTransformConfig) - - def get_python_transform(self, name: str) -> InfrahubPythonTransformConfig: - return self._get_resource(resource_id=name, resource_type=InfrahubPythonTransformConfig) - - def has_query(self, name: str) -> bool: - return self._has_resource(resource_id=name, resource_type=InfrahubRepositoryGraphQLConfig) - - def get_query(self, name: str) -> InfrahubRepositoryGraphQLConfig: - return self._get_resource(resource_id=name, resource_type=InfrahubRepositoryGraphQLConfig) - - -# --------------------------------------------------------------------------------- -# Main Infrahub Schema File -# --------------------------------------------------------------------------------- -class FilterSchema(BaseModel): - name: str - kind: str - description: Optional[str] = None - - -class RelationshipCardinality(str, Enum): - ONE = "one" - MANY = "many" - - -class RelationshipDirection(str, Enum): - BIDIR = "bidirectional" - OUTBOUND = "outbound" - INBOUND = "inbound" - - -class BranchSupportType(str, Enum): - AWARE = "aware" - AGNOSTIC = "agnostic" - LOCAL = "local" - - -class 
RelationshipKind(str, Enum): - GENERIC = "Generic" - ATTRIBUTE = "Attribute" - COMPONENT = "Component" - PARENT = "Parent" - GROUP = "Group" - HIERARCHY = "Hierarchy" - PROFILE = "Profile" - - class DropdownMutation(str, Enum): add = "SchemaDropdownAdd" remove = "SchemaDropdownRemove" @@ -314,205 +80,16 @@ class EnumMutation(str, Enum): remove = "SchemaEnumRemove" -class SchemaState(str, Enum): - PRESENT = "present" - ABSENT = "absent" - - -class AttributeSchema(BaseModel): - id: Optional[str] = None - state: SchemaState = SchemaState.PRESENT - name: str - kind: str - label: Optional[str] = None - description: Optional[str] = None - default_value: Optional[Any] = None - inherited: bool = False - unique: bool = False - branch: Optional[BranchSupportType] = None - optional: bool = False - read_only: bool = False - choices: Optional[list[dict[str, Any]]] = None - enum: Optional[list[Union[str, int]]] = None - max_length: Optional[int] = None - min_length: Optional[int] = None - regex: Optional[str] = None - order_weight: Optional[int] = None - - -class RelationshipSchema(BaseModel): - id: Optional[str] = None - state: SchemaState = SchemaState.PRESENT - name: str - peer: str - direction: RelationshipDirection = RelationshipDirection.BIDIR - kind: RelationshipKind = RelationshipKind.GENERIC - label: Optional[str] = None - description: Optional[str] = None - identifier: Optional[str] = None - inherited: bool = False - cardinality: str = "many" - branch: Optional[BranchSupportType] = None - optional: bool = True - read_only: bool = False - filters: list[FilterSchema] = Field(default_factory=list) - order_weight: Optional[int] = None - - -class BaseNodeSchema(BaseModel): - id: Optional[str] = None - state: SchemaState = SchemaState.PRESENT - name: str - label: Optional[str] = None - namespace: str - description: Optional[str] = None - attributes: list[AttributeSchema] = Field(default_factory=list) - relationships: list[RelationshipSchema] = Field(default_factory=list) - 
filters: list[FilterSchema] = Field(default_factory=list) - - @property - def kind(self) -> str: - return self.namespace + self.name - - def get_field(self, name: str, raise_on_error: bool = True) -> Union[AttributeSchema, RelationshipSchema, None]: - if attribute_field := self.get_attribute_or_none(name=name): - return attribute_field - - if relationship_field := self.get_relationship_or_none(name=name): - return relationship_field - - if not raise_on_error: - return None - - raise ValueError(f"Unable to find the field {name}") - - def get_attribute(self, name: str) -> AttributeSchema: - for item in self.attributes: - if item.name == name: - return item - raise ValueError(f"Unable to find the attribute {name}") - - def get_attribute_or_none(self, name: str) -> Optional[AttributeSchema]: - for item in self.attributes: - if item.name == name: - return item - return None - - def get_relationship(self, name: str) -> RelationshipSchema: - for item in self.relationships: - if item.name == name: - return item - raise ValueError(f"Unable to find the relationship {name}") - - def get_relationship_or_none(self, name: str) -> Optional[RelationshipSchema]: - for item in self.relationships: - if item.name == name: - return item - return None - - def get_relationship_by_identifier(self, id: str, raise_on_error: bool = True) -> Union[RelationshipSchema, None]: - for item in self.relationships: - if item.identifier == id: - return item - - if not raise_on_error: - return None - - raise ValueError(f"Unable to find the relationship {id}") - - def get_matching_relationship( - self, id: str, direction: RelationshipDirection = RelationshipDirection.BIDIR - ) -> RelationshipSchema: - valid_direction = RelationshipDirection.BIDIR - if direction == RelationshipDirection.INBOUND: - valid_direction = RelationshipDirection.OUTBOUND - elif direction == RelationshipDirection.OUTBOUND: - valid_direction = RelationshipDirection.INBOUND - - for item in self.relationships: - if item.identifier == 
id and item.direction == valid_direction: - return item - - raise ValueError(f"Unable to find the relationship {id} / ({valid_direction.value})") - - @property - def attribute_names(self) -> list[str]: - return [item.name for item in self.attributes] - - @property - def relationship_names(self) -> list[str]: - return [item.name for item in self.relationships] - - @property - def mandatory_input_names(self) -> list[str]: - return self.mandatory_attribute_names + self.mandatory_relationship_names - - @property - def mandatory_attribute_names(self) -> list[str]: - return [item.name for item in self.attributes if not item.optional and item.default_value is None] - - @property - def mandatory_relationship_names(self) -> list[str]: - return [item.name for item in self.relationships if not item.optional] - - @property - def local_attributes(self) -> list[AttributeSchema]: - return [item for item in self.attributes if not item.inherited] - - @property - def local_relationships(self) -> list[RelationshipSchema]: - return [item for item in self.relationships if not item.inherited] - - @property - def unique_attributes(self) -> list[AttributeSchema]: - return [item for item in self.attributes if item.unique] - - -class GenericSchema(BaseNodeSchema): - """A Generic can be either an Interface or a Union depending if there are some Attributes or Relationships defined.""" - - used_by: list[str] = Field(default_factory=list) - - -class NodeSchema(BaseNodeSchema): - inherit_from: list[str] = Field(default_factory=list) - branch: Optional[BranchSupportType] = None - default_filter: Optional[str] = None - human_friendly_id: Optional[list[str]] = None - - -class ProfileSchema(BaseNodeSchema): - inherit_from: list[str] = Field(default_factory=list) - - -class NodeExtensionSchema(BaseModel): - name: Optional[str] = None - kind: str - description: Optional[str] = None - label: Optional[str] = None - inherit_from: list[str] = Field(default_factory=list) - branch: 
Optional[BranchSupportType] = None - default_filter: Optional[str] = None - attributes: list[AttributeSchema] = Field(default_factory=list) - relationships: list[RelationshipSchema] = Field(default_factory=list) - - -class SchemaRoot(BaseModel): - version: str - generics: list[GenericSchema] = Field(default_factory=list) - nodes: list[NodeSchema] = Field(default_factory=list) - profiles: list[ProfileSchema] = Field(default_factory=list) - # node_extensions: list[NodeExtensionSchema] = Field(default_factory=list) - - -MainSchemaTypes: TypeAlias = Union[NodeSchema, GenericSchema, ProfileSchema] +MainSchemaTypes: TypeAlias = Union[NodeSchema, GenericSchema] +MainSchemaTypesAPI: TypeAlias = Union[NodeSchemaAPI, GenericSchemaAPI, ProfileSchemaAPI] +MainSchemaTypesAll: TypeAlias = Union[NodeSchema, GenericSchema, NodeSchemaAPI, GenericSchemaAPI, ProfileSchemaAPI] class InfrahubSchemaBase: def validate(self, data: dict[str, Any]) -> None: SchemaRoot(**data) - def validate_data_against_schema(self, schema: MainSchemaTypes, data: dict) -> None: + def validate_data_against_schema(self, schema: MainSchemaTypesAPI, data: dict) -> None: for key in data.keys(): if key not in schema.relationship_names + schema.attribute_names: identifier = f"{schema.kind}" @@ -523,7 +100,7 @@ def validate_data_against_schema(self, schema: MainSchemaTypes, data: dict) -> N def generate_payload_create( self, - schema: MainSchemaTypes, + schema: MainSchemaTypesAPI, data: dict, source: Optional[str] = None, owner: Optional[str] = None, @@ -599,7 +176,7 @@ async def get( branch: Optional[str] = None, refresh: bool = False, timeout: Optional[int] = None, - ) -> MainSchemaTypes: + ) -> MainSchemaTypesAPI: branch = branch or self.client.default_branch kind_str = self._get_schema_name(schema=kind) @@ -622,7 +199,7 @@ async def get( async def all( self, branch: Optional[str] = None, refresh: bool = False, namespaces: Optional[list[str]] = None - ) -> MutableMapping[str, MainSchemaTypes]: + ) -> 
MutableMapping[str, MainSchemaTypesAPI]: """Retrieve the entire schema for a given branch. if present in cache, the schema will be served from the cache, unless refresh is set to True @@ -809,7 +386,7 @@ async def add_dropdown_option( async def fetch( self, branch: str, namespaces: Optional[list[str]] = None, timeout: Optional[int] = None - ) -> MutableMapping[str, MainSchemaTypes]: + ) -> MutableMapping[str, MainSchemaTypesAPI]: """Fetch the schema from the server for a given branch. Args: @@ -830,17 +407,17 @@ async def fetch( data: MutableMapping[str, Any] = response.json() - nodes: MutableMapping[str, MainSchemaTypes] = {} + nodes: MutableMapping[str, MainSchemaTypesAPI] = {} for node_schema in data.get("nodes", []): - node = NodeSchema(**node_schema) + node = NodeSchemaAPI(**node_schema) nodes[node.kind] = node for generic_schema in data.get("generics", []): - generic = GenericSchema(**generic_schema) + generic = GenericSchemaAPI(**generic_schema) nodes[generic.kind] = generic for profile_schema in data.get("profiles", []): - profile = ProfileSchema(**profile_schema) + profile = ProfileSchemaAPI(**profile_schema) nodes[profile.kind] = profile return nodes @@ -853,7 +430,7 @@ def __init__(self, client: InfrahubClientSync): def all( self, branch: Optional[str] = None, refresh: bool = False, namespaces: Optional[list[str]] = None - ) -> MutableMapping[str, MainSchemaTypes]: + ) -> MutableMapping[str, MainSchemaTypesAPI]: """Retrieve the entire schema for a given branch. 
if present in cache, the schema will be served from the cache, unless refresh is set to True @@ -878,7 +455,7 @@ def get( branch: Optional[str] = None, refresh: bool = False, timeout: Optional[int] = None, - ) -> MainSchemaTypes: + ) -> MainSchemaTypesAPI: branch = branch or self.client.default_branch kind_str = self._get_schema_name(schema=kind) @@ -901,7 +478,7 @@ def get( def _get_kind_and_attribute_schema( self, kind: Union[str, InfrahubNodeTypes], attribute: str, branch: Optional[str] = None - ) -> tuple[str, AttributeSchema]: + ) -> tuple[str, AttributeSchemaAPI]: node_kind: str = kind._schema.kind if not isinstance(kind, str) else kind node_schema = self.client.schema.get(kind=node_kind, branch=branch) schema_attr = node_schema.get_attribute(name=attribute) @@ -1013,7 +590,7 @@ def add_dropdown_option( def fetch( self, branch: str, namespaces: Optional[list[str]] = None, timeout: Optional[int] = None - ) -> MutableMapping[str, MainSchemaTypes]: + ) -> MutableMapping[str, MainSchemaTypesAPI]: """Fetch the schema from the server for a given branch. 
Args: @@ -1034,17 +611,17 @@ def fetch( data: MutableMapping[str, Any] = response.json() - nodes: MutableMapping[str, MainSchemaTypes] = {} + nodes: MutableMapping[str, MainSchemaTypesAPI] = {} for node_schema in data.get("nodes", []): - node = NodeSchema(**node_schema) + node = NodeSchemaAPI(**node_schema) nodes[node.kind] = node for generic_schema in data.get("generics", []): - generic = GenericSchema(**generic_schema) + generic = GenericSchemaAPI(**generic_schema) nodes[generic.kind] = generic for profile_schema in data.get("profiles", []): - profile = ProfileSchema(**profile_schema) + profile = ProfileSchemaAPI(**profile_schema) nodes[profile.kind] = profile return nodes diff --git a/infrahub_sdk/schema/main.py b/infrahub_sdk/schema/main.py new file mode 100644 index 00000000..9fedd234 --- /dev/null +++ b/infrahub_sdk/schema/main.py @@ -0,0 +1,345 @@ +from __future__ import annotations + +import warnings +from enum import Enum +from typing import TYPE_CHECKING, Any, Optional, Union + +from pydantic import BaseModel, ConfigDict, Field + +if TYPE_CHECKING: + from ..node import InfrahubNode, InfrahubNodeSync + + InfrahubNodeTypes = Union[InfrahubNode, InfrahubNodeSync] + + +class RelationshipCardinality(str, Enum): + ONE = "one" + MANY = "many" + + +class BranchSupportType(str, Enum): + AWARE = "aware" + AGNOSTIC = "agnostic" + LOCAL = "local" + + +class RelationshipKind(str, Enum): + GENERIC = "Generic" + ATTRIBUTE = "Attribute" + COMPONENT = "Component" + PARENT = "Parent" + GROUP = "Group" + HIERARCHY = "Hierarchy" + PROFILE = "Profile" + + +class RelationshipDirection(str, Enum): + BIDIR = "bidirectional" + OUTBOUND = "outbound" + INBOUND = "inbound" + + +class AttributeKind(str, Enum): + ID = "ID" + TEXT = "Text" + STRING = "String" # deprecated + TEXTAREA = "TextArea" + DATETIME = "DateTime" + NUMBER = "Number" + DROPDOWN = "Dropdown" + EMAIL = "Email" + PASSWORD = "Password" # noqa: S105 + HASHEDPASSWORD = "HashedPassword" + URL = "URL" + FILE = "File" + 
MAC_ADDRESS = "MacAddress" + COLOR = "Color" + BANDWIDTH = "Bandwidth" + IPHOST = "IPHost" + IPNETWORK = "IPNetwork" + BOOLEAN = "Boolean" + CHECKBOX = "Checkbox" + LIST = "List" + JSON = "JSON" + ANY = "Any" + + def __getattr__(self, name: str) -> Any: + if name == "STRING": + warnings.warn( + f"{name} is deprecated and will be removed in future versions.", + DeprecationWarning, + stacklevel=2, + ) + return super().__getattribute__(name) + + +class SchemaState(str, Enum): + PRESENT = "present" + ABSENT = "absent" + + +class AllowOverrideType(str, Enum): + NONE = "none" + ANY = "any" + + +class RelationshipDeleteBehavior(str, Enum): + NO_ACTION = "no-action" + CASCADE = "cascade" + + +class AttributeSchema(BaseModel): + model_config = ConfigDict(use_enum_values=True) + + id: Optional[str] = None + state: SchemaState = SchemaState.PRESENT + name: str + kind: AttributeKind + label: Optional[str] = None + description: Optional[str] = None + default_value: Optional[Any] = None + unique: bool = False + branch: Optional[BranchSupportType] = None + optional: bool = False + choices: Optional[list[dict[str, Any]]] = None + enum: Optional[list[Union[str, int]]] = None + max_length: Optional[int] = None + min_length: Optional[int] = None + regex: Optional[str] = None + order_weight: Optional[int] = None + + +class AttributeSchemaAPI(AttributeSchema): + model_config = ConfigDict(use_enum_values=True) + + inherited: bool = False + read_only: bool = False + allow_override: AllowOverrideType = AllowOverrideType.ANY + + +class RelationshipSchema(BaseModel): + model_config = ConfigDict(use_enum_values=True) + + id: Optional[str] = None + state: SchemaState = SchemaState.PRESENT + name: str + peer: str + kind: RelationshipKind = RelationshipKind.GENERIC + label: Optional[str] = None + description: Optional[str] = None + identifier: Optional[str] = None + min_count: Optional[int] = None + max_count: Optional[int] = None + direction: RelationshipDirection = RelationshipDirection.BIDIR 
+ on_delete: Optional[RelationshipDeleteBehavior] = None + cardinality: str = "many" + branch: Optional[BranchSupportType] = None + optional: bool = True + order_weight: Optional[int] = None + + +class RelationshipSchemaAPI(RelationshipSchema): + model_config = ConfigDict(use_enum_values=True) + + inherited: bool = False + read_only: bool = False + hierarchical: Optional[str] = None + allow_override: AllowOverrideType = AllowOverrideType.ANY + + +class BaseSchemaAttrRel(BaseModel): + attributes: list[AttributeSchema] = Field(default_factory=list) + relationships: list[RelationshipSchema] = Field(default_factory=list) + + +class BaseSchemaAttrRelAPI(BaseModel): + attributes: list[AttributeSchemaAPI] = Field(default_factory=list) + relationships: list[RelationshipSchemaAPI] = Field(default_factory=list) + + def get_field( + self, name: str, raise_on_error: bool = True + ) -> Union[AttributeSchemaAPI, RelationshipSchemaAPI, None]: + if attribute_field := self.get_attribute_or_none(name=name): + return attribute_field + + if relationship_field := self.get_relationship_or_none(name=name): + return relationship_field + + if not raise_on_error: + return None + + raise ValueError(f"Unable to find the field {name}") + + def get_attribute(self, name: str) -> AttributeSchemaAPI: + for item in self.attributes: + if item.name == name: + return item + raise ValueError(f"Unable to find the attribute {name}") + + def get_attribute_or_none(self, name: str) -> Optional[AttributeSchemaAPI]: + for item in self.attributes: + if item.name == name: + return item + return None + + def get_relationship(self, name: str) -> RelationshipSchemaAPI: + for item in self.relationships: + if item.name == name: + return item + raise ValueError(f"Unable to find the relationship {name}") + + def get_relationship_or_none(self, name: str) -> Optional[RelationshipSchemaAPI]: + for item in self.relationships: + if item.name == name: + return item + return None + + def get_relationship_by_identifier( + 
self, id: str, raise_on_error: bool = True + ) -> Union[RelationshipSchemaAPI, None]: + for item in self.relationships: + if item.identifier == id: + return item + + if not raise_on_error: + return None + + raise ValueError(f"Unable to find the relationship {id}") + + def get_matching_relationship( + self, id: str, direction: RelationshipDirection = RelationshipDirection.BIDIR + ) -> RelationshipSchemaAPI: + valid_direction = RelationshipDirection.BIDIR + if direction == RelationshipDirection.INBOUND: + valid_direction = RelationshipDirection.OUTBOUND + elif direction == RelationshipDirection.OUTBOUND: + valid_direction = RelationshipDirection.INBOUND + for item in self.relationships: + if item.identifier == id and item.direction == valid_direction: + return item + raise ValueError(f"Unable to find the relationship {id} / ({valid_direction.value})") + + @property + def attribute_names(self) -> list[str]: + return [item.name for item in self.attributes] + + @property + def relationship_names(self) -> list[str]: + return [item.name for item in self.relationships] + + @property + def mandatory_input_names(self) -> list[str]: + return self.mandatory_attribute_names + self.mandatory_relationship_names + + @property + def mandatory_attribute_names(self) -> list[str]: + return [item.name for item in self.attributes if not item.optional and item.default_value is None] + + @property + def mandatory_relationship_names(self) -> list[str]: + return [item.name for item in self.relationships if not item.optional] + + @property + def local_attributes(self) -> list[AttributeSchemaAPI]: + return [item for item in self.attributes if not item.inherited] + + @property + def local_relationships(self) -> list[RelationshipSchemaAPI]: + return [item for item in self.relationships if not item.inherited] + + @property + def unique_attributes(self) -> list[AttributeSchemaAPI]: + return [item for item in self.attributes if item.unique] + + +class BaseSchema(BaseModel): + model_config = 
ConfigDict(use_enum_values=True) + + id: Optional[str] = None + state: SchemaState = SchemaState.PRESENT + name: str + label: Optional[str] = None + namespace: str + description: Optional[str] = None + include_in_menu: Optional[bool] = None + menu_placement: Optional[str] = None + icon: Optional[str] = None + uniqueness_constraints: Optional[list[list[str]]] = None + documentation: Optional[str] = None + + @property + def kind(self) -> str: + return self.namespace + self.name + + +class GenericSchema(BaseSchema, BaseSchemaAttrRel): + def convert_api(self) -> GenericSchemaAPI: + return GenericSchemaAPI(**self.model_dump()) + + +class GenericSchemaAPI(BaseSchema, BaseSchemaAttrRelAPI): + """A Generic can be either an Interface or a Union depending if there are some Attributes or Relationships defined.""" + + hash: Optional[str] = None + used_by: list[str] = Field(default_factory=list) + + +class BaseNodeSchema(BaseSchema): + model_config = ConfigDict(use_enum_values=True) + + inherit_from: list[str] = Field(default_factory=list) + branch: Optional[BranchSupportType] = None + default_filter: Optional[str] = None + human_friendly_id: Optional[list[str]] = None + generate_profile: Optional[bool] = None + parent: Optional[str] = None + children: Optional[str] = None + + +class NodeSchema(BaseNodeSchema, BaseSchemaAttrRel): + def convert_api(self) -> NodeSchemaAPI: + return NodeSchemaAPI(**self.model_dump()) + + +class NodeSchemaAPI(BaseNodeSchema, BaseSchemaAttrRelAPI): + hash: Optional[str] = None + hierarchy: Optional[str] = None + + +class ProfileSchemaAPI(BaseSchema, BaseSchemaAttrRelAPI): + inherit_from: list[str] = Field(default_factory=list) + + +class NodeExtensionSchema(BaseModel): + model_config = ConfigDict(use_enum_values=True) + + name: Optional[str] = None + kind: str + description: Optional[str] = None + label: Optional[str] = None + inherit_from: list[str] = Field(default_factory=list) + branch: Optional[BranchSupportType] = None + default_filter: 
Optional[str] = None + attributes: list[AttributeSchema] = Field(default_factory=list) + relationships: list[RelationshipSchema] = Field(default_factory=list) + + +class SchemaRoot(BaseModel): + model_config = ConfigDict(use_enum_values=True) + + version: str + generics: list[GenericSchema] = Field(default_factory=list) + nodes: list[NodeSchema] = Field(default_factory=list) + node_extensions: list[NodeExtensionSchema] = Field(default_factory=list) + + def to_schema_dict(self) -> dict[str, Any]: + return self.model_dump(exclude_unset=True, exclude_defaults=True) + + +class SchemaRootAPI(BaseModel): + model_config = ConfigDict(use_enum_values=True) + + version: str + generics: list[GenericSchemaAPI] = Field(default_factory=list) + nodes: list[NodeSchemaAPI] = Field(default_factory=list) + profiles: list[ProfileSchemaAPI] = Field(default_factory=list) diff --git a/infrahub_sdk/schema/repository.py b/infrahub_sdk/schema/repository.py new file mode 100644 index 00000000..6d32f727 --- /dev/null +++ b/infrahub_sdk/schema/repository.py @@ -0,0 +1,241 @@ +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union + +from pydantic import BaseModel, ConfigDict, Field, field_validator + +from .._importer import import_module +from ..checks import InfrahubCheck +from ..exceptions import ( + ModuleImportError, + ResourceNotDefinedError, +) +from ..generator import InfrahubGenerator +from ..transforms import InfrahubTransform +from ..utils import duplicates + +if TYPE_CHECKING: + from ..node import InfrahubNode, InfrahubNodeSync + + InfrahubNodeTypes = Union[InfrahubNode, InfrahubNodeSync] + +ResourceClass = TypeVar("ResourceClass") + + +class InfrahubRepositoryConfigElement(BaseModel): + """Class to regroup all elements of the infrahub configuration for a repository for typing purpose.""" + + +class InfrahubRepositoryArtifactDefinitionConfig(InfrahubRepositoryConfigElement): + model_config = 
ConfigDict(extra="forbid") + name: str = Field(..., description="The name of the artifact definition") + artifact_name: Optional[str] = Field(default=None, description="Name of the artifact created from this definition") + parameters: dict[str, Any] = Field(..., description="The input parameters required to render this artifact") + content_type: str = Field(..., description="The content type of the rendered artifact") + targets: str = Field(..., description="The group to target when creating artifacts") + transformation: str = Field(..., description="The transformation to use.") + + +class InfrahubJinja2TransformConfig(InfrahubRepositoryConfigElement): + model_config = ConfigDict(extra="forbid") + name: str = Field(..., description="The name of the transform") + query: str = Field(..., description="The name of the GraphQL Query") + template_path: Path = Field(..., description="The path within the repository of the template file") + description: Optional[str] = Field(default=None, description="Description for this transform") + + @property + def template_path_value(self) -> str: + return str(self.template_path) + + @property + def payload(self) -> dict[str, str]: + data = self.model_dump(exclude_none=True) + data["template_path"] = self.template_path_value + return data + + +class InfrahubCheckDefinitionConfig(InfrahubRepositoryConfigElement): + model_config = ConfigDict(extra="forbid") + name: str = Field(..., description="The name of the Check Definition") + file_path: Path = Field(..., description="The file within the repository with the check code.") + parameters: dict[str, Any] = Field( + default_factory=dict, description="The input parameters required to run this check" + ) + targets: Optional[str] = Field( + default=None, description="The group to target when running this check, leave blank for global checks" + ) + class_name: str = Field(default="Check", description="The name of the check class to run.") + + def load_class(self, import_root: Optional[str] = 
None, relative_path: Optional[str] = None) -> type[InfrahubCheck]: + module = import_module(module_path=self.file_path, import_root=import_root, relative_path=relative_path) + + if self.class_name not in dir(module): + raise ModuleImportError(message=f"The specified class {self.class_name} was not found within the module") + + check_class = getattr(module, self.class_name) + + if not issubclass(check_class, InfrahubCheck): + raise ModuleImportError(message=f"The specified class {self.class_name} is not an Infrahub Check") + + return check_class + + +class InfrahubGeneratorDefinitionConfig(InfrahubRepositoryConfigElement): + model_config = ConfigDict(extra="forbid") + name: str = Field(..., description="The name of the Generator Definition") + file_path: Path = Field(..., description="The file within the repository with the generator code.") + query: str = Field(..., description="The GraphQL query to use as input.") + parameters: dict[str, Any] = Field( + default_factory=dict, description="The input parameters required to run this check" + ) + targets: str = Field(..., description="The group to target when running this generator") + class_name: str = Field(default="Generator", description="The name of the generator class to run.") + convert_query_response: bool = Field( + default=False, + description="Decide if the generator should convert the result of the GraphQL query to SDK InfrahubNode objects.", + ) + + def load_class( + self, import_root: Optional[str] = None, relative_path: Optional[str] = None + ) -> type[InfrahubGenerator]: + module = import_module(module_path=self.file_path, import_root=import_root, relative_path=relative_path) + + if self.class_name not in dir(module): + raise ModuleImportError(message=f"The specified class {self.class_name} was not found within the module") + + generator_class = getattr(module, self.class_name) + + if not issubclass(generator_class, InfrahubGenerator): + raise ModuleImportError(message=f"The specified class 
{self.class_name} is not an Infrahub Generator") + + return generator_class + + +class InfrahubPythonTransformConfig(InfrahubRepositoryConfigElement): + model_config = ConfigDict(extra="forbid") + name: str = Field(..., description="The name of the Transform") + file_path: Path = Field(..., description="The file within the repository with the transform code.") + class_name: str = Field(default="Transform", description="The name of the transform class to run.") + + def load_class( + self, import_root: Optional[str] = None, relative_path: Optional[str] = None + ) -> type[InfrahubTransform]: + module = import_module(module_path=self.file_path, import_root=import_root, relative_path=relative_path) + + if self.class_name not in dir(module): + raise ModuleImportError(message=f"The specified class {self.class_name} was not found within the module") + + transform_class = getattr(module, self.class_name) + + if not issubclass(transform_class, InfrahubTransform): + raise ModuleImportError(message=f"The specified class {self.class_name} is not an Infrahub Transform") + + return transform_class + + +class InfrahubRepositoryGraphQLConfig(InfrahubRepositoryConfigElement): + model_config = ConfigDict(extra="forbid") + name: str = Field(..., description="The name of the GraphQL Query") + file_path: Path = Field(..., description="The file within the repository with the query code.") + + def load_query(self, relative_path: str = ".") -> str: + file_name = Path(f"{relative_path}/{self.file_path}") + with file_name.open("r", encoding="UTF-8") as file: + return file.read() + + +RESOURCE_MAP: dict[Any, str] = { + InfrahubJinja2TransformConfig: "jinja2_transforms", + InfrahubCheckDefinitionConfig: "check_definitions", + InfrahubRepositoryArtifactDefinitionConfig: "artifact_definitions", + InfrahubPythonTransformConfig: "python_transforms", + InfrahubGeneratorDefinitionConfig: "generator_definitions", + InfrahubRepositoryGraphQLConfig: "queries", +} + + +class 
InfrahubRepositoryConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + check_definitions: list[InfrahubCheckDefinitionConfig] = Field( + default_factory=list, description="User defined checks" + ) + schemas: list[Path] = Field(default_factory=list, description="Schema files") + jinja2_transforms: list[InfrahubJinja2TransformConfig] = Field( + default_factory=list, description="Jinja2 data transformations" + ) + artifact_definitions: list[InfrahubRepositoryArtifactDefinitionConfig] = Field( + default_factory=list, description="Artifact definitions" + ) + python_transforms: list[InfrahubPythonTransformConfig] = Field( + default_factory=list, description="Python data transformations" + ) + generator_definitions: list[InfrahubGeneratorDefinitionConfig] = Field( + default_factory=list, description="Generator definitions" + ) + queries: list[InfrahubRepositoryGraphQLConfig] = Field(default_factory=list, description="GraphQL Queries") + + @field_validator( + "check_definitions", + "jinja2_transforms", + "artifact_definitions", + "python_transforms", + "generator_definitions", + "queries", + ) + @classmethod + def unique_items(cls, v: list[Any]) -> list[Any]: + names = [item.name for item in v] + if dups := duplicates(names): + raise ValueError(f"Found multiples element with the same names: {dups}") + return v + + def _has_resource(self, resource_id: str, resource_type: type[ResourceClass], resource_field: str = "name") -> bool: + for item in getattr(self, RESOURCE_MAP[resource_type]): + if getattr(item, resource_field) == resource_id: + return True + return False + + def _get_resource( + self, resource_id: str, resource_type: type[ResourceClass], resource_field: str = "name" + ) -> ResourceClass: + for item in getattr(self, RESOURCE_MAP[resource_type]): + if getattr(item, resource_field) == resource_id: + return item + raise ResourceNotDefinedError(f"Unable to find {resource_id!r} in {RESOURCE_MAP[resource_type]!r}") + + def has_jinja2_transform(self, name: 
str) -> bool: + return self._has_resource(resource_id=name, resource_type=InfrahubJinja2TransformConfig) + + def get_jinja2_transform(self, name: str) -> InfrahubJinja2TransformConfig: + return self._get_resource(resource_id=name, resource_type=InfrahubJinja2TransformConfig) + + def has_check_definition(self, name: str) -> bool: + return self._has_resource(resource_id=name, resource_type=InfrahubCheckDefinitionConfig) + + def get_check_definition(self, name: str) -> InfrahubCheckDefinitionConfig: + return self._get_resource(resource_id=name, resource_type=InfrahubCheckDefinitionConfig) + + def has_artifact_definition(self, name: str) -> bool: + return self._has_resource(resource_id=name, resource_type=InfrahubRepositoryArtifactDefinitionConfig) + + def get_artifact_definition(self, name: str) -> InfrahubRepositoryArtifactDefinitionConfig: + return self._get_resource(resource_id=name, resource_type=InfrahubRepositoryArtifactDefinitionConfig) + + def has_generator_definition(self, name: str) -> bool: + return self._has_resource(resource_id=name, resource_type=InfrahubGeneratorDefinitionConfig) + + def get_generator_definition(self, name: str) -> InfrahubGeneratorDefinitionConfig: + return self._get_resource(resource_id=name, resource_type=InfrahubGeneratorDefinitionConfig) + + def has_python_transform(self, name: str) -> bool: + return self._has_resource(resource_id=name, resource_type=InfrahubPythonTransformConfig) + + def get_python_transform(self, name: str) -> InfrahubPythonTransformConfig: + return self._get_resource(resource_id=name, resource_type=InfrahubPythonTransformConfig) + + def has_query(self, name: str) -> bool: + return self._has_resource(resource_id=name, resource_type=InfrahubRepositoryGraphQLConfig) + + def get_query(self, name: str) -> InfrahubRepositoryGraphQLConfig: + return self._get_resource(resource_id=name, resource_type=InfrahubRepositoryGraphQLConfig) diff --git a/infrahub_sdk/spec/object.py b/infrahub_sdk/spec/object.py index 
24278f47..5650752d 100644 --- a/infrahub_sdk/spec/object.py +++ b/infrahub_sdk/spec/object.py @@ -3,7 +3,7 @@ from pydantic import BaseModel, Field from ..client import InfrahubClient -from ..schema import MainSchemaTypes +from ..schema import MainSchemaTypesAPI from ..yaml import InfrahubFile, InfrahubFileKind @@ -19,7 +19,7 @@ def enrich_node(cls, data: dict, context: dict) -> dict: async def create_node( cls, client: InfrahubClient, - schema: MainSchemaTypes, + schema: MainSchemaTypesAPI, data: dict, context: Optional[dict] = None, branch: Optional[str] = None, diff --git a/infrahub_sdk/testing/__init__.py b/infrahub_sdk/testing/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/infrahub_sdk/testing/docker.py b/infrahub_sdk/testing/docker.py new file mode 100644 index 00000000..4beccb16 --- /dev/null +++ b/infrahub_sdk/testing/docker.py @@ -0,0 +1,18 @@ +import pytest +from infrahub_testcontainers.helpers import TestInfrahubDocker + +from .. import Config, InfrahubClient, InfrahubClientSync + + +class TestInfrahubDockerClient(TestInfrahubDocker): + @pytest.fixture(scope="class") + def client(self, infrahub_port: int) -> InfrahubClient: + return InfrahubClient( + config=Config(username="admin", password="infrahub", address=f"http://localhost:{infrahub_port}") # noqa: S106 + ) + + @pytest.fixture(scope="class") + def client_sync(self, infrahub_port: int) -> InfrahubClientSync: + return InfrahubClientSync( + config=Config(username="admin", password="infrahub", address=f"http://localhost:{infrahub_port}") # noqa: S106 + ) diff --git a/infrahub_sdk/testing/repository.py b/infrahub_sdk/testing/repository.py new file mode 100644 index 00000000..28d770a3 --- /dev/null +++ b/infrahub_sdk/testing/repository.py @@ -0,0 +1,103 @@ +import asyncio +import shutil +from dataclasses import dataclass, field +from enum import Enum +from pathlib import Path +from typing import Optional + +from git.repo import Repo + +from infrahub_sdk import InfrahubClient +from 
infrahub_sdk.graphql import Mutation +from infrahub_sdk.protocols import CoreGenericRepository + + +# NOTE we shouldn't duplicate this, need to figure out a better solution +class RepositorySyncStatus(str, Enum): + UNKNOWN = "unknown" + IN_SYNC = "in-sync" + ERROR_IMPORT = "error-import" + SYNCING = "syncing" + + +class GitRepoType(str, Enum): + INTEGRATED = "CoreRepository" + READ_ONLY = "CoreReadOnlyRepository" + + +@dataclass +class GitRepo: + name: str + src_directory: Path + dst_directory: Path + + type: GitRepoType = GitRepoType.INTEGRATED + + _repo: Optional[Repo] = None + initial_branch: str = "main" + directories_to_ignore: list[str] = field(default_factory=list) + remote_directory_name: str = "/remote" + _branches: list[str] = field(default_factory=list) + + @property + def repo(self) -> Repo: + if self._repo: + return self._repo + raise ValueError("Repo hasn't been initialized yet") + + def __post_init__(self) -> None: + self.init() + + @property + def path(self) -> str: + return str(self.src_directory / self.name) + + def init(self) -> None: + shutil.copytree( + src=self.src_directory, + dst=self.dst_directory / self.name, + ignore=shutil.ignore_patterns(".git"), + ) + self._repo = Repo.init(self.dst_directory / self.name, initial_branch=self.initial_branch) + for untracked in self.repo.untracked_files: + self.repo.index.add(untracked) + self.repo.index.commit("First commit") + + self.repo.git.checkout(self.initial_branch) + + async def add_to_infrahub(self, client: InfrahubClient, branch: Optional[str] = None) -> dict: + input_data = { + "data": { + "name": {"value": self.name}, + "location": {"value": f"{self.remote_directory_name}/{self.name}"}, + }, + } + + query = Mutation( + mutation=f"{self.type.value}Create", + input_data=input_data, + query={"ok": None}, + ) + + return await client.execute_graphql( + query=query.render(), branch_name=branch or self.initial_branch, tracker="mutation-repository-create" + ) + + async def wait_for_sync_to_complete( 
+ self, client: InfrahubClient, branch: Optional[str] = None, interval: int = 5, retries: int = 6 + ) -> bool: + for _ in range(retries): + repo = await client.get( + kind=CoreGenericRepository, # type: ignore[type-abstract] + name__value=self.name, + branch=branch or self.initial_branch, + ) + status = repo.sync_status.value + if status == RepositorySyncStatus.IN_SYNC.value: + return True + if status == RepositorySyncStatus.ERROR_IMPORT.value: + return False + + await asyncio.sleep(interval) + + return False diff --git a/infrahub_sdk/testing/schemas/__init__.py b/infrahub_sdk/testing/schemas/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/infrahub_sdk/testing/schemas/car_person.py b/infrahub_sdk/testing/schemas/car_person.py new file mode 100644 index 00000000..632b7ad9 --- /dev/null +++ b/infrahub_sdk/testing/schemas/car_person.py @@ -0,0 +1,145 @@ +import pytest + +from infrahub_sdk import InfrahubClient +from infrahub_sdk.node import InfrahubNode +from infrahub_sdk.schema.main import AttributeKind, NodeSchema, RelationshipKind, SchemaRoot +from infrahub_sdk.schema.main import AttributeSchema as Attr +from infrahub_sdk.schema.main import RelationshipSchema as Rel + +NAMESPACE = "Testing" + +TESTING_MANUFACTURER = f"{NAMESPACE}Manufacturer" +TESTING_PERSON = f"{NAMESPACE}Person" +TESTING_CAR = f"{NAMESPACE}Car" + + +class SchemaCarPerson: + @pytest.fixture(scope="class") + def schema_person_base(self) -> NodeSchema: + return NodeSchema( + name="Person", + namespace=NAMESPACE, + include_in_menu=True, + label="Person", + human_friendly_id=["name__value"], + attributes=[ + Attr(name="name", kind=AttributeKind.TEXT, unique=True), + Attr(name="description", kind=AttributeKind.TEXT, optional=True), + Attr(name="height", kind=AttributeKind.NUMBER, optional=True), + Attr(name="age", kind=AttributeKind.NUMBER, optional=True), + ], + relationships=[ + Rel(name="cars", kind=RelationshipKind.GENERIC, optional=True, peer=TESTING_CAR, cardinality="many") 
+ ], + ) + + @pytest.fixture(scope="class") + def schema_car_base(self) -> NodeSchema: + return NodeSchema( + name="Car", + namespace=NAMESPACE, + include_in_menu=True, + default_filter="name__value", + human_friendly_id=["owner__name__value", "name__value"], + label="Car", + attributes=[ + Attr(name="name", kind=AttributeKind.TEXT), + Attr(name="description", kind=AttributeKind.TEXT, optional=True), + Attr(name="color", kind=AttributeKind.TEXT), + ], + relationships=[ + Rel( + name="owner", + kind=RelationshipKind.ATTRIBUTE, + optional=False, + peer=TESTING_PERSON, + cardinality="one", + ), + Rel( + name="manufacturer", + kind=RelationshipKind.ATTRIBUTE, + optional=False, + peer=TESTING_MANUFACTURER, + cardinality="one", + identifier="car__manufacturer", + ), + ], + ) + + @pytest.fixture(scope="class") + def schema_manufacturer_base(self) -> NodeSchema: + return NodeSchema( + name="Manufacturer", + namespace=NAMESPACE, + include_in_menu=True, + label="Manufacturer", + human_friendly_id=["name__value"], + attributes=[ + Attr(name="name", kind=AttributeKind.TEXT), + Attr(name="description", kind=AttributeKind.TEXT, optional=True), + ], + relationships=[ + Rel( + name="cars", + kind=RelationshipKind.GENERIC, + optional=True, + peer=TESTING_CAR, + cardinality="many", + identifier="car__manufacturer", + ), + Rel( + name="customers", + kind=RelationshipKind.GENERIC, + optional=True, + peer=TESTING_PERSON, + cardinality="many", + identifier="person__manufacturer", + ), + ], + ) + + @pytest.fixture(scope="class") + def schema_base( + self, + schema_car_base: NodeSchema, + schema_person_base: NodeSchema, + schema_manufacturer_base: NodeSchema, + ) -> SchemaRoot: + return SchemaRoot(version="1.0", nodes=[schema_car_base, schema_person_base, schema_manufacturer_base]) + + async def create_persons(self, client: InfrahubClient, branch: str) -> list[InfrahubNode]: + john = await client.create(kind=TESTING_PERSON, name="John Doe", branch=branch) + await john.save() + + jane = 
await client.create(kind=TESTING_PERSON, name="Jane Doe", branch=branch) + await jane.save() + + return [john, jane] + + async def create_manufacturers(self, client: InfrahubClient, branch: str) -> list[InfrahubNode]: + obj1 = await client.create(kind=TESTING_MANUFACTURER, name="Volkswagen", branch=branch) + await obj1.save() + + obj2 = await client.create(kind=TESTING_MANUFACTURER, name="Renault", branch=branch) + await obj2.save() + + obj3 = await client.create(kind=TESTING_MANUFACTURER, name="Mercedes", branch=branch) + await obj3.save() + + return [obj1, obj2, obj3] + + async def create_initial_data(self, client: InfrahubClient, branch: str) -> dict[str, list[InfrahubNode]]: + persons = await self.create_persons(client=client, branch=branch) + manufacturers = await self.create_manufacturers(client=client, branch=branch) + + car10 = await client.create( + kind=TESTING_CAR, name="Golf", color="Black", manufacturer=manufacturers[0].id, owner=persons[0].id + ) + await car10.save() + + car20 = await client.create( + kind=TESTING_CAR, name="Megane", color="Red", manufacturer=manufacturers[1].id, owner=persons[1].id + ) + await car20.save() + + return {TESTING_PERSON: persons, TESTING_CAR: [car10, car20], TESTING_MANUFACTURER: manufacturers} diff --git a/infrahub_sdk/transfer/exporter/json.py b/infrahub_sdk/transfer/exporter/json.py index 11e212f7..fadda9b2 100644 --- a/infrahub_sdk/transfer/exporter/json.py +++ b/infrahub_sdk/transfer/exporter/json.py @@ -9,13 +9,13 @@ from ...client import InfrahubClient from ...queries import QUERY_RELATIONSHIPS -from ...schema import MainSchemaTypes, NodeSchema +from ...schema import MainSchemaTypesAPI, NodeSchemaAPI from ..constants import ILLEGAL_NAMESPACES from ..exceptions import FileAlreadyExistsError, InvalidNamespaceError from .interface import ExporterInterface if TYPE_CHECKING: - from .node import InfrahubNode + from ...node import InfrahubNode class LineDelimitedJSONExporter(ExporterInterface): @@ -32,7 +32,7 @@ def 
wrapped_task_output(self, start: str, end: str = "[green]done") -> Generator self.console.print(f"{end}") def identify_many_to_many_relationships( - self, node_schema_map: dict[str, MainSchemaTypes] + self, node_schema_map: dict[str, MainSchemaTypesAPI] ) -> dict[tuple[str, str], str]: # Identify many to many relationships by src/dst couples many_relationship_identifiers: dict[tuple[str, str], str] = {} @@ -60,7 +60,7 @@ def identify_many_to_many_relationships( return many_relationship_identifiers async def retrieve_many_to_many_relationships( - self, node_schema_map: dict[str, MainSchemaTypes], branch: str + self, node_schema_map: dict[str, MainSchemaTypesAPI], branch: str ) -> list[dict[str, Any]]: has_remaining_items = True page_number = 1 @@ -113,7 +113,7 @@ async def export( # pylint: disable=too-many-branches node_schema_map = { kind: schema for kind, schema in node_schema_map.items() - if isinstance(schema, NodeSchema) + if isinstance(schema, NodeSchemaAPI) and schema.namespace not in illegal_namespaces and (not exclude or kind not in exclude) } diff --git a/infrahub_sdk/transfer/schema_sorter.py b/infrahub_sdk/transfer/schema_sorter.py index 58f734c4..8b4c5d85 100644 --- a/infrahub_sdk/transfer/schema_sorter.py +++ b/infrahub_sdk/transfer/schema_sorter.py @@ -1,7 +1,7 @@ from collections.abc import Sequence from typing import Optional -from ..schema import BaseNodeSchema +from ..schema import NodeSchema from ..topological_sort import DependencyCycleExistsError, topological_sort from .exceptions import SchemaImportError @@ -9,7 +9,7 @@ class InfrahubSchemaTopologicalSorter: def get_sorted_node_schema( self, - schemas: Sequence[BaseNodeSchema], + schemas: Sequence[NodeSchema], required_relationships_only: bool = True, include: Optional[list[str]] = None, ) -> list[set[str]]: diff --git a/infrahub_sdk/transforms.py b/infrahub_sdk/transforms.py index 21c2fb73..eb145783 100644 --- a/infrahub_sdk/transforms.py +++ b/infrahub_sdk/transforms.py @@ -14,7 +14,7 @@ 
from pathlib import Path from . import InfrahubClient - from .schema import InfrahubPythonTransformConfig + from .schema.repository import InfrahubPythonTransformConfig INFRAHUB_TRANSFORM_VARIABLE_TO_IMPORT = "INFRAHUB_TRANSFORMS" diff --git a/poetry.lock b/poetry.lock index 3481ce50..68ba393d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. [[package]] name = "annotated-types" @@ -35,13 +35,13 @@ trio = ["trio (>=0.23)"] [[package]] name = "astroid" -version = "3.2.4" +version = "3.1.0" description = "An abstract syntax tree for Python with inference support." optional = false python-versions = ">=3.8.0" files = [ - {file = "astroid-3.2.4-py3-none-any.whl", hash = "sha256:413658a61eeca6202a59231abb473f932038fbcbf1666587f66d482083413a25"}, - {file = "astroid-3.2.4.tar.gz", hash = "sha256:0e14202810b30da1b735827f78f5157be2bbd4a7a59b7707ca0bfc2fb4c0063a"}, + {file = "astroid-3.1.0-py3-none-any.whl", hash = "sha256:951798f922990137ac090c53af473db7ab4e70c770e6d7fae0cec59f74411819"}, + {file = "astroid-3.1.0.tar.gz", hash = "sha256:ac248253bfa4bd924a0de213707e7ebeeb3138abeb48d798784ead1e56d419d4"}, ] [package.dependencies] @@ -335,6 +335,28 @@ files = [ {file = "distlib-0.3.8.tar.gz", hash = "sha256:1530ea13e350031b6312d8580ddb6b27a104275a31106523b8f123787f494f64"}, ] +[[package]] +name = "docker" +version = "7.1.0" +description = "A Python library for the Docker Engine API." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "docker-7.1.0-py3-none-any.whl", hash = "sha256:c96b93b7f0a746f9e77d325bcfb87422a3d8bd4f03136ae8a85b37f1898d5fc0"}, + {file = "docker-7.1.0.tar.gz", hash = "sha256:ad8c70e6e3f8926cb8a92619b832b4ea5299e2831c14284663184e200546fa6c"}, +] + +[package.dependencies] +pywin32 = {version = ">=304", markers = "sys_platform == \"win32\""} +requests = ">=2.26.0" +urllib3 = ">=1.26.0" + +[package.extras] +dev = ["coverage (==7.2.7)", "pytest (==7.4.2)", "pytest-cov (==4.1.0)", "pytest-timeout (==2.1.0)", "ruff (==0.1.8)"] +docs = ["myst-parser (==0.18.0)", "sphinx (==5.1.1)"] +ssh = ["paramiko (>=2.4.3)"] +websockets = ["websocket-client (>=1.3.0)"] + [[package]] name = "exceptiongroup" version = "1.2.2" @@ -566,6 +588,21 @@ enabler = ["pytest-enabler (>=2.2)"] test = ["jaraco.test (>=5.4)", "pytest (>=6,!=8.1.*)", "zipp (>=3.17)"] type = ["pytest-mypy"] +[[package]] +name = "infrahub-testcontainers" +version = "1.1.0b2" +description = "Testcontainers instance for Infrahub to easily build integration tests" +optional = false +python-versions = "<4.0,>=3.9" +files = [ + {file = "infrahub_testcontainers-1.1.0b2-py3-none-any.whl", hash = "sha256:40a4f735b988db0f20eeedc68eab2fc40dcfba37382d9836a49bd6dbc282b80a"}, + {file = "infrahub_testcontainers-1.1.0b2.tar.gz", hash = "sha256:fd3738a8f6588c16a8d88944b8f0c9faaa3a9f390cd2817bdabc8e08d4dae6a6"}, +] + +[package.dependencies] +pytest = "*" +testcontainers = ">=4.8,<4.9" + [[package]] name = "iniconfig" version = "2.0.0" @@ -660,13 +697,13 @@ testing = ["Django", "attrs", "colorama", "docopt", "pytest (<7.0.0)"] [[package]] name = "jinja2" -version = "3.1.4" +version = "3.1.5" description = "A very fast and expressive template engine." 
optional = false python-versions = ">=3.7" files = [ - {file = "jinja2-3.1.4-py3-none-any.whl", hash = "sha256:bc5dd2abb727a5319567b7a813e6a2e7318c39f4f487cfe6c89c6f9c7d25197d"}, - {file = "jinja2-3.1.4.tar.gz", hash = "sha256:4a3aee7acbbe7303aede8e9648d13b8bf88a429282aa6122a993f0ac800cb369"}, + {file = "jinja2-3.1.5-py3-none-any.whl", hash = "sha256:aba0f4dc9ed8013c424088f68a5c226f7d6097ed89b246d7749c2ec4175c6adb"}, + {file = "jinja2-3.1.5.tar.gz", hash = "sha256:8fefff8dc3034e27bb80d67c671eb8a9bc424c0ef4c0826edbff304cceff43bb"}, ] [package.dependencies] @@ -1381,17 +1418,17 @@ windows-terminal = ["colorama (>=0.4.6)"] [[package]] name = "pylint" -version = "3.2.7" +version = "3.1.1" description = "python code static checker" optional = false python-versions = ">=3.8.0" files = [ - {file = "pylint-3.2.7-py3-none-any.whl", hash = "sha256:02f4aedeac91be69fb3b4bea997ce580a4ac68ce58b89eaefeaf06749df73f4b"}, - {file = "pylint-3.2.7.tar.gz", hash = "sha256:1b7a721b575eaeaa7d39db076b6e7743c993ea44f57979127c517c6c572c803e"}, + {file = "pylint-3.1.1-py3-none-any.whl", hash = "sha256:862eddf25dab42704c5f06d3688b8bc19ef4c99ad8a836b6ff260a3b2fbafee1"}, + {file = "pylint-3.1.1.tar.gz", hash = "sha256:c7c2652bf8099c7fb7a63bc6af5c5f8f7b9d7b392fa1d320cb020e222aff28c2"}, ] [package.dependencies] -astroid = ">=3.2.4,<=3.3.0-dev0" +astroid = ">=3.1.0,<=3.2.0-dev0" colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""} dill = [ {version = ">=0.2", markers = "python_version < \"3.11\""}, @@ -1548,6 +1585,33 @@ files = [ [package.extras] cli = ["click (>=5.0)"] +[[package]] +name = "pywin32" +version = "308" +description = "Python for Window Extensions" +optional = false +python-versions = "*" +files = [ + {file = "pywin32-308-cp310-cp310-win32.whl", hash = "sha256:796ff4426437896550d2981b9c2ac0ffd75238ad9ea2d3bfa67a1abd546d262e"}, + {file = "pywin32-308-cp310-cp310-win_amd64.whl", hash = "sha256:4fc888c59b3c0bef905ce7eb7e2106a07712015ea1c8234b703a088d46110e8e"}, + 
{file = "pywin32-308-cp310-cp310-win_arm64.whl", hash = "sha256:a5ab5381813b40f264fa3495b98af850098f814a25a63589a8e9eb12560f450c"}, + {file = "pywin32-308-cp311-cp311-win32.whl", hash = "sha256:5d8c8015b24a7d6855b1550d8e660d8daa09983c80e5daf89a273e5c6fb5095a"}, + {file = "pywin32-308-cp311-cp311-win_amd64.whl", hash = "sha256:575621b90f0dc2695fec346b2d6302faebd4f0f45c05ea29404cefe35d89442b"}, + {file = "pywin32-308-cp311-cp311-win_arm64.whl", hash = "sha256:100a5442b7332070983c4cd03f2e906a5648a5104b8a7f50175f7906efd16bb6"}, + {file = "pywin32-308-cp312-cp312-win32.whl", hash = "sha256:587f3e19696f4bf96fde9d8a57cec74a57021ad5f204c9e627e15c33ff568897"}, + {file = "pywin32-308-cp312-cp312-win_amd64.whl", hash = "sha256:00b3e11ef09ede56c6a43c71f2d31857cf7c54b0ab6e78ac659497abd2834f47"}, + {file = "pywin32-308-cp312-cp312-win_arm64.whl", hash = "sha256:9b4de86c8d909aed15b7011182c8cab38c8850de36e6afb1f0db22b8959e3091"}, + {file = "pywin32-308-cp313-cp313-win32.whl", hash = "sha256:1c44539a37a5b7b21d02ab34e6a4d314e0788f1690d65b48e9b0b89f31abbbed"}, + {file = "pywin32-308-cp313-cp313-win_amd64.whl", hash = "sha256:fd380990e792eaf6827fcb7e187b2b4b1cede0585e3d0c9e84201ec27b9905e4"}, + {file = "pywin32-308-cp313-cp313-win_arm64.whl", hash = "sha256:ef313c46d4c18dfb82a2431e3051ac8f112ccee1a34f29c263c583c568db63cd"}, + {file = "pywin32-308-cp37-cp37m-win32.whl", hash = "sha256:1f696ab352a2ddd63bd07430080dd598e6369152ea13a25ebcdd2f503a38f1ff"}, + {file = "pywin32-308-cp37-cp37m-win_amd64.whl", hash = "sha256:13dcb914ed4347019fbec6697a01a0aec61019c1046c2b905410d197856326a6"}, + {file = "pywin32-308-cp38-cp38-win32.whl", hash = "sha256:5794e764ebcabf4ff08c555b31bd348c9025929371763b2183172ff4708152f0"}, + {file = "pywin32-308-cp38-cp38-win_amd64.whl", hash = "sha256:3b92622e29d651c6b783e368ba7d6722b1634b8e70bd376fd7610fe1992e19de"}, + {file = "pywin32-308-cp39-cp39-win32.whl", hash = "sha256:7873ca4dc60ab3287919881a7d4f88baee4a6e639aa6962de25a98ba6b193341"}, + {file = 
"pywin32-308-cp39-cp39-win_amd64.whl", hash = "sha256:71b3322d949b4cc20776436a9c9ba0eeedcbc9c650daa536df63f0ff111bb920"}, +] + [[package]] name = "pyyaml" version = "6.0.2" @@ -1739,6 +1803,58 @@ pure-eval = "*" [package.extras] tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"] +[[package]] +name = "testcontainers" +version = "4.8.2" +description = "Python library for throwaway instances of anything that can run in a Docker container" +optional = false +python-versions = "<4.0,>=3.9" +files = [ + {file = "testcontainers-4.8.2-py3-none-any.whl", hash = "sha256:9e19af077cd96e1957c13ee466f1f32905bc6c5bc1bc98643eb18be1a989bfb0"}, + {file = "testcontainers-4.8.2.tar.gz", hash = "sha256:dd4a6a2ea09e3c3ecd39e180b6548105929d0bb78d665ce9919cb3f8c98f9853"}, +] + +[package.dependencies] +docker = "*" +typing-extensions = "*" +urllib3 = "*" +wrapt = "*" + +[package.extras] +arangodb = ["python-arango (>=7.8,<8.0)"] +aws = ["boto3", "httpx"] +azurite = ["azure-storage-blob (>=12.19,<13.0)"] +chroma = ["chromadb-client"] +clickhouse = ["clickhouse-driver"] +cosmosdb = ["azure-cosmos"] +db2 = ["ibm_db_sa", "sqlalchemy"] +generic = ["httpx", "redis"] +google = ["google-cloud-datastore (>=2)", "google-cloud-pubsub (>=2)"] +influxdb = ["influxdb", "influxdb-client"] +k3s = ["kubernetes", "pyyaml"] +keycloak = ["python-keycloak"] +localstack = ["boto3"] +mailpit = ["cryptography"] +minio = ["minio"] +mongodb = ["pymongo"] +mssql = ["pymssql", "sqlalchemy"] +mysql = ["pymysql[rsa]", "sqlalchemy"] +nats = ["nats-py"] +neo4j = ["neo4j"] +opensearch = ["opensearch-py"] +oracle = ["oracledb", "sqlalchemy"] +oracle-free = ["oracledb", "sqlalchemy"] +qdrant = ["qdrant-client"] +rabbitmq = ["pika"] +redis = ["redis"] +registry = ["bcrypt"] +scylla = ["cassandra-driver (==3.29.1)"] +selenium = ["selenium"] +sftp = ["cryptography"] +test-module-import = ["httpx"] +trino = ["trino"] +weaviate = ["weaviate-client (>=4.5.4,<5.0.0)"] + [[package]] name = "toml" version = 
"0.10.2" @@ -2026,6 +2142,80 @@ files = [ {file = "wcwidth-0.2.13.tar.gz", hash = "sha256:72ea0c06399eb286d978fdedb6923a9eb47e1c486ce63e9b4e64fc18303972b5"}, ] +[[package]] +name = "wrapt" +version = "1.17.0" +description = "Module for decorators, wrappers and monkey patching." +optional = false +python-versions = ">=3.8" +files = [ + {file = "wrapt-1.17.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2a0c23b8319848426f305f9cb0c98a6e32ee68a36264f45948ccf8e7d2b941f8"}, + {file = "wrapt-1.17.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b1ca5f060e205f72bec57faae5bd817a1560fcfc4af03f414b08fa29106b7e2d"}, + {file = "wrapt-1.17.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e185ec6060e301a7e5f8461c86fb3640a7beb1a0f0208ffde7a65ec4074931df"}, + {file = "wrapt-1.17.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bb90765dd91aed05b53cd7a87bd7f5c188fcd95960914bae0d32c5e7f899719d"}, + {file = "wrapt-1.17.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:879591c2b5ab0a7184258274c42a126b74a2c3d5a329df16d69f9cee07bba6ea"}, + {file = "wrapt-1.17.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:fce6fee67c318fdfb7f285c29a82d84782ae2579c0e1b385b7f36c6e8074fffb"}, + {file = "wrapt-1.17.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:0698d3a86f68abc894d537887b9bbf84d29bcfbc759e23f4644be27acf6da301"}, + {file = "wrapt-1.17.0-cp310-cp310-win32.whl", hash = "sha256:69d093792dc34a9c4c8a70e4973a3361c7a7578e9cd86961b2bbf38ca71e4e22"}, + {file = "wrapt-1.17.0-cp310-cp310-win_amd64.whl", hash = "sha256:f28b29dc158ca5d6ac396c8e0a2ef45c4e97bb7e65522bfc04c989e6fe814575"}, + {file = "wrapt-1.17.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:74bf625b1b4caaa7bad51d9003f8b07a468a704e0644a700e936c357c17dd45a"}, + {file = "wrapt-1.17.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:0f2a28eb35cf99d5f5bd12f5dd44a0f41d206db226535b37b0c60e9da162c3ed"}, + {file = "wrapt-1.17.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:81b1289e99cf4bad07c23393ab447e5e96db0ab50974a280f7954b071d41b489"}, + {file = "wrapt-1.17.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9f2939cd4a2a52ca32bc0b359015718472d7f6de870760342e7ba295be9ebaf9"}, + {file = "wrapt-1.17.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:6a9653131bda68a1f029c52157fd81e11f07d485df55410401f745007bd6d339"}, + {file = "wrapt-1.17.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:4e4b4385363de9052dac1a67bfb535c376f3d19c238b5f36bddc95efae15e12d"}, + {file = "wrapt-1.17.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:bdf62d25234290db1837875d4dceb2151e4ea7f9fff2ed41c0fde23ed542eb5b"}, + {file = "wrapt-1.17.0-cp311-cp311-win32.whl", hash = "sha256:5d8fd17635b262448ab8f99230fe4dac991af1dabdbb92f7a70a6afac8a7e346"}, + {file = "wrapt-1.17.0-cp311-cp311-win_amd64.whl", hash = "sha256:92a3d214d5e53cb1db8b015f30d544bc9d3f7179a05feb8f16df713cecc2620a"}, + {file = "wrapt-1.17.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:89fc28495896097622c3fc238915c79365dd0ede02f9a82ce436b13bd0ab7569"}, + {file = "wrapt-1.17.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:875d240fdbdbe9e11f9831901fb8719da0bd4e6131f83aa9f69b96d18fae7504"}, + {file = "wrapt-1.17.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e5ed16d95fd142e9c72b6c10b06514ad30e846a0d0917ab406186541fe68b451"}, + {file = "wrapt-1.17.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:18b956061b8db634120b58f668592a772e87e2e78bc1f6a906cfcaa0cc7991c1"}, + {file = "wrapt-1.17.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = 
"sha256:daba396199399ccabafbfc509037ac635a6bc18510ad1add8fd16d4739cdd106"}, + {file = "wrapt-1.17.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:4d63f4d446e10ad19ed01188d6c1e1bb134cde8c18b0aa2acfd973d41fcc5ada"}, + {file = "wrapt-1.17.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:8a5e7cc39a45fc430af1aefc4d77ee6bad72c5bcdb1322cfde852c15192b8bd4"}, + {file = "wrapt-1.17.0-cp312-cp312-win32.whl", hash = "sha256:0a0a1a1ec28b641f2a3a2c35cbe86c00051c04fffcfcc577ffcdd707df3f8635"}, + {file = "wrapt-1.17.0-cp312-cp312-win_amd64.whl", hash = "sha256:3c34f6896a01b84bab196f7119770fd8466c8ae3dfa73c59c0bb281e7b588ce7"}, + {file = "wrapt-1.17.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:714c12485aa52efbc0fc0ade1e9ab3a70343db82627f90f2ecbc898fdf0bb181"}, + {file = "wrapt-1.17.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da427d311782324a376cacb47c1a4adc43f99fd9d996ffc1b3e8529c4074d393"}, + {file = "wrapt-1.17.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ba1739fb38441a27a676f4de4123d3e858e494fac05868b7a281c0a383c098f4"}, + {file = "wrapt-1.17.0-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e711fc1acc7468463bc084d1b68561e40d1eaa135d8c509a65dd534403d83d7b"}, + {file = "wrapt-1.17.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:140ea00c87fafc42739bd74a94a5a9003f8e72c27c47cd4f61d8e05e6dec8721"}, + {file = "wrapt-1.17.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:73a96fd11d2b2e77d623a7f26e004cc31f131a365add1ce1ce9a19e55a1eef90"}, + {file = "wrapt-1.17.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:0b48554952f0f387984da81ccfa73b62e52817a4386d070c75e4db7d43a28c4a"}, + {file = "wrapt-1.17.0-cp313-cp313-win32.whl", hash = "sha256:498fec8da10e3e62edd1e7368f4b24aa362ac0ad931e678332d1b209aec93045"}, + {file = "wrapt-1.17.0-cp313-cp313-win_amd64.whl", hash = 
"sha256:fd136bb85f4568fffca995bd3c8d52080b1e5b225dbf1c2b17b66b4c5fa02838"}, + {file = "wrapt-1.17.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:17fcf043d0b4724858f25b8826c36e08f9fb2e475410bece0ec44a22d533da9b"}, + {file = "wrapt-1.17.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e4a557d97f12813dc5e18dad9fa765ae44ddd56a672bb5de4825527c847d6379"}, + {file = "wrapt-1.17.0-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0229b247b0fc7dee0d36176cbb79dbaf2a9eb7ecc50ec3121f40ef443155fb1d"}, + {file = "wrapt-1.17.0-cp313-cp313t-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8425cfce27b8b20c9b89d77fb50e368d8306a90bf2b6eef2cdf5cd5083adf83f"}, + {file = "wrapt-1.17.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:9c900108df470060174108012de06d45f514aa4ec21a191e7ab42988ff42a86c"}, + {file = "wrapt-1.17.0-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:4e547b447073fc0dbfcbff15154c1be8823d10dab4ad401bdb1575e3fdedff1b"}, + {file = "wrapt-1.17.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:914f66f3b6fc7b915d46c1cc424bc2441841083de01b90f9e81109c9759e43ab"}, + {file = "wrapt-1.17.0-cp313-cp313t-win32.whl", hash = "sha256:a4192b45dff127c7d69b3bdfb4d3e47b64179a0b9900b6351859f3001397dabf"}, + {file = "wrapt-1.17.0-cp313-cp313t-win_amd64.whl", hash = "sha256:4f643df3d4419ea3f856c5c3f40fec1d65ea2e89ec812c83f7767c8730f9827a"}, + {file = "wrapt-1.17.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:69c40d4655e078ede067a7095544bcec5a963566e17503e75a3a3e0fe2803b13"}, + {file = "wrapt-1.17.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2f495b6754358979379f84534f8dd7a43ff8cff2558dcdea4a148a6e713a758f"}, + {file = "wrapt-1.17.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:baa7ef4e0886a6f482e00d1d5bcd37c201b383f1d314643dfb0367169f94f04c"}, + {file = "wrapt-1.17.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a8fc931382e56627ec4acb01e09ce66e5c03c384ca52606111cee50d931a342d"}, + {file = "wrapt-1.17.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:8f8909cdb9f1b237786c09a810e24ee5e15ef17019f7cecb207ce205b9b5fcce"}, + {file = "wrapt-1.17.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:ad47b095f0bdc5585bced35bd088cbfe4177236c7df9984b3cc46b391cc60627"}, + {file = "wrapt-1.17.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:948a9bd0fb2c5120457b07e59c8d7210cbc8703243225dbd78f4dfc13c8d2d1f"}, + {file = "wrapt-1.17.0-cp38-cp38-win32.whl", hash = "sha256:5ae271862b2142f4bc687bdbfcc942e2473a89999a54231aa1c2c676e28f29ea"}, + {file = "wrapt-1.17.0-cp38-cp38-win_amd64.whl", hash = "sha256:f335579a1b485c834849e9075191c9898e0731af45705c2ebf70e0cd5d58beed"}, + {file = "wrapt-1.17.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d751300b94e35b6016d4b1e7d0e7bbc3b5e1751e2405ef908316c2a9024008a1"}, + {file = "wrapt-1.17.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7264cbb4a18dc4acfd73b63e4bcfec9c9802614572025bdd44d0721983fc1d9c"}, + {file = "wrapt-1.17.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:33539c6f5b96cf0b1105a0ff4cf5db9332e773bb521cc804a90e58dc49b10578"}, + {file = "wrapt-1.17.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c30970bdee1cad6a8da2044febd824ef6dc4cc0b19e39af3085c763fdec7de33"}, + {file = "wrapt-1.17.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:bc7f729a72b16ee21795a943f85c6244971724819819a41ddbaeb691b2dd85ad"}, + {file = "wrapt-1.17.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:6ff02a91c4fc9b6a94e1c9c20f62ea06a7e375f42fe57587f004d1078ac86ca9"}, + {file = 
"wrapt-1.17.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:2dfb7cff84e72e7bf975b06b4989477873dcf160b2fd89959c629535df53d4e0"}, + {file = "wrapt-1.17.0-cp39-cp39-win32.whl", hash = "sha256:2399408ac33ffd5b200480ee858baa58d77dd30e0dd0cab6a8a9547135f30a88"}, + {file = "wrapt-1.17.0-cp39-cp39-win_amd64.whl", hash = "sha256:4f763a29ee6a20c529496a20a7bcb16a73de27f5da6a843249c7047daf135977"}, + {file = "wrapt-1.17.0-py3-none-any.whl", hash = "sha256:d2c63b93548eda58abf5188e505ffed0229bf675f7c3090f8e36ad55b8cbc371"}, + {file = "wrapt-1.17.0.tar.gz", hash = "sha256:16187aa2317c731170a88ef35e8937ae0f533c402872c1ee5e6d079fcf320801"}, +] + [[package]] name = "yamllint" version = "1.35.1" @@ -2071,4 +2261,4 @@ tests = ["Jinja2", "pytest", "pyyaml", "rich"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "bd7225275eeae26e41660131b9802cc569e73586b16029bced8c7633725b2c95" +content-hash = "03566b615c5d853f4a90a262c472fb5e0260108bd1a99cd0fa902eb93aeef2bb" diff --git a/pyproject.toml b/pyproject.toml index ef4bdb40..49e7c11e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,11 +1,11 @@ [project] name = "infrahub-sdk" -version = "1.2.0-dev0" +version = "1.3.0" requires-python = ">=3.9" [tool.poetry] name = "infrahub-sdk" -version = "1.2.0" +version = "1.3.0" description = "Python Client to interact with Infrahub" authors = ["OpsMill "] readme = "README.md" @@ -71,6 +71,8 @@ pytest-xdist = "^3.3.1" types-python-slugify = "^8.0.0.3" invoke = "^2.2.0" towncrier = "^24.8.0" +infrahub-testcontainers = "^1.1.0b2" +astroid = "~3.1" [tool.poetry.extras] ctl = ["Jinja2", "numpy", "pyarrow", "pyyaml", "rich", "toml", "typer"] diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index aa2e6ffd..0596ce23 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -1,565 +1,563 @@ -import asyncio +# import asyncio import os -from typing import Any, Optional - -import httpx -import pytest -import ujson -from 
fastapi.testclient import TestClient -from infrahub import config -from infrahub.components import ComponentType -from infrahub.core.initialization import first_time_initialization, initialization -from infrahub.core.node import Node -from infrahub.core.utils import delete_all_nodes -from infrahub.database import InfrahubDatabase, get_db -from infrahub.lock import initialize_lock -from infrahub.message_bus import InfrahubMessage -from infrahub.message_bus.types import MessageTTL -from infrahub.services.adapters.message_bus import InfrahubMessageBus - -from infrahub_sdk.schema import NodeSchema, SchemaRoot -from infrahub_sdk.types import HTTPMethod + +# import httpx +# import pytest +# import ujson +# from fastapi.testclient import TestClient +# from infrahub import config +# from infrahub.components import ComponentType +# from infrahub.core.initialization import first_time_initialization, initialization +# from infrahub.core.node import Node +# from infrahub.core.utils import delete_all_nodes +# from infrahub.database import InfrahubDatabase, get_db +# from infrahub.lock import initialize_lock +# from infrahub.message_bus import InfrahubMessage +# from infrahub.message_bus.types import MessageTTL +# from infrahub.services.adapters.message_bus import InfrahubMessageBus +# from infrahub_sdk.schema import NodeSchema, SchemaRoot +# from infrahub_sdk.types import HTTPMethod from infrahub_sdk.utils import str_to_bool BUILD_NAME = os.environ.get("INFRAHUB_BUILD_NAME", "infrahub") TEST_IN_DOCKER = str_to_bool(os.environ.get("INFRAHUB_TEST_IN_DOCKER", "false")) -@pytest.fixture(scope="session", autouse=True) -def add_tracker(): - os.environ["PYTEST_RUNNING"] = "true" - - -# pylint: disable=redefined-outer-name -class InfrahubTestClient(TestClient): - def _request( - self, url: str, method: HTTPMethod, headers: dict[str, Any], timeout: int, payload: Optional[dict] = None - ) -> httpx.Response: - content = None - if payload: - content = 
str(ujson.dumps(payload)).encode("UTF-8") - with self as client: - return client.request( - method=method.value, - url=url, - headers=headers, - timeout=timeout, - content=content, - ) - - async def async_request( - self, url: str, method: HTTPMethod, headers: dict[str, Any], timeout: int, payload: Optional[dict] = None - ) -> httpx.Response: - return self._request(url=url, method=method, headers=headers, timeout=timeout, payload=payload) - - def sync_request( - self, url: str, method: HTTPMethod, headers: dict[str, Any], timeout: int, payload: Optional[dict] = None - ) -> httpx.Response: - return self._request(url=url, method=method, headers=headers, timeout=timeout, payload=payload) - - -@pytest.fixture(scope="session") -def event_loop(): - """Overrides pytest default function scoped event loop""" - policy = asyncio.get_event_loop_policy() - loop = policy.new_event_loop() - yield loop - loop.close() - - -@pytest.fixture(scope="module", autouse=True) -def execute_before_any_test(worker_id, tmpdir_factory): - config.load_and_exit() - - config.SETTINGS.storage.driver = config.StorageDriver.FileSystemStorage - - if TEST_IN_DOCKER: - try: - db_id = int(worker_id[2]) + 1 - except (ValueError, IndexError): - db_id = 1 - config.SETTINGS.cache.address = f"{BUILD_NAME}-cache-1" - config.SETTINGS.database.address = f"{BUILD_NAME}-database-{db_id}" - config.SETTINGS.storage.local = config.FileSystemStorageSettings(path="/opt/infrahub/storage") - else: - storage_dir = tmpdir_factory.mktemp("storage") - config.SETTINGS.storage.local.path_ = str(storage_dir) - - config.SETTINGS.broker.enable = False - config.SETTINGS.cache.enable = True - config.SETTINGS.miscellaneous.start_background_runner = False - config.SETTINGS.security.secret_key = "4e26b3d9-b84f-42c9-a03f-fee3ada3b2fa" - config.SETTINGS.main.internal_address = "http://mock" - config.OVERRIDE.message_bus = BusRecorder() - - initialize_lock() - - -@pytest.fixture(scope="module") -async def db() -> InfrahubDatabase: - 
driver = InfrahubDatabase(driver=await get_db(retry=1)) - - yield driver - - await driver.close() - - -@pytest.fixture(scope="module") -async def init_db_base(db: InfrahubDatabase): - await delete_all_nodes(db=db) - await first_time_initialization(db=db) - await initialization(db=db) - - -@pytest.fixture(scope="module") -async def builtin_org_schema() -> SchemaRoot: - SCHEMA = { - "version": "1.0", - "nodes": [ - { - "name": "Organization", - "namespace": "Test", - "description": "An organization represent a legal entity, a company.", - "include_in_menu": True, - "label": "Organization", - "icon": "mdi:domain", - "default_filter": "name__value", - "order_by": ["name__value"], - "display_labels": ["label__value"], - "branch": "aware", - "attributes": [ - {"name": "name", "kind": "Text", "unique": True}, - {"name": "label", "kind": "Text", "optional": True}, - {"name": "description", "kind": "Text", "optional": True}, - ], - "relationships": [ - { - "name": "tags", - "peer": "BuiltinTag", - "kind": "Attribute", - "optional": True, - "cardinality": "many", - }, - ], - }, - { - "name": "Status", - "namespace": "Builtin", - "description": "Represent the status of an object: active, maintenance", - "include_in_menu": True, - "icon": "mdi:list-status", - "label": "Status", - "default_filter": "name__value", - "order_by": ["name__value"], - "display_labels": ["label__value"], - "branch": "aware", - "attributes": [ - {"name": "name", "kind": "Text", "unique": True}, - {"name": "label", "kind": "Text", "optional": True}, - {"name": "description", "kind": "Text", "optional": True}, - ], - }, - { - "name": "Role", - "namespace": "Builtin", - "description": "Represent the role of an object", - "include_in_menu": True, - "icon": "mdi:ballot", - "label": "Role", - "default_filter": "name__value", - "order_by": ["name__value"], - "display_labels": ["label__value"], - "branch": "aware", - "attributes": [ - {"name": "name", "kind": "Text", "unique": True}, - {"name": "label", 
"kind": "Text", "optional": True}, - {"name": "description", "kind": "Text", "optional": True}, - ], - }, - { - "name": "Location", - "namespace": "Builtin", - "description": "A location represent a physical element: a building, a site, a city", - "include_in_menu": True, - "icon": "mdi:map-marker-radius-outline", - "label": "Location", - "default_filter": "name__value", - "order_by": ["name__value"], - "display_labels": ["name__value"], - "branch": "aware", - "attributes": [ - {"name": "name", "kind": "Text", "unique": True}, - {"name": "description", "kind": "Text", "optional": True}, - {"name": "type", "kind": "Text"}, - ], - "relationships": [ - { - "name": "tags", - "peer": "BuiltinTag", - "kind": "Attribute", - "optional": True, - "cardinality": "many", - }, - ], - }, - { - "name": "Criticality", - "namespace": "Builtin", - "description": "Level of criticality expressed from 1 to 10.", - "include_in_menu": True, - "icon": "mdi:alert-octagon-outline", - "label": "Criticality", - "default_filter": "name__value", - "order_by": ["name__value"], - "display_labels": ["name__value"], - "branch": "aware", - "attributes": [ - {"name": "name", "kind": "Text", "unique": True}, - {"name": "level", "kind": "Number", "enum": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]}, - {"name": "description", "kind": "Text", "optional": True}, - ], - }, - ], - } - - return SCHEMA - - -@pytest.fixture -async def location_schema() -> NodeSchema: - data = { - "name": "Location", - "namespace": "Builtin", - "default_filter": "name__value", - "attributes": [ - {"name": "name", "kind": "String", "unique": True}, - {"name": "description", "kind": "String", "optional": True}, - {"name": "type", "kind": "String"}, - ], - "relationships": [ - { - "name": "tags", - "peer": "BuiltinTag", - "optional": True, - "cardinality": "many", - }, - { - "name": "primary_tag", - "peer": "BultinTag", - "optional": True, - "cardinality": "one", - }, - ], - } - return NodeSchema(**data) # type: ignore - - -@pytest.fixture 
-async def location_cdg(db: InfrahubDatabase, tag_blue: Node, tag_red: Node) -> Node: - obj = await Node.init(schema="BuiltinLocation", db=db) - await obj.new(db=db, name="cdg01", type="SITE", tags=[tag_blue, tag_red]) - await obj.save(db=db) - return obj - - -@pytest.fixture -async def tag_blue(db: InfrahubDatabase) -> Node: - obj = await Node.init(schema="BuiltinTag", db=db) - await obj.new(db=db, name="Blue") - await obj.save(db=db) - return obj - - -@pytest.fixture -async def tag_red(db: InfrahubDatabase) -> Node: - obj = await Node.init(schema="BuiltinTag", db=db) - await obj.new(db=db, name="Red") - await obj.save(db=db) - return obj - - -@pytest.fixture -async def tag_green(db: InfrahubDatabase) -> Node: - obj = await Node.init(schema="BuiltinTag", db=db) - await obj.new(db=db, name="Green") - await obj.save(db=db) - return obj - - -@pytest.fixture -async def first_account(db: InfrahubDatabase) -> Node: - obj = await Node.init(db=db, schema="CoreAccount") - await obj.new(db=db, name="First Account", account_type="Git", password="TestPassword123") - await obj.save(db=db) - return obj - - -@pytest.fixture -async def second_account(db: InfrahubDatabase) -> Node: - obj = await Node.init(db=db, schema="CoreAccount") - await obj.new(db=db, name="Second Account", account_type="Git", password="TestPassword123") - await obj.save(db=db) - return obj - - -@pytest.fixture -async def repo01(db: InfrahubDatabase) -> Node: - obj = await Node.init(db=db, schema="CoreRepository") - await obj.new(db=db, name="repo01", location="https://github.com/my/repo.git") - await obj.save(db=db) - return obj - - -@pytest.fixture -async def repo99(db: InfrahubDatabase) -> Node: - obj = await Node.init(db=db, schema="CoreRepository") - await obj.new(db=db, name="repo99", location="https://github.com/my/repo99.git") - await obj.save(db=db) - return obj - - -@pytest.fixture -async def gqlquery01(db: InfrahubDatabase) -> Node: - obj = await Node.init(db=db, schema="CoreGraphQLQuery") - await 
obj.new(db=db, name="query01", query="query { device { name { value }}}") - await obj.save(db=db) - return obj - - -@pytest.fixture -async def gqlquery02(db: InfrahubDatabase, repo01: Node, tag_blue: Node, tag_red: Node) -> Node: - obj = await Node.init(db=db, schema="CoreGraphQLQuery") - await obj.new( - db=db, - name="query02", - query="query { CoreRepository { edges { node { name { value }}}}}", - repository=repo01, - tags=[tag_blue, tag_red], - ) - await obj.save(db=db) - return obj - - -@pytest.fixture -async def gqlquery03(db: InfrahubDatabase, repo01: Node, tag_blue: Node, tag_red: Node) -> Node: - obj = await Node.init(db=db, schema="CoreGraphQLQuery") - await obj.new( - db=db, - name="query03", - query="query { CoreRepository { edges { node { name { value }}}}}", - repository=repo01, - tags=[tag_blue, tag_red], - ) - await obj.save(db=db) - return obj - - -@pytest.fixture -async def schema_extension_01() -> dict[str, Any]: - return { - "version": "1.0", - "nodes": [ - { - "name": "Rack", - "namespace": "Infra", - "description": "A Rack represents a physical two- or four-post equipment rack in which devices can be installed.", - "label": "Rack", - "default_filter": "name__value", - "display_labels": ["name__value"], - "attributes": [ - {"name": "name", "kind": "Text"}, - {"name": "description", "kind": "Text", "optional": True}, - ], - "relationships": [ - { - "name": "tags", - "peer": "BuiltinTag", - "optional": True, - "cardinality": "many", - "kind": "Attribute", - }, - ], - } - ], - "extensions": { - "nodes": [ - { - "kind": "BuiltinTag", - "relationships": [ - { - "name": "racks", - "peer": "InfraRack", - "optional": True, - "cardinality": "many", - "kind": "Generic", - } - ], - } - ] - }, - } - - -@pytest.fixture -async def schema_extension_02() -> dict[str, Any]: - return { - "version": "1.0", - "nodes": [ - { - "name": "Contract", - "namespace": "Procurement", - "description": "Generic Contract", - "label": "Contract", - "display_labels": 
["contract_ref__value"], - "order_by": ["contract_ref__value"], - "attributes": [ - { - "name": "contract_ref", - "label": "Contract Reference", - "kind": "Text", - "unique": True, - }, - {"name": "description", "kind": "Text", "optional": True}, - ], - "relationships": [ - { - "name": "tags", - "peer": "BuiltinTag", - "optional": True, - "cardinality": "many", - "kind": "Attribute", - }, - ], - } - ], - "extensions": { - "nodes": [ - { - "kind": "BuiltinTag", - "relationships": [ - { - "name": "contracts", - "peer": "ProcurementContract", - "optional": True, - "cardinality": "many", - "kind": "Generic", - } - ], - } - ] - }, - } - - -@pytest.fixture(scope="module") -async def ipam_schema() -> SchemaRoot: - SCHEMA = { - "version": "1.0", - "nodes": [ - { - "name": "IPPrefix", - "namespace": "Ipam", - "include_in_menu": False, - "inherit_from": ["BuiltinIPPrefix"], - "description": "IPv4 or IPv6 network", - "icon": "mdi:ip-network", - "label": "IP Prefix", - }, - { - "name": "IPAddress", - "namespace": "Ipam", - "include_in_menu": False, - "inherit_from": ["BuiltinIPAddress"], - "description": "IP Address", - "icon": "mdi:ip-outline", - "label": "IP Address", - }, - { - "name": "Device", - "namespace": "Infra", - "label": "Device", - "human_friendly_id": ["name__value"], - "order_by": ["name__value"], - "display_labels": ["name__value"], - "attributes": [{"name": "name", "kind": "Text", "unique": True}], - "relationships": [ - { - "name": "primary_address", - "peer": "IpamIPAddress", - "label": "Primary IP Address", - "optional": True, - "cardinality": "one", - "kind": "Attribute", - } - ], - }, - ], - } - - return SCHEMA - - -@pytest.fixture(scope="module") -async def hierarchical_schema() -> dict: - schema = { - "version": "1.0", - "generics": [ - { - "name": "Generic", - "namespace": "Location", - "description": "Generic hierarchical location", - "label": "Location", - "hierarchical": True, - "human_friendly_id": ["name__value"], - "include_in_menu": True, - 
"attributes": [ - {"name": "name", "kind": "Text", "unique": True, "order_weight": 900}, - ], - } - ], - "nodes": [ - { - "name": "Country", - "namespace": "Location", - "description": "A country within a continent.", - "inherit_from": ["LocationGeneric"], - "generate_profile": False, - "default_filter": "name__value", - "order_by": ["name__value"], - "display_labels": ["name__value"], - "children": "LocationSite", - "attributes": [{"name": "shortname", "kind": "Text"}], - }, - { - "name": "Site", - "namespace": "Location", - "description": "A site within a country.", - "inherit_from": ["LocationGeneric"], - "default_filter": "name__value", - "order_by": ["name__value"], - "display_labels": ["name__value"], - "children": "", - "parent": "LocationCountry", - "attributes": [{"name": "shortname", "kind": "Text"}], - }, - ], - } - return schema - - -class BusRecorder(InfrahubMessageBus): - def __init__(self, component_type: Optional[ComponentType] = None): - self.messages: list[InfrahubMessage] = [] - self.messages_per_routing_key: dict[str, list[InfrahubMessage]] = {} - - async def publish( - self, message: InfrahubMessage, routing_key: str, delay: Optional[MessageTTL] = None, is_retry: bool = False - ) -> None: - self.messages.append(message) - if routing_key not in self.messages_per_routing_key: - self.messages_per_routing_key[routing_key] = [] - self.messages_per_routing_key[routing_key].append(message) - - @property - def seen_routing_keys(self) -> list[str]: - return list(self.messages_per_routing_key.keys()) +# @pytest.fixture(scope="session", autouse=True) +# def add_tracker(): +# os.environ["PYTEST_RUNNING"] = "true" + + +# # pylint: disable=redefined-outer-name +# class InfrahubTestClient(TestClient): +# def _request( +# self, url: str, method: HTTPMethod, headers: dict[str, Any], timeout: int, payload: Optional[dict] = None +# ) -> httpx.Response: +# content = None +# if payload: +# content = str(ujson.dumps(payload)).encode("UTF-8") +# with self as client: 
+# return client.request( +# method=method.value, +# url=url, +# headers=headers, +# timeout=timeout, +# content=content, +# ) + +# async def async_request( +# self, url: str, method: HTTPMethod, headers: dict[str, Any], timeout: int, payload: Optional[dict] = None +# ) -> httpx.Response: +# return self._request(url=url, method=method, headers=headers, timeout=timeout, payload=payload) + +# def sync_request( +# self, url: str, method: HTTPMethod, headers: dict[str, Any], timeout: int, payload: Optional[dict] = None +# ) -> httpx.Response: +# return self._request(url=url, method=method, headers=headers, timeout=timeout, payload=payload) + + +# @pytest.fixture(scope="session") +# def event_loop(): +# """Overrides pytest default function scoped event loop""" +# policy = asyncio.get_event_loop_policy() +# loop = policy.new_event_loop() +# yield loop +# loop.close() + + +# @pytest.fixture(scope="module", autouse=True) +# def execute_before_any_test(worker_id, tmpdir_factory): +# config.load_and_exit() + +# config.SETTINGS.storage.driver = config.StorageDriver.FileSystemStorage + +# if TEST_IN_DOCKER: +# try: +# db_id = int(worker_id[2]) + 1 +# except (ValueError, IndexError): +# db_id = 1 +# config.SETTINGS.cache.address = f"{BUILD_NAME}-cache-1" +# config.SETTINGS.database.address = f"{BUILD_NAME}-database-{db_id}" +# config.SETTINGS.storage.local = config.FileSystemStorageSettings(path="/opt/infrahub/storage") +# else: +# storage_dir = tmpdir_factory.mktemp("storage") +# config.SETTINGS.storage.local.path_ = str(storage_dir) + +# config.SETTINGS.broker.enable = False +# config.SETTINGS.cache.enable = True +# config.SETTINGS.miscellaneous.start_background_runner = False +# config.SETTINGS.security.secret_key = "4e26b3d9-b84f-42c9-a03f-fee3ada3b2fa" +# config.SETTINGS.main.internal_address = "http://mock" +# config.OVERRIDE.message_bus = BusRecorder() + +# initialize_lock() + + +# @pytest.fixture(scope="module") +# async def db() -> InfrahubDatabase: +# driver = 
InfrahubDatabase(driver=await get_db(retry=1)) + +# yield driver + +# await driver.close() + + +# @pytest.fixture(scope="module") +# async def init_db_base(db: InfrahubDatabase): +# await delete_all_nodes(db=db) +# await first_time_initialization(db=db) +# await initialization(db=db) + + +# @pytest.fixture(scope="module") +# async def builtin_org_schema() -> SchemaRoot: +# SCHEMA = { +# "version": "1.0", +# "nodes": [ +# { +# "name": "Organization", +# "namespace": "Test", +# "description": "An organization represent a legal entity, a company.", +# "include_in_menu": True, +# "label": "Organization", +# "icon": "mdi:domain", +# "default_filter": "name__value", +# "order_by": ["name__value"], +# "display_labels": ["label__value"], +# "branch": "aware", +# "attributes": [ +# {"name": "name", "kind": "Text", "unique": True}, +# {"name": "label", "kind": "Text", "optional": True}, +# {"name": "description", "kind": "Text", "optional": True}, +# ], +# "relationships": [ +# { +# "name": "tags", +# "peer": "BuiltinTag", +# "kind": "Attribute", +# "optional": True, +# "cardinality": "many", +# }, +# ], +# }, +# { +# "name": "Status", +# "namespace": "Builtin", +# "description": "Represent the status of an object: active, maintenance", +# "include_in_menu": True, +# "icon": "mdi:list-status", +# "label": "Status", +# "default_filter": "name__value", +# "order_by": ["name__value"], +# "display_labels": ["label__value"], +# "branch": "aware", +# "attributes": [ +# {"name": "name", "kind": "Text", "unique": True}, +# {"name": "label", "kind": "Text", "optional": True}, +# {"name": "description", "kind": "Text", "optional": True}, +# ], +# }, +# { +# "name": "Role", +# "namespace": "Builtin", +# "description": "Represent the role of an object", +# "include_in_menu": True, +# "icon": "mdi:ballot", +# "label": "Role", +# "default_filter": "name__value", +# "order_by": ["name__value"], +# "display_labels": ["label__value"], +# "branch": "aware", +# "attributes": [ +# {"name": 
"name", "kind": "Text", "unique": True}, +# {"name": "label", "kind": "Text", "optional": True}, +# {"name": "description", "kind": "Text", "optional": True}, +# ], +# }, +# { +# "name": "Location", +# "namespace": "Builtin", +# "description": "A location represent a physical element: a building, a site, a city", +# "include_in_menu": True, +# "icon": "mdi:map-marker-radius-outline", +# "label": "Location", +# "default_filter": "name__value", +# "order_by": ["name__value"], +# "display_labels": ["name__value"], +# "branch": "aware", +# "attributes": [ +# {"name": "name", "kind": "Text", "unique": True}, +# {"name": "description", "kind": "Text", "optional": True}, +# {"name": "type", "kind": "Text"}, +# ], +# "relationships": [ +# { +# "name": "tags", +# "peer": "BuiltinTag", +# "kind": "Attribute", +# "optional": True, +# "cardinality": "many", +# }, +# ], +# }, +# { +# "name": "Criticality", +# "namespace": "Builtin", +# "description": "Level of criticality expressed from 1 to 10.", +# "include_in_menu": True, +# "icon": "mdi:alert-octagon-outline", +# "label": "Criticality", +# "default_filter": "name__value", +# "order_by": ["name__value"], +# "display_labels": ["name__value"], +# "branch": "aware", +# "attributes": [ +# {"name": "name", "kind": "Text", "unique": True}, +# {"name": "level", "kind": "Number", "enum": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]}, +# {"name": "description", "kind": "Text", "optional": True}, +# ], +# }, +# ], +# } + +# return SCHEMA + + +# @pytest.fixture +# async def location_schema() -> NodeSchema: +# data = { +# "name": "Location", +# "namespace": "Builtin", +# "default_filter": "name__value", +# "attributes": [ +# {"name": "name", "kind": "String", "unique": True}, +# {"name": "description", "kind": "String", "optional": True}, +# {"name": "type", "kind": "String"}, +# ], +# "relationships": [ +# { +# "name": "tags", +# "peer": "BuiltinTag", +# "optional": True, +# "cardinality": "many", +# }, +# { +# "name": "primary_tag", +# "peer": 
"BultinTag", +# "optional": True, +# "cardinality": "one", +# }, +# ], +# } +# return NodeSchema(**data) # type: ignore + + +# @pytest.fixture +# async def location_cdg(db: InfrahubDatabase, tag_blue: Node, tag_red: Node) -> Node: +# obj = await Node.init(schema="BuiltinLocation", db=db) +# await obj.new(db=db, name="cdg01", type="SITE", tags=[tag_blue, tag_red]) +# await obj.save(db=db) +# return obj + + +# @pytest.fixture +# async def tag_blue(db: InfrahubDatabase) -> Node: +# obj = await Node.init(schema="BuiltinTag", db=db) +# await obj.new(db=db, name="Blue") +# await obj.save(db=db) +# return obj + + +# @pytest.fixture +# async def tag_red(db: InfrahubDatabase) -> Node: +# obj = await Node.init(schema="BuiltinTag", db=db) +# await obj.new(db=db, name="Red") +# await obj.save(db=db) +# return obj + + +# @pytest.fixture +# async def tag_green(db: InfrahubDatabase) -> Node: +# obj = await Node.init(schema="BuiltinTag", db=db) +# await obj.new(db=db, name="Green") +# await obj.save(db=db) +# return obj + + +# @pytest.fixture +# async def first_account(db: InfrahubDatabase) -> Node: +# obj = await Node.init(db=db, schema="CoreAccount") +# await obj.new(db=db, name="First Account", account_type="Git", password="TestPassword123") +# await obj.save(db=db) +# return obj + + +# @pytest.fixture +# async def second_account(db: InfrahubDatabase) -> Node: +# obj = await Node.init(db=db, schema="CoreAccount") +# await obj.new(db=db, name="Second Account", account_type="Git", password="TestPassword123") +# await obj.save(db=db) +# return obj + + +# @pytest.fixture +# async def repo01(db: InfrahubDatabase) -> Node: +# obj = await Node.init(db=db, schema="CoreRepository") +# await obj.new(db=db, name="repo01", location="https://github.com/my/repo.git") +# await obj.save(db=db) +# return obj + + +# @pytest.fixture +# async def repo99(db: InfrahubDatabase) -> Node: +# obj = await Node.init(db=db, schema="CoreRepository") +# await obj.new(db=db, name="repo99", 
location="https://github.com/my/repo99.git") +# await obj.save(db=db) +# return obj + + +# @pytest.fixture +# async def gqlquery01(db: InfrahubDatabase) -> Node: +# obj = await Node.init(db=db, schema="CoreGraphQLQuery") +# await obj.new(db=db, name="query01", query="query { device { name { value }}}") +# await obj.save(db=db) +# return obj + + +# @pytest.fixture +# async def gqlquery02(db: InfrahubDatabase, repo01: Node, tag_blue: Node, tag_red: Node) -> Node: +# obj = await Node.init(db=db, schema="CoreGraphQLQuery") +# await obj.new( +# db=db, +# name="query02", +# query="query { CoreRepository { edges { node { name { value }}}}}", +# repository=repo01, +# tags=[tag_blue, tag_red], +# ) +# await obj.save(db=db) +# return obj + + +# @pytest.fixture +# async def gqlquery03(db: InfrahubDatabase, repo01: Node, tag_blue: Node, tag_red: Node) -> Node: +# obj = await Node.init(db=db, schema="CoreGraphQLQuery") +# await obj.new( +# db=db, +# name="query03", +# query="query { CoreRepository { edges { node { name { value }}}}}", +# repository=repo01, +# tags=[tag_blue, tag_red], +# ) +# await obj.save(db=db) +# return obj + + +# @pytest.fixture +# async def schema_extension_01() -> dict[str, Any]: +# return { +# "version": "1.0", +# "nodes": [ +# { +# "name": "Rack", +# "namespace": "Infra", +# "description": "A Rack represents a physical two- or four-post equipment rack in which devices can be installed.", +# "label": "Rack", +# "default_filter": "name__value", +# "display_labels": ["name__value"], +# "attributes": [ +# {"name": "name", "kind": "Text"}, +# {"name": "description", "kind": "Text", "optional": True}, +# ], +# "relationships": [ +# { +# "name": "tags", +# "peer": "BuiltinTag", +# "optional": True, +# "cardinality": "many", +# "kind": "Attribute", +# }, +# ], +# } +# ], +# "extensions": { +# "nodes": [ +# { +# "kind": "BuiltinTag", +# "relationships": [ +# { +# "name": "racks", +# "peer": "InfraRack", +# "optional": True, +# "cardinality": "many", +# "kind": 
"Generic", +# } +# ], +# } +# ] +# }, +# } + + +# @pytest.fixture +# async def schema_extension_02() -> dict[str, Any]: +# return { +# "version": "1.0", +# "nodes": [ +# { +# "name": "Contract", +# "namespace": "Procurement", +# "description": "Generic Contract", +# "label": "Contract", +# "display_labels": ["contract_ref__value"], +# "order_by": ["contract_ref__value"], +# "attributes": [ +# { +# "name": "contract_ref", +# "label": "Contract Reference", +# "kind": "Text", +# "unique": True, +# }, +# {"name": "description", "kind": "Text", "optional": True}, +# ], +# "relationships": [ +# { +# "name": "tags", +# "peer": "BuiltinTag", +# "optional": True, +# "cardinality": "many", +# "kind": "Attribute", +# }, +# ], +# } +# ], +# "extensions": { +# "nodes": [ +# { +# "kind": "BuiltinTag", +# "relationships": [ +# { +# "name": "contracts", +# "peer": "ProcurementContract", +# "optional": True, +# "cardinality": "many", +# "kind": "Generic", +# } +# ], +# } +# ] +# }, +# } + + +# @pytest.fixture(scope="module") +# async def ipam_schema() -> SchemaRoot: +# SCHEMA = { +# "version": "1.0", +# "nodes": [ +# { +# "name": "IPPrefix", +# "namespace": "Ipam", +# "include_in_menu": False, +# "inherit_from": ["BuiltinIPPrefix"], +# "description": "IPv4 or IPv6 network", +# "icon": "mdi:ip-network", +# "label": "IP Prefix", +# }, +# { +# "name": "IPAddress", +# "namespace": "Ipam", +# "include_in_menu": False, +# "inherit_from": ["BuiltinIPAddress"], +# "description": "IP Address", +# "icon": "mdi:ip-outline", +# "label": "IP Address", +# }, +# { +# "name": "Device", +# "namespace": "Infra", +# "label": "Device", +# "human_friendly_id": ["name__value"], +# "order_by": ["name__value"], +# "display_labels": ["name__value"], +# "attributes": [{"name": "name", "kind": "Text", "unique": True}], +# "relationships": [ +# { +# "name": "primary_address", +# "peer": "IpamIPAddress", +# "label": "Primary IP Address", +# "optional": True, +# "cardinality": "one", +# "kind": "Attribute", +# 
} +# ], +# }, +# ], +# } + +# return SCHEMA + + +# @pytest.fixture(scope="module") +# async def hierarchical_schema() -> dict: +# schema = { +# "version": "1.0", +# "generics": [ +# { +# "name": "Generic", +# "namespace": "Location", +# "description": "Generic hierarchical location", +# "label": "Location", +# "hierarchical": True, +# "human_friendly_id": ["name__value"], +# "include_in_menu": True, +# "attributes": [ +# {"name": "name", "kind": "Text", "unique": True, "order_weight": 900}, +# ], +# } +# ], +# "nodes": [ +# { +# "name": "Country", +# "namespace": "Location", +# "description": "A country within a continent.", +# "inherit_from": ["LocationGeneric"], +# "generate_profile": False, +# "default_filter": "name__value", +# "order_by": ["name__value"], +# "display_labels": ["name__value"], +# "children": "LocationSite", +# "attributes": [{"name": "shortname", "kind": "Text"}], +# }, +# { +# "name": "Site", +# "namespace": "Location", +# "description": "A site within a country.", +# "inherit_from": ["LocationGeneric"], +# "default_filter": "name__value", +# "order_by": ["name__value"], +# "display_labels": ["name__value"], +# "children": "", +# "parent": "LocationCountry", +# "attributes": [{"name": "shortname", "kind": "Text"}], +# }, +# ], +# } +# return schema + + +# class BusRecorder(InfrahubMessageBus): +# def __init__(self, component_type: Optional[ComponentType] = None): +# self.messages: list[InfrahubMessage] = [] +# self.messages_per_routing_key: dict[str, list[InfrahubMessage]] = {} + +# async def publish( +# self, message: InfrahubMessage, routing_key: str, delay: Optional[MessageTTL] = None, is_retry: bool = False +# ) -> None: +# self.messages.append(message) +# if routing_key not in self.messages_per_routing_key: +# self.messages_per_routing_key[routing_key] = [] +# self.messages_per_routing_key[routing_key].append(message) + +# @property +# def seen_routing_keys(self) -> list[str]: +# return list(self.messages_per_routing_key.keys()) diff 
--git a/tests/integration/test_node.py b/tests/integration/test_node.py index b03100dc..7421cacf 100644 --- a/tests/integration/test_node.py +++ b/tests/integration/test_node.py @@ -1,403 +1,361 @@ import pytest -from infrahub.core.manager import NodeManager -from infrahub.core.node import Node -from infrahub.database import InfrahubDatabase -from infrahub.server import app -from infrahub_sdk import Config, InfrahubClient -from infrahub_sdk.exceptions import NodeNotFoundError, UninitializedError +from infrahub_sdk import InfrahubClient +from infrahub_sdk.exceptions import NodeNotFoundError from infrahub_sdk.node import InfrahubNode - -from .conftest import InfrahubTestClient +from infrahub_sdk.schema import NodeSchema, NodeSchemaAPI, SchemaRoot +from infrahub_sdk.testing.docker import TestInfrahubDockerClient +from infrahub_sdk.testing.schemas.car_person import TESTING_MANUFACTURER, SchemaCarPerson # pylint: disable=unused-argument -class TestInfrahubNode: - @pytest.fixture(scope="class") - async def test_client(self): - return InfrahubTestClient(app) - - @pytest.fixture - async def client(self, test_client): - config = Config(username="admin", password="infrahub", requester=test_client.async_request) - return InfrahubClient(config=config) - +class TestInfrahubNode(TestInfrahubDockerClient, SchemaCarPerson): @pytest.fixture(scope="class") - async def load_builtin_schema(self, db: InfrahubDatabase, test_client: InfrahubTestClient, builtin_org_schema): - config = Config(username="admin", password="infrahub", requester=test_client.async_request) - client = InfrahubClient(config=config) - response = await client.schema.load(schemas=[builtin_org_schema]) - assert not response.errors + def infrahub_version(self) -> str: + return "1.0.10" @pytest.fixture(scope="class") - async def load_ipam_schema(self, db: InfrahubDatabase, test_client: InfrahubTestClient, ipam_schema) -> None: - config = Config(username="admin", password="infrahub", requester=test_client.async_request) 
- client = InfrahubClient(config=config) - response = await client.schema.load(schemas=[ipam_schema]) - assert not response.errors - - @pytest.fixture - async def default_ipam_namespace(self, client: InfrahubClient) -> InfrahubNode: - return await client.get(kind="IpamNamespace", name__value="default") - - async def test_node_create(self, client: InfrahubClient, init_db_base, load_builtin_schema, location_schema): - data = { - "name": {"value": "JFK1"}, - "description": {"value": "JFK Airport"}, - "type": {"value": "SITE"}, - } - node = InfrahubNode(client=client, schema=location_schema, data=data) - await node.save() - - assert node.id is not None - - async def test_node_delete_client( - self, - db: InfrahubDatabase, - client: InfrahubClient, - init_db_base, - load_builtin_schema, - location_schema, - ): - data = { - "name": {"value": "ARN"}, - "description": {"value": "Arlanda Airport"}, - "type": {"value": "SITE"}, - } - node = InfrahubNode(client=client, schema=location_schema, data=data) - await node.save() - nodedb_pre_delete = await NodeManager.get_one(id=node.id, db=db, include_owner=True, include_source=True) - - await node.delete() - nodedb_post_delete = await NodeManager.get_one(id=node.id, db=db, include_owner=True, include_source=True) - assert nodedb_pre_delete - assert nodedb_pre_delete.id - assert not nodedb_post_delete - - async def test_node_delete_node( - self, - db: InfrahubDatabase, - client: InfrahubClient, - init_db_base, - load_builtin_schema, - location_schema, - ): - obj = await Node.init(db=db, schema="CoreAccount") - await obj.new(db=db, name="delete-my-account", account_type="Git", password="delete-my-password") - await obj.save(db=db) - node_pre_delete = await client.get(kind="CoreAccount", name__value="delete-my-account") - assert node_pre_delete - assert node_pre_delete.id - await node_pre_delete.delete() - with pytest.raises(NodeNotFoundError): - await client.get(kind="CoreAccount", name__value="delete-my-account") - - async def 
test_node_create_with_relationships( - self, - db: InfrahubDatabase, - client: InfrahubClient, - init_db_base, - load_builtin_schema, - tag_blue: Node, - tag_red: Node, - repo01: Node, - gqlquery01: Node, - ): - data = { - "name": {"value": "rfile01"}, - "template_path": {"value": "mytemplate.j2"}, - "query": gqlquery01.id, - "repository": {"id": repo01.id}, - "tags": [tag_blue.id, tag_red.id], - } - - node = await client.create(kind="CoreTransformJinja2", data=data) - await node.save() - - assert node.id is not None - - nodedb = await NodeManager.get_one(id=node.id, db=db, include_owner=True, include_source=True) - assert nodedb.name.value == node.name.value # type: ignore[attr-defined] - querydb = await nodedb.query.get_peer(db=db) - assert node.query.id == querydb.id # type: ignore[attr-defined] + async def initial_schema(self, default_branch: str, client: InfrahubClient, schema_base: SchemaRoot) -> None: + await client.schema.wait_until_converged(branch=default_branch) - async def test_node_update_payload_with_relationships( - self, - db: InfrahubDatabase, - client: InfrahubClient, - init_db_base, - load_builtin_schema, - tag_blue: Node, - tag_red: Node, - repo01: Node, - gqlquery01: Node, - ): - data = { - "name": "rfile10", - "template_path": "mytemplate.j2", - "query": gqlquery01.id, - "repository": repo01.id, - "tags": [tag_blue.id, tag_red.id], - } - schema = await client.schema.get(kind="CoreTransformJinja2", branch="main") - create_payload = client.schema.generate_payload_create( - schema=schema, data=data, source=repo01.id, is_protected=True + resp = await client.schema.load( + schemas=[schema_base.to_schema_dict()], branch=default_branch, wait_until_converged=True ) - obj = await client.create(kind="CoreTransformJinja2", branch="main", **create_payload) - await obj.save() - - assert obj.id is not None - nodedb = await client.get(kind="CoreTransformJinja2", id=str(obj.id)) - - input_data = nodedb._generate_input_data()["data"]["data"] - assert 
input_data["name"]["value"] == "rfile10" - # Validate that the source isn't a dictionary bit a reference to the repo - assert input_data["name"]["source"] == repo01.id + assert resp.errors == {} - async def test_node_create_with_properties( - self, - db: InfrahubDatabase, - client: InfrahubClient, - init_db_base, - load_builtin_schema, - tag_blue: Node, - tag_red: Node, - repo01: Node, - gqlquery01: Node, - first_account: Node, + async def test_node_create( + self, client: InfrahubClient, initial_schema: None, schema_manufacturer_base: NodeSchema ): + schema_manufacturer = NodeSchemaAPI(**schema_manufacturer_base.model_dump(exclude_unset=True)) data = { - "name": { - "value": "rfile02", - "is_protected": True, - "source": first_account.id, - "owner": first_account.id, - }, - "template_path": {"value": "mytemplate.j2"}, - "query": {"id": gqlquery01.id}, # "source": first_account.id, "owner": first_account.id}, - "repository": {"id": repo01.id}, # "source": first_account.id, "owner": first_account.id}, - "tags": [tag_blue.id, tag_red.id], + "name": {"value": "Fiat"}, + "description": {"value": "An italian brand"}, } - - node = await client.create(kind="CoreTransformJinja2", data=data) - await node.save() - - assert node.id is not None - - nodedb = await NodeManager.get_one(id=node.id, db=db, include_owner=True, include_source=True) - assert nodedb.name.value == node.name.value # type: ignore[attr-defined] - assert nodedb.name.is_protected is True - - async def test_node_update( - self, - db: InfrahubDatabase, - client: InfrahubClient, - init_db_base, - load_builtin_schema, - tag_blue: Node, - tag_red: Node, - repo99: Node, - ): - node = await client.get(kind="CoreRepository", name__value="repo99") - assert node.id is not None - - node.name.value = "repo95" # type: ignore[attr-defined] - node.tags.add(tag_blue.id) # type: ignore[attr-defined] - node.tags.add(tag_red.id) # type: ignore[attr-defined] + node = InfrahubNode(client=client, schema=schema_manufacturer, 
data=data) await node.save() - nodedb = await NodeManager.get_one(id=node.id, db=db, include_owner=True, include_source=True) - assert nodedb.name.value == "repo95" - tags = await nodedb.tags.get(db=db) - assert len(tags) == 2 - - async def test_node_update_2( - self, - db: InfrahubDatabase, - client: InfrahubClient, - init_db_base, - load_builtin_schema, - tag_green: Node, - tag_red: Node, - tag_blue: Node, - gqlquery02: Node, - repo99: Node, - ): - node = await client.get(kind="CoreGraphQLQuery", name__value="query02") assert node.id is not None - node.name.value = "query021" # type: ignore[attr-defined] - node.repository = repo99.id # type: ignore[attr-defined] - node.tags.add(tag_green.id) # type: ignore[attr-defined] - node.tags.remove(tag_red.id) # type: ignore[attr-defined] - await node.save() - - nodedb = await NodeManager.get_one(id=node.id, db=db, include_owner=True, include_source=True) - repodb = await nodedb.repository.get_peer(db=db) - assert repodb.id == repo99.id - - tags = await nodedb.tags.get(db=db) - assert sorted([tag.peer_id for tag in tags]) == sorted([tag_green.id, tag_blue.id]) - - async def test_node_update_3_idempotency( + async def test_node_delete( self, - db: InfrahubDatabase, + default_branch: str, client: InfrahubClient, - init_db_base, - load_builtin_schema, - tag_green: Node, - tag_red: Node, - tag_blue: Node, - gqlquery03: Node, - repo99: Node, + initial_schema: None, ): - node = await client.get(kind="CoreGraphQLQuery", name__value="query03") - assert node.id is not None - - updated_query = f"\n\n{node.query.value}" # type: ignore[attr-defined] - node.name.value = "query031" # type: ignore[attr-defined] - node.query.value = updated_query # type: ignore[attr-defined] - first_update = node._generate_input_data(exclude_unmodified=True) - await node.save() - nodedb = await NodeManager.get_one(id=node.id, db=db, include_owner=True, include_source=True) + await self.create_manufacturers(client=client, branch=default_branch) - node = 
await client.get(kind="CoreGraphQLQuery", name__value="query031") + obj: InfrahubNode = client.store.get_by_hfid(key=f"{TESTING_MANUFACTURER}__Volkswagen") + await obj.delete() - node.name.value = "query031" # type: ignore[attr-defined] - node.query.value = updated_query # type: ignore[attr-defined] - - second_update = node._generate_input_data(exclude_unmodified=True) - - assert nodedb.query.value == updated_query # type: ignore[attr-defined] - assert "query" in first_update["data"]["data"] - assert "value" in first_update["data"]["data"]["query"] - assert first_update["variables"] - assert "query" not in second_update["data"]["data"] - assert not second_update["variables"] - - async def test_convert_node( - self, - db: InfrahubDatabase, - client: InfrahubClient, - location_schema, - init_db_base, - load_builtin_schema, - location_cdg: Node, - ): - data = await location_cdg.to_graphql(db=db) - node = InfrahubNode(client=client, schema=location_schema, data=data) - - # pylint: disable=no-member - assert node.name.value == "cdg01" # type: ignore[attr-defined] - - async def test_relationship_manager_errors_without_fetch(self, client: InfrahubClient, load_builtin_schema): - organization = await client.create("TestOrganization", name="organization-1") - await organization.save() - tag = await client.create("BuiltinTag", name="blurple") - await tag.save() - - with pytest.raises(UninitializedError, match=r"Must call fetch"): - organization.tags.add(tag) - - await organization.tags.fetch() - organization.tags.add(tag) - await organization.save() - - organization = await client.get("TestOrganization", name__value="organization-1") - assert [t.id for t in organization.tags.peers] == [tag.id] - - async def test_relationships_not_overwritten( - self, client: InfrahubClient, load_builtin_schema, schema_extension_01 - ): - await client.schema.load(schemas=[schema_extension_01]) - rack = await client.create("InfraRack", name="rack-1") - await rack.save() - tag = await 
client.create("BuiltinTag", name="blizzow") - # TODO: is it a bug that we need to save the object and fetch the tags before adding to a RelationshipManager now? - await tag.save() - await tag.racks.fetch() - tag.racks.add(rack) - await tag.save() - tag_2 = await client.create("BuiltinTag", name="blizzow2") - await tag_2.save() - - # the "rack" object has no link to the "tag" object here - # rack.tags.peers is empty - rack.name.value = "New Rack Name" - await rack.save() - - # assert that the above rack.save() did not overwrite the existing Rack-Tag relationship - refreshed_rack = await client.get("InfraRack", id=rack.id) - await refreshed_rack.tags.fetch() - assert [t.id for t in refreshed_rack.tags.peers] == [tag.id] - - # check that we can purposefully remove a tag - refreshed_rack.tags.remove(tag.id) - await refreshed_rack.save() - rack_without_tag = await client.get("InfraRack", id=rack.id) - await rack_without_tag.tags.fetch() - assert rack_without_tag.tags.peers == [] - - # check that we can purposefully add a tag - rack_without_tag.tags.add(tag_2) - await rack_without_tag.save() - refreshed_rack_with_tag = await client.get("InfraRack", id=rack.id) - await refreshed_rack_with_tag.tags.fetch() - assert [t.id for t in refreshed_rack_with_tag.tags.peers] == [tag_2.id] - - async def test_node_create_from_pool( - self, db: InfrahubDatabase, client: InfrahubClient, init_db_base, default_ipam_namespace, load_ipam_schema - ): - ip_prefix = await client.create(kind="IpamIPPrefix", prefix="192.0.2.0/24") - await ip_prefix.save() - - ip_pool = await client.create( - kind="CoreIPAddressPool", - name="Core loopbacks 1", - default_address_type="IpamIPAddress", - default_prefix_length=32, - ip_namespace=default_ipam_namespace, - resources=[ip_prefix], - ) - await ip_pool.save() - - devices = [] - for i in range(1, 5): - d = await client.create(kind="InfraDevice", name=f"core0{i}", primary_address=ip_pool) - await d.save() - devices.append(d) - - assert 
[str(device.primary_address.peer.address.value) for device in devices] == [ - "192.0.2.1/32", - "192.0.2.2/32", - "192.0.2.3/32", - "192.0.2.4/32", - ] - - async def test_node_update_from_pool( - self, db: InfrahubDatabase, client: InfrahubClient, init_db_base, default_ipam_namespace, load_ipam_schema - ): - starter_ip_address = await client.create(kind="IpamIPAddress", address="10.0.0.1/32") - await starter_ip_address.save() - - ip_prefix = await client.create(kind="IpamIPPrefix", prefix="192.168.0.0/24") - await ip_prefix.save() - - ip_pool = await client.create( - kind="CoreIPAddressPool", - name="Core loopbacks 2", - default_address_type="IpamIPAddress", - default_prefix_length=32, - ip_namespace=default_ipam_namespace, - resources=[ip_prefix], - ) - await ip_pool.save() - - device = await client.create(kind="InfraDevice", name="core05", primary_address=starter_ip_address) - await device.save() - - device.primary_address = ip_pool - await device.save() - - assert str(device.primary_address.peer.address.value) == "192.168.0.1/32" + with pytest.raises(NodeNotFoundError): + await client.get(kind=TESTING_MANUFACTURER, id=obj.id) + + # async def test_node_create_with_relationships( + # self, + # db: InfrahubDatabase, + # client: InfrahubClient, + # init_db_base, + # load_builtin_schema, + # tag_blue: Node, + # tag_red: Node, + # repo01: Node, + # gqlquery01: Node, + # ): + # data = { + # "name": {"value": "rfile01"}, + # "template_path": {"value": "mytemplate.j2"}, + # "query": gqlquery01.id, + # "repository": {"id": repo01.id}, + # "tags": [tag_blue.id, tag_red.id], + # } + + # node = await client.create(kind="CoreTransformJinja2", data=data) + # await node.save() + + # assert node.id is not None + + # nodedb = await NodeManager.get_one(id=node.id, db=db, include_owner=True, include_source=True) + # assert nodedb.name.value == node.name.value # type: ignore[attr-defined] + # querydb = await nodedb.query.get_peer(db=db) + # assert node.query.id == querydb.id # type: 
ignore[attr-defined] + + # async def test_node_update_payload_with_relationships( + # self, + # db: InfrahubDatabase, + # client: InfrahubClient, + # init_db_base, + # load_builtin_schema, + # tag_blue: Node, + # tag_red: Node, + # repo01: Node, + # gqlquery01: Node, + # ): + # data = { + # "name": "rfile10", + # "template_path": "mytemplate.j2", + # "query": gqlquery01.id, + # "repository": repo01.id, + # "tags": [tag_blue.id, tag_red.id], + # } + # schema = await client.schema.get(kind="CoreTransformJinja2", branch="main") + # create_payload = client.schema.generate_payload_create( + # schema=schema, data=data, source=repo01.id, is_protected=True + # ) + # obj = await client.create(kind="CoreTransformJinja2", branch="main", **create_payload) + # await obj.save() + + # assert obj.id is not None + # nodedb = await client.get(kind="CoreTransformJinja2", id=str(obj.id)) + + # input_data = nodedb._generate_input_data()["data"]["data"] + # assert input_data["name"]["value"] == "rfile10" + # # Validate that the source isn't a dictionary bit a reference to the repo + # assert input_data["name"]["source"] == repo01.id + + # async def test_node_create_with_properties( + # self, + # db: InfrahubDatabase, + # client: InfrahubClient, + # init_db_base, + # load_builtin_schema, + # tag_blue: Node, + # tag_red: Node, + # repo01: Node, + # gqlquery01: Node, + # first_account: Node, + # ): + # data = { + # "name": { + # "value": "rfile02", + # "is_protected": True, + # "source": first_account.id, + # "owner": first_account.id, + # }, + # "template_path": {"value": "mytemplate.j2"}, + # "query": {"id": gqlquery01.id}, # "source": first_account.id, "owner": first_account.id}, + # "repository": {"id": repo01.id}, # "source": first_account.id, "owner": first_account.id}, + # "tags": [tag_blue.id, tag_red.id], + # } + + # node = await client.create(kind="CoreTransformJinja2", data=data) + # await node.save() + + # assert node.id is not None + + # nodedb = await 
NodeManager.get_one(id=node.id, db=db, include_owner=True, include_source=True) + # assert nodedb.name.value == node.name.value # type: ignore[attr-defined] + # assert nodedb.name.is_protected is True + + # async def test_node_update( + # self, + # db: InfrahubDatabase, + # client: InfrahubClient, + # init_db_base, + # load_builtin_schema, + # tag_blue: Node, + # tag_red: Node, + # repo99: Node, + # ): + # node = await client.get(kind="CoreRepository", name__value="repo99") + # assert node.id is not None + + # node.name.value = "repo95" # type: ignore[attr-defined] + # node.tags.add(tag_blue.id) # type: ignore[attr-defined] + # node.tags.add(tag_red.id) # type: ignore[attr-defined] + # await node.save() + + # nodedb = await NodeManager.get_one(id=node.id, db=db, include_owner=True, include_source=True) + # assert nodedb.name.value == "repo95" + # tags = await nodedb.tags.get(db=db) + # assert len(tags) == 2 + + # async def test_node_update_2( + # self, + # db: InfrahubDatabase, + # client: InfrahubClient, + # init_db_base, + # load_builtin_schema, + # tag_green: Node, + # tag_red: Node, + # tag_blue: Node, + # gqlquery02: Node, + # repo99: Node, + # ): + # node = await client.get(kind="CoreGraphQLQuery", name__value="query02") + # assert node.id is not None + + # node.name.value = "query021" # type: ignore[attr-defined] + # node.repository = repo99.id # type: ignore[attr-defined] + # node.tags.add(tag_green.id) # type: ignore[attr-defined] + # node.tags.remove(tag_red.id) # type: ignore[attr-defined] + # await node.save() + + # nodedb = await NodeManager.get_one(id=node.id, db=db, include_owner=True, include_source=True) + # repodb = await nodedb.repository.get_peer(db=db) + # assert repodb.id == repo99.id + + # tags = await nodedb.tags.get(db=db) + # assert sorted([tag.peer_id for tag in tags]) == sorted([tag_green.id, tag_blue.id]) + + # async def test_node_update_3_idempotency( + # self, + # db: InfrahubDatabase, + # client: InfrahubClient, + # init_db_base, + # 
load_builtin_schema, + # tag_green: Node, + # tag_red: Node, + # tag_blue: Node, + # gqlquery03: Node, + # repo99: Node, + # ): + # node = await client.get(kind="CoreGraphQLQuery", name__value="query03") + # assert node.id is not None + + # updated_query = f"\n\n{node.query.value}" # type: ignore[attr-defined] + # node.name.value = "query031" # type: ignore[attr-defined] + # node.query.value = updated_query # type: ignore[attr-defined] + # first_update = node._generate_input_data(exclude_unmodified=True) + # await node.save() + # nodedb = await NodeManager.get_one(id=node.id, db=db, include_owner=True, include_source=True) + + # node = await client.get(kind="CoreGraphQLQuery", name__value="query031") + + # node.name.value = "query031" # type: ignore[attr-defined] + # node.query.value = updated_query # type: ignore[attr-defined] + + # second_update = node._generate_input_data(exclude_unmodified=True) + + # assert nodedb.query.value == updated_query # type: ignore[attr-defined] + # assert "query" in first_update["data"]["data"] + # assert "value" in first_update["data"]["data"]["query"] + # assert first_update["variables"] + # assert "query" not in second_update["data"]["data"] + # assert not second_update["variables"] + + # async def test_convert_node( + # self, + # db: InfrahubDatabase, + # client: InfrahubClient, + # location_schema, + # init_db_base, + # load_builtin_schema, + # location_cdg: Node, + # ): + # data = await location_cdg.to_graphql(db=db) + # node = InfrahubNode(client=client, schema=location_schema, data=data) + + # # pylint: disable=no-member + # assert node.name.value == "cdg01" # type: ignore[attr-defined] + + # async def test_relationship_manager_errors_without_fetch(self, client: InfrahubClient, load_builtin_schema): + # organization = await client.create("TestOrganization", name="organization-1") + # await organization.save() + # tag = await client.create("BuiltinTag", name="blurple") + # await tag.save() + + # with 
pytest.raises(UninitializedError, match=r"Must call fetch"): + # organization.tags.add(tag) + + # await organization.tags.fetch() + # organization.tags.add(tag) + # await organization.save() + + # organization = await client.get("TestOrganization", name__value="organization-1") + # assert [t.id for t in organization.tags.peers] == [tag.id] + + # async def test_relationships_not_overwritten( + # self, client: InfrahubClient, load_builtin_schema, schema_extension_01 + # ): + # await client.schema.load(schemas=[schema_extension_01]) + # rack = await client.create("InfraRack", name="rack-1") + # await rack.save() + # tag = await client.create("BuiltinTag", name="blizzow") + # # TODO: is it a bug that we need to save the object and fetch the tags before adding to a RelationshipManager now? + # await tag.save() + # await tag.racks.fetch() + # tag.racks.add(rack) + # await tag.save() + # tag_2 = await client.create("BuiltinTag", name="blizzow2") + # await tag_2.save() + + # # the "rack" object has no link to the "tag" object here + # # rack.tags.peers is empty + # rack.name.value = "New Rack Name" + # await rack.save() + + # # assert that the above rack.save() did not overwrite the existing Rack-Tag relationship + # refreshed_rack = await client.get("InfraRack", id=rack.id) + # await refreshed_rack.tags.fetch() + # assert [t.id for t in refreshed_rack.tags.peers] == [tag.id] + + # # check that we can purposefully remove a tag + # refreshed_rack.tags.remove(tag.id) + # await refreshed_rack.save() + # rack_without_tag = await client.get("InfraRack", id=rack.id) + # await rack_without_tag.tags.fetch() + # assert rack_without_tag.tags.peers == [] + + # # check that we can purposefully add a tag + # rack_without_tag.tags.add(tag_2) + # await rack_without_tag.save() + # refreshed_rack_with_tag = await client.get("InfraRack", id=rack.id) + # await refreshed_rack_with_tag.tags.fetch() + # assert [t.id for t in refreshed_rack_with_tag.tags.peers] == [tag_2.id] + + # async def 
test_node_create_from_pool( + # self, db: InfrahubDatabase, client: InfrahubClient, init_db_base, default_ipam_namespace, load_ipam_schema + # ): + # ip_prefix = await client.create(kind="IpamIPPrefix", prefix="192.0.2.0/24") + # await ip_prefix.save() + + # ip_pool = await client.create( + # kind="CoreIPAddressPool", + # name="Core loopbacks 1", + # default_address_type="IpamIPAddress", + # default_prefix_length=32, + # ip_namespace=default_ipam_namespace, + # resources=[ip_prefix], + # ) + # await ip_pool.save() + + # devices = [] + # for i in range(1, 5): + # d = await client.create(kind="InfraDevice", name=f"core0{i}", primary_address=ip_pool) + # await d.save() + # devices.append(d) + + # assert [str(device.primary_address.peer.address.value) for device in devices] == [ + # "192.0.2.1/32", + # "192.0.2.2/32", + # "192.0.2.3/32", + # "192.0.2.4/32", + # ] + + # async def test_node_update_from_pool( + # self, db: InfrahubDatabase, client: InfrahubClient, init_db_base, default_ipam_namespace, load_ipam_schema + # ): + # starter_ip_address = await client.create(kind="IpamIPAddress", address="10.0.0.1/32") + # await starter_ip_address.save() + + # ip_prefix = await client.create(kind="IpamIPPrefix", prefix="192.168.0.0/24") + # await ip_prefix.save() + + # ip_pool = await client.create( + # kind="CoreIPAddressPool", + # name="Core loopbacks 2", + # default_address_type="IpamIPAddress", + # default_prefix_length=32, + # ip_namespace=default_ipam_namespace, + # resources=[ip_prefix], + # ) + # await ip_pool.save() + + # device = await client.create(kind="InfraDevice", name="core05", primary_address=starter_ip_address) + # await device.save() + + # device.primary_address = ip_pool + # await device.save() + + # assert str(device.primary_address.peer.address.value) == "192.168.0.1/32" diff --git a/tests/integration/test_schema.py b/tests/integration/test_schema.py index 1f193127..e288f189 100644 --- a/tests/integration/test_schema.py +++ 
b/tests/integration/test_schema.py @@ -3,7 +3,7 @@ from infrahub.server import app from infrahub_sdk import Config, InfrahubClient -from infrahub_sdk.schema import NodeSchema +from infrahub_sdk.schema import NodeSchemaAPI from .conftest import InfrahubTestClient @@ -28,14 +28,14 @@ async def test_schema_all(self, client, init_db_base): assert len(schema_nodes) == len(nodes) + len(generics) + len(profiles) assert "BuiltinTag" in schema_nodes - assert isinstance(schema_nodes["BuiltinTag"], NodeSchema) + assert isinstance(schema_nodes["BuiltinTag"], NodeSchemaAPI) async def test_schema_get(self, client, init_db_base): config = Config(username="admin", password="infrahub", requester=client.async_request) ifc = InfrahubClient(config=config) schema_node = await ifc.schema.get(kind="BuiltinTag") - assert isinstance(schema_node, NodeSchema) + assert isinstance(schema_node, NodeSchemaAPI) assert ifc.default_branch in ifc.schema.cache nodes = [node for node in core_models["nodes"] if node["namespace"] != "Internal"] generics = [node for node in core_models["generics"] if node["namespace"] != "Internal"] diff --git a/tests/unit/sdk/conftest.py b/tests/unit/sdk/conftest.py index c3473b2e..18279afa 100644 --- a/tests/unit/sdk/conftest.py +++ b/tests/unit/sdk/conftest.py @@ -11,7 +11,7 @@ from pytest_httpx import HTTPXMock from infrahub_sdk import Config, InfrahubClient, InfrahubClientSync -from infrahub_sdk.schema import BranchSupportType, NodeSchema +from infrahub_sdk.schema import BranchSupportType, NodeSchema, NodeSchemaAPI from infrahub_sdk.utils import get_fixtures_dir # pylint: disable=redefined-outer-name,unused-argument @@ -125,7 +125,7 @@ def replace_annotations(parameters: Mapping[str, Parameter]) -> tuple[str, str]: @pytest.fixture -async def location_schema() -> NodeSchema: +async def location_schema() -> NodeSchemaAPI: data = { "name": "Location", "namespace": "Builtin", @@ -157,11 +157,11 @@ async def location_schema() -> NodeSchema: }, ], } - return 
NodeSchema(**data) # type: ignore + return NodeSchema(**data).convert_api() # type: ignore @pytest.fixture -async def schema_with_hfid() -> dict[str, NodeSchema]: +async def schema_with_hfid() -> dict[str, NodeSchemaAPI]: data = { "location": { "name": "Location", @@ -222,11 +222,11 @@ async def schema_with_hfid() -> dict[str, NodeSchema]: ], }, } - return {k: NodeSchema(**v) for k, v in data.items()} # type: ignore + return {k: NodeSchema(**v).convert_api() for k, v in data.items()} # type: ignore @pytest.fixture -async def std_group_schema() -> NodeSchema: +async def std_group_schema() -> NodeSchemaAPI: data = { "name": "StandardGroup", "namespace": "Core", @@ -236,7 +236,7 @@ async def std_group_schema() -> NodeSchema: {"name": "description", "kind": "String", "optional": True}, ], } - return NodeSchema(**data) # type: ignore + return NodeSchema(**data).convert_api() # type: ignore @pytest.fixture @@ -506,17 +506,17 @@ async def location_data02(): @pytest.fixture -async def tag_schema() -> NodeSchema: +async def tag_schema() -> NodeSchemaAPI: data = { "name": "Tag", "namespace": "Builtin", "default_filter": "name__value", "attributes": [ - {"name": "name", "kind": "String", "unique": True}, - {"name": "description", "kind": "String", "optional": True}, + {"name": "name", "kind": "Text", "unique": True}, + {"name": "description", "kind": "Text", "optional": True}, ], } - return NodeSchema(**data) # type: ignore + return NodeSchema(**data).convert_api() # type: ignore @pytest.fixture @@ -694,7 +694,7 @@ async def tag_green_data(): @pytest.fixture -async def rfile_schema() -> NodeSchema: +async def rfile_schema() -> NodeSchemaAPI: data = { "name": "TransformJinja2", "namespace": "Core", @@ -730,11 +730,11 @@ async def rfile_schema() -> NodeSchema: }, ], } - return NodeSchema(**data) # type: ignore + return NodeSchema(**data).convert_api() @pytest.fixture -async def ipaddress_schema() -> NodeSchema: +async def ipaddress_schema() -> NodeSchemaAPI: data = { "name": 
"IPAddress", "namespace": "Infra", @@ -754,11 +754,11 @@ async def ipaddress_schema() -> NodeSchema: } ], } - return NodeSchema(**data) # type: ignore + return NodeSchema(**data).convert_api() # type: ignore @pytest.fixture -async def ipnetwork_schema() -> NodeSchema: +async def ipnetwork_schema() -> NodeSchemaAPI: data = { "name": "IPNetwork", "namespace": "Infra", @@ -778,11 +778,11 @@ async def ipnetwork_schema() -> NodeSchema: } ], } - return NodeSchema(**data) # type: ignore + return NodeSchema(**data).convert_api() # type: ignore @pytest.fixture -async def ipam_ipprefix_schema() -> NodeSchema: +async def ipam_ipprefix_schema() -> NodeSchemaAPI: data = { "name": "IPNetwork", "namespace": "Ipam", @@ -791,11 +791,11 @@ async def ipam_ipprefix_schema() -> NodeSchema: "order_by": ["prefix_value"], "inherit_from": ["BuiltinIPAddress"], } - return NodeSchema(**data) # type: ignore + return NodeSchema(**data).convert_api() # type: ignore @pytest.fixture -async def simple_device_schema() -> NodeSchema: +async def simple_device_schema() -> NodeSchemaAPI: data = { "name": "Device", "namespace": "Infra", @@ -823,7 +823,7 @@ async def simple_device_schema() -> NodeSchema: }, ], } - return NodeSchema(**data) # type: ignore + return NodeSchema(**data).convert_api() # type: ignore @pytest.fixture @@ -897,7 +897,7 @@ async def ipam_ipprefix_data(): @pytest.fixture -async def ipaddress_pool_schema() -> NodeSchema: +async def ipaddress_pool_schema() -> NodeSchemaAPI: data = { "name": "IPAddressPool", "namespace": "Core", @@ -943,11 +943,11 @@ async def ipaddress_pool_schema() -> NodeSchema: }, ], } - return NodeSchema(**data) # type: ignore + return NodeSchema(**data).convert_api() # type: ignore @pytest.fixture -async def ipprefix_pool_schema() -> NodeSchema: +async def ipprefix_pool_schema() -> NodeSchemaAPI: data = { "name": "IPPrefixPool", "namespace": "Core", @@ -1002,11 +1002,11 @@ async def ipprefix_pool_schema() -> NodeSchema: }, ], } - return NodeSchema(**data) # type: 
ignore + return NodeSchema(**data).convert_api() # type: ignore @pytest.fixture -async def address_schema() -> NodeSchema: +async def address_schema() -> NodeSchemaAPI: data = { "name": "Address", "namespace": "Infra", @@ -1021,7 +1021,7 @@ async def address_schema() -> NodeSchema: ], "relationships": [], } - return NodeSchema(**data) # type: ignore + return NodeSchemaAPI(**data) @pytest.fixture @@ -1065,7 +1065,7 @@ async def address_data(): @pytest.fixture -async def device_schema() -> NodeSchema: +async def device_schema() -> NodeSchemaAPI: data = { "name": "Device", "namespace": "Infra", @@ -1111,7 +1111,7 @@ async def device_schema() -> NodeSchema: {"name": "artifacts", "peer": "CoreArtifact", "optional": True, "cardinality": "many", "kind": "Generic"}, ], } - return NodeSchema(**data) # type: ignore + return NodeSchema(**data).convert_api() # type: ignore @pytest.fixture @@ -1272,7 +1272,7 @@ async def device_data(): @pytest.fixture -async def artifact_definition_schema() -> NodeSchema: +async def artifact_definition_schema() -> NodeSchemaAPI: data = { "name": "ArtifactDefinition", "namespace": "Core", @@ -1285,7 +1285,7 @@ async def artifact_definition_schema() -> NodeSchema: {"name": "artifact_name", "kind": "Text"}, ], } - return NodeSchema(**data) # type: ignore + return NodeSchema(**data).convert_api() # type: ignore @pytest.fixture diff --git a/tests/unit/sdk/test_node.py b/tests/unit/sdk/test_node.py index 543c8992..60ee0620 100644 --- a/tests/unit/sdk/test_node.py +++ b/tests/unit/sdk/test_node.py @@ -14,10 +14,10 @@ RelatedNodeBase, RelationshipManagerBase, ) +from infrahub_sdk.schema import GenericSchema, NodeSchemaAPI if TYPE_CHECKING: from infrahub_sdk.client import InfrahubClient, InfrahubClientSync - from infrahub_sdk.schema import GenericSchema # pylint: disable=no-member,too-many-lines # type: ignore[attr-defined] @@ -133,7 +133,7 @@ async def test_node_hfid(client, schema_with_hfid, client_type): @pytest.mark.parametrize("client_type", 
client_types) -async def test_init_node_data_user(client, location_schema, client_type): +async def test_init_node_data_user(client, location_schema: NodeSchemaAPI, client_type): data = { "name": {"value": "JFK1"}, "description": {"value": "JFK Airport"}, @@ -151,7 +151,7 @@ async def test_init_node_data_user(client, location_schema, client_type): @pytest.mark.parametrize("client_type", client_types) -async def test_init_node_data_user_with_relationships(client, location_schema, client_type): +async def test_init_node_data_user_with_relationships(client, location_schema: NodeSchemaAPI, client_type): data = { "name": {"value": "JFK1"}, "description": {"value": "JFK Airport"}, @@ -177,7 +177,7 @@ async def test_init_node_data_user_with_relationships(client, location_schema, c @pytest.mark.parametrize("client_type", client_types) -async def test_init_node_data_graphql(client, location_schema, location_data01, client_type): +async def test_init_node_data_graphql(client, location_schema: NodeSchemaAPI, location_data01, client_type): if client_type == "standard": node = InfrahubNode(client=client, schema=location_schema, data=location_data01) else: @@ -197,7 +197,7 @@ async def test_init_node_data_graphql(client, location_schema, location_data01, @pytest.mark.parametrize("client_type", client_types) -async def test_query_data_no_filters(clients, location_schema, client_type): +async def test_query_data_no_filters(clients, location_schema: NodeSchemaAPI, client_type): if client_type == "standard": client: InfrahubClient = getattr(clients, client_type) # type: ignore[annotation-unchecked] node = InfrahubNode(client=client, schema=location_schema) @@ -297,7 +297,7 @@ async def test_query_data_no_filters(clients, location_schema, client_type): @pytest.mark.parametrize("client_type", client_types) -async def test_query_data_node(clients, location_schema, client_type): +async def test_query_data_node(clients, location_schema: NodeSchemaAPI, client_type): if client_type == 
"standard": client: InfrahubClient = getattr(clients, client_type) # type: ignore[annotation-unchecked] node = InfrahubNode(client=client, schema=location_schema) @@ -748,7 +748,7 @@ async def test_query_data_generic_fragment(clients, mock_schema_query_02, client @pytest.mark.parametrize("client_type", client_types) -async def test_query_data_include(client, location_schema, client_type): +async def test_query_data_include(client, location_schema: NodeSchemaAPI, client_type): if client_type == "standard": node = InfrahubNode(client=client, schema=location_schema) data = await node.generate_query_data(include=["tags"]) @@ -870,7 +870,7 @@ async def test_query_data_include(client, location_schema, client_type): @pytest.mark.parametrize("client_type", client_types) -async def test_query_data_exclude(client, location_schema, client_type): +async def test_query_data_exclude(client, location_schema: NodeSchemaAPI, client_type): if client_type == "standard": node = InfrahubNode(client=client, schema=location_schema) data = await node.generate_query_data(exclude=["description", "primary_tag"]) @@ -929,7 +929,7 @@ async def test_query_data_exclude(client, location_schema, client_type): @pytest.mark.parametrize("client_type", client_types) -async def test_create_input_data(client, location_schema, client_type): +async def test_create_input_data(client, location_schema: NodeSchemaAPI, client_type): data = {"name": {"value": "JFK1"}, "description": {"value": "JFK Airport"}, "type": {"value": "SITE"}} if client_type == "standard": diff --git a/tests/unit/sdk/test_schema.py b/tests/unit/sdk/test_schema.py index 0729daff..5c44e9c7 100644 --- a/tests/unit/sdk/test_schema.py +++ b/tests/unit/sdk/test_schema.py @@ -10,14 +10,16 @@ from infrahub_sdk.ctl.schema import display_schema_load_errors from infrahub_sdk.exceptions import SchemaNotFoundError, ValidationError from infrahub_sdk.schema import ( + InfrahubSchema, + InfrahubSchemaSync, + NodeSchemaAPI, +) +from 
infrahub_sdk.schema.repository import ( InfrahubCheckDefinitionConfig, InfrahubJinja2TransformConfig, InfrahubPythonTransformConfig, InfrahubRepositoryArtifactDefinitionConfig, InfrahubRepositoryConfig, - InfrahubSchema, - InfrahubSchemaSync, - NodeSchema, ) from tests.unit.sdk.conftest import BothClients @@ -59,7 +61,7 @@ async def test_fetch_schema(mock_schema_query_01, client_type): # pylint: disab "CoreGraphQLQuery", "CoreRepository", ] - assert isinstance(nodes["BuiltinTag"], NodeSchema) + assert isinstance(nodes["BuiltinTag"], NodeSchemaAPI) @pytest.mark.parametrize("client_type", client_types)