From 902e99d900bce4487983c0e85d86e80e0b39ab5d Mon Sep 17 00:00:00 2001 From: Jonathan Katz <44128041+jkatz01@users.noreply.github.com> Date: Thu, 19 Mar 2026 09:50:18 -0400 Subject: [PATCH] Add automation_type filter to count policies endpoint (#42007) **Related issue:** Resolves #41987 # Checklist for submitter ## Testing - [x] Added/updated automated tests - [ ] Where appropriate, [automated tests simulate multiple hosts and test for host isolation](https://github.com/fleetdm/fleet/blob/main/docs/Contributing/reference/patterns-backend.md#unit-testing) (updates to one hosts's records do not affect another) - [x] QA'd all new/changed functionality manually - Tested with the "scripts" filter and >20 policies with that automation, and together with #41991 the policy count and pagination is correct --- server/datastore/mysql/policies.go | 46 ++++++++++---- server/datastore/mysql/policies_test.go | 81 ++++++++++++++++++------- server/fleet/datastore.go | 8 +-- server/fleet/service.go | 4 +- server/mock/datastore_mock.go | 24 ++++---- server/mock/service/service_mock.go | 12 ++-- server/service/global_policies.go | 2 +- server/service/team_policies.go | 11 ++-- 8 files changed, 124 insertions(+), 64 deletions(-) diff --git a/server/datastore/mysql/policies.go b/server/datastore/mysql/policies.go index 72e4ee87d3b..4a3cc87f8df 100644 --- a/server/datastore/mysql/policies.go +++ b/server/datastore/mysql/policies.go @@ -37,7 +37,7 @@ const policyCols = ` p.id, p.team_id, p.resolution, p.name, p.query, p.description, p.author_id, p.platforms, p.created_at, p.updated_at, p.critical, p.calendar_events_enabled, p.software_installer_id, p.script_id, - p.vpp_apps_teams_id, p.conditional_access_enabled, p.type, + p.vpp_apps_teams_id, p.conditional_access_enabled, p.type, p.patch_software_title_id ` @@ -847,7 +847,7 @@ func getInheritedPoliciesForTeam(ctx context.Context, q sqlx.QueryerContext, tea // CountPolicies returns the total number of team policies. // If teamID is nil, it returns the total number of global policies. -func (ds *Datastore) CountPolicies(ctx context.Context, teamID *uint, matchQuery string) (int, error) { +func (ds *Datastore) CountPolicies(ctx context.Context, teamID *uint, matchQuery string, automationType string) (int, error) { var ( query string args []interface{} @@ -861,6 +861,18 @@ func (ds *Datastore) CountPolicies(ctx context.Context, teamID *uint, matchQuery args = append(args, *teamID) } + if teamID != nil { + automationFilter, filterArgs, err := ds.createAutomationClause(ctx, automationType, *teamID) + if err != nil { + return 0, ctxerr.Wrap(ctx, err, "build automation filter clause") + } + + query += automationFilter + if len(filterArgs) > 0 { + args = append(args, filterArgs...) + } + } + // We must normalize the name for full Unicode support (Unicode equivalence). match := norm.NFC.String(matchQuery) query, args = searchLike(query, args, match, policySearchColumns...) @@ -873,18 +885,28 @@ func (ds *Datastore) CountPolicies(ctx context.Context, teamID *uint, matchQuery return count, nil } -func (ds *Datastore) CountMergedTeamPolicies(ctx context.Context, teamID uint, matchQuery string) (int, error) { +func (ds *Datastore) CountMergedTeamPolicies(ctx context.Context, teamID uint, matchQuery string, automationType string) (int, error) { var args []interface{} query := `SELECT count(*) FROM policies p WHERE (p.team_id = ? OR p.team_id IS NULL)` args = append(args, teamID) + automationFilter, filterArgs, err := ds.createAutomationClause(ctx, automationType, teamID) + if err != nil { + return 0, ctxerr.Wrap(ctx, err, "build automation filter clause") + } + + query += automationFilter + if len(filterArgs) > 0 { + args = append(args, filterArgs...) + } + // We must normalize the name for full Unicode support (Unicode equivalence). match := norm.NFC.String(matchQuery) query, args = searchLike(query, args, match, policySearchColumns...) var count int - err := sqlx.GetContext(ctx, ds.reader(ctx), &count, query, args...) + err = sqlx.GetContext(ctx, ds.reader(ctx), &count, query, args...) if err != nil { return 0, ctxerr.Wrap(ctx, err, "counting merged team policies") } @@ -1144,8 +1166,8 @@ func newTeamPolicy(ctx context.Context, db sqlx.ExtContext, teamID uint, authorI return policyDB(ctx, db, policyID, &teamID) } -func (ds *Datastore) ListTeamPolicies(ctx context.Context, teamID uint, opts fleet.ListOptions, iopts fleet.ListOptions, automationFilter string) (teamPolicies, inheritedPolicies []*fleet.Policy, err error) { - filterClause, filterArgs, err := ds.createAutomationClause(ctx, automationFilter, teamID) +func (ds *Datastore) ListTeamPolicies(ctx context.Context, teamID uint, opts fleet.ListOptions, iopts fleet.ListOptions, automationType string) (teamPolicies, inheritedPolicies []*fleet.Policy, err error) { + filterClause, filterArgs, err := ds.createAutomationClause(ctx, automationType, teamID) if err != nil { return nil, nil, ctxerr.Wrap(ctx, err, "build automation filter clause") } @@ -1162,10 +1184,10 @@ func (ds *Datastore) ListTeamPolicies(ctx context.Context, teamID uint, opts fle return teamPolicies, inheritedPolicies, err } -func (ds *Datastore) ListMergedTeamPolicies(ctx context.Context, teamID uint, opts fleet.ListOptions, automationFilter string) ([]*fleet.Policy, error) { +func (ds *Datastore) ListMergedTeamPolicies(ctx context.Context, teamID uint, opts fleet.ListOptions, automationType string) ([]*fleet.Policy, error) { var args []interface{} - automationClause, filterArgs, err := ds.createAutomationClause(ctx, automationFilter, teamID) + automationFilter, filterArgs, err := ds.createAutomationClause(ctx, automationType, teamID) if err != nil { return nil, ctxerr.Wrap(ctx, err, "build automation filter clause") } @@ -1184,7 +1206,7 @@ func (ds *Datastore) ListMergedTeamPolicies(ctx context.Context, teamID uint, op AND (p.team_id IS NOT NULL OR ps.inherited_team_id = ?) WHERE (p.team_id = ? OR p.team_id IS NULL) %s - `, automationClause) + `, automationFilter) args = append(args, teamID, teamID) if len(filterArgs) > 0 { @@ -2572,9 +2594,9 @@ func (ds *Datastore) GetPatchPolicy(ctx context.Context, teamID *uint, titleID u return &policy, nil } -func (ds *Datastore) createAutomationClause(ctx context.Context, automationFilter string, teamID uint) (string, []any, error) { +func (ds *Datastore) createAutomationClause(ctx context.Context, automationType string, teamID uint) (string, []any, error) { // TODO: improve filtering by "other" - if automationFilter == "other" { + if automationType == "other" { team, err := ds.TeamLite(ctx, teamID) if err != nil { return "", nil, ctxerr.Wrap(ctx, err, "getting team config") @@ -2592,7 +2614,7 @@ func (ds *Datastore) createAutomationClause(ctx context.Context, automationFilte return clause, args, nil } - switch automationFilter { + switch automationType { case "software": return " AND (p.software_installer_id IS NOT NULL OR p.vpp_apps_teams_id IS NOT NULL OR p.type = 'patch')", nil, nil case "scripts": diff --git a/server/datastore/mysql/policies_test.go b/server/datastore/mysql/policies_test.go index 4664bd5f3ff..4c2e4063c35 100644 --- a/server/datastore/mysql/policies_test.go +++ b/server/datastore/mysql/policies_test.go @@ -3523,15 +3523,15 @@ func testCountPolicies(t *testing.T, ds *Datastore) { require.NoError(t, err) // no policies - globalCount, err := ds.CountPolicies(ctx, nil, "") + globalCount, err := ds.CountPolicies(ctx, nil, "", "") require.NoError(t, err) assert.Equal(t, 0, globalCount) - teamCount, err := ds.CountPolicies(ctx, &tm.ID, "") + teamCount, err := ds.CountPolicies(ctx, &tm.ID, "", "") require.NoError(t, err) assert.Equal(t, 0, teamCount) - mergedCount, err := ds.CountMergedTeamPolicies(ctx, tm.ID, "") + mergedCount, err := ds.CountMergedTeamPolicies(ctx, tm.ID, "", "") require.NoError(t, err) assert.Equal(t, 0, mergedCount) @@ -3541,15 +3541,15 @@ func testCountPolicies(t *testing.T, ds *Datastore) { require.NoError(t, err) } - globalCount, err = ds.CountPolicies(ctx, nil, "") + globalCount, err = ds.CountPolicies(ctx, nil, "", "") require.NoError(t, err) assert.Equal(t, 10, globalCount) - teamCount, err = ds.CountPolicies(ctx, &tm.ID, "") + teamCount, err = ds.CountPolicies(ctx, &tm.ID, "", "") require.NoError(t, err) assert.Equal(t, 0, teamCount) - mergedCount, err = ds.CountMergedTeamPolicies(ctx, tm.ID, "") + mergedCount, err = ds.CountMergedTeamPolicies(ctx, tm.ID, "", "") require.NoError(t, err) assert.Equal(t, 10, mergedCount) @@ -3559,30 +3559,35 @@ func testCountPolicies(t *testing.T, ds *Datastore) { require.NoError(t, err) } - teamCount, err = ds.CountPolicies(ctx, &tm.ID, "") + teamCount, err = ds.CountPolicies(ctx, &tm.ID, "", "") require.NoError(t, err) assert.Equal(t, 5, teamCount) - globalCount, err = ds.CountPolicies(ctx, nil, "") + globalCount, err = ds.CountPolicies(ctx, nil, "", "") require.NoError(t, err) assert.Equal(t, 10, globalCount) - mergedCount, err = ds.CountMergedTeamPolicies(ctx, tm.ID, "") + mergedCount, err = ds.CountMergedTeamPolicies(ctx, tm.ID, "", "") require.NoError(t, err) assert.Equal(t, 15, mergedCount) // test filter - globalCount, err = ds.CountPolicies(ctx, nil, "global policy 1") + globalCount, err = ds.CountPolicies(ctx, nil, "global policy 1", "") require.NoError(t, err) assert.Equal(t, 1, globalCount) - teamCount, err = ds.CountPolicies(ctx, &tm.ID, "team policy 1") + teamCount, err = ds.CountPolicies(ctx, &tm.ID, "team policy 1", "") require.NoError(t, err) assert.Equal(t, 1, teamCount) - mergedCount, err = ds.CountMergedTeamPolicies(ctx, tm.ID, "policy 1") + mergedCount, err = ds.CountMergedTeamPolicies(ctx, tm.ID, "policy 1", "") require.NoError(t, err) assert.Equal(t, 2, mergedCount) + + // test automation filter doesn't affect global policy count + globalCount, err = ds.CountPolicies(ctx, nil, "", "scripts") + require.NoError(t, err) + assert.Equal(t, 10, globalCount) } func testUpdatePolicyHostCounts(t *testing.T, ds *Datastore) { @@ -3812,10 +3817,10 @@ func testPoliciesNameUnicode(t *testing.T, ds *Datastore) { assert.Equal(t, equivalentNames[0], inheritedPolicies[0].Name) // CountPolicies - count, err := ds.CountPolicies(context.Background(), &team.ID, equivalentNames[1]) + count, err := ds.CountPolicies(context.Background(), &team.ID, equivalentNames[1], "") assert.NoError(t, err) assert.Equal(t, 1, count) - count, err = ds.CountPolicies(context.Background(), nil, equivalentNames[1]) + count, err = ds.CountPolicies(context.Background(), nil, equivalentNames[1], "") assert.NoError(t, err) assert.Equal(t, 1, count) } @@ -5651,7 +5656,7 @@ func testTeamPoliciesNoTeam(t *testing.T, ds *Datastore) { require.Equal(t, p, globalPolicy) ids = append(ids, globalPolicy.ID) } - c, err := ds.CountPolicies(ctx, nil, "") + c, err := ds.CountPolicies(ctx, nil, "", "") require.NoError(t, err) require.Equal(t, 2, c) globalPoliciesByID, err := ds.PoliciesByID(ctx, ids) @@ -5685,10 +5690,10 @@ func testTeamPoliciesNoTeam(t *testing.T, ds *Datastore) { require.NoError(t, err) require.Len(t, teamPoliciesByID, 1) require.Equal(t, teamPoliciesByID[teamPolicies[0].ID], teamPolicies[0]) - c, err = ds.CountMergedTeamPolicies(ctx, team1.ID, "") + c, err = ds.CountMergedTeamPolicies(ctx, team1.ID, "", "") require.NoError(t, err) require.Equal(t, 3, c) - c, err = ds.CountPolicies(ctx, &team1.ID, "") + c, err = ds.CountPolicies(ctx, &team1.ID, "", "") require.NoError(t, err) require.Equal(t, 1, c) mergedTeamPolicies, err := ds.ListMergedTeamPolicies(ctx, team1.ID, fleet.ListOptions{}, "") @@ -5733,10 +5738,10 @@ func testTeamPoliciesNoTeam(t *testing.T, ds *Datastore) { require.Len(t, teamPoliciesByID, 2) require.Equal(t, teamPoliciesByID[teamPolicies[0].ID], teamPolicies[0]) require.Equal(t, teamPoliciesByID[teamPolicies[1].ID], teamPolicies[1]) - c, err = ds.CountMergedTeamPolicies(ctx, team2.ID, "") + c, err = ds.CountMergedTeamPolicies(ctx, team2.ID, "", "") require.NoError(t, err) require.Equal(t, 4, c) - c, err = ds.CountPolicies(ctx, &team2.ID, "") + c, err = ds.CountPolicies(ctx, &team2.ID, "", "") require.NoError(t, err) require.Equal(t, 2, c) mergedTeamPolicies, err = ds.ListMergedTeamPolicies(ctx, team2.ID, fleet.ListOptions{}, "") @@ -5784,10 +5789,10 @@ func testTeamPoliciesNoTeam(t *testing.T, ds *Datastore) { require.Len(t, teamPoliciesByID, 2) require.Equal(t, teamPoliciesByID[teamPolicies[0].ID], teamPolicies[0]) require.Equal(t, teamPoliciesByID[teamPolicies[1].ID], teamPolicies[1]) - c, err = ds.CountMergedTeamPolicies(ctx, fleet.PolicyNoTeamID, "") + c, err = ds.CountMergedTeamPolicies(ctx, fleet.PolicyNoTeamID, "", "") require.NoError(t, err) require.Equal(t, 4, c) - c, err = ds.CountPolicies(ctx, ptr.Uint(fleet.PolicyNoTeamID), "") + c, err = ds.CountPolicies(ctx, ptr.Uint(fleet.PolicyNoTeamID), "", "") require.NoError(t, err) require.Equal(t, 2, c) mergedTeamPolicies, err = ds.ListMergedTeamPolicies(ctx, fleet.PolicyNoTeamID, fleet.ListOptions{}, "") @@ -7353,7 +7358,11 @@ func testTeamPolicyAutomationFilter(t *testing.T, ds *Datastore) { assert.Equal(t, teamWebhookPolicy.ID, merged[6].ID) assert.Equal(t, teamPatchPolicy.ID, merged[7].ID) - // Test filters + mergedCount, err := ds.CountMergedTeamPolicies(ctx, 0, "", "") + require.NoError(t, err) + assert.Equal(t, 8, mergedCount) + + // Test software merged, err = ds.ListMergedTeamPolicies(ctx, 0, fleet.ListOptions{ OrderKey: "name", OrderDirection: fleet.OrderAscending, @@ -7364,6 +7373,11 @@ func testTeamPolicyAutomationFilter(t *testing.T, ds *Datastore) { assert.Equal(t, teamAppStorePolicy.ID, merged[1].ID) assert.Equal(t, teamPatchPolicy.ID, merged[2].ID) + mergedCount, err = ds.CountMergedTeamPolicies(ctx, 0, "", "software") + require.NoError(t, err) + assert.Equal(t, 3, mergedCount) + + // Test scripts merged, err = ds.ListMergedTeamPolicies(ctx, 0, fleet.ListOptions{ OrderKey: "name", OrderDirection: fleet.OrderAscending, @@ -7372,6 +7386,11 @@ func testTeamPolicyAutomationFilter(t *testing.T, ds *Datastore) { require.Len(t, merged, 1) assert.Equal(t, teamScriptPolicy.ID, merged[0].ID) + mergedCount, err = ds.CountMergedTeamPolicies(ctx, 0, "", "scripts") + require.NoError(t, err) + assert.Equal(t, 1, mergedCount) + + // Test calendar merged, err = ds.ListMergedTeamPolicies(ctx, 0, fleet.ListOptions{ OrderKey: "name", OrderDirection: fleet.OrderAscending, @@ -7380,6 +7399,11 @@ func testTeamPolicyAutomationFilter(t *testing.T, ds *Datastore) { require.Len(t, merged, 1) assert.Equal(t, teamCalendarPolicy.ID, merged[0].ID) + mergedCount, err = ds.CountMergedTeamPolicies(ctx, 0, "", "calendar") + require.NoError(t, err) + assert.Equal(t, 1, mergedCount) + + // Test conditional_access merged, err = ds.ListMergedTeamPolicies(ctx, 0, fleet.ListOptions{ OrderKey: "name", OrderDirection: fleet.OrderAscending, @@ -7388,6 +7412,11 @@ func testTeamPolicyAutomationFilter(t *testing.T, ds *Datastore) { require.Len(t, merged, 1) assert.Equal(t, teamConditionalPolicy.ID, merged[0].ID) + mergedCount, err = ds.CountMergedTeamPolicies(ctx, 0, "", "conditional_access") + require.NoError(t, err) + assert.Equal(t, 1, mergedCount) + + // Test other merged, err = ds.ListMergedTeamPolicies(ctx, 0, fleet.ListOptions{ OrderKey: "name", OrderDirection: fleet.OrderAscending, @@ -7396,6 +7425,10 @@ func testTeamPolicyAutomationFilter(t *testing.T, ds *Datastore) { require.Len(t, merged, 1) assert.Equal(t, teamWebhookPolicy.ID, merged[0].ID) + mergedCount, err = ds.CountMergedTeamPolicies(ctx, 0, "", "other") + require.NoError(t, err) + assert.Equal(t, 1, mergedCount) + // Test not merged policies, _, err := ds.ListTeamPolicies(ctx, 0, fleet.ListOptions{ OrderKey: "name", @@ -7406,4 +7439,8 @@ func testTeamPolicyAutomationFilter(t *testing.T, ds *Datastore) { assert.Equal(t, teamInstallerPolicy.ID, policies[0].ID) assert.Equal(t, teamAppStorePolicy.ID, policies[1].ID) assert.Equal(t, teamPatchPolicy.ID, policies[2].ID) + + mergedCount, err = ds.CountPolicies(ctx, ptr.Uint(0), "", "software") + require.NoError(t, err) + assert.Equal(t, 3, mergedCount) } diff --git a/server/fleet/datastore.go b/server/fleet/datastore.go index 94bc052bf3e..0cce1d0e146 100644 --- a/server/fleet/datastore.go +++ b/server/fleet/datastore.go @@ -828,8 +828,8 @@ type Datastore interface { ListGlobalPolicies(ctx context.Context, opts ListOptions) ([]*Policy, error) PoliciesByID(ctx context.Context, ids []uint) (map[uint]*Policy, error) DeleteGlobalPolicies(ctx context.Context, ids []uint) ([]uint, error) - CountPolicies(ctx context.Context, teamID *uint, matchQuery string) (int, error) - CountMergedTeamPolicies(ctx context.Context, teamID uint, matchQuery string) (int, error) + CountPolicies(ctx context.Context, teamID *uint, matchQuery string, automationType string) (int, error) + CountMergedTeamPolicies(ctx context.Context, teamID uint, matchQuery string, automationType string) (int, error) UpdateHostPolicyCounts(ctx context.Context) error PolicyQueriesForHost(ctx context.Context, host *Host) (map[string]string, error) @@ -905,8 +905,8 @@ type Datastore interface { // Team Policies NewTeamPolicy(ctx context.Context, teamID uint, authorID *uint, args PolicyPayload) (*Policy, error) - ListTeamPolicies(ctx context.Context, teamID uint, opts ListOptions, iopts ListOptions, automationFilter string) (teamPolicies, inheritedPolicies []*Policy, err error) - ListMergedTeamPolicies(ctx context.Context, teamID uint, opts ListOptions, automationFilter string) ([]*Policy, error) + ListTeamPolicies(ctx context.Context, teamID uint, opts ListOptions, iopts ListOptions, automationType string) (teamPolicies, inheritedPolicies []*Policy, err error) + ListMergedTeamPolicies(ctx context.Context, teamID uint, opts ListOptions, automationType string) ([]*Policy, error) DeleteTeamPolicies(ctx context.Context, teamID uint, ids []uint) ([]uint, error) TeamPolicy(ctx context.Context, teamID uint, policyID uint) (*Policy, error) diff --git a/server/fleet/service.go b/server/fleet/service.go index 074733abbc6..c4e1545f523 100644 --- a/server/fleet/service.go +++ b/server/fleet/service.go @@ -829,11 +829,11 @@ type Service interface { // Team Policies NewTeamPolicy(ctx context.Context, teamID uint, p NewTeamPolicyPayload) (*Policy, error) - ListTeamPolicies(ctx context.Context, teamID uint, opts ListOptions, iopts ListOptions, mergeInherited bool, automationFilter string) (teamPolicies, inheritedPolicies []*Policy, err error) + ListTeamPolicies(ctx context.Context, teamID uint, opts ListOptions, iopts ListOptions, mergeInherited bool, automationType string) (teamPolicies, inheritedPolicies []*Policy, err error) DeleteTeamPolicies(ctx context.Context, teamID uint, ids []uint) ([]uint, error) ModifyTeamPolicy(ctx context.Context, teamID uint, id uint, p ModifyPolicyPayload) (*Policy, error) GetTeamPolicyByIDQueries(ctx context.Context, teamID uint, policyID uint) (*Policy, error) - CountTeamPolicies(ctx context.Context, teamID uint, matchQuery string, mergeInherited bool) (int, int, error) + CountTeamPolicies(ctx context.Context, teamID uint, matchQuery string, mergeInherited bool, automationType string) (int, int, error) // ///////////////////////////////////////////////////////////////////////////// // Geolocation diff --git a/server/mock/datastore_mock.go b/server/mock/datastore_mock.go index 96ad024782c..aa7a8215776 100644 --- a/server/mock/datastore_mock.go +++ b/server/mock/datastore_mock.go @@ -623,9 +623,9 @@ type PoliciesByIDFunc func(ctx context.Context, ids []uint) (map[uint]*fleet.Pol type DeleteGlobalPoliciesFunc func(ctx context.Context, ids []uint) ([]uint, error) -type CountPoliciesFunc func(ctx context.Context, teamID *uint, matchQuery string) (int, error) +type CountPoliciesFunc func(ctx context.Context, teamID *uint, matchQuery string, automationType string) (int, error) -type CountMergedTeamPoliciesFunc func(ctx context.Context, teamID uint, matchQuery string) (int, error) +type CountMergedTeamPoliciesFunc func(ctx context.Context, teamID uint, matchQuery string, automationType string) (int, error) type UpdateHostPolicyCountsFunc func(ctx context.Context) error @@ -695,9 +695,9 @@ type ListOutOfDateCalendarEventsFunc func(ctx context.Context, t time.Time) ([]* type NewTeamPolicyFunc func(ctx context.Context, teamID uint, authorID *uint, args fleet.PolicyPayload) (*fleet.Policy, error) -type ListTeamPoliciesFunc func(ctx context.Context, teamID uint, opts fleet.ListOptions, iopts fleet.ListOptions, automationFilter string) (teamPolicies []*fleet.Policy, inheritedPolicies []*fleet.Policy, err error) +type ListTeamPoliciesFunc func(ctx context.Context, teamID uint, opts fleet.ListOptions, iopts fleet.ListOptions, automationType string) (teamPolicies []*fleet.Policy, inheritedPolicies []*fleet.Policy, err error) -type ListMergedTeamPoliciesFunc func(ctx context.Context, teamID uint, opts fleet.ListOptions, automationFilter string) ([]*fleet.Policy, error) +type ListMergedTeamPoliciesFunc func(ctx context.Context, teamID uint, opts fleet.ListOptions, automationType string) ([]*fleet.Policy, error) type DeleteTeamPoliciesFunc func(ctx context.Context, teamID uint, ids []uint) ([]uint, error) @@ -6642,18 +6642,18 @@ func (s *DataStore) DeleteGlobalPolicies(ctx context.Context, ids []uint) ([]uin return s.DeleteGlobalPoliciesFunc(ctx, ids) } -func (s *DataStore) CountPolicies(ctx context.Context, teamID *uint, matchQuery string) (int, error) { +func (s *DataStore) CountPolicies(ctx context.Context, teamID *uint, matchQuery string, automationType string) (int, error) { s.mu.Lock() s.CountPoliciesFuncInvoked = true s.mu.Unlock() - return s.CountPoliciesFunc(ctx, teamID, matchQuery) + return s.CountPoliciesFunc(ctx, teamID, matchQuery, automationType) } -func (s *DataStore) CountMergedTeamPolicies(ctx context.Context, teamID uint, matchQuery string) (int, error) { +func (s *DataStore) CountMergedTeamPolicies(ctx context.Context, teamID uint, matchQuery string, automationType string) (int, error) { s.mu.Lock() s.CountMergedTeamPoliciesFuncInvoked = true s.mu.Unlock() - return s.CountMergedTeamPoliciesFunc(ctx, teamID, matchQuery) + return s.CountMergedTeamPoliciesFunc(ctx, teamID, matchQuery, automationType) } func (s *DataStore) UpdateHostPolicyCounts(ctx context.Context) error { @@ -6894,18 +6894,18 @@ func (s *DataStore) NewTeamPolicy(ctx context.Context, teamID uint, authorID *ui return s.NewTeamPolicyFunc(ctx, teamID, authorID, args) } -func (s *DataStore) ListTeamPolicies(ctx context.Context, teamID uint, opts fleet.ListOptions, iopts fleet.ListOptions, automationFilter string) (teamPolicies []*fleet.Policy, inheritedPolicies []*fleet.Policy, err error) { +func (s *DataStore) ListTeamPolicies(ctx context.Context, teamID uint, opts fleet.ListOptions, iopts fleet.ListOptions, automationType string) (teamPolicies []*fleet.Policy, inheritedPolicies []*fleet.Policy, err error) { s.mu.Lock() s.ListTeamPoliciesFuncInvoked = true s.mu.Unlock() - return s.ListTeamPoliciesFunc(ctx, teamID, opts, iopts, automationFilter) + return s.ListTeamPoliciesFunc(ctx, teamID, opts, iopts, automationType) } -func (s *DataStore) ListMergedTeamPolicies(ctx context.Context, teamID uint, opts fleet.ListOptions, automationFilter string) ([]*fleet.Policy, error) { +func (s *DataStore) ListMergedTeamPolicies(ctx context.Context, teamID uint, opts fleet.ListOptions, automationType string) ([]*fleet.Policy, error) { s.mu.Lock() s.ListMergedTeamPoliciesFuncInvoked = true s.mu.Unlock() - return s.ListMergedTeamPoliciesFunc(ctx, teamID, opts, automationFilter) + return s.ListMergedTeamPoliciesFunc(ctx, teamID, opts, automationType) } func (s *DataStore) DeleteTeamPolicies(ctx context.Context, teamID uint, ids []uint) ([]uint, error) { diff --git a/server/mock/service/service_mock.go b/server/mock/service/service_mock.go index f73d9a09ae2..a9dad33e226 100644 --- a/server/mock/service/service_mock.go +++ b/server/mock/service/service_mock.go @@ -511,7 +511,7 @@ type ListSoftwareByCVEFunc func(ctx context.Context, cve string, teamID *uint) ( type NewTeamPolicyFunc func(ctx context.Context, teamID uint, p fleet.NewTeamPolicyPayload) (*fleet.Policy, error) -type ListTeamPoliciesFunc func(ctx context.Context, teamID uint, opts fleet.ListOptions, iopts fleet.ListOptions, mergeInherited bool, automationFilter string) (teamPolicies []*fleet.Policy, inheritedPolicies []*fleet.Policy, err error) +type ListTeamPoliciesFunc func(ctx context.Context, teamID uint, opts fleet.ListOptions, iopts fleet.ListOptions, mergeInherited bool, automationType string) (teamPolicies []*fleet.Policy, inheritedPolicies []*fleet.Policy, err error) type DeleteTeamPoliciesFunc func(ctx context.Context, teamID uint, ids []uint) ([]uint, error) @@ -519,7 +519,7 @@ type ModifyTeamPolicyFunc func(ctx context.Context, teamID uint, id uint, p flee type GetTeamPolicyByIDQueriesFunc func(ctx context.Context, teamID uint, policyID uint) (*fleet.Policy, error) -type CountTeamPoliciesFunc func(ctx context.Context, teamID uint, matchQuery string, mergeInherited bool) (int, int, error) +type CountTeamPoliciesFunc func(ctx context.Context, teamID uint, matchQuery string, mergeInherited bool, automationType string) (int, int, error) type LookupGeoIPFunc func(ctx context.Context, ip string) *fleet.GeoLocation @@ -3930,11 +3930,11 @@ func (s *Service) NewTeamPolicy(ctx context.Context, teamID uint, p fleet.NewTea return s.NewTeamPolicyFunc(ctx, teamID, p) } -func (s *Service) ListTeamPolicies(ctx context.Context, teamID uint, opts fleet.ListOptions, iopts fleet.ListOptions, mergeInherited bool, automationFilter string) (teamPolicies []*fleet.Policy, inheritedPolicies []*fleet.Policy, err error) { +func (s *Service) ListTeamPolicies(ctx context.Context, teamID uint, opts fleet.ListOptions, iopts fleet.ListOptions, mergeInherited bool, automationType string) (teamPolicies []*fleet.Policy, inheritedPolicies []*fleet.Policy, err error) { s.mu.Lock() s.ListTeamPoliciesFuncInvoked = true s.mu.Unlock() - return s.ListTeamPoliciesFunc(ctx, teamID, opts, iopts, mergeInherited, automationFilter) + return s.ListTeamPoliciesFunc(ctx, teamID, opts, iopts, mergeInherited, automationType) } func (s *Service) DeleteTeamPolicies(ctx context.Context, teamID uint, ids []uint) ([]uint, error) { @@ -3958,11 +3958,11 @@ func (s *Service) GetTeamPolicyByIDQueries(ctx context.Context, teamID uint, pol return s.GetTeamPolicyByIDQueriesFunc(ctx, teamID, policyID) } -func (s *Service) CountTeamPolicies(ctx context.Context, teamID uint, matchQuery string, mergeInherited bool) (int, int, error) { +func (s *Service) CountTeamPolicies(ctx context.Context, teamID uint, matchQuery string, mergeInherited bool, automationType string) (int, int, error) { s.mu.Lock() s.CountTeamPoliciesFuncInvoked = true s.mu.Unlock() - return s.CountTeamPoliciesFunc(ctx, teamID, matchQuery, mergeInherited) + return s.CountTeamPoliciesFunc(ctx, teamID, matchQuery, mergeInherited, automationType) } func (s *Service) LookupGeoIP(ctx context.Context, ip string) *fleet.GeoLocation { diff --git a/server/service/global_policies.go b/server/service/global_policies.go index aab99440a30..dbc1f9c9fea 100644 --- a/server/service/global_policies.go +++ b/server/service/global_policies.go @@ -206,7 +206,7 @@ func (svc Service) CountGlobalPolicies(ctx context.Context, matchQuery string) ( return 0, err } - count, err := svc.ds.CountPolicies(ctx, nil, matchQuery) + count, err := svc.ds.CountPolicies(ctx, nil, matchQuery, "") if err != nil { return 0, err } diff --git a/server/service/team_policies.go b/server/service/team_policies.go index 5b7d84b937d..ae92e3d96d3 100644 --- a/server/service/team_policies.go +++ b/server/service/team_policies.go @@ -339,6 +339,7 @@ type countTeamPoliciesRequest struct { ListOptions fleet.ListOptions `url:"list_options"` TeamID uint `url:"fleet_id"` MergeInherited bool `query:"merge_inherited,optional"` + AutomationType string `query:"automation_type,optional"` } type countTeamPoliciesResponse struct { @@ -351,14 +352,14 @@ func (r countTeamPoliciesResponse) Error() error { return r.Err } func countTeamPoliciesEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (fleet.Errorer, error) { req := request.(*countTeamPoliciesRequest) - count, inheritedCount, err := svc.CountTeamPolicies(ctx, req.TeamID, req.ListOptions.MatchQuery, req.MergeInherited) + count, inheritedCount, err := svc.CountTeamPolicies(ctx, req.TeamID, req.ListOptions.MatchQuery, req.MergeInherited, req.AutomationType) if err != nil { return countTeamPoliciesResponse{Err: err}, nil } return countTeamPoliciesResponse{Count: count, InheritedPolicyCount: inheritedCount}, nil } -func (svc *Service) CountTeamPolicies(ctx context.Context, teamID uint, matchQuery string, mergeInherited bool) (int, int, error) { +func (svc *Service) CountTeamPolicies(ctx context.Context, teamID uint, matchQuery string, mergeInherited bool, automationType string) (int, int, error) { if err := svc.authz.Authorize(ctx, &fleet.Policy{ PolicyData: fleet.PolicyData{ TeamID: ptr.Uint(teamID), @@ -374,18 +375,18 @@ func (svc *Service) CountTeamPolicies(ctx context.Context, teamID uint, matchQue } if mergeInherited { - count, err := svc.ds.CountMergedTeamPolicies(ctx, teamID, matchQuery) + count, err := svc.ds.CountMergedTeamPolicies(ctx, teamID, matchQuery, automationType) if err != nil { return 0, 0, err } - inheritedCount, err := svc.ds.CountPolicies(ctx, nil, matchQuery) + inheritedCount, err := svc.ds.CountPolicies(ctx, nil, matchQuery, automationType) if err != nil { return 0, 0, err } return count, inheritedCount, nil } - count, err := svc.ds.CountPolicies(ctx, &teamID, matchQuery) + count, err := svc.ds.CountPolicies(ctx, &teamID, matchQuery, automationType) if err != nil { return 0, 0, err }