1 change: 1 addition & 0 deletions changelog/159.added.md
@@ -0,0 +1 @@
Add the ability to batch API queries for `all` and `filter` functions.
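The flag is opt-in on both the async and sync clients and defaults to False. A minimal usage sketch of the async side, assuming the package's usual InfrahubClient entry point; the server address, the InfraDevice kind, and the role__value filter are illustrative placeholders, not part of this PR:

    import asyncio

    from infrahub_sdk import InfrahubClient

    async def main() -> None:
        client = InfrahubClient(address="http://localhost:8000")  # illustrative address
        # Fetch every page of results concurrently instead of one page at a time.
        devices = await client.all(kind="InfraDevice", parallel=True)
        # filters() accepts the same flag alongside ordinary filter kwargs.
        edge_devices = await client.filters(kind="InfraDevice", role__value="edge", parallel=True)
        print(f"{len(devices)} devices, {len(edge_devices)} edge devices")

    asyncio.run(main())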
151 changes: 111 additions & 40 deletions infrahub_sdk/client.py
@@ -559,6 +559,7 @@ async def all(
fragment: bool = ...,
prefetch_relationships: bool = ...,
property: bool = ...,
parallel: bool = ...,
) -> list[SchemaType]: ...

@overload
@@ -576,6 +577,7 @@ async def all(
fragment: bool = ...,
prefetch_relationships: bool = ...,
property: bool = ...,
parallel: bool = ...,
) -> list[InfrahubNode]: ...

async def all(
@@ -592,6 +594,7 @@ async def all(
fragment: bool = False,
prefetch_relationships: bool = False,
property: bool = False,
parallel: bool = False,
) -> list[InfrahubNode] | list[SchemaType]:
"""Retrieve all nodes of a given kind

@@ -607,6 +610,7 @@ async def all(
exclude (list[str], optional): List of attributes or relationships to exclude from the query.
fragment (bool, optional): Flag to use GraphQL fragments for generic schemas.
prefetch_relationships (bool, optional): Flag to indicate whether to prefetch related node data.
parallel (bool, optional): Whether to use parallel processing for the query.

Returns:
list[InfrahubNode]: List of Nodes
@@ -624,6 +628,7 @@ async def all(
fragment=fragment,
prefetch_relationships=prefetch_relationships,
property=property,
parallel=parallel,
)

@overload
@@ -642,6 +647,7 @@ async def filters(
prefetch_relationships: bool = ...,
partial_match: bool = ...,
property: bool = ...,
parallel: bool = ...,
**kwargs: Any,
) -> list[SchemaType]: ...

@@ -661,6 +667,7 @@ async def filters(
prefetch_relationships: bool = ...,
partial_match: bool = ...,
property: bool = ...,
parallel: bool = ...,
**kwargs: Any,
) -> list[InfrahubNode]: ...

@@ -679,6 +686,7 @@ async def filters(
prefetch_relationships: bool = False,
partial_match: bool = False,
property: bool = False,
parallel: bool = False,
**kwargs: Any,
) -> list[InfrahubNode] | list[SchemaType]:
"""Retrieve nodes of a given kind based on provided filters.
@@ -696,32 +704,26 @@ async def filters(
fragment (bool, optional): Flag to use GraphQL fragments for generic schemas.
prefetch_relationships (bool, optional): Flag to indicate whether to prefetch related node data.
partial_match (bool, optional): Allow partial match of filter criteria for the query.
parallel (bool, optional): Whether to use parallel processing for the query.
**kwargs (Any): Additional filter criteria for the query.

Returns:
list[InfrahubNode]: List of Nodes that match the given filters.
"""
schema = await self.schema.get(kind=kind, branch=branch)

branch = branch or self.default_branch
schema = await self.schema.get(kind=kind, branch=branch)
if at:
at = Timestamp(at)

node = InfrahubNode(client=self, schema=schema, branch=branch)
filters = kwargs
pagination_size = self.pagination_size

nodes: list[InfrahubNode] = []
related_nodes: list[InfrahubNode] = []

has_remaining_items = True
page_number = 1

while has_remaining_items:
page_offset = (page_number - 1) * self.pagination_size

async def process_page(page_offset: int, page_number: int) -> tuple[dict, ProcessRelationsNode]:
"""Process a single page of results."""
query_data = await InfrahubNode(client=self, schema=schema, branch=branch).generate_query_data(
offset=offset or page_offset,
limit=limit or self.pagination_size,
limit=limit or pagination_size,
filters=filters,
include=include,
exclude=exclude,
@@ -746,14 +748,48 @@ async def filters(
prefetch_relationships=prefetch_relationships,
timeout=timeout,
)
nodes.extend(process_result["nodes"])
related_nodes.extend(process_result["related_nodes"])
return response, process_result

async def process_batch() -> tuple[list[InfrahubNode], list[InfrahubNode]]:
"""Process queries in parallel mode."""
nodes = []
related_nodes = []
batch_process = await self.create_batch()
count = await self.count(kind=schema.kind)
total_pages = (count + pagination_size - 1) // pagination_size

for page_number in range(1, total_pages + 1):
page_offset = (page_number - 1) * pagination_size
batch_process.add(task=process_page, node=node, page_offset=page_offset, page_number=page_number)

remaining_items = response[schema.kind].get("count", 0) - (page_offset + self.pagination_size)
if remaining_items < 0 or offset is not None or limit is not None:
has_remaining_items = False
async for _, response in batch_process.execute():
nodes.extend(response[1]["nodes"])
related_nodes.extend(response[1]["related_nodes"])

page_number += 1
return nodes, related_nodes

async def process_non_batch() -> tuple[list[InfrahubNode], list[InfrahubNode]]:
"""Process queries without parallel mode."""
nodes = []
related_nodes = []
has_remaining_items = True
page_number = 1

while has_remaining_items:
page_offset = (page_number - 1) * pagination_size
response, process_result = await process_page(page_offset, page_number)

nodes.extend(process_result["nodes"])
related_nodes.extend(process_result["related_nodes"])
remaining_items = response[schema.kind].get("count", 0) - (page_offset + pagination_size)
if remaining_items < 0 or offset is not None or limit is not None:
has_remaining_items = False
page_number += 1

return nodes, related_nodes

# Select parallel or non-parallel processing
nodes, related_nodes = await (process_batch() if parallel else process_non_batch())

if populate_store:
for node in nodes:
@@ -763,7 +799,6 @@ async def filters(
for node in related_nodes:
if node.id:
self.store.set(key=node.id, node=node)

return nodes

def clone(self) -> InfrahubClient:
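A note on the arithmetic above: process_batch() asks the server for the total count up front, then sizes the fan-out with a ceiling division before queueing one process_page() task per page on the batch. A small self-contained sketch of that computation (the helper name is mine, not part of the PR):

    # Ceiling division without math.ceil: (count + size - 1) // size.
    def total_pages(count: int, pagination_size: int) -> int:
        return (count + pagination_size - 1) // pagination_size

    assert total_pages(0, 50) == 0    # empty result set: no tasks queued
    assert total_pages(50, 50) == 1   # exact multiple: a single page
    assert total_pages(120, 50) == 3  # pages at offsets 0, 50, and 100

The up-front count query costs one extra round trip, but it lets batch_process.execute() issue every page request concurrently instead of discovering the end of the result set page by page.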
@@ -1602,6 +1637,7 @@ def all(
fragment: bool = ...,
prefetch_relationships: bool = ...,
property: bool = ...,
parallel: bool = ...,
) -> list[SchemaTypeSync]: ...

@overload
@@ -1619,6 +1655,7 @@ def all(
fragment: bool = ...,
prefetch_relationships: bool = ...,
property: bool = ...,
parallel: bool = ...,
) -> list[InfrahubNodeSync]: ...

def all(
@@ -1635,6 +1672,7 @@ def all(
fragment: bool = False,
prefetch_relationships: bool = False,
property: bool = False,
parallel: bool = False,
) -> list[InfrahubNodeSync] | list[SchemaTypeSync]:
"""Retrieve all nodes of a given kind

@@ -1650,6 +1688,7 @@ def all(
exclude (list[str], optional): List of attributes or relationships to exclude from the query.
fragment (bool, optional): Flag to use GraphQL fragments for generic schemas.
prefetch_relationships (bool, optional): Flag to indicate whether to prefetch related node data.
parallel (bool, optional): Whether to use parallel processing for the query.

Returns:
list[InfrahubNodeSync]: List of Nodes
@@ -1667,6 +1706,7 @@ def all(
fragment=fragment,
prefetch_relationships=prefetch_relationships,
property=property,
parallel=parallel,
)

def _process_nodes_and_relationships(
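The synchronous client gets the same opt-in flag, with create_batch() and count() used in their blocking forms. A usage sketch, assuming the package's InfrahubClientSync entry point (not shown in this diff); the address and kind are again placeholders:

    from infrahub_sdk import InfrahubClientSync

    client = InfrahubClientSync(address="http://localhost:8000")  # illustrative address
    # Same flag as the async client; pages are fetched through a sync batch.
    devices = client.all(kind="InfraDevice", parallel=True)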
@@ -1720,6 +1760,7 @@ def filters(
prefetch_relationships: bool = ...,
partial_match: bool = ...,
property: bool = ...,
parallel: bool = ...,
**kwargs: Any,
) -> list[SchemaTypeSync]: ...

@@ -1739,6 +1780,7 @@ def filters(
prefetch_relationships: bool = ...,
partial_match: bool = ...,
property: bool = ...,
parallel: bool = ...,
**kwargs: Any,
) -> list[InfrahubNodeSync]: ...

@@ -1757,6 +1799,7 @@ def filters(
prefetch_relationships: bool = False,
partial_match: bool = False,
property: bool = False,
parallel: bool = False,
**kwargs: Any,
) -> list[InfrahubNodeSync] | list[SchemaTypeSync]:
"""Retrieve nodes of a given kind based on provided filters.
@@ -1774,32 +1817,25 @@ def filters(
fragment (bool, optional): Flag to use GraphQL fragments for generic schemas.
prefetch_relationships (bool, optional): Flag to indicate whether to prefetch related node data.
partial_match (bool, optional): Allow partial match of filter criteria for the query.
parallel (bool, optional): Whether to use parallel processing for the query.
**kwargs (Any): Additional filter criteria for the query.

Returns:
list[InfrahubNodeSync]: List of Nodes that match the given filters.
"""
schema = self.schema.get(kind=kind, branch=branch)

branch = branch or self.default_branch
schema = self.schema.get(kind=kind, branch=branch)
node = InfrahubNodeSync(client=self, schema=schema, branch=branch)
if at:
at = Timestamp(at)

node = InfrahubNodeSync(client=self, schema=schema, branch=branch)
filters = kwargs
pagination_size = self.pagination_size

nodes: list[InfrahubNodeSync] = []
related_nodes: list[InfrahubNodeSync] = []

has_remaining_items = True
page_number = 1

while has_remaining_items:
page_offset = (page_number - 1) * self.pagination_size

def process_page(page_offset: int, page_number: int) -> tuple[dict, ProcessRelationsNodeSync]:
"""Process a single page of results."""
query_data = InfrahubNodeSync(client=self, schema=schema, branch=branch).generate_query_data(
offset=offset or page_offset,
limit=limit or self.pagination_size,
limit=limit or pagination_size,
filters=filters,
include=include,
exclude=exclude,
@@ -1824,14 +1860,50 @@ def filters(
prefetch_relationships=prefetch_relationships,
timeout=timeout,
)
nodes.extend(process_result["nodes"])
related_nodes.extend(process_result["related_nodes"])
return response, process_result

def process_batch() -> tuple[list[InfrahubNodeSync], list[InfrahubNodeSync]]:
"""Process queries in parallel mode."""
nodes = []
related_nodes = []
batch_process = self.create_batch()

remaining_items = response[schema.kind].get("count", 0) - (page_offset + self.pagination_size)
if remaining_items < 0 or offset is not None or limit is not None:
has_remaining_items = False
count = self.count(kind=schema.kind)
total_pages = (count + pagination_size - 1) // pagination_size

page_number += 1
for page_number in range(1, total_pages + 1):
page_offset = (page_number - 1) * pagination_size
batch_process.add(task=process_page, node=node, page_offset=page_offset, page_number=page_number)

for _, response in batch_process.execute():
nodes.extend(response[1]["nodes"])
related_nodes.extend(response[1]["related_nodes"])

return nodes, related_nodes

def process_non_batch() -> tuple[list[InfrahubNodeSync], list[InfrahubNodeSync]]:
"""Process queries without parallel mode."""
nodes = []
related_nodes = []
has_remaining_items = True
page_number = 1

while has_remaining_items:
page_offset = (page_number - 1) * pagination_size
response, process_result = process_page(page_offset, page_number)

nodes.extend(process_result["nodes"])
related_nodes.extend(process_result["related_nodes"])

remaining_items = response[schema.kind].get("count", 0) - (page_offset + pagination_size)
if remaining_items < 0 or offset is not None or limit is not None:
has_remaining_items = False
page_number += 1

return nodes, related_nodes

# Select parallel or non-parallel processing
nodes, related_nodes = process_batch() if parallel else process_non_batch()

if populate_store:
for node in nodes:
@@ -1841,7 +1913,6 @@ def filters(
for node in related_nodes:
if node.id:
self.store.set(key=node.id, node=node)

return nodes

@overload
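For comparison, the sequential fallback in process_non_batch() never issues a count query up front; it stops once the count reported alongside a page shows nothing is left, or after a single page when the caller passed an explicit offset or limit. A standalone trace of that stop condition with illustrative numbers:

    # Mimics process_non_batch()'s loop for count=120, pagination_size=50.
    count, pagination_size = 120, 50
    page_number, has_remaining_items = 1, True
    while has_remaining_items:
        page_offset = (page_number - 1) * pagination_size
        remaining_items = count - (page_offset + pagination_size)
        print(f"page {page_number}: offset={page_offset}, remaining={remaining_items}")
        if remaining_items < 0:  # an explicit offset/limit also ends the loop
            has_remaining_items = False
        page_number += 1
    # Fetches pages 1-3 (70, then 20, then -30 remaining). When count is an
    # exact multiple of the page size, the `< 0` check costs one extra, empty page.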