From 1d38c310ca331561aa88ca691eedce65e369c888 Mon Sep 17 00:00:00 2001 From: Michael Wilson Date: Thu, 1 Feb 2024 17:01:33 -0500 Subject: [PATCH] [v15] Fix list all access list members/reviews pagination. (#37673) * Fix list all access list members pagination. Two things were happening that was preventing the cache of access list members from initializing if the number of members was greater than the default page size of 200: 1. The incorrect page token was being used in the ListAllAccessListMembers call. 2. The pagination key for listing all of these members was incorrect due to the use of sub-prefixes. Both issues have been corrected. The ListAllAccessListMembers call in the access list service now uses its own custom pagination logic. * Realized reviews suffer from the same problem. * Simplify nextKey calc. * Use a function that returns the next element and allow the caller to construct their own next key. * Use *T instead of making Resource comparable. * Cleanup code, use nils instead of named return, eliminate named returns to avoid confusion. --- lib/services/local/access_list.go | 28 +++-- lib/services/local/access_list_test.go | 118 ++++++++++++++++++++- lib/services/local/generic/generic.go | 21 +++- lib/services/local/generic/generic_test.go | 43 ++++++++ 4 files changed, 196 insertions(+), 14 deletions(-) diff --git a/lib/services/local/access_list.go b/lib/services/local/access_list.go index 9d3e6a9ca4182..a48b119bb1bc2 100644 --- a/lib/services/local/access_list.go +++ b/lib/services/local/access_list.go @@ -271,10 +271,16 @@ func (a *AccessListService) ListAccessListMembers(ctx context.Context, accessLis } // ListAllAccessListMembers returns a paginated list of all access list members for all access lists. -func (a *AccessListService) ListAllAccessListMembers(ctx context.Context, pageSize int, pageToken string) (members []*accesslist.AccessListMember, nextToken string, err error) { - // Locks are not used here as these operations are more likely to be used by the cache. - // Lists all access list members for all access lists. - return a.memberService.ListResources(ctx, pageSize, nextToken) +func (a *AccessListService) ListAllAccessListMembers(ctx context.Context, pageSize int, pageToken string) ([]*accesslist.AccessListMember, string, error) { + members, next, err := a.memberService.ListResourcesReturnNextResource(ctx, pageSize, pageToken) + if err != nil { + return nil, "", trace.Wrap(err) + } + var nextKey string + if next != nil { + nextKey = (*next).Spec.AccessList + string(backend.Separator) + (*next).Metadata.Name + } + return members, nextKey, nil } // GetAccessListMember returns the specified access list member resource. @@ -485,10 +491,16 @@ func (a *AccessListService) ListAccessListReviews(ctx context.Context, accessLis } // ListAllAccessListReviews will list access list reviews for all access lists. -func (a *AccessListService) ListAllAccessListReviews(ctx context.Context, pageSize int, pageToken string) (reviews []*accesslist.Review, nextToken string, err error) { - // Locks are not used here as these operations are more likely to be used by the cache. - // Lists all access list reviews for all access lists. - return a.reviewService.ListResources(ctx, pageSize, pageToken) +func (a *AccessListService) ListAllAccessListReviews(ctx context.Context, pageSize int, pageToken string) ([]*accesslist.Review, string, error) { + reviews, next, err := a.reviewService.ListResourcesReturnNextResource(ctx, pageSize, pageToken) + if err != nil { + return nil, "", trace.Wrap(err) + } + var nextKey string + if next != nil { + nextKey = (*next).Spec.AccessList + string(backend.Separator) + (*next).Metadata.Name + } + return reviews, nextKey, nil } // CreateAccessListReview will create a new review for an access list. diff --git a/lib/services/local/access_list_test.go b/lib/services/local/access_list_test.go index e0e597119def4..e59990aded6f3 100644 --- a/lib/services/local/access_list_test.go +++ b/lib/services/local/access_list_test.go @@ -20,6 +20,8 @@ package local import ( "context" + "fmt" + "strconv" "testing" "time" @@ -954,7 +956,7 @@ func newAccessListReview(t *testing.T, accessList, name string) *accesslist.Revi review, err := accesslist.NewReview( header.Metadata{ - Name: "test-access-list-review", + Name: name, }, accesslist.ReviewSpec{ AccessList: accessList, @@ -1353,6 +1355,120 @@ func TestChangingOwnershipModeIsAnError(t *testing.T) { } } +func TestAccessListService_ListAllAccessListMembers(t *testing.T) { + ctx := context.Background() + clock := clockwork.NewFakeClock() + + mem, err := memory.New(memory.Config{ + Context: ctx, + Clock: clock, + }) + require.NoError(t, err) + + service := newAccessListService(t, mem, clock, true /* igsEnabled */) + + const numAccessLists = 10 + const numAccessListMembersPerAccessList = 250 + totalMembers := numAccessLists * numAccessListMembersPerAccessList + + // Create several access lists. + expectedMembers := make([]*accesslist.AccessListMember, totalMembers) + for i := 0; i < numAccessLists; i++ { + alName := strconv.Itoa(i) + _, err := service.UpsertAccessList(ctx, newAccessList(t, alName, clock)) + require.NoError(t, err) + + for j := 0; j < numAccessListMembersPerAccessList; j++ { + member := newAccessListMember(t, alName, fmt.Sprintf("%03d", j)) + expectedMembers[i*numAccessListMembersPerAccessList+j] = member + _, err := service.UpsertAccessListMember(ctx, member) + require.NoError(t, err) + } + } + + allMembers := make([]*accesslist.AccessListMember, 0, totalMembers) + var nextToken string + for { + var members []*accesslist.AccessListMember + var err error + members, nextToken, err = service.ListAllAccessListMembers(ctx, 0, nextToken) + require.NoError(t, err) + + allMembers = append(allMembers, members...) + + if nextToken == "" { + break + } + } + + require.Empty(t, cmp.Diff(expectedMembers, allMembers, cmpopts.IgnoreFields(header.Metadata{}, "ID", "Revision"))) +} + +func TestAccessListService_ListAllAccessListReviews(t *testing.T) { + ctx := context.Background() + clock := clockwork.NewFakeClock() + + mem, err := memory.New(memory.Config{ + Context: ctx, + Clock: clock, + }) + require.NoError(t, err) + + service := newAccessListService(t, mem, clock, true /* igsEnabled */) + + const numAccessLists = 10 + const numAccessListReviewsPerAccessList = 250 + totalReviews := numAccessLists * numAccessListReviewsPerAccessList + + // Create several access lists. + expectedReviews := make([]*accesslist.Review, totalReviews) + for i := 0; i < numAccessLists; i++ { + alName := strconv.Itoa(i) + _, err := service.UpsertAccessList(ctx, newAccessList(t, alName, clock)) + require.NoError(t, err) + + for j := 0; j < numAccessListReviewsPerAccessList; j++ { + review, err := accesslist.NewReview( + header.Metadata{ + Name: strconv.Itoa(j), + }, + accesslist.ReviewSpec{ + AccessList: alName, + Reviewers: []string{ + "user1", + }, + ReviewDate: time.Now(), + }, + ) + require.NoError(t, err) + review, _, err = service.CreateAccessListReview(ctx, review) + expectedReviews[i*numAccessListReviewsPerAccessList+j] = review + require.NoError(t, err) + } + } + + allReviews := make([]*accesslist.Review, 0, totalReviews) + var nextToken string + for { + var reviews []*accesslist.Review + var err error + reviews, nextToken, err = service.ListAllAccessListReviews(ctx, 0, nextToken) + require.NoError(t, err) + + allReviews = append(allReviews, reviews...) + + if nextToken == "" { + break + } + } + + require.Empty(t, cmp.Diff(expectedReviews, allReviews, cmpopts.IgnoreFields(header.Metadata{}, "ID", "Revision"), cmpopts.SortSlices( + func(r1, r2 *accesslist.Review) bool { + return r1.GetName() < r2.GetName() + }), + )) +} + func newAccessListService(t *testing.T, mem *memory.Memory, clock clockwork.Clock, igsEnabled bool) *AccessListService { t.Helper() diff --git a/lib/services/local/generic/generic.go b/lib/services/local/generic/generic.go index d20bd76a58aca..a7998796ac475 100644 --- a/lib/services/local/generic/generic.go +++ b/lib/services/local/generic/generic.go @@ -146,6 +146,17 @@ func (s *Service[T]) GetResources(ctx context.Context) ([]T, error) { // ListResources returns a paginated list of resources. func (s *Service[T]) ListResources(ctx context.Context, pageSize int, pageToken string) ([]T, string, error) { + resources, next, err := s.ListResourcesReturnNextResource(ctx, pageSize, pageToken) + var nextKey string + if next != nil { + nextKey = backend.GetPaginationKey(*next) + } + return resources, nextKey, trace.Wrap(err) +} + +// ListResourcesReturnNextResource returns a paginated list of resources. The next resource is returned, which allows consumers to construct +// the next pagination key as appropriate. +func (s *Service[T]) ListResourcesReturnNextResource(ctx context.Context, pageSize int, pageToken string) ([]T, *T, error) { rangeStart := backend.Key(s.backendPrefix, pageToken) rangeEnd := backend.RangeEnd(backend.ExactKey(s.backendPrefix)) @@ -159,26 +170,26 @@ func (s *Service[T]) ListResources(ctx context.Context, pageSize int, pageToken // no filter provided get the range directly result, err := s.backend.GetRange(ctx, rangeStart, rangeEnd, limit) if err != nil { - return nil, "", trace.Wrap(err) + return nil, nil, trace.Wrap(err) } out := make([]T, 0, len(result.Items)) for _, item := range result.Items { resource, err := s.unmarshalFunc(item.Value, services.WithRevision(item.Revision), services.WithResourceID(item.ID)) if err != nil { - return nil, "", trace.Wrap(err) + return nil, nil, trace.Wrap(err) } out = append(out, resource) } - var nextKey string + var next *T if len(out) > pageSize { - nextKey = backend.GetPaginationKey(out[len(out)-1]) + next = &out[pageSize] // Truncate the last item that was used to determine next row existence. out = out[:pageSize] } - return out, nextKey, nil + return out, next, nil } // GetResource returns the specified resource. diff --git a/lib/services/local/generic/generic_test.go b/lib/services/local/generic/generic_test.go index 313220156df74..29752e5985180 100644 --- a/lib/services/local/generic/generic_test.go +++ b/lib/services/local/generic/generic_test.go @@ -274,3 +274,46 @@ func TestGenericCRUD(t *testing.T) { require.Empty(t, nextToken) require.Empty(t, out) } + +func TestGenericListResourcesReturnNextResource(t *testing.T) { + ctx := context.Background() + + memBackend, err := memory.New(memory.Config{ + Context: ctx, + Clock: clockwork.NewFakeClock(), + }) + require.NoError(t, err) + + service, err := NewService(&ServiceConfig[*testResource]{ + Backend: memBackend, + ResourceKind: "generic resource", + PageLimit: 200, + BackendPrefix: "generic_prefix", + UnmarshalFunc: unmarshalResource, + MarshalFunc: marshalResource, + }) + require.NoError(t, err) + + // Create a couple test resources. + r1 := newTestResource("r1") + r2 := newTestResource("r2") + + _, err = service.WithPrefix("a-unique-prefix").UpsertResource(ctx, r1) + require.NoError(t, err) + _, err = service.WithPrefix("another-unique-prefix").UpsertResource(ctx, r2) + require.NoError(t, err) + + page, next, err := service.ListResourcesReturnNextResource(ctx, 1, "") + require.NoError(t, err) + require.Empty(t, cmp.Diff([]*testResource{r1}, page, + cmpopts.IgnoreFields(types.Metadata{}, "ID"), + )) + require.NotNil(t, next) + + page, next, err = service.ListResourcesReturnNextResource(ctx, 1, "another-unique-prefix"+string(backend.Separator)+backend.GetPaginationKey(*next)) + require.NoError(t, err) + require.Empty(t, cmp.Diff([]*testResource{r2}, page, + cmpopts.IgnoreFields(types.Metadata{}, "ID"), + )) + require.Nil(t, next) +}