1 change: 1 addition & 0 deletions changelog/152.added.md
@@ -0,0 +1 @@
Add 'schema_hash' parameter to client.schema.all to only refresh the schema if the provided hash differs from what the client has already cached.
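
A minimal usage sketch of the new parameter; the address is a placeholder, and client construction mirrors the unit test at the bottom of this diff:

```python
import asyncio

from infrahub_sdk import Config, InfrahubClient


async def main() -> None:
    # Placeholder address; construct the client however your project already does.
    client = InfrahubClient(config=Config(address="http://localhost:8000"))

    # The first call populates the client-side cache and records the branch schema hash.
    await client.schema.all(branch="main")
    cached_hash = client.schema.cache["main"].hash

    # Ask for a refresh but pass the hash we already hold: if it still matches the
    # cached BranchSchema, the SDK skips the extra HTTP round trip and reuses the cache.
    await client.schema.all(branch="main", refresh=True, schema_hash=cached_hash)


asyncio.run(main())
```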
81 changes: 55 additions & 26 deletions infrahub_sdk/schema/__init__.py
@@ -1,7 +1,6 @@
from __future__ import annotations

import asyncio
from collections import defaultdict
from collections.abc import MutableMapping
from enum import Enum
from time import sleep
@@ -22,6 +21,7 @@
from .main import (
AttributeSchema,
AttributeSchemaAPI,
BranchSchema,
BranchSupportType,
GenericSchema,
GenericSchemaAPI,
Expand Down Expand Up @@ -169,7 +169,7 @@
class InfrahubSchema(InfrahubSchemaBase):
def __init__(self, client: InfrahubClient):
self.client = client
self.cache: dict = defaultdict(lambda: dict)
self.cache: dict[str, BranchSchema] = {}

async def get(
self,
@@ -183,23 +183,27 @@
kind_str = self._get_schema_name(schema=kind)

if refresh:
self.cache[branch] = await self.fetch(branch=branch, timeout=timeout)
self.cache[branch] = await self._fetch(branch=branch, timeout=timeout)

if branch in self.cache and kind_str in self.cache[branch]:
return self.cache[branch][kind_str]
if branch in self.cache and kind_str in self.cache[branch].nodes:
return self.cache[branch].nodes[kind_str]

# Fetching the latest schema from the server if we didn't fetch it earlier
# because we couldn't find the object on the local cache
if not refresh:
self.cache[branch] = await self.fetch(branch=branch, timeout=timeout)
self.cache[branch] = await self._fetch(branch=branch, timeout=timeout)

if branch in self.cache and kind_str in self.cache[branch]:
return self.cache[branch][kind_str]
if branch in self.cache and kind_str in self.cache[branch].nodes:
return self.cache[branch].nodes[kind_str]

raise SchemaNotFoundError(identifier=kind_str)

async def all(
self, branch: str | None = None, refresh: bool = False, namespaces: list[str] | None = None
self,
branch: str | None = None,
refresh: bool = False,
namespaces: list[str] | None = None,
schema_hash: str | None = None,
) -> MutableMapping[str, MainSchemaTypesAPI]:
"""Retrieve the entire schema for a given branch.

@@ -209,15 +213,19 @@
Args:
branch (str, optional): Name of the branch to query. Defaults to default_branch.
refresh (bool, optional): Force a refresh of the schema. Defaults to False.
schema_hash (str, optional): Only refresh if the current schema doesn't match this hash.

Returns:
dict[str, MainSchemaTypes]: Dictionary of all schema organized by kind
"""
branch = branch or self.client.default_branch
if refresh and branch in self.cache and schema_hash and self.cache[branch].hash == schema_hash:
refresh = False

if refresh or branch not in self.cache:
self.cache[branch] = await self.fetch(branch=branch, namespaces=namespaces)
self.cache[branch] = await self._fetch(branch=branch, namespaces=namespaces)

return self.cache[branch]
return self.cache[branch].nodes

async def load(
self, schemas: list[dict], branch: str | None = None, wait_until_converged: bool = False
@@ -392,11 +400,17 @@

Args:
branch (str): Name of the branch to fetch the schema for.
timeout (int, optional): Overrides default timeout used when querying the graphql API. Specified in seconds.
timeout (int, optional): Overrides default timeout used when querying the GraphQL API. Specified in seconds.

Returns:
dict[str, MainSchemaTypes]: Dictionary of all schema organized by kind
"""
branch_schema = await self._fetch(branch=branch, namespaces=namespaces, timeout=timeout)
return branch_schema.nodes

async def _fetch(
self, branch: str, namespaces: list[str] | None = None, timeout: int | None = None
) -> BranchSchema:
url_parts = [("branch", branch)]
if namespaces:
url_parts.extend([("namespaces", ns) for ns in namespaces])
@@ -425,16 +439,22 @@
template = TemplateSchemaAPI(**template_schema)
nodes[template.kind] = template

return nodes
schema_hash = data.get("main", "")

return BranchSchema(hash=schema_hash, nodes=nodes)


class InfrahubSchemaSync(InfrahubSchemaBase):
def __init__(self, client: InfrahubClientSync):
self.client = client
self.cache: dict = defaultdict(lambda: dict)
self.cache: dict[str, BranchSchema] = {}

def all(
self, branch: str | None = None, refresh: bool = False, namespaces: list[str] | None = None
self,
branch: str | None = None,
refresh: bool = False,
namespaces: list[str] | None = None,
schema_hash: str | None = None,
) -> MutableMapping[str, MainSchemaTypesAPI]:
"""Retrieve the entire schema for a given branch.

@@ -444,15 +464,19 @@
Args:
branch (str, optional): Name of the branch to query. Defaults to default_branch.
refresh (bool, optional): Force a refresh of the schema. Defaults to False.
schema_hash (str, optional): Only refresh if the current schema doesn't match this hash.

Returns:
dict[str, MainSchemaTypes]: Dictionary of all schema organized by kind
"""
branch = branch or self.client.default_branch
if refresh and branch in self.cache and schema_hash and self.cache[branch].hash == schema_hash:
refresh = False

if refresh or branch not in self.cache:
self.cache[branch] = self.fetch(branch=branch, namespaces=namespaces)
self.cache[branch] = self._fetch(branch=branch, namespaces=namespaces)

return self.cache[branch]
return self.cache[branch].nodes

def get(
self,
@@ -466,18 +490,18 @@
kind_str = self._get_schema_name(schema=kind)

if refresh:
self.cache[branch] = self.fetch(branch=branch)
self.cache[branch] = self._fetch(branch=branch)

if branch in self.cache and kind_str in self.cache[branch]:
return self.cache[branch][kind_str]
if branch in self.cache and kind_str in self.cache[branch].nodes:
return self.cache[branch].nodes[kind_str]

# Fetching the latest schema from the server if we didn't fetch it earlier
# because we couldn't find the object on the local cache
if not refresh:
self.cache[branch] = self.fetch(branch=branch, timeout=timeout)
self.cache[branch] = self._fetch(branch=branch, timeout=timeout)

if branch in self.cache and kind_str in self.cache[branch]:
return self.cache[branch][kind_str]
if branch in self.cache and kind_str in self.cache[branch].nodes:
return self.cache[branch].nodes[kind_str]

raise SchemaNotFoundError(identifier=kind_str)

@@ -600,17 +624,20 @@

Args:
branch (str): Name of the branch to fetch the schema for.
timeout (int, optional): Overrides default timeout used when querying the graphql API. Specified in seconds.
timeout (int, optional): Overrides default timeout used when querying the GraphQL API. Specified in seconds.

Returns:
dict[str, MainSchemaTypes]: Dictionary of all schema organized by kind
"""
branch_schema = self._fetch(branch=branch, namespaces=namespaces, timeout=timeout)
return branch_schema.nodes

def _fetch(self, branch: str, namespaces: list[str] | None = None, timeout: int | None = None) -> BranchSchema:
url_parts = [("branch", branch)]
if namespaces:
url_parts.extend([("namespaces", ns) for ns in namespaces])
query_params = urlencode(url_parts)
url = f"{self.client.address}/api/schema?{query_params}"

response = self.client._get(url=url, timeout=timeout)
response.raise_for_status()

@@ -633,7 +660,9 @@
template = TemplateSchemaAPI(**template_schema)
nodes[template.kind] = template

return nodes
schema_hash = data.get("main", "")

return BranchSchema(hash=schema_hash, nodes=nodes)

def load(
self, schemas: list[dict], branch: str | None = None, wait_until_converged: bool = False
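
The caching change itself is small; the sketch below restates the decision the new `all()` makes, as a standalone helper with simplified types (not part of the SDK, just an illustration of the logic added above):

```python
from infrahub_sdk.schema.main import BranchSchema


def should_fetch(
    cache: dict[str, BranchSchema],
    branch: str,
    refresh: bool,
    schema_hash: str | None,
) -> bool:
    # A forced refresh is downgraded to a cache hit when the caller already
    # holds the hash of the cached branch schema.
    if refresh and branch in cache and schema_hash and cache[branch].hash == schema_hash:
        refresh = False
    # Fetch when a refresh is still requested or the branch was never cached.
    return refresh or branch not in cache
```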
8 changes: 8 additions & 0 deletions infrahub_sdk/schema/main.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import warnings
from collections.abc import MutableMapping
from enum import Enum
from typing import TYPE_CHECKING, Any, Union

@@ -348,3 +349,10 @@ class SchemaRootAPI(BaseModel):
nodes: list[NodeSchemaAPI] = Field(default_factory=list)
profiles: list[ProfileSchemaAPI] = Field(default_factory=list)
templates: list[TemplateSchemaAPI] = Field(default_factory=list)


class BranchSchema(BaseModel):
hash: str = Field(...)
nodes: MutableMapping[str, GenericSchemaAPI | NodeSchemaAPI | ProfileSchemaAPI | TemplateSchemaAPI] = Field(
default_factory=dict
)
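
A small sketch of what the new model holds once `_fetch` has parsed the `/api/schema` response; the hash value is copied from the fixture below and the nodes mapping is left empty for brevity:

```python
from infrahub_sdk.schema.main import BranchSchema

# The cache now maps a branch name to a BranchSchema, pairing the hash reported
# under the "main" key of the /api/schema response with the parsed node schemas.
branch_schema = BranchSchema(hash="c0272bc24cd943f21cf30affda06b12d", nodes={})
cache: dict[str, BranchSchema] = {"main": branch_schema}

assert cache["main"].hash == "c0272bc24cd943f21cf30affda06b12d"
assert cache["main"].nodes == {}
```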
1 change: 1 addition & 0 deletions tests/fixtures/schema_01.json
@@ -1,4 +1,5 @@
{
"main": "c0272bc24cd943f21cf30affda06b12d",
"nodes": [
{
"name": "GraphQLQuery",
27 changes: 27 additions & 0 deletions tests/unit/sdk/test_schema.py
@@ -64,6 +64,33 @@ async def test_fetch_schema(mock_schema_query_01, client_type):
assert isinstance(nodes["BuiltinTag"], NodeSchemaAPI)


@pytest.mark.parametrize("client_type", client_types)
async def test_fetch_schema_conditional_refresh(mock_schema_query_01: HTTPXMock, client_type: str) -> None:
"""Verify that only one schema request is sent if we request to update the schema but already have the correct hash"""
if client_type == "standard":
client = InfrahubClient(config=Config(address="http://mock", insert_tracker=True))
nodes = await client.schema.all(branch="main")
schema_hash = client.schema.cache["main"].hash
assert schema_hash
nodes = await client.schema.all(branch="main", refresh=True, schema_hash=schema_hash)
else:
client = InfrahubClientSync(config=Config(address="http://mock", insert_tracker=True))
nodes = client.schema.all(branch="main")
schema_hash = client.schema.cache["main"].hash
assert schema_hash
nodes = client.schema.all(branch="main", refresh=True, schema_hash=schema_hash)

assert len(nodes) == 4
assert sorted(nodes.keys()) == [
"BuiltinLocation",
"BuiltinTag",
"CoreGraphQLQuery",
"CoreRepository",
]
assert isinstance(nodes["BuiltinTag"], NodeSchemaAPI)
assert len(mock_schema_query_01.get_requests()) == 1


@pytest.mark.parametrize("client_type", client_types)
async def test_schema_data_validation(rfile_schema, client_type):
if client_type == "standard":