Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Start DRY'ing filtered paginate code #16099

Merged
merged 17 commits into from May 15, 2023
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 1 addition & 8 deletions types/query/collections_pagination.go
Expand Up @@ -53,9 +53,7 @@ func CollectionFilteredPaginate[K, V any, C Collection[K, V]](
predicateFunc func(key K, value V) (include bool, err error),
opts ...func(opt *CollectionsPaginateOptions[K]),
) ([]collections.KeyValue[K, V], *PageResponse, error) {
if pageReq == nil {
pageReq = &PageRequest{}
}
pageReq = initPageRequestDefaults(pageReq)

offset := pageReq.Offset
key := pageReq.Key
Expand All @@ -67,11 +65,6 @@ func CollectionFilteredPaginate[K, V any, C Collection[K, V]](
return nil, nil, fmt.Errorf("invalid request, either offset or key is expected, got both")
}

if limit == 0 {
limit = DefaultLimit
countTotal = true
}

var (
results []collections.KeyValue[K, V]
pageRes *PageResponse
Expand Down
215 changes: 84 additions & 131 deletions types/query/filtered_pagination.go
Expand Up @@ -22,106 +22,115 @@ func FilteredPaginate(
pageRequest *PageRequest,
onResult func(key, value []byte, accumulate bool) (bool, error),
) (*PageResponse, error) {
// if the PageRequest is nil, use default PageRequest
if pageRequest == nil {
pageRequest = &PageRequest{}
}

offset := pageRequest.Offset
key := pageRequest.Key
limit := pageRequest.Limit
countTotal := pageRequest.CountTotal
reverse := pageRequest.Reverse
pageRequest = initPageRequestDefaults(pageRequest)

if offset > 0 && key != nil {
if pageRequest.Offset > 0 && pageRequest.Key != nil {
return nil, fmt.Errorf("invalid request, either offset or key is expected, got both")
}

if limit == 0 {
limit = DefaultLimit

// count total results when the limit is zero/not supplied
countTotal = true
}

if len(key) != 0 {
iterator := getIterator(prefixStore, key, reverse)
defer iterator.Close()
var (
numHits uint64
nextKey []byte
err error
)

var (
numHits uint64
nextKey []byte
)
iterator := getIterator(prefixStore, pageRequest.Key, pageRequest.Reverse)
defer iterator.Close()

if len(pageRequest.Key) != 0 {
accumulateFn := func(_ uint64) bool { return true }
for ; iterator.Valid(); iterator.Next() {
if numHits == limit {
if numHits == pageRequest.Limit {
nextKey = iterator.Key()
break
}

if iterator.Error() != nil {
return nil, iterator.Error()
}

hit, err := onResult(iterator.Key(), iterator.Value(), true)
numHits, err = processResult(iterator, numHits, onResult, accumulateFn)
if err != nil {
return nil, err
}

if hit {
numHits++
}
}

return &PageResponse{
NextKey: nextKey,
}, nil
}

iterator := getIterator(prefixStore, nil, reverse)
defer iterator.Close()

end := offset + limit

var (
numHits uint64
nextKey []byte
)
end := pageRequest.Offset + pageRequest.Limit
accumulateFn := func(numHits uint64) bool { return numHits >= pageRequest.Offset && numHits < end }

for ; iterator.Valid(); iterator.Next() {
if iterator.Error() != nil {
return nil, iterator.Error()
}

accumulate := numHits >= offset && numHits < end
hit, err := onResult(iterator.Key(), iterator.Value(), accumulate)
numHits, err = processResult(iterator, numHits, onResult, accumulateFn)
if err != nil {
return nil, err
}

if hit {
numHits++
}

if numHits == end+1 {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I got this wrong right here, because it looks at iterator.Key(). (Which is a different value in this loop)

This is a bit annoying to make consistent with the old behavior. Will adjust the de-duplicated loop to get this consistent

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Basically the old behavior would ideally worked with numHits = end right here.

But it actually made nextKey skip to the next hit, which is kind of annoying to get to work.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current PR has this problem fixed.

In a later breaking release when redesigning the Filtered paginate, I would recommend changing this behavior. (As it will both be more correct from a performance expectation perspective, and code quality cleanup. We could delete all the extra for loops everywhere.)

However this PR is fully non-breaking.

if nextKey == nil {
nextKey = iterator.Key()
}

if !countTotal {
if !pageRequest.CountTotal {
break
}
}
}

res := &PageResponse{NextKey: nextKey}
if countTotal {
if pageRequest.CountTotal {
res.Total = numHits
}

return res, nil
}

func processResult(iterator types.Iterator, numHits uint64, onResult func(key, value []byte, accumulate bool) (bool, error), accumulateFn func(numHits uint64) bool) (uint64, error) {
if iterator.Error() != nil {
return numHits, iterator.Error()
}

accumulate := accumulateFn(numHits)
hit, err := onResult(iterator.Key(), iterator.Value(), accumulate)
if err != nil {
return numHits, err
}

if hit {
numHits++
}

return numHits, nil
}

func genericProcessResult[T, F proto.Message](iterator types.Iterator, numHits uint64, onResult func(key []byte, value T) (F, error), accumulateFn func(numHits uint64) bool,
constructor func() T, cdc codec.BinaryCodec, results []F,
) ([]F, uint64, error) {
if iterator.Error() != nil {
return results, numHits, iterator.Error()
}

protoMsg := constructor()

err := cdc.Unmarshal(iterator.Value(), protoMsg)
if err != nil {
return results, numHits, err
}

val, err := onResult(iterator.Key(), protoMsg)
if err != nil {
return results, numHits, err
}

if proto.Size(val) != 0 {
// Previously this was the "accumulate" flag
if accumulateFn(numHits) {
results = append(results, val)
}
numHits++
}

return results, numHits, nil
}

// GenericFilteredPaginate does pagination of all the results in the PrefixStore based on the
// provided PageRequest. `onResult` should be used to filter or transform the results.
// `c` is a constructor function that needs to return a new instance of the type T (this is to
Expand All @@ -137,119 +146,63 @@ func GenericFilteredPaginate[T, F proto.Message](
onResult func(key []byte, value T) (F, error),
constructor func() T,
) ([]F, *PageResponse, error) {
// if the PageRequest is nil, use default PageRequest
if pageRequest == nil {
pageRequest = &PageRequest{}
}

offset := pageRequest.Offset
key := pageRequest.Key
limit := pageRequest.Limit
countTotal := pageRequest.CountTotal
reverse := pageRequest.Reverse
pageRequest = initPageRequestDefaults(pageRequest)
results := []F{}

if offset > 0 && key != nil {
if pageRequest.Offset > 0 && pageRequest.Key != nil {
return results, nil, fmt.Errorf("invalid request, either offset or key is expected, got both")
}

if limit == 0 {
limit = DefaultLimit

// count total results when the limit is zero/not supplied
countTotal = true
}

if len(key) != 0 {
iterator := getIterator(prefixStore, key, reverse)
defer iterator.Close()
var (
numHits uint64
nextKey []byte
err error
)

var (
numHits uint64
nextKey []byte
)
iterator := getIterator(prefixStore, pageRequest.Key, pageRequest.Reverse)
defer iterator.Close()

if len(pageRequest.Key) != 0 {
accumulateFn := func(_ uint64) bool { return true }
for ; iterator.Valid(); iterator.Next() {
if numHits == limit {
if numHits == pageRequest.Limit {
nextKey = iterator.Key()
break
}

if iterator.Error() != nil {
return nil, nil, iterator.Error()
}

protoMsg := constructor()

err := cdc.Unmarshal(iterator.Value(), protoMsg)
results, numHits, err = genericProcessResult(iterator, numHits, onResult, accumulateFn, constructor, cdc, results)
if err != nil {
return nil, nil, err
}

val, err := onResult(iterator.Key(), protoMsg)
if err != nil {
return nil, nil, err
}

if proto.Size(val) != 0 {
results = append(results, val)
numHits++
}
}

return results, &PageResponse{
NextKey: nextKey,
}, nil
}

iterator := getIterator(prefixStore, nil, reverse)
defer iterator.Close()

end := offset + limit

var (
numHits uint64
nextKey []byte
)
end := pageRequest.Offset + pageRequest.Limit
accumulateFn := func(numHits uint64) bool { return numHits >= pageRequest.Offset && numHits < end }

for ; iterator.Valid(); iterator.Next() {
if iterator.Error() != nil {
return nil, nil, iterator.Error()
}

protoMsg := constructor()

err := cdc.Unmarshal(iterator.Value(), protoMsg)
if err != nil {
return nil, nil, err
}

val, err := onResult(iterator.Key(), protoMsg)
results, numHits, err = genericProcessResult(iterator, numHits, onResult, accumulateFn, constructor, cdc, results)
if err != nil {
return nil, nil, err
}

if proto.Size(val) != 0 {
// Previously this was the "accumulate" flag
if numHits >= offset && numHits < end {
results = append(results, val)
}
numHits++
}

if numHits == end+1 {
if nextKey == nil {
nextKey = iterator.Key()
}

if !countTotal {
if !pageRequest.CountTotal {
break
}
}
}

res := &PageResponse{NextKey: nextKey}
if countTotal {
if pageRequest.CountTotal {
res.Total = numHits
}

Expand Down