Skip to content

Commit

Permalink
[feature to main branch] Add binary format support with HNSW method i…
Browse files Browse the repository at this point in the history
…n Faiss Engine (opensearch-project#1829)

* Add faiss custom patch to support search parameter in binary index (opensearch-project#1815)

Signed-off-by: Heemin Kim <heemin@amazon.com>

* Add binary format support with HNSW method in Faiss Engine (opensearch-project#1781)

Signed-off-by: Heemin Kim <heemin@amazon.com>

---------

Signed-off-by: Heemin Kim <heemin@amazon.com>
  • Loading branch information
heemin32 committed Jul 16, 2024
1 parent 31a4d3e commit fe1d86f
Show file tree
Hide file tree
Showing 75 changed files with 2,568 additions and 430 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
* Adds dynamic query parameter ef_search [#1783](https://github.com/opensearch-project/k-NN/pull/1783)
* Adds dynamic query parameter ef_search in radial search faiss engine [#1790](https://github.com/opensearch-project/k-NN/pull/1790)
* Add binary format support with HNSW method in Faiss Engine [#1781](https://github.com/opensearch-project/k-NN/pull/1781)
### Enhancements
### Bug Fixes
* Fixing the arithmetic to find the number of vectors to stream from Java to the JNI layer. [#1804](https://github.com/opensearch-project/k-NN/pull/1804)
Expand Down
3 changes: 2 additions & 1 deletion jni/cmake/init-faiss.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@ if (NOT EXISTS ${FAISS_REPO_DIR})
endif ()

# Check whether the patch files exist; when they do not, git apply is skipped during the CI build. See the Ubuntu job in CI.yml.
find_path(PATCH_FILE NAMES 0001-Custom-patch-to-support-multi-vector.patch 0002-Enable-precomp-table-to-be-shared-ivfpq.patch 0003-Custom-patch-to-support-range-search-params.patch PATHS ${CMAKE_CURRENT_SOURCE_DIR}/patches/faiss NO_DEFAULT_PATH)
find_path(PATCH_FILE NAMES 0001-Custom-patch-to-support-multi-vector.patch 0002-Enable-precomp-table-to-be-shared-ivfpq.patch 0003-Custom-patch-to-support-range-search-params.patch 0004-Custom-patch-to-support-binary-vector.patch PATHS ${CMAKE_CURRENT_SOURCE_DIR}/patches/faiss NO_DEFAULT_PATH)

# If it exists, apply patches
if (EXISTS ${PATCH_FILE})
message(STATUS "Applying custom patches.")
# Apply every custom faiss patch in order and stop at the first failure.
# RESULT_CODE and ERROR_MSG are overwritten by each execute_process call, so
# they must be inspected immediately after each patch; checking only once at
# the end would report nothing but the outcome of the last patch.
foreach(FAISS_PATCH_NAME
        0001-Custom-patch-to-support-multi-vector.patch
        0002-Enable-precomp-table-to-be-shared-ivfpq.patch
        0003-Custom-patch-to-support-range-search-params.patch
        0004-Custom-patch-to-support-binary-vector.patch)
    execute_process(COMMAND git ${GIT_PATCH_COMMAND} --3way --ignore-space-change --ignore-whitespace ${CMAKE_CURRENT_SOURCE_DIR}/patches/faiss/${FAISS_PATCH_NAME} WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/external/faiss ERROR_VARIABLE ERROR_MSG RESULT_VARIABLE RESULT_CODE)
    if(RESULT_CODE)
        message(FATAL_ERROR "Failed to apply patch ${FAISS_PATCH_NAME}:\n${ERROR_MSG}")
    endif()
endforeach()
Expand Down
2 changes: 1 addition & 1 deletion jni/include/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ namespace knn_jni {
jbyteArray queryVectorJ, jint kJ, jobject methodParamsJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ);

// Free the index located in memory at indexPointerJ
void Free(jlong indexPointer);
void Free(jlong indexPointer, jboolean isBinaryIndexJ);

// Free shared index state in memory at shareIndexStatePointerJ
void FreeSharedIndexState(jlong shareIndexStatePointerJ);
Expand Down
4 changes: 2 additions & 2 deletions jni/include/org_opensearch_knn_jni_FaissService.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,10 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryBin
/*
* Class: org_opensearch_knn_jni_FaissService
* Method: free
* Signature: (J)V
* Signature: (JZ)V
*/
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_free
(JNIEnv *, jclass, jlong);
(JNIEnv *, jclass, jlong, jboolean);

/*
* Class: org_opensearch_knn_jni_FaissService
Expand Down
294 changes: 294 additions & 0 deletions jni/patches/faiss/0004-Custom-patch-to-support-binary-vector.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,294 @@
From 4d374aa47d4415cbda04b299788988f4ff6e5da0 Mon Sep 17 00:00:00 2001
From: Heemin Kim <heemin@amazon.com>
Date: Wed, 10 Jul 2024 16:06:36 -0700
Subject: [PATCH] 0004-Custom-patch-to-support-binary-vector

Signed-off-by: Heemin Kim <heemin@amazon.com>
---
faiss/IndexBinaryHNSW.cpp | 59 +++++++++++++------
faiss/IndexBinaryIVF.cpp | 28 +++++++--
tests/test_id_grouper.cpp | 117 +++++++++++++++++++++++++++++++++++++-
3 files changed, 179 insertions(+), 25 deletions(-)

diff --git a/faiss/IndexBinaryHNSW.cpp b/faiss/IndexBinaryHNSW.cpp
index f1bda08f..32627cb0 100644
--- a/faiss/IndexBinaryHNSW.cpp
+++ b/faiss/IndexBinaryHNSW.cpp
@@ -189,37 +189,62 @@ void IndexBinaryHNSW::train(idx_t n, const uint8_t* x) {
is_trained = true;
}

+namespace {
+template <class BlockResultHandler>
+void hnsw_search(
+ const IndexBinaryHNSW* index,
+ idx_t n,
+ const uint8_t* x,
+ BlockResultHandler& bres,
+ const SearchParameters* params_in) {
+ const SearchParametersHNSW* params = nullptr;
+ const HNSW& hnsw = index->hnsw;
+
+ if (params_in) {
+ params = dynamic_cast<const SearchParametersHNSW*>(params_in);
+ FAISS_THROW_IF_NOT_MSG(params, "params type invalid");
+ }
+#pragma omp parallel
+ {
+ VisitedTable vt(index->ntotal);
+ std::unique_ptr<DistanceComputer> dis(index->get_distance_computer());
+ typename BlockResultHandler::SingleResultHandler res(bres);
+
+#pragma omp for
+ for (idx_t i = 0; i < n; i++) {
+ res.begin(i);
+ dis->set_query((float*)(x + i * index->code_size));
+ hnsw.search(*dis, res, vt, params);
+ res.end();
+ }
+ }
+}
+
+} // anonymous namespace
+
void IndexBinaryHNSW::search(
idx_t n,
const uint8_t* x,
idx_t k,
int32_t* distances,
idx_t* labels,
- const SearchParameters* params) const {
- FAISS_THROW_IF_NOT_MSG(
- !params, "search params not supported for this index");
+ const SearchParameters* params_in) const {
FAISS_THROW_IF_NOT(k > 0);

// we use the buffer for distances as float but convert them back
// to int in the end
float* distances_f = (float*)distances;

- using RH = HeapBlockResultHandler<HNSW::C>;
- RH bres(n, distances_f, labels, k);
+ if (params_in && params_in->grp) {
+ using RH = GroupedHeapBlockResultHandler<HNSW::C>;
+ RH bres(n, distances_f, labels, k, params_in->grp);

-#pragma omp parallel
- {
- VisitedTable vt(ntotal);
- std::unique_ptr<DistanceComputer> dis(get_distance_computer());
- RH::SingleResultHandler res(bres);
+ hnsw_search(this, n, x, bres, params_in);
+ } else {
+ using RH = HeapBlockResultHandler<HNSW::C>;
+ RH bres(n, distances_f, labels, k);

-#pragma omp for
- for (idx_t i = 0; i < n; i++) {
- res.begin(i);
- dis->set_query((float*)(x + i * code_size));
- hnsw.search(*dis, res, vt);
- res.end();
- }
+ hnsw_search(this, n, x, bres, params_in);
}

#pragma omp parallel for
diff --git a/faiss/IndexBinaryIVF.cpp b/faiss/IndexBinaryIVF.cpp
index ab1b9fd8..de996df3 100644
--- a/faiss/IndexBinaryIVF.cpp
+++ b/faiss/IndexBinaryIVF.cpp
@@ -113,25 +113,41 @@ void IndexBinaryIVF::search(
idx_t k,
int32_t* distances,
idx_t* labels,
- const SearchParameters* params) const {
- FAISS_THROW_IF_NOT_MSG(
- !params, "search params not supported for this index");
+ const SearchParameters* params_in) const {
FAISS_THROW_IF_NOT(k > 0);
+ const IVFSearchParameters* params = nullptr;
+ if (params_in) {
+ params = dynamic_cast<const IVFSearchParameters*>(params_in);
+ FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
+ }
+ const size_t nprobe_2 = std::min(nlist, params ? params->nprobe : this->nprobe);
FAISS_THROW_IF_NOT(nprobe > 0);

- const size_t nprobe_2 = std::min(nlist, this->nprobe);
std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe_2]);
std::unique_ptr<int32_t[]> coarse_dis(new int32_t[n * nprobe_2]);

double t0 = getmillisecs();
- quantizer->search(n, x, nprobe_2, coarse_dis.get(), idx.get());
+ quantizer->search(
+ n,
+ x,
+ nprobe_2,
+ coarse_dis.get(),
+ idx.get(),
+ params ? params->quantizer_params : nullptr);
indexIVF_stats.quantization_time += getmillisecs() - t0;

t0 = getmillisecs();
invlists->prefetch_lists(idx.get(), n * nprobe_2);

search_preassigned(
- n, x, k, idx.get(), coarse_dis.get(), distances, labels, false);
+ n,
+ x,
+ k,
+ idx.get(),
+ coarse_dis.get(),
+ distances, labels,
+ false,
+ params);
indexIVF_stats.search_time += getmillisecs() - t0;
}

diff --git a/tests/test_id_grouper.cpp b/tests/test_id_grouper.cpp
index 6601795b..bd8ab5f9 100644
--- a/tests/test_id_grouper.cpp
+++ b/tests/test_id_grouper.cpp
@@ -14,10 +14,10 @@
#include <faiss/IndexIDMap.h>
#include <faiss/MetricType.h>
#include <faiss/impl/IDGrouper.h>
+#include "faiss/IndexBinaryHNSW.h"

// 64-bit int
using idx_t = faiss::idx_t;
-
using namespace faiss;

TEST(IdGrouper, get_group) {
@@ -172,7 +172,58 @@ TEST(IdGrouper, bitmap_with_hnsw) {
delete[] xb;
}

-TEST(IdGrouper, bitmap_with_hnswn_idmap) {
+TEST(IdGrouper, bitmap_with_binary_hnsw) {
+ int d = 16; // dimension
+ int nb = 10; // database size
+
+ std::vector<uint8_t> database(nb * (d / 8));
+ for (size_t i = 0; i < nb * (d / 8); i++) {
+ database[i] = rand() % 0x100;
+ }
+
+ uint64_t bitmap[1] = {};
+ faiss::IDGrouperBitmap id_grouper(1, bitmap);
+ for (int i = 0; i < nb; i++) {
+ if (i % 2 == 1) {
+ id_grouper.set_group(i);
+ }
+ }
+
+ int k = 10;
+ int m = 8;
+ faiss::IndexBinary* index =
+ new faiss::IndexBinaryHNSW(d, m);
+ index->add(nb, database.data()); // add vectors to the index
+
+ // search
+ idx_t* I = new idx_t[k];
+ int32_t* D = new int32_t[k];
+
+ auto pSearchParameters = new faiss::SearchParametersHNSW();
+ pSearchParameters->grp = &id_grouper;
+
+ index->search(1, database.data(), k, D, I, pSearchParameters);
+
+ std::unordered_set<int> group_ids;
+ ASSERT_EQ(0, I[0]);
+ ASSERT_EQ(0, D[0]);
+ group_ids.insert(id_grouper.get_group(I[0]));
+ for (int j = 1; j < 5; j++) {
+ ASSERT_NE(-1, I[j]);
+ ASSERT_NE(std::numeric_limits<int32_t>::max(), D[j]);
+ group_ids.insert(id_grouper.get_group(I[j]));
+ }
+ for (int j = 5; j < k; j++) {
+ ASSERT_EQ(-1, I[j]);
+ ASSERT_EQ(std::numeric_limits<int32_t>::max(), D[j]);
+ }
+ ASSERT_EQ(5, group_ids.size());
+
+ delete[] I;
+ delete[] D;
+}
+
+TEST(IdGrouper, bitmap_with_hnsw_idmap) {
int d = 1; // dimension
int nb = 10; // database size

@@ -239,3 +290,65 @@ TEST(IdGrouper, bitmap_with_hnswn_idmap) {
delete[] D;
delete[] xb;
}
+
+TEST(IdGrouper, bitmap_with_binary_hnsw_idmap) {
+ int d = 16; // dimension
+ int nb = 10; // database size
+
+ std::vector<uint8_t> database(nb * (d / 8));
+ for (size_t i = 0; i < nb * (d / 8); i++) {
+ database[i] = rand() % 0x100;
+ }
+
+ idx_t* xids = new idx_t[nb];
+ uint64_t bitmap[1] = {};
+ faiss::IDGrouperBitmap id_grouper(1, bitmap);
+ int num_grp = 0;
+ int grp_size = 2;
+ int id_in_grp = 0;
+ for (int i = 0; i < nb; i++) {
+ xids[i] = i + num_grp;
+ id_in_grp++;
+ if (id_in_grp == grp_size) {
+ id_grouper.set_group(i + num_grp + 1);
+ num_grp++;
+ id_in_grp = 0;
+ }
+ }
+
+ int k = 10;
+ int m = 8;
+
+ faiss::IndexBinary* index =
+ new faiss::IndexBinaryHNSW(d, m);
+ faiss::IndexBinaryIDMap id_map =
+ faiss::IndexBinaryIDMap(index); // add vectors to the index
+ id_map.add_with_ids(nb, database.data(), xids);
+
+ // search
+ idx_t* I = new idx_t[k];
+ int32_t* D = new int32_t[k];
+
+ auto pSearchParameters = new faiss::SearchParametersHNSW();
+ pSearchParameters->grp = &id_grouper;
+
+ id_map.search(1, database.data(), k, D, I, pSearchParameters);
+
+ std::unordered_set<int> group_ids;
+ ASSERT_EQ(0, I[0]);
+ ASSERT_EQ(0, D[0]);
+ group_ids.insert(id_grouper.get_group(I[0]));
+ for (int j = 1; j < 5; j++) {
+ ASSERT_NE(-1, I[j]);
+ ASSERT_NE(std::numeric_limits<int32_t>::max(), D[j]);
+ group_ids.insert(id_grouper.get_group(I[j]));
+ }
+ for (int j = 5; j < k; j++) {
+ ASSERT_EQ(-1, I[j]);
+ ASSERT_EQ(std::numeric_limits<int32_t>::max(), D[j]);
+ }
+ ASSERT_EQ(5, group_ids.size());
+
+ delete[] I;
+ delete[] D;
+}
\ No newline at end of file
--
2.39.3 (Apple Git-146)

13 changes: 10 additions & 3 deletions jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -531,9 +531,16 @@ jobjectArray knn_jni::faiss_wrapper::QueryBinaryIndex_WithFilter(knn_jni::JNIUti
return results;
}

void knn_jni::faiss_wrapper::Free(jlong indexPointer) {
auto *indexWrapper = reinterpret_cast<faiss::Index*>(indexPointer);
delete indexWrapper;
// Deletes the index whose address is carried in indexPointer.
// isBinaryIndexJ selects the type the handle is cast to before deletion:
// faiss::IndexBinary when it is true, faiss::Index otherwise.
void knn_jni::faiss_wrapper::Free(jlong indexPointer, jboolean isBinaryIndexJ) {
    if (static_cast<bool>(isBinaryIndexJ)) {
        // Binary index: release through the IndexBinary type.
        delete reinterpret_cast<faiss::IndexBinary*>(indexPointer);
        return;
    }

    // Non-binary index: release through the Index type.
    delete reinterpret_cast<faiss::Index*>(indexPointer);
}

void knn_jni::faiss_wrapper::FreeSharedIndexState(jlong shareIndexStatePointerJ) {
Expand Down
4 changes: 2 additions & 2 deletions jni/src/org_opensearch_knn_jni_FaissService.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,10 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryBin

}

JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_free(JNIEnv * env, jclass cls, jlong indexPointerJ)
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_free(JNIEnv * env, jclass cls, jlong indexPointerJ, jboolean isBinaryIndexJ)
{
try {
return knn_jni::faiss_wrapper::Free(indexPointerJ);
return knn_jni::faiss_wrapper::Free(indexPointerJ, isBinaryIndexJ);
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
Expand Down
16 changes: 15 additions & 1 deletion jni/tests/faiss_wrapper_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,21 @@ TEST(FaissFreeTest, BasicAssertions) {
test_util::FaissCreateIndex(dim, method, metricType));

// Free created index --> memory check should catch failure
knn_jni::faiss_wrapper::Free(reinterpret_cast<jlong>(createdIndex));
knn_jni::faiss_wrapper::Free(reinterpret_cast<jlong>(createdIndex), JNI_FALSE);
}


// Creates a binary index and releases it through Free with the binary flag set.
TEST(FaissBinaryFreeTest, BasicAssertions) {
    // Index configuration
    int dim = 8;
    std::string method = "BHNSW32";

    // Build the binary index under test
    faiss::IndexBinary *createdIndex = test_util::FaissCreateBinaryIndex(dim, method);

    // Pass the index address to Free as a binary index
    // --> memory check should catch failure
    const jlong indexAddress = reinterpret_cast<jlong>(createdIndex);
    knn_jni::faiss_wrapper::Free(indexAddress, JNI_TRUE);
}

TEST(FaissInitLibraryTest, BasicAssertions) {
Expand Down
Loading

0 comments on commit fe1d86f

Please sign in to comment.