Skip to content

Commit

Permalink
Change the type of slice_nqs and slice_topks from int32_t[] to int64_t[]
Browse files Browse the repository at this point in the history
Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
  • Loading branch information
cydrain committed Aug 29, 2022
1 parent c4a9e13 commit b219652
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 50 deletions.
15 changes: 10 additions & 5 deletions internal/core/src/segcore/Reduce.h
Expand Up @@ -10,10 +10,11 @@
// or implied. See the License for the specific language governing permissions and limitations under the License

#pragma once
#include <cstdint>
#include <vector>

#include <algorithm>
#include <cstdint>
#include <memory>
#include <vector>

#include "utils/Status.h"
#include "common/type_c.h"
Expand All @@ -31,9 +32,13 @@ class ReduceHelper {
public:
explicit ReduceHelper(std::vector<SearchResult*>& search_results,
milvus::query::Plan* plan,
std::vector<int64_t>& slice_nqs,
std::vector<int64_t>& slice_topKs)
: search_results_(search_results), plan_(plan), slice_nqs_(slice_nqs), slice_topKs_(slice_topKs) {
int64_t* slice_nqs,
int64_t* slice_topKs,
int64_t slice_num)
: search_results_(search_results),
plan_(plan),
slice_nqs_(slice_nqs, slice_nqs + slice_num),
slice_topKs_(slice_topKs, slice_topKs + slice_num) {
Initialize();
}

Expand Down
16 changes: 4 additions & 12 deletions internal/core/src/segcore/reduce_c.cpp
Expand Up @@ -26,9 +26,9 @@ ReduceSearchResultsAndFillData(CSearchResultDataBlobs* cSearchResultDataBlobs,
CSearchPlan c_plan,
CSearchResult* c_search_results,
int64_t num_segments,
int32_t* slice_nqs,
int32_t* slice_topKs,
int32_t num_slices) {
int64_t* slice_nqs,
int64_t* slice_topKs,
int64_t num_slices) {
try {
// get SearchResult and SearchPlan
auto plan = static_cast<milvus::query::Plan*>(c_plan);
Expand All @@ -38,15 +38,7 @@ ReduceSearchResultsAndFillData(CSearchResultDataBlobs* cSearchResultDataBlobs,
search_results[i] = static_cast<SearchResult*>(c_search_results[i]);
}

// get slice_nqs and slice_topKs
auto slice_nqs_vec = std::vector<int64_t>(num_slices);
auto slice_topKs_vec = std::vector<int64_t>(num_slices);
for (int i = 0; i < num_slices; i++) {
slice_nqs_vec[i] = slice_nqs[i];
slice_topKs_vec[i] = slice_topKs[i];
}

auto reduce_helper = milvus::segcore::ReduceHelper(search_results, plan, slice_nqs_vec, slice_topKs_vec);
auto reduce_helper = milvus::segcore::ReduceHelper(search_results, plan, slice_nqs, slice_topKs, num_slices);
reduce_helper.Reduce();
reduce_helper.Marshal();

Expand Down
6 changes: 3 additions & 3 deletions internal/core/src/segcore/reduce_c.h
Expand Up @@ -24,9 +24,9 @@ ReduceSearchResultsAndFillData(CSearchResultDataBlobs* cSearchResultDataBlobs,
CSearchPlan c_plan,
CSearchResult* search_results,
int64_t num_segments,
int32_t* slice_nqs,
int32_t* slice_topKs,
int32_t num_slices);
int64_t* slice_nqs,
int64_t* slice_topKs,
int64_t num_slices);

CStatus
GetSearchResultDataBlob(CProto* searchResultDataBlob,
Expand Down
28 changes: 14 additions & 14 deletions internal/core/unittest/test_c_api.cpp
Expand Up @@ -1137,8 +1137,8 @@ TEST(CApiTest, ReudceNullResult) {
dataset.timestamps_.push_back(1);

{
auto slice_nqs = std::vector<int32_t>{10};
auto slice_topKs = std::vector<int32_t>{1};
auto slice_nqs = std::vector<int64_t>{10};
auto slice_topKs = std::vector<int64_t>{1};
std::vector<CSearchResult> results;
CSearchResult res;
status = Search(segment, plan, placeholderGroup, dataset.timestamps_[0], &res, -1);
Expand Down Expand Up @@ -1214,8 +1214,8 @@ TEST(CApiTest, ReduceRemoveDuplicates) {
dataset.timestamps_.push_back(1);

{
auto slice_nqs = std::vector<int32_t>{num_queries / 2, num_queries / 2};
auto slice_topKs = std::vector<int32_t>{topK / 2, topK};
auto slice_nqs = std::vector<int64_t>{num_queries / 2, num_queries / 2};
auto slice_topKs = std::vector<int64_t>{topK / 2, topK};
std::vector<CSearchResult> results;
CSearchResult res1, res2;
status = Search(segment, plan, placeholderGroup, dataset.timestamps_[0], &res1, -1);
Expand All @@ -1239,8 +1239,8 @@ TEST(CApiTest, ReduceRemoveDuplicates) {
int nq1 = num_queries / 3;
int nq2 = num_queries / 3;
int nq3 = num_queries - nq1 - nq2;
auto slice_nqs = std::vector<int32_t>{nq1, nq2, nq3};
auto slice_topKs = std::vector<int32_t>{topK / 2, topK, topK};
auto slice_nqs = std::vector<int64_t>{nq1, nq2, nq3};
auto slice_topKs = std::vector<int64_t>{topK / 2, topK, topK};
std::vector<CSearchResult> results;
CSearchResult res1, res2, res3;
status = Search(segment, plan, placeholderGroup, dataset.timestamps_[0], &res1, -1);
Expand Down Expand Up @@ -1324,13 +1324,13 @@ testReduceSearchWithExpr(int N, int topK, int num_queries) {
results.push_back(res1);
results.push_back(res2);

auto slice_nqs = std::vector<int32_t>{num_queries / 2, num_queries / 2};
auto slice_nqs = std::vector<int64_t>{num_queries / 2, num_queries / 2};
if (num_queries == 1) {
slice_nqs = std::vector<int32_t>{num_queries};
slice_nqs = std::vector<int64_t>{num_queries};
}
auto slice_topKs = std::vector<int32_t>{topK / 2, topK};
auto slice_topKs = std::vector<int64_t>{topK / 2, topK};
if (topK == 1) {
slice_topKs = std::vector<int32_t>{topK, topK};
slice_topKs = std::vector<int64_t>{topK, topK};
}

// 1. reduce
Expand Down Expand Up @@ -2749,8 +2749,8 @@ TEST(CApiTest, Indexing_With_binary_Predicate_Term) {
std::vector<CSearchResult> results;
results.push_back(c_search_result_on_bigIndex);

auto slice_nqs = std::vector<int32_t>{num_queries};
auto slice_topKs = std::vector<int32_t>{topK};
auto slice_nqs = std::vector<int64_t>{num_queries};
auto slice_topKs = std::vector<int64_t>{topK};

CSearchResultDataBlobs cSearchResultData;
status = ReduceSearchResultsAndFillData(&cSearchResultData, plan, results.data(), results.size(), slice_nqs.data(),
Expand Down Expand Up @@ -2915,8 +2915,8 @@ TEST(CApiTest, Indexing_Expr_With_binary_Predicate_Term) {
std::vector<CSearchResult> results;
results.push_back(c_search_result_on_bigIndex);

auto slice_nqs = std::vector<int32_t>{num_queries};
auto slice_topKs = std::vector<int32_t>{topK};
auto slice_nqs = std::vector<int64_t>{num_queries};
auto slice_topKs = std::vector<int64_t>{topK};

CSearchResultDataBlobs cSearchResultData;
status = ReduceSearchResultsAndFillData(&cSearchResultData, plan, results.data(), results.size(), slice_nqs.data(),
Expand Down
2 changes: 1 addition & 1 deletion internal/querynode/mock_test.go
Expand Up @@ -1632,7 +1632,7 @@ func checkSearchResult(nq int64, plan *SearchPlan, searchResult *SearchResult) e
if result.TopK != sliceTopKs[i] {
return fmt.Errorf("unexpected topK when checkSearchResult")
}
if result.NumQueries != int64(sInfo.sliceNQs[i]) {
if result.NumQueries != sInfo.sliceNQs[i] {
return fmt.Errorf("unexpected nq when checkSearchResult")
}
// search empty segment, return empty result.IDs
Expand Down
24 changes: 12 additions & 12 deletions internal/querynode/reduce.go
Expand Up @@ -28,8 +28,8 @@ import (
)

type sliceInfo struct {
sliceNQs []int32
sliceTopKs []int32
sliceNQs []int64
sliceTopKs []int64
}

// SearchResult contains a pointer to the search result in C++ memory
Expand All @@ -47,8 +47,8 @@ type RetrieveResult struct {

func parseSliceInfo(originNQs []int64, originTopKs []int64, nqPerSlice int64) *sliceInfo {
sInfo := &sliceInfo{
sliceNQs: make([]int32, 0),
sliceTopKs: make([]int32, 0),
sliceNQs: make([]int64, 0),
sliceTopKs: make([]int64, 0),
}

if nqPerSlice == 0 {
Expand All @@ -57,20 +57,20 @@ func parseSliceInfo(originNQs []int64, originTopKs []int64, nqPerSlice int64) *s

for i := 0; i < len(originNQs); i++ {
for j := 0; j < int(originNQs[i]/nqPerSlice); j++ {
sInfo.sliceNQs = append(sInfo.sliceNQs, int32(nqPerSlice))
sInfo.sliceTopKs = append(sInfo.sliceTopKs, int32(originTopKs[i]))
sInfo.sliceNQs = append(sInfo.sliceNQs, nqPerSlice)
sInfo.sliceTopKs = append(sInfo.sliceTopKs, originTopKs[i])
}
if tailSliceSize := originNQs[i] % nqPerSlice; tailSliceSize > 0 {
sInfo.sliceNQs = append(sInfo.sliceNQs, int32(tailSliceSize))
sInfo.sliceTopKs = append(sInfo.sliceTopKs, int32(originTopKs[i]))
sInfo.sliceNQs = append(sInfo.sliceNQs, tailSliceSize)
sInfo.sliceTopKs = append(sInfo.sliceTopKs, originTopKs[i])
}
}

return sInfo
}

func reduceSearchResultsAndFillData(plan *SearchPlan, searchResults []*SearchResult,
numSegments int64, sliceNQs []int32, sliceTopKs []int32) (searchResultDataBlobs, error) {
numSegments int64, sliceNQs []int64, sliceTopKs []int64) (searchResultDataBlobs, error) {
if plan.cSearchPlan == nil {
return nil, fmt.Errorf("nil search plan")
}
Expand All @@ -92,9 +92,9 @@ func reduceSearchResultsAndFillData(plan *SearchPlan, searchResults []*SearchRes
}
cSearchResultPtr := (*C.CSearchResult)(&cSearchResults[0])
cNumSegments := C.int64_t(numSegments)
var cSliceNQSPtr = (*C.int32_t)(&sliceNQs[0])
var cSliceTopKSPtr = (*C.int32_t)(&sliceTopKs[0])
var cNumSlices = C.int32_t(len(sliceNQs))
var cSliceNQSPtr = (*C.int64_t)(&sliceNQs[0])
var cSliceTopKSPtr = (*C.int64_t)(&sliceTopKs[0])
var cNumSlices = C.int64_t(len(sliceNQs))
var cSearchResultDataBlobs searchResultDataBlobs
status := C.ReduceSearchResultsAndFillData(&cSearchResultDataBlobs, plan.cSearchPlan, cSearchResultPtr,
cNumSegments, cSliceNQSPtr, cSliceTopKSPtr, cNumSlices)
Expand Down
6 changes: 3 additions & 3 deletions internal/querynode/reduce_test.go
Expand Up @@ -37,8 +37,8 @@ func TestReduce_parseSliceInfo(t *testing.T) {
nqPerSlice := int64(2)
sInfo := parseSliceInfo(originNQs, originTopKs, nqPerSlice)

expectedSliceNQs := []int32{2, 2, 1, 2}
expectedSliceTopKs := []int32{10, 5, 5, 20}
expectedSliceNQs := []int64{2, 2, 1, 2}
expectedSliceTopKs := []int64{10, 5, 5, 20}
assert.True(t, funcutil.SliceSetEqual(sInfo.sliceNQs, expectedSliceNQs))
assert.True(t, funcutil.SliceSetEqual(sInfo.sliceTopKs, expectedSliceTopKs))
}
Expand Down Expand Up @@ -117,7 +117,7 @@ func TestReduce_Invalid(t *testing.T) {
assert.NoError(t, err)
searchResults := make([]*SearchResult, 0)
searchResults = append(searchResults, nil)
_, err = reduceSearchResultsAndFillData(searchReq.plan, searchResults, 1, []int32{10}, []int32{10})
_, err = reduceSearchResultsAndFillData(searchReq.plan, searchResults, 1, []int64{10}, []int64{10})
assert.Error(t, err)
})
}

0 comments on commit b219652

Please sign in to comment.