From e77546783fd4d0f49158165186483b32716e871d Mon Sep 17 00:00:00 2001 From: Patrick Ogenstad Date: Thu, 20 Mar 2025 10:25:42 +0100 Subject: [PATCH] Add parameter to only optionally refresh the schema hash on schema.all --- changelog/152.added.md | 1 + infrahub_sdk/schema/__init__.py | 81 ++++++++++++++++++++++----------- infrahub_sdk/schema/main.py | 8 ++++ tests/fixtures/schema_01.json | 1 + tests/unit/sdk/test_schema.py | 27 +++++++++++ 5 files changed, 92 insertions(+), 26 deletions(-) create mode 100644 changelog/152.added.md diff --git a/changelog/152.added.md b/changelog/152.added.md new file mode 100644 index 00000000..fa1277fe --- /dev/null +++ b/changelog/152.added.md @@ -0,0 +1 @@ +Add 'schema_hash' parameter to client.schema.all to only optionally refresh the schema if the provided hash differs from what the client has already cached. diff --git a/infrahub_sdk/schema/__init__.py b/infrahub_sdk/schema/__init__.py index 9b23fe49..080d7237 100644 --- a/infrahub_sdk/schema/__init__.py +++ b/infrahub_sdk/schema/__init__.py @@ -1,7 +1,6 @@ from __future__ import annotations import asyncio -from collections import defaultdict from collections.abc import MutableMapping from enum import Enum from time import sleep @@ -22,6 +21,7 @@ from .main import ( AttributeSchema, AttributeSchemaAPI, + BranchSchema, BranchSupportType, GenericSchema, GenericSchemaAPI, @@ -169,7 +169,7 @@ def _get_schema_name(schema: type[SchemaType | SchemaTypeSync] | str) -> str: class InfrahubSchema(InfrahubSchemaBase): def __init__(self, client: InfrahubClient): self.client = client - self.cache: dict = defaultdict(lambda: dict) + self.cache: dict[str, BranchSchema] = {} async def get( self, @@ -183,23 +183,27 @@ async def get( kind_str = self._get_schema_name(schema=kind) if refresh: - self.cache[branch] = await self.fetch(branch=branch, timeout=timeout) + self.cache[branch] = await self._fetch(branch=branch, timeout=timeout) - if branch in self.cache and kind_str in self.cache[branch]: - return self.cache[branch][kind_str] + if branch in self.cache and kind_str in self.cache[branch].nodes: + return self.cache[branch].nodes[kind_str] # Fetching the latest schema from the server if we didn't fetch it earlier # because we coulnd't find the object on the local cache if not refresh: - self.cache[branch] = await self.fetch(branch=branch, timeout=timeout) + self.cache[branch] = await self._fetch(branch=branch, timeout=timeout) - if branch in self.cache and kind_str in self.cache[branch]: - return self.cache[branch][kind_str] + if branch in self.cache and kind_str in self.cache[branch].nodes: + return self.cache[branch].nodes[kind_str] raise SchemaNotFoundError(identifier=kind_str) async def all( - self, branch: str | None = None, refresh: bool = False, namespaces: list[str] | None = None + self, + branch: str | None = None, + refresh: bool = False, + namespaces: list[str] | None = None, + schema_hash: str | None = None, ) -> MutableMapping[str, MainSchemaTypesAPI]: """Retrieve the entire schema for a given branch. @@ -209,15 +213,19 @@ async def all( Args: branch (str, optional): Name of the branch to query. Defaults to default_branch. refresh (bool, optional): Force a refresh of the schema. Defaults to False. + schema_hash (str, optional): Only refresh if the current schema doesn't match this hash. Returns: dict[str, MainSchemaTypes]: Dictionary of all schema organized by kind """ branch = branch or self.client.default_branch + if refresh and branch in self.cache and schema_hash and self.cache[branch].hash == schema_hash: + refresh = False + if refresh or branch not in self.cache: - self.cache[branch] = await self.fetch(branch=branch, namespaces=namespaces) + self.cache[branch] = await self._fetch(branch=branch, namespaces=namespaces) - return self.cache[branch] + return self.cache[branch].nodes async def load( self, schemas: list[dict], branch: str | None = None, wait_until_converged: bool = False @@ -392,11 +400,17 @@ async def fetch( Args: branch (str): Name of the branch to fetch the schema for. - timeout (int, optional): Overrides default timeout used when querying the graphql API. Specified in seconds. + timeout (int, optional): Overrides default timeout used when querying the GraphQL API. Specified in seconds. Returns: dict[str, MainSchemaTypes]: Dictionary of all schema organized by kind """ + branch_schema = await self._fetch(branch=branch, namespaces=namespaces, timeout=timeout) + return branch_schema.nodes + + async def _fetch( + self, branch: str, namespaces: list[str] | None = None, timeout: int | None = None + ) -> BranchSchema: url_parts = [("branch", branch)] if namespaces: url_parts.extend([("namespaces", ns) for ns in namespaces]) @@ -425,16 +439,22 @@ async def fetch( template = TemplateSchemaAPI(**template_schema) nodes[template.kind] = template - return nodes + schema_hash = data.get("main", "") + + return BranchSchema(hash=schema_hash, nodes=nodes) class InfrahubSchemaSync(InfrahubSchemaBase): def __init__(self, client: InfrahubClientSync): self.client = client - self.cache: dict = defaultdict(lambda: dict) + self.cache: dict[str, BranchSchema] = {} def all( - self, branch: str | None = None, refresh: bool = False, namespaces: list[str] | None = None + self, + branch: str | None = None, + refresh: bool = False, + namespaces: list[str] | None = None, + schema_hash: str | None = None, ) -> MutableMapping[str, MainSchemaTypesAPI]: """Retrieve the entire schema for a given branch. @@ -444,15 +464,19 @@ def all( Args: branch (str, optional): Name of the branch to query. Defaults to default_branch. refresh (bool, optional): Force a refresh of the schema. Defaults to False. + schema_hash (str, optional): Only refresh if the current schema doesn't match this hash. Returns: dict[str, MainSchemaTypes]: Dictionary of all schema organized by kind """ branch = branch or self.client.default_branch + if refresh and branch in self.cache and schema_hash and self.cache[branch].hash == schema_hash: + refresh = False + if refresh or branch not in self.cache: - self.cache[branch] = self.fetch(branch=branch, namespaces=namespaces) + self.cache[branch] = self._fetch(branch=branch, namespaces=namespaces) - return self.cache[branch] + return self.cache[branch].nodes def get( self, @@ -466,18 +490,18 @@ def get( kind_str = self._get_schema_name(schema=kind) if refresh: - self.cache[branch] = self.fetch(branch=branch) + self.cache[branch] = self._fetch(branch=branch) - if branch in self.cache and kind_str in self.cache[branch]: - return self.cache[branch][kind_str] + if branch in self.cache and kind_str in self.cache[branch].nodes: + return self.cache[branch].nodes[kind_str] # Fetching the latest schema from the server if we didn't fetch it earlier # because we coulnd't find the object on the local cache if not refresh: - self.cache[branch] = self.fetch(branch=branch, timeout=timeout) + self.cache[branch] = self._fetch(branch=branch, timeout=timeout) - if branch in self.cache and kind_str in self.cache[branch]: - return self.cache[branch][kind_str] + if branch in self.cache and kind_str in self.cache[branch].nodes: + return self.cache[branch].nodes[kind_str] raise SchemaNotFoundError(identifier=kind_str) @@ -600,17 +624,20 @@ def fetch( Args: branch (str): Name of the branch to fetch the schema for. - timeout (int, optional): Overrides default timeout used when querying the graphql API. Specified in seconds. + timeout (int, optional): Overrides default timeout used when querying the GraphQL API. Specified in seconds. Returns: dict[str, MainSchemaTypes]: Dictionary of all schema organized by kind """ + branch_schema = self._fetch(branch=branch, namespaces=namespaces, timeout=timeout) + return branch_schema.nodes + + def _fetch(self, branch: str, namespaces: list[str] | None = None, timeout: int | None = None) -> BranchSchema: url_parts = [("branch", branch)] if namespaces: url_parts.extend([("namespaces", ns) for ns in namespaces]) query_params = urlencode(url_parts) url = f"{self.client.address}/api/schema?{query_params}" - response = self.client._get(url=url, timeout=timeout) response.raise_for_status() @@ -633,7 +660,9 @@ def fetch( template = TemplateSchemaAPI(**template_schema) nodes[template.kind] = template - return nodes + schema_hash = data.get("main", "") + + return BranchSchema(hash=schema_hash, nodes=nodes) def load( self, schemas: list[dict], branch: str | None = None, wait_until_converged: bool = False diff --git a/infrahub_sdk/schema/main.py b/infrahub_sdk/schema/main.py index 57aaa890..af5556b3 100644 --- a/infrahub_sdk/schema/main.py +++ b/infrahub_sdk/schema/main.py @@ -1,6 +1,7 @@ from __future__ import annotations import warnings +from collections.abc import MutableMapping from enum import Enum from typing import TYPE_CHECKING, Any, Union @@ -348,3 +349,10 @@ class SchemaRootAPI(BaseModel): nodes: list[NodeSchemaAPI] = Field(default_factory=list) profiles: list[ProfileSchemaAPI] = Field(default_factory=list) templates: list[TemplateSchemaAPI] = Field(default_factory=list) + + +class BranchSchema(BaseModel): + hash: str = Field(...) + nodes: MutableMapping[str, GenericSchemaAPI | NodeSchemaAPI | ProfileSchemaAPI | TemplateSchemaAPI] = Field( + default_factory=dict + ) diff --git a/tests/fixtures/schema_01.json b/tests/fixtures/schema_01.json index bf6f016e..344ebeab 100644 --- a/tests/fixtures/schema_01.json +++ b/tests/fixtures/schema_01.json @@ -1,4 +1,5 @@ { + "main": "c0272bc24cd943f21cf30affda06b12d", "nodes": [ { "name": "GraphQLQuery", diff --git a/tests/unit/sdk/test_schema.py b/tests/unit/sdk/test_schema.py index fcffb1c0..ee07a137 100644 --- a/tests/unit/sdk/test_schema.py +++ b/tests/unit/sdk/test_schema.py @@ -64,6 +64,33 @@ async def test_fetch_schema(mock_schema_query_01, client_type): assert isinstance(nodes["BuiltinTag"], NodeSchemaAPI) +@pytest.mark.parametrize("client_type", client_types) +async def test_fetch_schema_conditional_refresh(mock_schema_query_01: HTTPXMock, client_type: str) -> None: + """Verify that only one schema request is sent if we request to update the schema but already have the correct hash""" + if client_type == "standard": + client = InfrahubClient(config=Config(address="http://mock", insert_tracker=True)) + nodes = await client.schema.all(branch="main") + schema_hash = client.schema.cache["main"].hash + assert schema_hash + nodes = await client.schema.all(branch="main", refresh=True, schema_hash=schema_hash) + else: + client = InfrahubClientSync(config=Config(address="http://mock", insert_tracker=True)) + nodes = client.schema.all(branch="main") + schema_hash = client.schema.cache["main"].hash + assert schema_hash + nodes = client.schema.all(branch="main", refresh=True, schema_hash=schema_hash) + + assert len(nodes) == 4 + assert sorted(nodes.keys()) == [ + "BuiltinLocation", + "BuiltinTag", + "CoreGraphQLQuery", + "CoreRepository", + ] + assert isinstance(nodes["BuiltinTag"], NodeSchemaAPI) + assert len(mock_schema_query_01.get_requests()) == 1 + + @pytest.mark.parametrize("client_type", client_types) async def test_schema_data_validation(rfile_schema, client_type): if client_type == "standard":