diff --git a/internal/proxy/search_util.go b/internal/proxy/search_util.go index 7eedc621da84..0989871cd1f1 100644 --- a/internal/proxy/search_util.go +++ b/internal/proxy/search_util.go @@ -28,6 +28,9 @@ type rankParams struct { // parseSearchInfo returns QueryInfo and offset func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema, ignoreOffset bool) (*planpb.QueryInfo, int64, error) { + //0. parse iterator field + isIterator, _ := funcutil.GetAttrByKeyFromRepeatedKV(IteratorField, searchParamsPair) + // 1. parse offset and real topk topKStr, err := funcutil.GetAttrByKeyFromRepeatedKV(TopKKey, searchParamsPair) if err != nil { @@ -38,7 +41,13 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb return nil, 0, fmt.Errorf("%s [%s] is invalid", TopKKey, topKStr) } if err := validateLimit(topK); err != nil { - return nil, 0, fmt.Errorf("%s [%d] is invalid, %w", TopKKey, topK, err) + if isIterator == "True" { + topK = Params.QuotaConfig.TopKLimit.GetAsInt64() + //1. if the request is from iterator, we set topK to QuotaLimit as the iterator can resolve too large topK problem + //2. GetAsInt64 has cached inside, no need to worry about cpu cost for parsing here + } else { + return nil, 0, fmt.Errorf("%s [%d] is invalid, %w", TopKKey, topK, err) + } } var offset int64 @@ -109,8 +118,7 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb } } - // 6. parse iterator tag, prevent trying to groupBy when doing iteration or doing range-search - isIterator, _ := funcutil.GetAttrByKeyFromRepeatedKV(IteratorField, searchParamsPair) + // 6. disable groupBy for iterator and range search if isIterator == "True" && groupByFieldId > 0 { return nil, 0, merr.WrapErrParameterInvalid("", "", "Not allowed to do groupBy when doing iteration") diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index ffb95290f0f2..2397c3270aad 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -2104,6 +2104,26 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) { assert.Nil(t, info) assert.ErrorIs(t, err, merr.ErrParameterInvalid) }) + t.Run("check iterator and topK", func(t *testing.T) { + normalParam := getValidSearchParams() + normalParam = append(normalParam, &commonpb.KeyValuePair{ + Key: IteratorField, + Value: "True", + }) + resetSearchParamsValue(normalParam, TopKKey, `1024000`) + fields := make([]*schemapb.FieldSchema, 0) + fields = append(fields, &schemapb.FieldSchema{ + FieldID: int64(101), + Name: "string_field", + }) + schema := &schemapb.CollectionSchema{ + Fields: fields, + } + info, _, err := parseSearchInfo(normalParam, schema, false) + assert.NotNil(t, info) + assert.NoError(t, err) + assert.Equal(t, Params.QuotaConfig.TopKLimit.GetAsInt64(), info.Topk) + }) } func getSearchResultData(nq, topk int64) *schemapb.SearchResultData {