forked from opensearch-project/k-NN
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[feature to main branch] Add binary format support with HNSW method i…
…n Faiss Engine (opensearch-project#1829) * Add faiss custom patch to support search parameter in binary index (opensearch-project#1815) Signed-off-by: Heemin Kim <heemin@amazon.com> * Add binary format support with HNSW method in Faiss Engine (opensearch-project#1781) Signed-off-by: Heemin Kim <heemin@amazon.com> --------- Signed-off-by: Heemin Kim <heemin@amazon.com>
- Loading branch information
Showing
75 changed files
with
2,568 additions
and
430 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
294 changes: 294 additions & 0 deletions
294
jni/patches/faiss/0004-Custom-patch-to-support-binary-vector.patch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,294 @@ | ||
From 4d374aa47d4415cbda04b299788988f4ff6e5da0 Mon Sep 17 00:00:00 2001 | ||
From: Heemin Kim <heemin@amazon.com> | ||
Date: Wed, 10 Jul 2024 16:06:36 -0700 | ||
Subject: [PATCH] 0004-Custom-patch-to-support-binary-vector | ||
|
||
Signed-off-by: Heemin Kim <heemin@amazon.com> | ||
--- | ||
faiss/IndexBinaryHNSW.cpp | 59 +++++++++++++------ | ||
faiss/IndexBinaryIVF.cpp | 28 +++++++-- | ||
tests/test_id_grouper.cpp | 117 +++++++++++++++++++++++++++++++++++++- | ||
3 files changed, 179 insertions(+), 25 deletions(-) | ||
|
||
diff --git a/faiss/IndexBinaryHNSW.cpp b/faiss/IndexBinaryHNSW.cpp | ||
index f1bda08f..32627cb0 100644 | ||
--- a/faiss/IndexBinaryHNSW.cpp | ||
+++ b/faiss/IndexBinaryHNSW.cpp | ||
@@ -189,37 +189,62 @@ void IndexBinaryHNSW::train(idx_t n, const uint8_t* x) { | ||
is_trained = true; | ||
} | ||
|
||
+namespace { | ||
+template <class BlockResultHandler> | ||
+void hnsw_search( | ||
+ const IndexBinaryHNSW* index, | ||
+ idx_t n, | ||
+ const uint8_t* x, | ||
+ BlockResultHandler& bres, | ||
+ const SearchParameters* params_in) { | ||
+ const SearchParametersHNSW* params = nullptr; | ||
+ const HNSW& hnsw = index->hnsw; | ||
+ | ||
+ if (params_in) { | ||
+ params = dynamic_cast<const SearchParametersHNSW*>(params_in); | ||
+ FAISS_THROW_IF_NOT_MSG(params, "params type invalid"); | ||
+ } | ||
+#pragma omp parallel | ||
+ { | ||
+ VisitedTable vt(index->ntotal); | ||
+ std::unique_ptr<DistanceComputer> dis(index->get_distance_computer()); | ||
+ typename BlockResultHandler::SingleResultHandler res(bres); | ||
+ | ||
+#pragma omp for | ||
+ for (idx_t i = 0; i < n; i++) { | ||
+ res.begin(i); | ||
+ dis->set_query((float*)(x + i * index->code_size)); | ||
+ hnsw.search(*dis, res, vt, params); | ||
+ res.end(); | ||
+ } | ||
+ } | ||
+} | ||
+ | ||
+} // anonymous namespace | ||
+ | ||
void IndexBinaryHNSW::search( | ||
idx_t n, | ||
const uint8_t* x, | ||
idx_t k, | ||
int32_t* distances, | ||
idx_t* labels, | ||
- const SearchParameters* params) const { | ||
- FAISS_THROW_IF_NOT_MSG( | ||
- !params, "search params not supported for this index"); | ||
+ const SearchParameters* params_in) const { | ||
FAISS_THROW_IF_NOT(k > 0); | ||
|
||
// we use the buffer for distances as float but convert them back | ||
// to int in the end | ||
float* distances_f = (float*)distances; | ||
|
||
- using RH = HeapBlockResultHandler<HNSW::C>; | ||
- RH bres(n, distances_f, labels, k); | ||
+ if (params_in && params_in->grp) { | ||
+ using RH = GroupedHeapBlockResultHandler<HNSW::C>; | ||
+ RH bres(n, distances_f, labels, k, params_in->grp); | ||
|
||
-#pragma omp parallel | ||
- { | ||
- VisitedTable vt(ntotal); | ||
- std::unique_ptr<DistanceComputer> dis(get_distance_computer()); | ||
- RH::SingleResultHandler res(bres); | ||
+ hnsw_search(this, n, x, bres, params_in); | ||
+ } else { | ||
+ using RH = HeapBlockResultHandler<HNSW::C>; | ||
+ RH bres(n, distances_f, labels, k); | ||
|
||
-#pragma omp for | ||
- for (idx_t i = 0; i < n; i++) { | ||
- res.begin(i); | ||
- dis->set_query((float*)(x + i * code_size)); | ||
- hnsw.search(*dis, res, vt); | ||
- res.end(); | ||
- } | ||
+ hnsw_search(this, n, x, bres, params_in); | ||
} | ||
|
||
#pragma omp parallel for | ||
diff --git a/faiss/IndexBinaryIVF.cpp b/faiss/IndexBinaryIVF.cpp | ||
index ab1b9fd8..de996df3 100644 | ||
--- a/faiss/IndexBinaryIVF.cpp | ||
+++ b/faiss/IndexBinaryIVF.cpp | ||
@@ -113,25 +113,41 @@ void IndexBinaryIVF::search( | ||
idx_t k, | ||
int32_t* distances, | ||
idx_t* labels, | ||
- const SearchParameters* params) const { | ||
- FAISS_THROW_IF_NOT_MSG( | ||
- !params, "search params not supported for this index"); | ||
+ const SearchParameters* params_in) const { | ||
FAISS_THROW_IF_NOT(k > 0); | ||
+ const IVFSearchParameters* params = nullptr; | ||
+ if (params_in) { | ||
+ params = dynamic_cast<const IVFSearchParameters*>(params_in); | ||
+ FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type"); | ||
+ } | ||
+ const size_t nprobe_2 = std::min(nlist, params ? params->nprobe : this->nprobe); | ||
FAISS_THROW_IF_NOT(nprobe > 0); | ||
|
||
- const size_t nprobe_2 = std::min(nlist, this->nprobe); | ||
std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe_2]); | ||
std::unique_ptr<int32_t[]> coarse_dis(new int32_t[n * nprobe_2]); | ||
|
||
double t0 = getmillisecs(); | ||
- quantizer->search(n, x, nprobe_2, coarse_dis.get(), idx.get()); | ||
+ quantizer->search( | ||
+ n, | ||
+ x, | ||
+ nprobe_2, | ||
+ coarse_dis.get(), | ||
+ idx.get(), | ||
+ params ? params->quantizer_params : nullptr); | ||
indexIVF_stats.quantization_time += getmillisecs() - t0; | ||
|
||
t0 = getmillisecs(); | ||
invlists->prefetch_lists(idx.get(), n * nprobe_2); | ||
|
||
search_preassigned( | ||
- n, x, k, idx.get(), coarse_dis.get(), distances, labels, false); | ||
+ n, | ||
+ x, | ||
+ k, | ||
+ idx.get(), | ||
+ coarse_dis.get(), | ||
+ distances, labels, | ||
+ false, | ||
+ params); | ||
indexIVF_stats.search_time += getmillisecs() - t0; | ||
} | ||
|
||
diff --git a/tests/test_id_grouper.cpp b/tests/test_id_grouper.cpp | ||
index 6601795b..bd8ab5f9 100644 | ||
--- a/tests/test_id_grouper.cpp | ||
+++ b/tests/test_id_grouper.cpp | ||
@@ -14,10 +14,10 @@ | ||
#include <faiss/IndexIDMap.h> | ||
#include <faiss/MetricType.h> | ||
#include <faiss/impl/IDGrouper.h> | ||
+#include "faiss/IndexBinaryHNSW.h" | ||
|
||
// 64-bit int | ||
using idx_t = faiss::idx_t; | ||
- | ||
using namespace faiss; | ||
|
||
TEST(IdGrouper, get_group) { | ||
@@ -172,7 +172,58 @@ TEST(IdGrouper, bitmap_with_hnsw) { | ||
delete[] xb; | ||
} | ||
|
||
-TEST(IdGrouper, bitmap_with_hnswn_idmap) { | ||
+TEST(IdGrouper, bitmap_with_binary_hnsw) { | ||
+ int d = 16; // dimension | ||
+ int nb = 10; // database size | ||
+ | ||
+ std::vector<uint8_t> database(nb * (d / 8)); | ||
+ for (size_t i = 0; i < nb * (d / 8); i++) { | ||
+ database[i] = rand() % 0x100; | ||
+ } | ||
+ | ||
+ uint64_t bitmap[1] = {}; | ||
+ faiss::IDGrouperBitmap id_grouper(1, bitmap); | ||
+ for (int i = 0; i < nb; i++) { | ||
+ if (i % 2 == 1) { | ||
+ id_grouper.set_group(i); | ||
+ } | ||
+ } | ||
+ | ||
+ int k = 10; | ||
+ int m = 8; | ||
+ faiss::IndexBinary* index = | ||
+ new faiss::IndexBinaryHNSW(d, m); | ||
+ index->add(nb, database.data()); // add vectors to the index | ||
+ | ||
+ // search | ||
+ idx_t* I = new idx_t[k]; | ||
+ int32_t* D = new int32_t[k]; | ||
+ | ||
+ auto pSearchParameters = new faiss::SearchParametersHNSW(); | ||
+ pSearchParameters->grp = &id_grouper; | ||
+ | ||
+ index->search(1, database.data(), k, D, I, pSearchParameters); | ||
+ | ||
+ std::unordered_set<int> group_ids; | ||
+ ASSERT_EQ(0, I[0]); | ||
+ ASSERT_EQ(0, D[0]); | ||
+ group_ids.insert(id_grouper.get_group(I[0])); | ||
+ for (int j = 1; j < 5; j++) { | ||
+ ASSERT_NE(-1, I[j]); | ||
+ ASSERT_NE(std::numeric_limits<int32_t>::max(), D[j]); | ||
+ group_ids.insert(id_grouper.get_group(I[j])); | ||
+ } | ||
+ for (int j = 5; j < k; j++) { | ||
+ ASSERT_EQ(-1, I[j]); | ||
+ ASSERT_EQ(std::numeric_limits<int32_t>::max(), D[j]); | ||
+ } | ||
+ ASSERT_EQ(5, group_ids.size()); | ||
+ | ||
+ delete[] I; | ||
+ delete[] D; | ||
+} | ||
+ | ||
+TEST(IdGrouper, bitmap_with_hnsw_idmap) { | ||
int d = 1; // dimension | ||
int nb = 10; // database size | ||
|
||
@@ -239,3 +290,65 @@ TEST(IdGrouper, bitmap_with_hnswn_idmap) { | ||
delete[] D; | ||
delete[] xb; | ||
} | ||
+ | ||
+TEST(IdGrouper, bitmap_with_binary_hnsw_idmap) { | ||
+ int d = 16; // dimension | ||
+ int nb = 10; // database size | ||
+ | ||
+ std::vector<uint8_t> database(nb * (d / 8)); | ||
+ for (size_t i = 0; i < nb * (d / 8); i++) { | ||
+ database[i] = rand() % 0x100; | ||
+ } | ||
+ | ||
+ idx_t* xids = new idx_t[nb]; | ||
+ uint64_t bitmap[1] = {}; | ||
+ faiss::IDGrouperBitmap id_grouper(1, bitmap); | ||
+ int num_grp = 0; | ||
+ int grp_size = 2; | ||
+ int id_in_grp = 0; | ||
+ for (int i = 0; i < nb; i++) { | ||
+ xids[i] = i + num_grp; | ||
+ id_in_grp++; | ||
+ if (id_in_grp == grp_size) { | ||
+ id_grouper.set_group(i + num_grp + 1); | ||
+ num_grp++; | ||
+ id_in_grp = 0; | ||
+ } | ||
+ } | ||
+ | ||
+ int k = 10; | ||
+ int m = 8; | ||
+ | ||
+ faiss::IndexBinary* index = | ||
+ new faiss::IndexBinaryHNSW(d, m); | ||
+ faiss::IndexBinaryIDMap id_map = | ||
+ faiss::IndexBinaryIDMap(index); // add vectors to the index | ||
+ id_map.add_with_ids(nb, database.data(), xids); | ||
+ | ||
+ // search | ||
+ idx_t* I = new idx_t[k]; | ||
+ int32_t* D = new int32_t[k]; | ||
+ | ||
+ auto pSearchParameters = new faiss::SearchParametersHNSW(); | ||
+ pSearchParameters->grp = &id_grouper; | ||
+ | ||
+ id_map.search(1, database.data(), k, D, I, pSearchParameters); | ||
+ | ||
+ std::unordered_set<int> group_ids; | ||
+ ASSERT_EQ(0, I[0]); | ||
+ ASSERT_EQ(0, D[0]); | ||
+ group_ids.insert(id_grouper.get_group(I[0])); | ||
+ for (int j = 1; j < 5; j++) { | ||
+ ASSERT_NE(-1, I[j]); | ||
+ ASSERT_NE(std::numeric_limits<int32_t>::max(), D[j]); | ||
+ group_ids.insert(id_grouper.get_group(I[j])); | ||
+ } | ||
+ for (int j = 5; j < k; j++) { | ||
+ ASSERT_EQ(-1, I[j]); | ||
+ ASSERT_EQ(std::numeric_limits<int32_t>::max(), D[j]); | ||
+ } | ||
+ ASSERT_EQ(5, group_ids.size()); | ||
+ | ||
+ delete[] I; | ||
+ delete[] D; | ||
+} | ||
\ No newline at end of file | ||
-- | ||
2.39.3 (Apple Git-146) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.