Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/count-method-filters.changed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added possibility to use filters for the SDK client's count method
14 changes: 12 additions & 2 deletions infrahub_sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,16 +547,21 @@ async def count(
at: Timestamp | None = None,
branch: str | None = None,
timeout: int | None = None,
**kwargs: Any,
) -> int:
"""Return the number of nodes of a given kind."""
filters = kwargs
schema = await self.schema.get(kind=kind, branch=branch)

branch = branch or self.default_branch
if at:
at = Timestamp(at)

response = await self.execute_graphql(
query=Query(query={schema.kind: {"count": None}}).render(), branch_name=branch, at=at, timeout=timeout
query=Query(query={schema.kind: {"count": None, "@filters": filters}}).render(),
branch_name=branch,
at=at,
timeout=timeout,
)
return int(response.get(schema.kind, {}).get("count", 0))

Expand Down Expand Up @@ -1651,16 +1656,21 @@ def count(
at: Timestamp | None = None,
branch: str | None = None,
timeout: int | None = None,
**kwargs: Any,
) -> int:
"""Return the number of nodes of a given kind."""
filters = kwargs
schema = self.schema.get(kind=kind, branch=branch)

branch = branch or self.default_branch
if at:
at = Timestamp(at)

response = self.execute_graphql(
query=Query(query={schema.kind: {"count": None}}).render(), branch_name=branch, at=at, timeout=timeout
query=Query(query={schema.kind: {"count": None, "@filters": filters}}).render(),
branch_name=branch,
at=at,
timeout=timeout,
)
return int(response.get(schema.kind, {}).get("count", 0))

Expand Down
8 changes: 8 additions & 0 deletions tests/integration/test_infrahub_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,14 @@ async def test_create_branch_async(self, client: InfrahubClient, base_dataset):
task_id = await client.branch.create(branch_name="new-branch-2", wait_until_completion=False)
assert isinstance(task_id, str)

async def test_count(self, client: InfrahubClient, base_dataset):
count = await client.count(kind=TESTING_PERSON)
assert count == 3

async def test_count_with_filter(self, client: InfrahubClient, base_dataset):
count = await client.count(kind=TESTING_PERSON, name__values=["Liam Walker", "Ethan Carter"])
assert count == 2

# async def test_get_generic_filter_source(self, client: InfrahubClient, base_dataset):
# admin = await client.get(kind="CoreAccount", name__value="admin")

Expand Down
10 changes: 10 additions & 0 deletions tests/unit/sdk/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,16 @@ async def test_method_count(clients, mock_query_repository_count, client_type):
assert count == 5


@pytest.mark.parametrize("client_type", client_types)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if this test is really checking anything since the payload is static
ideally we should add an integration tests for it

async def test_method_count_with_filter(clients, mock_query_repository_count, client_type): # pylint: disable=unused-argument
if client_type == "standard":
count = await clients.standard.count(kind="CoreRepository", name__value="test")
else:
count = clients.sync.count(kind="CoreRepository", name__value="test")

assert count == 5


@pytest.mark.parametrize("client_type", client_types)
async def test_method_get_version(clients, mock_query_infrahub_version, client_type): # pylint: disable=unused-argument
if client_type == "standard":
Expand Down
Loading