Skip to content

Commit

Permalink
[skip e2e] Optimize test_reduce (#18957)
Browse files Browse the repository at this point in the history
Signed-off-by: yudong.cai <yudong.cai@zilliz.com>

Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
  • Loading branch information
cydrain committed Sep 1, 2022
1 parent 0e4e796 commit 4ded453
Showing 1 changed file with 65 additions and 78 deletions.
143 changes: 65 additions & 78 deletions internal/core/unittest/test_reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,110 +14,97 @@
#include <random>
#include <vector>

#include "knowhere/index/vector_index/helpers/IndexParameter.h"
#include "query/SubSearchResult.h"

using namespace milvus;
using namespace milvus::query;

TEST(Reduce, SubQueryResult) {
int64_t num_queries = 512;
int64_t topk = 32;
int64_t iteration = 50;
int64_t round_decimal = 3;
constexpr int64_t limit = 100000000L;
auto metric_type = knowhere::metric::L2;
using queue_type = std::priority_queue<int64_t>;
using SubSearchResultUniq = std::unique_ptr<SubSearchResult>;

std::default_random_engine e(42);

std::vector<queue_type> ref_results(num_queries);
for (auto& ref_result : ref_results) {
for (int i = 0; i < topk; ++i) {
ref_result.push(limit);
std::unique_ptr<SubSearchResult>
GenSubSearchResult(const int64_t nq,
const int64_t topk,
const knowhere::MetricType &metric_type,
const int64_t round_decimal) {
constexpr int64_t limit = 1000000L;
bool is_ip = (metric_type == knowhere::metric::IP);
SubSearchResultUniq sub_result = std::make_unique<SubSearchResult>(nq, topk, metric_type, round_decimal);
std::vector<int64_t> ids;
std::vector<float> distances;
for (int n = 0; n < nq; ++n) {
for (int k = 0; k < topk; ++k) {
auto gen_x = e() % limit;
ids.push_back(gen_x);
distances.push_back(gen_x);
}
}
std::default_random_engine e(42);
SubSearchResult final_result(num_queries, topk, metric_type, round_decimal);
for (int i = 0; i < iteration; ++i) {
std::vector<int64_t> ids;
std::vector<float> distances;
for (int n = 0; n < num_queries; ++n) {
for (int k = 0; k < topk; ++k) {
auto gen_x = e() % limit;
ref_results[n].push(gen_x);
ref_results[n].pop();
ids.push_back(gen_x);
distances.push_back(gen_x);
}
std::sort(ids.begin() + n * topk, ids.begin() + n * topk + topk);
std::sort(distances.begin() + n * topk, distances.begin() + n * topk + topk);
if (is_ip) {
std::sort(ids.begin() + n * topk, ids.begin() + (n + 1) * topk, std::greater<int64_t>());
std::sort(distances.begin() + n * topk, distances.begin() + (n + 1) * topk, std::greater<float>());
} else {
std::sort(ids.begin() + n * topk, ids.begin() + (n + 1) * topk);
std::sort(distances.begin() + n * topk, distances.begin() + (n + 1) * topk);
}
SubSearchResult sub_result(num_queries, topk, metric_type, round_decimal);
sub_result.mutable_distances() = distances;
sub_result.mutable_seg_offsets() = ids;
final_result.merge(sub_result);
}
sub_result->mutable_distances() = std::move(distances);
sub_result->mutable_seg_offsets() = std::move(ids);
return sub_result;
}

for (int n = 0; n < num_queries; ++n) {
ASSERT_EQ(ref_results[n].size(), topk);
template<class queue_type>
void
CheckSubSearchResult(const int64_t nq,
const int64_t topk,
SubSearchResult& search_result,
std::vector<queue_type>& result_ref) {
ASSERT_EQ(result_ref.size(), nq);
for (int n = 0; n < nq; ++n) {
ASSERT_EQ(result_ref[n].size(), topk);
for (int k = 0; k < topk; ++k) {
auto ref_x = ref_results[n].top();
ref_results[n].pop();
auto ref_x = result_ref[n].top();
result_ref[n].pop();
auto index = n * topk + topk - 1 - k;
auto id = final_result.get_seg_offsets()[index];
auto distance = final_result.get_distances()[index];
auto id = search_result.get_seg_offsets()[index];
auto distance = search_result.get_distances()[index];
ASSERT_EQ(id, ref_x);
ASSERT_EQ(distance, ref_x);
}
}
}

TEST(Reduce, SubSearchResultDesc) {
int64_t num_queries = 512;
int64_t topk = 32;
int64_t iteration = 50;
template<class queue_type>
void
TestSubSearchResultMerge(const knowhere::MetricType& metric_type) {
int64_t num_queries = 16;
int64_t topk = 10;
int64_t iteration = 10;
int64_t round_decimal = 3;
constexpr int64_t limit = 100000000L;
constexpr int64_t init_value = 0;
auto metric_type = knowhere::metric::IP;
using queue_type = std::priority_queue<int64_t, std::vector<int64_t>, std::greater<int64_t>>;

std::vector<queue_type> ref_results(num_queries);
for (auto& ref_result : ref_results) {
for (int i = 0; i < topk; ++i) {
ref_result.push(init_value);
}
}
std::default_random_engine e(42);
std::vector<queue_type> result_ref(num_queries);

SubSearchResult final_result(num_queries, topk, metric_type, round_decimal);
for (int i = 0; i < iteration; ++i) {
std::vector<int64_t> ids;
std::vector<float> distances;
SubSearchResultUniq sub_result = GenSubSearchResult(num_queries, topk, metric_type, round_decimal);
auto ids = sub_result->get_ids();
for (int n = 0; n < num_queries; ++n) {
for (int k = 0; k < topk; ++k) {
auto gen_x = e() % limit;
ref_results[n].push(gen_x);
ref_results[n].pop();
ids.push_back(gen_x);
distances.push_back(gen_x);
int64_t x = ids[n * topk + k];
result_ref[n].push(x);
if (result_ref[n].size() > topk) {
result_ref[n].pop();
}
}
std::sort(ids.begin() + n * topk, ids.begin() + n * topk + topk, std::greater<int64_t>());
std::sort(distances.begin() + n * topk, distances.begin() + n * topk + topk, std::greater<float>());
}
SubSearchResult sub_result(num_queries, topk, metric_type, round_decimal);
sub_result.mutable_distances() = distances;
sub_result.mutable_seg_offsets() = ids;
final_result.merge(sub_result);
final_result.merge(*sub_result);
}
CheckSubSearchResult<queue_type>(num_queries, topk, final_result, result_ref);
}

for (int n = 0; n < num_queries; ++n) {
ASSERT_EQ(ref_results[n].size(), topk);
for (int k = 0; k < topk; ++k) {
auto ref_x = ref_results[n].top();
ref_results[n].pop();
auto index = n * topk + topk - 1 - k;
auto id = final_result.get_seg_offsets()[index];
auto distance = final_result.get_distances()[index];
ASSERT_EQ(id, ref_x);
ASSERT_EQ(distance, ref_x);
}
}
TEST(Reduce, SubSearchResult) {
using queue_type_l2 = std::priority_queue<int64_t, std::vector<int64_t>, std::less<int64_t>>;
using queue_type_ip = std::priority_queue<int64_t, std::vector<int64_t>, std::greater<int64_t>>;
TestSubSearchResultMerge<queue_type_l2>(knowhere::metric::L2);
TestSubSearchResultMerge<queue_type_ip>(knowhere::metric::IP);
}

0 comments on commit 4ded453

Please sign in to comment.