diff --git a/infrahub_sdk/client.py b/infrahub_sdk/client.py index bcab0ea1..85cfdc16 100644 --- a/infrahub_sdk/client.py +++ b/infrahub_sdk/client.py @@ -58,6 +58,8 @@ if TYPE_CHECKING: from types import TracebackType + from .context import RequestContext + SchemaType = TypeVar("SchemaType", bound=CoreNode) SchemaTypeSync = TypeVar("SchemaTypeSync", bound=CoreNodeSync) @@ -139,6 +141,7 @@ def __init__( self.identifier = self.config.identifier self.group_context: InfrahubGroupContext | InfrahubGroupContextSync self._initialize() + self._request_context: RequestContext | None = None def _initialize(self) -> None: """Sets the properties for each version of the client""" @@ -153,6 +156,14 @@ def _echo(self, url: str, query: str, variables: dict | None = None) -> None: if variables: print(f"VARIABLES:\n{ujson.dumps(variables, indent=4)}\n") + @property + def request_context(self) -> RequestContext | None: + return self._request_context + + @request_context.setter + def request_context(self, request_context: RequestContext) -> None: + self._request_context = request_context + def start_tracking( self, identifier: str | None = None, diff --git a/infrahub_sdk/context.py b/infrahub_sdk/context.py new file mode 100644 index 00000000..201a9ef9 --- /dev/null +++ b/infrahub_sdk/context.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +from pydantic import BaseModel, Field + + +class ContextAccount(BaseModel): + id: str = Field(..., description="The ID of the account") + + +class RequestContext(BaseModel): + """The context can be used to override settings such as the account within mutations.""" + + account: ContextAccount | None = Field(default=None, description="Account tied to the context") diff --git a/infrahub_sdk/generator.py b/infrahub_sdk/generator.py index 3ba6c767..831d9c98 100644 --- a/infrahub_sdk/generator.py +++ b/infrahub_sdk/generator.py @@ -11,6 +11,7 @@ if TYPE_CHECKING: from .client import InfrahubClient + from .context import RequestContext from .node import InfrahubNode from .store import NodeStore @@ -29,6 +30,7 @@ def __init__( params: dict | None = None, convert_query_response: bool = False, logger: logging.Logger | None = None, + request_context: RequestContext | None = None, ) -> None: self.query = query self.branch = branch @@ -44,6 +46,7 @@ def __init__( self.infrahub_node = infrahub_node self.convert_query_response = convert_query_response self.logger = logger if logger else logging.getLogger("infrahub.tasks") + self.request_context = request_context @property def store(self) -> NodeStore: diff --git a/infrahub_sdk/node.py b/infrahub_sdk/node.py index f34a7b5a..5119c415 100644 --- a/infrahub_sdk/node.py +++ b/infrahub_sdk/node.py @@ -22,6 +22,7 @@ from typing_extensions import Self from .client import InfrahubClient, InfrahubClientSync + from .context import RequestContext from .schema import AttributeSchemaAPI, MainSchemaTypesAPI, RelationshipSchemaAPI from .types import Order @@ -766,6 +767,16 @@ def _init_attributes(self, data: dict | None = None) -> None: Attribute(name=attr_name, schema=attr_schema, data=attr_data), ) + def _get_request_context(self, request_context: RequestContext | None = None) -> dict[str, Any] | None: + if request_context: + return request_context.model_dump(exclude_none=True) + + client: InfrahubClient | InfrahubClientSync | None = getattr(self, "_client", None) + if not client or not client.request_context: + return None + + return client.request_context.model_dump(exclude_none=True) + def _init_relationships(self, data: dict | None = None) -> None: pass @@ -794,7 +805,12 @@ def is_resource_pool(self) -> bool: def get_raw_graphql_data(self) -> dict | None: return self._data - def _generate_input_data(self, exclude_unmodified: bool = False, exclude_hfid: bool = False) -> dict[str, dict]: # noqa: C901 + def _generate_input_data( # noqa: C901 + self, + exclude_unmodified: bool = False, + exclude_hfid: bool = False, + request_context: RequestContext | None = None, + ) -> dict[str, dict]: """Generate a dictionary that represent the input data required by a mutation. Returns: @@ -872,7 +888,15 @@ def _generate_input_data(self, exclude_unmodified: bool = False, exclude_hfid: b elif self.hfid is not None and not exclude_hfid: data["hfid"] = self.hfid - return {"data": {"data": data}, "variables": variables, "mutation_variables": mutation_variables} + mutation_payload = {"data": data} + if context_data := self._get_request_context(request_context=request_context): + mutation_payload["context"] = context_data + + return { + "data": mutation_payload, + "variables": variables, + "mutation_variables": mutation_variables, + } @staticmethod def _strip_unmodified_dict(data: dict, original_data: dict, variables: dict, item: str) -> None: @@ -1129,8 +1153,11 @@ async def artifact_fetch(self, name: str) -> str | dict[str, Any]: content = await self._client.object_store.get(identifier=artifact.storage_id.value) # type: ignore[attr-defined] return content - async def delete(self, timeout: int | None = None) -> None: + async def delete(self, timeout: int | None = None, request_context: RequestContext | None = None) -> None: input_data = {"data": {"id": self.id}} + if context_data := self._get_request_context(request_context=request_context): + input_data["context"] = context_data + mutation_query = {"ok": None} query = Mutation( mutation=f"{self._schema.kind}Delete", @@ -1145,12 +1172,16 @@ async def delete(self, timeout: int | None = None) -> None: ) async def save( - self, allow_upsert: bool = False, update_group_context: bool | None = None, timeout: int | None = None + self, + allow_upsert: bool = False, + update_group_context: bool | None = None, + timeout: int | None = None, + request_context: RequestContext | None = None, ) -> None: if self._existing is False or allow_upsert is True: - await self.create(allow_upsert=allow_upsert, timeout=timeout) + await self.create(allow_upsert=allow_upsert, timeout=timeout, request_context=request_context) else: - await self.update(timeout=timeout) + await self.update(timeout=timeout, request_context=request_context) if update_group_context is None and self._client.mode == InfrahubClientMode.TRACKING: update_group_context = True @@ -1379,15 +1410,17 @@ async def _process_mutation_result( await related_node.fetch(timeout=timeout) setattr(self, rel_name, related_node) - async def create(self, allow_upsert: bool = False, timeout: int | None = None) -> None: + async def create( + self, allow_upsert: bool = False, timeout: int | None = None, request_context: RequestContext | None = None + ) -> None: mutation_query = self._generate_mutation_query() if allow_upsert: - input_data = self._generate_input_data(exclude_hfid=False) + input_data = self._generate_input_data(exclude_hfid=False, request_context=request_context) mutation_name = f"{self._schema.kind}Upsert" tracker = f"mutation-{str(self._schema.kind).lower()}-upsert" else: - input_data = self._generate_input_data(exclude_hfid=True) + input_data = self._generate_input_data(exclude_hfid=True, request_context=request_context) mutation_name = f"{self._schema.kind}Create" tracker = f"mutation-{str(self._schema.kind).lower()}-create" query = Mutation( @@ -1405,8 +1438,10 @@ async def create(self, allow_upsert: bool = False, timeout: int | None = None) - ) await self._process_mutation_result(mutation_name=mutation_name, response=response, timeout=timeout) - async def update(self, do_full_update: bool = False, timeout: int | None = None) -> None: - input_data = self._generate_input_data(exclude_unmodified=not do_full_update) + async def update( + self, do_full_update: bool = False, timeout: int | None = None, request_context: RequestContext | None = None + ) -> None: + input_data = self._generate_input_data(exclude_unmodified=not do_full_update, request_context=request_context) mutation_query = self._generate_mutation_query() mutation_name = f"{self._schema.kind}Update" @@ -1645,8 +1680,11 @@ def artifact_fetch(self, name: str) -> str | dict[str, Any]: content = self._client.object_store.get(identifier=artifact.storage_id.value) # type: ignore[attr-defined] return content - def delete(self, timeout: int | None = None) -> None: + def delete(self, timeout: int | None = None, request_context: RequestContext | None = None) -> None: input_data = {"data": {"id": self.id}} + if context_data := self._get_request_context(request_context=request_context): + input_data["context"] = context_data + mutation_query = {"ok": None} query = Mutation( mutation=f"{self._schema.kind}Delete", @@ -1661,12 +1699,16 @@ def delete(self, timeout: int | None = None) -> None: ) def save( - self, allow_upsert: bool = False, update_group_context: bool | None = None, timeout: int | None = None + self, + allow_upsert: bool = False, + update_group_context: bool | None = None, + timeout: int | None = None, + request_context: RequestContext | None = None, ) -> None: if self._existing is False or allow_upsert is True: - self.create(allow_upsert=allow_upsert, timeout=timeout) + self.create(allow_upsert=allow_upsert, timeout=timeout, request_context=request_context) else: - self.update(timeout=timeout) + self.update(timeout=timeout, request_context=request_context) if update_group_context is None and self._client.mode == InfrahubClientMode.TRACKING: update_group_context = True @@ -1890,15 +1932,17 @@ def _process_mutation_result( related_node.fetch(timeout=timeout) setattr(self, rel_name, related_node) - def create(self, allow_upsert: bool = False, timeout: int | None = None) -> None: + def create( + self, allow_upsert: bool = False, timeout: int | None = None, request_context: RequestContext | None = None + ) -> None: mutation_query = self._generate_mutation_query() if allow_upsert: - input_data = self._generate_input_data(exclude_hfid=False) + input_data = self._generate_input_data(exclude_hfid=False, request_context=request_context) mutation_name = f"{self._schema.kind}Upsert" tracker = f"mutation-{str(self._schema.kind).lower()}-upsert" else: - input_data = self._generate_input_data(exclude_hfid=True) + input_data = self._generate_input_data(exclude_hfid=True, request_context=request_context) mutation_name = f"{self._schema.kind}Create" tracker = f"mutation-{str(self._schema.kind).lower()}-create" query = Mutation( @@ -1917,8 +1961,10 @@ def create(self, allow_upsert: bool = False, timeout: int | None = None) -> None ) self._process_mutation_result(mutation_name=mutation_name, response=response, timeout=timeout) - def update(self, do_full_update: bool = False, timeout: int | None = None) -> None: - input_data = self._generate_input_data(exclude_unmodified=not do_full_update) + def update( + self, do_full_update: bool = False, timeout: int | None = None, request_context: RequestContext | None = None + ) -> None: + input_data = self._generate_input_data(exclude_unmodified=not do_full_update, request_context=request_context) mutation_query = self._generate_mutation_query() mutation_name = f"{self._schema.kind}Update" diff --git a/infrahub_sdk/protocols_base.py b/infrahub_sdk/protocols_base.py index 4c227117..c634d37f 100644 --- a/infrahub_sdk/protocols_base.py +++ b/infrahub_sdk/protocols_base.py @@ -5,6 +5,7 @@ if TYPE_CHECKING: import ipaddress + from .context import RequestContext from .schema import MainSchemaTypes @@ -169,13 +170,23 @@ def extract(self, params: dict[str, str]) -> dict[str, Any]: ... @runtime_checkable class CoreNode(CoreNodeBase, Protocol): - async def save(self, allow_upsert: bool = False, update_group_context: bool | None = None) -> None: ... + async def save( + self, + allow_upsert: bool = False, + update_group_context: bool | None = None, + timeout: int | None = None, + request_context: RequestContext | None = None, + ) -> None: ... - async def delete(self) -> None: ... + async def delete(self, timeout: int | None = None, request_context: RequestContext | None = None) -> None: ... - async def update(self, do_full_update: bool) -> None: ... + async def update( + self, do_full_update: bool, timeout: int | None = None, request_context: RequestContext | None = None + ) -> None: ... - async def create(self, allow_upsert: bool = False) -> None: ... + async def create( + self, allow_upsert: bool = False, timeout: int | None = None, request_context: RequestContext | None = None + ) -> None: ... async def add_relationships(self, relation_to_update: str, related_nodes: list[str]) -> None: ... @@ -184,13 +195,23 @@ async def remove_relationships(self, relation_to_update: str, related_nodes: lis @runtime_checkable class CoreNodeSync(CoreNodeBase, Protocol): - def save(self, allow_upsert: bool = False, update_group_context: bool | None = None) -> None: ... - - def delete(self) -> None: ... - - def update(self, do_full_update: bool) -> None: ... - - def create(self, allow_upsert: bool = False) -> None: ... + def save( + self, + allow_upsert: bool = False, + update_group_context: bool | None = None, + timeout: int | None = None, + request_context: RequestContext | None = None, + ) -> None: ... + + def delete(self, timeout: int | None = None, request_context: RequestContext | None = None) -> None: ... + + def update( + self, do_full_update: bool, timeout: int | None = None, request_context: RequestContext | None = None + ) -> None: ... + + def create( + self, allow_upsert: bool = False, timeout: int | None = None, request_context: RequestContext | None = None + ) -> None: ... def add_relationships(self, relation_to_update: str, related_nodes: list[str]) -> None: ... diff --git a/tests/unit/sdk/test_client.py b/tests/unit/sdk/test_client.py index 4d849221..9f0f4e33 100644 --- a/tests/unit/sdk/test_client.py +++ b/tests/unit/sdk/test_client.py @@ -7,8 +7,14 @@ from infrahub_sdk.exceptions import NodeNotFoundError from infrahub_sdk.node import InfrahubNode, InfrahubNodeSync -async_client_methods = [method for method in dir(InfrahubClient) if not method.startswith("_")] -sync_client_methods = [method for method in dir(InfrahubClientSync) if not method.startswith("_")] +excluded_methods = ["request_context"] + +async_client_methods = [ + method for method in dir(InfrahubClient) if not method.startswith("_") and method not in excluded_methods +] +sync_client_methods = [ + method for method in dir(InfrahubClientSync) if not method.startswith("_") and method not in excluded_methods +] batch_client_types = [ ("standard", False),