diff --git a/changelog/+schema-fetch.added.md b/changelog/+schema-fetch.added.md new file mode 100644 index 00000000..52399a47 --- /dev/null +++ b/changelog/+schema-fetch.added.md @@ -0,0 +1 @@ +By default, schema.fetch will now populate the cache (this behavior can be changed with `populate_cache`) diff --git a/changelog/+schema-set-cache.added.md b/changelog/+schema-set-cache.added.md new file mode 100644 index 00000000..1dcb4ce4 --- /dev/null +++ b/changelog/+schema-set-cache.added.md @@ -0,0 +1 @@ +Add method `client.schema.set_cache()` to populate the cache manually (primarily for unit testing) \ No newline at end of file diff --git a/changelog/+schema-timeout.deprecated.md b/changelog/+schema-timeout.deprecated.md new file mode 100644 index 00000000..788ba67c --- /dev/null +++ b/changelog/+schema-timeout.deprecated.md @@ -0,0 +1 @@ +The 'timeout' parameter while creating a node or fetching the schema has been deprecated. the default_timeout will be used instead. \ No newline at end of file diff --git a/infrahub_sdk/checks.py b/infrahub_sdk/checks.py index ad692537..79511ccb 100644 --- a/infrahub_sdk/checks.py +++ b/infrahub_sdk/checks.py @@ -83,7 +83,7 @@ def client(self, value: InfrahubClient) -> None: async def init(cls, client: InfrahubClient | None = None, *args: Any, **kwargs: Any) -> InfrahubCheck: """Async init method, If an existing InfrahubClient client hasn't been provided, one will be created automatically.""" warnings.warn( - "InfrahubCheck.init has been deprecated and will be removed in the version in Infrahub SDK 2.0.0", + "InfrahubCheck.init has been deprecated and will be removed in version 2.0.0 of the Infrahub Python SDK", DeprecationWarning, stacklevel=1, ) diff --git a/infrahub_sdk/schema/__init__.py b/infrahub_sdk/schema/__init__.py index 5ae97777..be1cfab9 100644 --- a/infrahub_sdk/schema/__init__.py +++ b/infrahub_sdk/schema/__init__.py @@ -2,6 +2,7 @@ import asyncio import json +import warnings from collections.abc import MutableMapping from enum import Enum from time import sleep @@ -90,6 +91,13 @@ class EnumMutation(str, Enum): class InfrahubSchemaBase: + client: InfrahubClient | InfrahubClientSync + cache: dict[str, BranchSchema] + + def __init__(self, client: InfrahubClient | InfrahubClientSync): + self.client = client + self.cache = {} + def validate(self, data: dict[str, Any]) -> None: SchemaRoot(**data) @@ -102,6 +110,23 @@ def validate_data_against_schema(self, schema: MainSchemaTypesAPI, data: dict) - message=f"{key} is not a valid value for {identifier}", ) + def set_cache(self, schema: dict[str, Any] | SchemaRootAPI | BranchSchema, branch: str | None = None) -> None: + """ + Set the cache manually (primarily for unit testing) + + Args: + schema: The schema to set the cache as provided by the /api/schema endpoint either in dict or SchemaRootAPI format + branch: The name of the branch to set the cache for. + """ + branch = branch or self.client.default_branch + + if isinstance(schema, SchemaRootAPI): + schema = BranchSchema.from_schema_root_api(data=schema) + elif isinstance(schema, dict): + schema = BranchSchema.from_api_response(data=schema) + + self.cache[branch] = schema + def generate_payload_create( self, schema: MainSchemaTypesAPI, @@ -187,11 +212,18 @@ def _parse_schema_response(response: httpx.Response, branch: str) -> MutableMapp return data + @staticmethod + def _deprecated_schema_timeout() -> None: + warnings.warn( + "The 'timeout' parameter is deprecated while fetching the schema and will be removed version 2.0.0 of the Infrahub Python SDK. " + "Use client.default_timeout instead.", + DeprecationWarning, + stacklevel=2, + ) + class InfrahubSchema(InfrahubSchemaBase): - def __init__(self, client: InfrahubClient): - self.client = client - self.cache: dict[str, BranchSchema] = {} + client: InfrahubClient async def get( self, @@ -204,8 +236,11 @@ async def get( kind_str = self._get_schema_name(schema=kind) + if timeout: + self._deprecated_schema_timeout() + if refresh: - self.cache[branch] = await self._fetch(branch=branch, timeout=timeout) + self.cache[branch] = await self._fetch(branch=branch) if branch in self.cache and kind_str in self.cache[branch].nodes: return self.cache[branch].nodes[kind_str] @@ -213,7 +248,7 @@ async def get( # Fetching the latest schema from the server if we didn't fetch it earlier # because we coulnd't find the object on the local cache if not refresh: - self.cache[branch] = await self._fetch(branch=branch, timeout=timeout) + self.cache[branch] = await self._fetch(branch=branch) if branch in self.cache and kind_str in self.cache[branch].nodes: return self.cache[branch].nodes[kind_str] @@ -416,59 +451,45 @@ async def add_dropdown_option( ) async def fetch( - self, branch: str, namespaces: list[str] | None = None, timeout: int | None = None + self, branch: str, namespaces: list[str] | None = None, timeout: int | None = None, populate_cache: bool = True ) -> MutableMapping[str, MainSchemaTypesAPI]: """Fetch the schema from the server for a given branch. Args: - branch (str): Name of the branch to fetch the schema for. - timeout (int, optional): Overrides default timeout used when querying the GraphQL API. Specified in seconds. + branch: Name of the branch to fetch the schema for. + timeout: Overrides default timeout used when querying the schema. deprecated. + populate_cache: Whether to populate the cache with the fetched schema. Defaults to True. Returns: dict[str, MainSchemaTypes]: Dictionary of all schema organized by kind """ - branch_schema = await self._fetch(branch=branch, namespaces=namespaces, timeout=timeout) + + if timeout: + self._deprecated_schema_timeout() + + branch_schema = await self._fetch(branch=branch, namespaces=namespaces) + + if populate_cache: + self.cache[branch] = branch_schema + return branch_schema.nodes - async def _fetch( - self, branch: str, namespaces: list[str] | None = None, timeout: int | None = None - ) -> BranchSchema: + async def _fetch(self, branch: str, namespaces: list[str] | None = None) -> BranchSchema: url_parts = [("branch", branch)] if namespaces: url_parts.extend([("namespaces", ns) for ns in namespaces]) query_params = urlencode(url_parts) url = f"{self.client.address}/api/schema?{query_params}" - response = await self.client._get(url=url, timeout=timeout) + response = await self.client._get(url=url) data = self._parse_schema_response(response=response, branch=branch) - nodes: MutableMapping[str, MainSchemaTypesAPI] = {} - for node_schema in data.get("nodes", []): - node = NodeSchemaAPI(**node_schema) - nodes[node.kind] = node - - for generic_schema in data.get("generics", []): - generic = GenericSchemaAPI(**generic_schema) - nodes[generic.kind] = generic - - for profile_schema in data.get("profiles", []): - profile = ProfileSchemaAPI(**profile_schema) - nodes[profile.kind] = profile - - for template_schema in data.get("templates", []): - template = TemplateSchemaAPI(**template_schema) - nodes[template.kind] = template - - schema_hash = data.get("main", "") - - return BranchSchema(hash=schema_hash, nodes=nodes) + return BranchSchema.from_api_response(data=data) class InfrahubSchemaSync(InfrahubSchemaBase): - def __init__(self, client: InfrahubClientSync): - self.client = client - self.cache: dict[str, BranchSchema] = {} + client: InfrahubClientSync def all( self, @@ -506,10 +527,25 @@ def get( refresh: bool = False, timeout: int | None = None, ) -> MainSchemaTypesAPI: + """ + Retrieve a specific schema object from the server. + + Args: + kind: The kind of schema object to retrieve. + branch: The branch to retrieve the schema from. + refresh: Whether to refresh the schema. + timeout: Overrides default timeout used when querying the GraphQL API. Specified in seconds (deprecated). + + Returns: + MainSchemaTypes: The schema object. + """ branch = branch or self.client.default_branch kind_str = self._get_schema_name(schema=kind) + if timeout: + self._deprecated_schema_timeout() + if refresh: self.cache[branch] = self._fetch(branch=branch) @@ -519,7 +555,7 @@ def get( # Fetching the latest schema from the server if we didn't fetch it earlier # because we coulnd't find the object on the local cache if not refresh: - self.cache[branch] = self._fetch(branch=branch, timeout=timeout) + self.cache[branch] = self._fetch(branch=branch) if branch in self.cache and kind_str in self.cache[branch].nodes: return self.cache[branch].nodes[kind_str] @@ -639,49 +675,39 @@ def add_dropdown_option( ) def fetch( - self, branch: str, namespaces: list[str] | None = None, timeout: int | None = None + self, branch: str, namespaces: list[str] | None = None, timeout: int | None = None, populate_cache: bool = True ) -> MutableMapping[str, MainSchemaTypesAPI]: """Fetch the schema from the server for a given branch. Args: - branch (str): Name of the branch to fetch the schema for. - timeout (int, optional): Overrides default timeout used when querying the GraphQL API. Specified in seconds. + branch: Name of the branch to fetch the schema for. + timeout: Overrides default timeout used when querying the GraphQL API. Specified in seconds (deprecated). + populate_cache: Whether to populate the cache with the fetched schema. Defaults to True. Returns: dict[str, MainSchemaTypes]: Dictionary of all schema organized by kind """ - branch_schema = self._fetch(branch=branch, namespaces=namespaces, timeout=timeout) + if timeout: + self._deprecated_schema_timeout() + + branch_schema = self._fetch(branch=branch, namespaces=namespaces) + + if populate_cache: + self.cache[branch] = branch_schema + return branch_schema.nodes - def _fetch(self, branch: str, namespaces: list[str] | None = None, timeout: int | None = None) -> BranchSchema: + def _fetch(self, branch: str, namespaces: list[str] | None = None) -> BranchSchema: url_parts = [("branch", branch)] if namespaces: url_parts.extend([("namespaces", ns) for ns in namespaces]) query_params = urlencode(url_parts) url = f"{self.client.address}/api/schema?{query_params}" - response = self.client._get(url=url, timeout=timeout) - data = self._parse_schema_response(response=response, branch=branch) + response = self.client._get(url=url) - nodes: MutableMapping[str, MainSchemaTypesAPI] = {} - for node_schema in data.get("nodes", []): - node = NodeSchemaAPI(**node_schema) - nodes[node.kind] = node - - for generic_schema in data.get("generics", []): - generic = GenericSchemaAPI(**generic_schema) - nodes[generic.kind] = generic - - for profile_schema in data.get("profiles", []): - profile = ProfileSchemaAPI(**profile_schema) - nodes[profile.kind] = profile - - for template_schema in data.get("templates", []): - template = TemplateSchemaAPI(**template_schema) - nodes[template.kind] = template - - schema_hash = data.get("main", "") + data = self._parse_schema_response(response=response, branch=branch) - return BranchSchema(hash=schema_hash, nodes=nodes) + return BranchSchema.from_api_response(data=data) def load( self, schemas: list[dict], branch: str | None = None, wait_until_converged: bool = False diff --git a/infrahub_sdk/schema/main.py b/infrahub_sdk/schema/main.py index af5556b3..8776c6d6 100644 --- a/infrahub_sdk/schema/main.py +++ b/infrahub_sdk/schema/main.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any, Union from pydantic import BaseModel, ConfigDict, Field +from typing_extensions import Self if TYPE_CHECKING: from ..node import InfrahubNode, InfrahubNodeSync @@ -344,7 +345,7 @@ def to_schema_dict(self) -> dict[str, Any]: class SchemaRootAPI(BaseModel): model_config = ConfigDict(use_enum_values=True) - version: str + main: str | None = None generics: list[GenericSchemaAPI] = Field(default_factory=list) nodes: list[NodeSchemaAPI] = Field(default_factory=list) profiles: list[ProfileSchemaAPI] = Field(default_factory=list) @@ -356,3 +357,32 @@ class BranchSchema(BaseModel): nodes: MutableMapping[str, GenericSchemaAPI | NodeSchemaAPI | ProfileSchemaAPI | TemplateSchemaAPI] = Field( default_factory=dict ) + + @classmethod + def from_api_response(cls, data: MutableMapping[str, Any]) -> Self: + """ + Convert an API response from /api/schema into a BranchSchema object. + """ + return cls.from_schema_root_api(data=SchemaRootAPI(**data)) + + @classmethod + def from_schema_root_api(cls, data: SchemaRootAPI) -> Self: + """ + Convert a SchemaRootAPI object to a BranchSchema object. + """ + nodes: MutableMapping[str, GenericSchemaAPI | NodeSchemaAPI | ProfileSchemaAPI | TemplateSchemaAPI] = {} + for node in data.nodes: + nodes[node.kind] = node + + for generic in data.generics: + nodes[generic.kind] = generic + + for profile in data.profiles: + nodes[profile.kind] = profile + + for template in data.templates: + nodes[template.kind] = template + + schema_hash = data.main or "" + + return cls(hash=schema_hash, nodes=nodes) diff --git a/tests/unit/sdk/conftest.py b/tests/unit/sdk/conftest.py index fdd63de1..432b9897 100644 --- a/tests/unit/sdk/conftest.py +++ b/tests/unit/sdk/conftest.py @@ -1795,38 +1795,58 @@ async def mock_query_repository_page2_2( @pytest.fixture -async def mock_schema_query_01(httpx_mock: HTTPXMock) -> HTTPXMock: +async def schema_query_01_data() -> dict: response_text = (get_fixtures_dir() / "schema_01.json").read_text(encoding="UTF-8") + return ujson.loads(response_text) + +@pytest.fixture +async def schema_query_02_data() -> dict: + response_text = (get_fixtures_dir() / "schema_02.json").read_text(encoding="UTF-8") + return ujson.loads(response_text) + + +@pytest.fixture +async def schema_query_04_data() -> dict: + response_text = (get_fixtures_dir() / "schema_04.json").read_text(encoding="UTF-8") + return ujson.loads(response_text) + + +@pytest.fixture +async def schema_query_05_data() -> dict: + response_text = (get_fixtures_dir() / "schema_05.json").read_text(encoding="UTF-8") + return ujson.loads(response_text) + + +@pytest.fixture +async def mock_schema_query_01(httpx_mock: HTTPXMock, schema_query_01_data: dict) -> HTTPXMock: httpx_mock.add_response( method="GET", url="http://mock/api/schema?branch=main", - json=ujson.loads(response_text), + json=schema_query_01_data, is_reusable=True, ) return httpx_mock @pytest.fixture -async def mock_schema_query_02(httpx_mock: HTTPXMock) -> HTTPXMock: - response_text = (get_fixtures_dir() / "schema_02.json").read_text(encoding="UTF-8") +async def mock_schema_query_02(httpx_mock: HTTPXMock, schema_query_02_data: dict) -> HTTPXMock: httpx_mock.add_response( method="GET", url=re.compile(r"^http://mock/api/schema\?branch=(main|cr1234)"), - json=ujson.loads(response_text), + json=schema_query_02_data, is_reusable=True, ) return httpx_mock @pytest.fixture -async def mock_schema_query_05(httpx_mock: HTTPXMock) -> HTTPXMock: - response_text = (get_fixtures_dir() / "schema_05.json").read_text(encoding="UTF-8") - +async def mock_schema_query_05(httpx_mock: HTTPXMock, schema_query_05_data: dict) -> HTTPXMock: httpx_mock.add_response( method="GET", url="http://mock/api/schema?branch=main", - json=ujson.loads(response_text), + json=schema_query_05_data, + is_reusable=True, ) return httpx_mock @@ -1933,13 +1953,11 @@ async def mock_rest_api_artifact_fetch(httpx_mock: HTTPXMock) -> HTTPXMock: @pytest.fixture -async def mock_rest_api_artifact_generate(httpx_mock: HTTPXMock) -> HTTPXMock: - schema_response = (get_fixtures_dir() / "schema_04.json").read_text(encoding="UTF-8") - +async def mock_rest_api_artifact_generate(httpx_mock: HTTPXMock, schema_query_04_data: dict) -> HTTPXMock: httpx_mock.add_response( method="GET", url="http://mock/api/schema?branch=main", - json=ujson.loads(schema_response), + json=schema_query_04_data, is_reusable=True, ) diff --git a/tests/unit/sdk/test_schema.py b/tests/unit/sdk/test_schema.py index ee07a137..32123226 100644 --- a/tests/unit/sdk/test_schema.py +++ b/tests/unit/sdk/test_schema.py @@ -9,11 +9,7 @@ from infrahub_sdk import Config, InfrahubClient, InfrahubClientSync from infrahub_sdk.ctl.schema import display_schema_load_errors from infrahub_sdk.exceptions import SchemaNotFoundError, ValidationError -from infrahub_sdk.schema import ( - InfrahubSchema, - InfrahubSchemaSync, - NodeSchemaAPI, -) +from infrahub_sdk.schema import BranchSchema, InfrahubSchema, InfrahubSchemaSync, NodeSchemaAPI from infrahub_sdk.schema.repository import ( InfrahubCheckDefinitionConfig, InfrahubJinja2TransformConfig, @@ -222,6 +218,34 @@ async def test_schema_wait_happy_path(clients: BothClients, client_type: list[st assert len(httpx_mock.get_requests()) == 2 +@pytest.mark.parametrize("client_type", client_types) +async def test_schema_set_cache_dict(clients: BothClients, client_type: list[str], schema_query_01_data: dict) -> None: + if client_type == "standard": + client = clients.standard + else: + client = clients.sync + + client.schema.set_cache(schema_query_01_data, branch="branch1") + assert "branch1" in client.schema.cache + assert client.schema.cache["branch1"].nodes["CoreGraphQLQuery"] + + +@pytest.mark.parametrize("client_type", client_types) +async def test_schema_set_cache_branch_schema( + clients: BothClients, client_type: list[str], schema_query_01_data: dict +) -> None: + if client_type == "standard": + client = clients.standard + else: + client = clients.sync + + schema = BranchSchema.from_api_response(schema_query_01_data) + + client.schema.set_cache(schema) + assert "main" in client.schema.cache + assert client.schema.cache["main"].nodes["CoreGraphQLQuery"] + + async def test_infrahub_repository_config_getters(): repo_config = InfrahubRepositoryConfig( jinja2_transforms=[