Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions infrahub_sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@
if TYPE_CHECKING:
from types import TracebackType

from .context import RequestContext


SchemaType = TypeVar("SchemaType", bound=CoreNode)
SchemaTypeSync = TypeVar("SchemaTypeSync", bound=CoreNodeSync)
Expand Down Expand Up @@ -139,6 +141,7 @@ def __init__(
self.identifier = self.config.identifier
self.group_context: InfrahubGroupContext | InfrahubGroupContextSync
self._initialize()
self._request_context: RequestContext | None = None

def _initialize(self) -> None:
"""Sets the properties for each version of the client"""
Expand All @@ -153,6 +156,14 @@ def _echo(self, url: str, query: str, variables: dict | None = None) -> None:
if variables:
print(f"VARIABLES:\n{ujson.dumps(variables, indent=4)}\n")

@property
def request_context(self) -> RequestContext | None:
return self._request_context

@request_context.setter
def request_context(self, request_context: RequestContext) -> None:
self._request_context = request_context

def start_tracking(
self,
identifier: str | None = None,
Expand Down
13 changes: 13 additions & 0 deletions infrahub_sdk/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from __future__ import annotations

Check warning on line 1 in infrahub_sdk/context.py

View check run for this annotation

Codecov / codecov/patch

infrahub_sdk/context.py#L1

Added line #L1 was not covered by tests

from pydantic import BaseModel, Field

Check warning on line 3 in infrahub_sdk/context.py

View check run for this annotation

Codecov / codecov/patch

infrahub_sdk/context.py#L3

Added line #L3 was not covered by tests


class ContextAccount(BaseModel):
id: str = Field(..., description="The ID of the account")

Check warning on line 7 in infrahub_sdk/context.py

View check run for this annotation

Codecov / codecov/patch

infrahub_sdk/context.py#L6-L7

Added lines #L6 - L7 were not covered by tests


class RequestContext(BaseModel):

Check warning on line 10 in infrahub_sdk/context.py

View check run for this annotation

Codecov / codecov/patch

infrahub_sdk/context.py#L10

Added line #L10 was not covered by tests
"""The context can be used to override settings such as the account within mutations."""

account: ContextAccount | None = Field(default=None, description="Account tied to the context")

Check warning on line 13 in infrahub_sdk/context.py

View check run for this annotation

Codecov / codecov/patch

infrahub_sdk/context.py#L13

Added line #L13 was not covered by tests
3 changes: 3 additions & 0 deletions infrahub_sdk/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

if TYPE_CHECKING:
from .client import InfrahubClient
from .context import RequestContext
from .node import InfrahubNode
from .store import NodeStore

Expand All @@ -29,6 +30,7 @@
params: dict | None = None,
convert_query_response: bool = False,
logger: logging.Logger | None = None,
request_context: RequestContext | None = None,
) -> None:
self.query = query
self.branch = branch
Expand All @@ -44,6 +46,7 @@
self.infrahub_node = infrahub_node
self.convert_query_response = convert_query_response
self.logger = logger if logger else logging.getLogger("infrahub.tasks")
self.request_context = request_context

Check warning on line 49 in infrahub_sdk/generator.py

View check run for this annotation

Codecov / codecov/patch

infrahub_sdk/generator.py#L49

Added line #L49 was not covered by tests

@property
def store(self) -> NodeStore:
Expand Down
86 changes: 66 additions & 20 deletions infrahub_sdk/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from typing_extensions import Self

from .client import InfrahubClient, InfrahubClientSync
from .context import RequestContext
from .schema import AttributeSchemaAPI, MainSchemaTypesAPI, RelationshipSchemaAPI
from .types import Order

Expand Down Expand Up @@ -766,6 +767,16 @@
Attribute(name=attr_name, schema=attr_schema, data=attr_data),
)

def _get_request_context(self, request_context: RequestContext | None = None) -> dict[str, Any] | None:
if request_context:
return request_context.model_dump(exclude_none=True)

client: InfrahubClient | InfrahubClientSync | None = getattr(self, "_client", None)
if not client or not client.request_context:

Check warning on line 775 in infrahub_sdk/node.py

View check run for this annotation

Codecov / codecov/patch

infrahub_sdk/node.py#L775

Added line #L775 was not covered by tests
return None

return client.request_context.model_dump(exclude_none=True)

def _init_relationships(self, data: dict | None = None) -> None:
pass

Expand Down Expand Up @@ -794,7 +805,12 @@
def get_raw_graphql_data(self) -> dict | None:
return self._data

def _generate_input_data(self, exclude_unmodified: bool = False, exclude_hfid: bool = False) -> dict[str, dict]: # noqa: C901
def _generate_input_data( # noqa: C901
self,

Check warning on line 809 in infrahub_sdk/node.py

View check run for this annotation

Codecov / codecov/patch

infrahub_sdk/node.py#L809

Added line #L809 was not covered by tests
exclude_unmodified: bool = False,
exclude_hfid: bool = False,
request_context: RequestContext | None = None,
) -> dict[str, dict]:
"""Generate a dictionary that represent the input data required by a mutation.

Returns:
Expand Down Expand Up @@ -872,7 +888,15 @@
elif self.hfid is not None and not exclude_hfid:
data["hfid"] = self.hfid

return {"data": {"data": data}, "variables": variables, "mutation_variables": mutation_variables}
mutation_payload = {"data": data}
if context_data := self._get_request_context(request_context=request_context):

Check warning on line 892 in infrahub_sdk/node.py

View check run for this annotation

Codecov / codecov/patch

infrahub_sdk/node.py#L892

Added line #L892 was not covered by tests
mutation_payload["context"] = context_data

return {
"data": mutation_payload,

Check warning on line 896 in infrahub_sdk/node.py

View check run for this annotation

Codecov / codecov/patch

infrahub_sdk/node.py#L896

Added line #L896 was not covered by tests
"variables": variables,
"mutation_variables": mutation_variables,
}

@staticmethod
def _strip_unmodified_dict(data: dict, original_data: dict, variables: dict, item: str) -> None:
Expand Down Expand Up @@ -1129,8 +1153,11 @@
content = await self._client.object_store.get(identifier=artifact.storage_id.value) # type: ignore[attr-defined]
return content

async def delete(self, timeout: int | None = None) -> None:
async def delete(self, timeout: int | None = None, request_context: RequestContext | None = None) -> None:
input_data = {"data": {"id": self.id}}
if context_data := self._get_request_context(request_context=request_context):
input_data["context"] = context_data

mutation_query = {"ok": None}
query = Mutation(
mutation=f"{self._schema.kind}Delete",
Expand All @@ -1145,12 +1172,16 @@
)

async def save(
self, allow_upsert: bool = False, update_group_context: bool | None = None, timeout: int | None = None
self,
allow_upsert: bool = False,
update_group_context: bool | None = None,
timeout: int | None = None,
request_context: RequestContext | None = None,
) -> None:
if self._existing is False or allow_upsert is True:
await self.create(allow_upsert=allow_upsert, timeout=timeout)
await self.create(allow_upsert=allow_upsert, timeout=timeout, request_context=request_context)
else:
await self.update(timeout=timeout)
await self.update(timeout=timeout, request_context=request_context)

if update_group_context is None and self._client.mode == InfrahubClientMode.TRACKING:
update_group_context = True
Expand Down Expand Up @@ -1379,15 +1410,17 @@
await related_node.fetch(timeout=timeout)
setattr(self, rel_name, related_node)

async def create(self, allow_upsert: bool = False, timeout: int | None = None) -> None:
async def create(
self, allow_upsert: bool = False, timeout: int | None = None, request_context: RequestContext | None = None

Check warning on line 1414 in infrahub_sdk/node.py

View check run for this annotation

Codecov / codecov/patch

infrahub_sdk/node.py#L1413-L1414

Added lines #L1413 - L1414 were not covered by tests
) -> None:
mutation_query = self._generate_mutation_query()

if allow_upsert:
input_data = self._generate_input_data(exclude_hfid=False)
input_data = self._generate_input_data(exclude_hfid=False, request_context=request_context)
mutation_name = f"{self._schema.kind}Upsert"
tracker = f"mutation-{str(self._schema.kind).lower()}-upsert"
else:
input_data = self._generate_input_data(exclude_hfid=True)
input_data = self._generate_input_data(exclude_hfid=True, request_context=request_context)

Check warning on line 1423 in infrahub_sdk/node.py

View check run for this annotation

Codecov / codecov/patch

infrahub_sdk/node.py#L1423

Added line #L1423 was not covered by tests
mutation_name = f"{self._schema.kind}Create"
tracker = f"mutation-{str(self._schema.kind).lower()}-create"
query = Mutation(
Expand All @@ -1405,8 +1438,10 @@
)
await self._process_mutation_result(mutation_name=mutation_name, response=response, timeout=timeout)

async def update(self, do_full_update: bool = False, timeout: int | None = None) -> None:
input_data = self._generate_input_data(exclude_unmodified=not do_full_update)
async def update(
self, do_full_update: bool = False, timeout: int | None = None, request_context: RequestContext | None = None
) -> None:
input_data = self._generate_input_data(exclude_unmodified=not do_full_update, request_context=request_context)
mutation_query = self._generate_mutation_query()
mutation_name = f"{self._schema.kind}Update"

Expand Down Expand Up @@ -1645,8 +1680,11 @@
content = self._client.object_store.get(identifier=artifact.storage_id.value) # type: ignore[attr-defined]
return content

def delete(self, timeout: int | None = None) -> None:
def delete(self, timeout: int | None = None, request_context: RequestContext | None = None) -> None:
input_data = {"data": {"id": self.id}}
if context_data := self._get_request_context(request_context=request_context):
input_data["context"] = context_data

Check warning on line 1687 in infrahub_sdk/node.py

View check run for this annotation

Codecov / codecov/patch

infrahub_sdk/node.py#L1687

Added line #L1687 was not covered by tests
mutation_query = {"ok": None}
query = Mutation(
mutation=f"{self._schema.kind}Delete",
Expand All @@ -1661,12 +1699,16 @@
)

def save(
self, allow_upsert: bool = False, update_group_context: bool | None = None, timeout: int | None = None
self,
allow_upsert: bool = False,
update_group_context: bool | None = None,
timeout: int | None = None,
request_context: RequestContext | None = None,
) -> None:
if self._existing is False or allow_upsert is True:
self.create(allow_upsert=allow_upsert, timeout=timeout)
self.create(allow_upsert=allow_upsert, timeout=timeout, request_context=request_context)
else:
self.update(timeout=timeout)
self.update(timeout=timeout, request_context=request_context)

if update_group_context is None and self._client.mode == InfrahubClientMode.TRACKING:
update_group_context = True
Expand Down Expand Up @@ -1890,15 +1932,17 @@
related_node.fetch(timeout=timeout)
setattr(self, rel_name, related_node)

def create(self, allow_upsert: bool = False, timeout: int | None = None) -> None:
def create(
self, allow_upsert: bool = False, timeout: int | None = None, request_context: RequestContext | None = None

Check warning on line 1936 in infrahub_sdk/node.py

View check run for this annotation

Codecov / codecov/patch

infrahub_sdk/node.py#L1935-L1936

Added lines #L1935 - L1936 were not covered by tests
) -> None:
mutation_query = self._generate_mutation_query()

if allow_upsert:
input_data = self._generate_input_data(exclude_hfid=False)
input_data = self._generate_input_data(exclude_hfid=False, request_context=request_context)
mutation_name = f"{self._schema.kind}Upsert"
tracker = f"mutation-{str(self._schema.kind).lower()}-upsert"
else:
input_data = self._generate_input_data(exclude_hfid=True)
input_data = self._generate_input_data(exclude_hfid=True, request_context=request_context)

Check warning on line 1945 in infrahub_sdk/node.py

View check run for this annotation

Codecov / codecov/patch

infrahub_sdk/node.py#L1945

Added line #L1945 was not covered by tests
mutation_name = f"{self._schema.kind}Create"
tracker = f"mutation-{str(self._schema.kind).lower()}-create"
query = Mutation(
Expand All @@ -1917,8 +1961,10 @@
)
self._process_mutation_result(mutation_name=mutation_name, response=response, timeout=timeout)

def update(self, do_full_update: bool = False, timeout: int | None = None) -> None:
input_data = self._generate_input_data(exclude_unmodified=not do_full_update)
def update(
self, do_full_update: bool = False, timeout: int | None = None, request_context: RequestContext | None = None
) -> None:
input_data = self._generate_input_data(exclude_unmodified=not do_full_update, request_context=request_context)
mutation_query = self._generate_mutation_query()
mutation_name = f"{self._schema.kind}Update"

Expand Down
43 changes: 32 additions & 11 deletions infrahub_sdk/protocols_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
if TYPE_CHECKING:
import ipaddress

from .context import RequestContext
from .schema import MainSchemaTypes


Expand Down Expand Up @@ -169,13 +170,23 @@ def extract(self, params: dict[str, str]) -> dict[str, Any]: ...

@runtime_checkable
class CoreNode(CoreNodeBase, Protocol):
async def save(self, allow_upsert: bool = False, update_group_context: bool | None = None) -> None: ...
async def save(
self,
allow_upsert: bool = False,
update_group_context: bool | None = None,
timeout: int | None = None,
request_context: RequestContext | None = None,
) -> None: ...

async def delete(self) -> None: ...
async def delete(self, timeout: int | None = None, request_context: RequestContext | None = None) -> None: ...

async def update(self, do_full_update: bool) -> None: ...
async def update(
self, do_full_update: bool, timeout: int | None = None, request_context: RequestContext | None = None
) -> None: ...

async def create(self, allow_upsert: bool = False) -> None: ...
async def create(
self, allow_upsert: bool = False, timeout: int | None = None, request_context: RequestContext | None = None
) -> None: ...

async def add_relationships(self, relation_to_update: str, related_nodes: list[str]) -> None: ...

Expand All @@ -184,13 +195,23 @@ async def remove_relationships(self, relation_to_update: str, related_nodes: lis

@runtime_checkable
class CoreNodeSync(CoreNodeBase, Protocol):
def save(self, allow_upsert: bool = False, update_group_context: bool | None = None) -> None: ...

def delete(self) -> None: ...

def update(self, do_full_update: bool) -> None: ...

def create(self, allow_upsert: bool = False) -> None: ...
def save(
self,
allow_upsert: bool = False,
update_group_context: bool | None = None,
timeout: int | None = None,
request_context: RequestContext | None = None,
) -> None: ...

def delete(self, timeout: int | None = None, request_context: RequestContext | None = None) -> None: ...

def update(
self, do_full_update: bool, timeout: int | None = None, request_context: RequestContext | None = None
) -> None: ...

def create(
self, allow_upsert: bool = False, timeout: int | None = None, request_context: RequestContext | None = None
) -> None: ...

def add_relationships(self, relation_to_update: str, related_nodes: list[str]) -> None: ...

Expand Down
10 changes: 8 additions & 2 deletions tests/unit/sdk/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,14 @@
from infrahub_sdk.exceptions import NodeNotFoundError
from infrahub_sdk.node import InfrahubNode, InfrahubNodeSync

async_client_methods = [method for method in dir(InfrahubClient) if not method.startswith("_")]
sync_client_methods = [method for method in dir(InfrahubClientSync) if not method.startswith("_")]
excluded_methods = ["request_context"]

async_client_methods = [
method for method in dir(InfrahubClient) if not method.startswith("_") and method not in excluded_methods
]
sync_client_methods = [
method for method in dir(InfrahubClientSync) if not method.startswith("_") and method not in excluded_methods
]

batch_client_types = [
("standard", False),
Expand Down