From b6b338e661c245dc155b15b87b429c8b5c85f8cc Mon Sep 17 00:00:00 2001 From: uestc-lfs Date: Sun, 28 Jun 2020 00:05:24 +0800 Subject: [PATCH 1/4] 1. Replace the template interface searchKnn with virtual interface 2. add assert.h, or it will not compile --- hnswlib/hnswalg.h | 28 ++++++++++++---------------- hnswlib/hnswlib.h | 5 ++--- 2 files changed, 14 insertions(+), 19 deletions(-) diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index 97bdcd18..2db735d9 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -8,6 +8,7 @@ #include #include +#include namespace hnswlib { typedef unsigned int tableint; @@ -406,7 +407,7 @@ namespace hnswlib { top_candidates.pop(); } - tableint next_closest_entry_point = selectedNeighbors[0]; + tableint next_closest_entry_point = selectedNeighbors.back(); { linklistsizeint *ll_cur; @@ -1156,24 +1157,19 @@ namespace hnswlib { return result; }; - template - std::vector> - searchKnn(const void* query_data, size_t k, Comp comp) { - std::vector> result; - if (cur_element_count == 0) return result; - - auto ret = searchKnn(query_data, k); - - while (!ret.empty()) { - result.push_back(ret.top()); - ret.pop(); + int searchKnn(const void* x, + int k, labeltype* labels, dist_t* dists = nullptr) const override { + if (labels == nullptr) return -1; + auto ret = searchKnn(x, k); + for (int i = k - 1; i >= 0; --i) { + if (dists) + dists[i] = ret.top().first; + labels[i] = ret.top().second; } - - std::sort(result.begin(), result.end(), comp); - - return result; + return 0; } + void checkIntegrity(){ int connections_checked=0; std::vector inbound_connections_num(cur_element_count,0); diff --git a/hnswlib/hnswlib.h b/hnswlib/hnswlib.h index c26f80b5..6ef54495 100644 --- a/hnswlib/hnswlib.h +++ b/hnswlib/hnswlib.h @@ -71,9 +71,8 @@ namespace hnswlib { public: virtual void addPoint(const void *datapoint, labeltype label)=0; virtual std::priority_queue> searchKnn(const void *, size_t) const = 0; - template - std::vector> searchKnn(const void*, 
size_t, Comp) { - } + virtual int searchKnn(const void* x, + int k, labeltype* labels, dist_t* dists) const = 0; virtual void saveIndex(const std::string &location)=0; virtual ~AlgorithmInterface(){ } From 898718801e91c8f04ef734a6dbecec0909fed3b2 Mon Sep 17 00:00:00 2001 From: uestc-lfs Date: Sun, 28 Jun 2020 00:08:09 +0800 Subject: [PATCH 2/4] minor fix --- hnswlib/hnswalg.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index 2db735d9..342e4ad5 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -407,7 +407,7 @@ namespace hnswlib { top_candidates.pop(); } - tableint next_closest_entry_point = selectedNeighbors.back(); + tableint next_closest_entry_point = selectedNeighbors[0]; { linklistsizeint *ll_cur; From 9fe639d71f3dc3dd793723395a7510258bf698bb Mon Sep 17 00:00:00 2001 From: uestc-lfs Date: Sun, 13 Dec 2020 00:22:59 +0800 Subject: [PATCH 3/4] fix interface --- examples/searchKnnCloserFirst_test.cpp | 84 ++++++++++++++++++++++++++ hnswlib/bruteforce.h | 18 ------ hnswlib/hnswalg.h | 13 ---- hnswlib/hnswlib.h | 25 +++++++- 4 files changed, 107 insertions(+), 33 deletions(-) create mode 100644 examples/searchKnnCloserFirst_test.cpp diff --git a/examples/searchKnnCloserFirst_test.cpp b/examples/searchKnnCloserFirst_test.cpp new file mode 100644 index 00000000..cc1392c8 --- /dev/null +++ b/examples/searchKnnCloserFirst_test.cpp @@ -0,0 +1,84 @@ +// This is a test file for testing the interface +// >>> virtual std::vector> +// >>> searchKnnCloserFirst(const void* query_data, size_t k) const; +// of class AlgorithmInterface + +#include "../hnswlib/hnswlib.h" + +#include + +#include +#include + +namespace +{ + +using idx_t = hnswlib::labeltype; + +void test() { + int d = 4; + idx_t n = 100; + idx_t nq = 10; + size_t k = 10; + + std::vector data(n * d); + std::vector query(nq * d); + + std::mt19937 rng; + rng.seed(47); + std::uniform_real_distribution<> distrib; + + for (idx_t i = 0; i < n * d; ++i) { + 
data[i] = distrib(rng); + } + for (idx_t i = 0; i < nq * d; ++i) { + query[i] = distrib(rng); + } + + + hnswlib::L2Space space(d); + hnswlib::AlgorithmInterface* alg_brute = new hnswlib::BruteforceSearch(&space, 2 * n); + hnswlib::AlgorithmInterface* alg_hnsw = new hnswlib::HierarchicalNSW(&space, 2 * n); + + for (size_t i = 0; i < n; ++i) { + alg_brute->addPoint(data.data() + d * i, i); + alg_hnsw->addPoint(data.data() + d * i, i); + } + + // test searchKnnCloserFirst of BruteforceSearch + for (size_t j = 0; j < nq; ++j) { + const void* p = query.data() + j * d; + auto gd = alg_brute->searchKnn(p, k); + auto res = alg_brute->searchKnnCloserFirst(p, k); + assert(gd.size() == res.size()); + size_t t = gd.size(); + while (!gd.empty()) { + assert(gd.top() == res[--t]); + gd.pop(); + } + } + for (size_t j = 0; j < nq; ++j) { + const void* p = query.data() + j * d; + auto gd = alg_hnsw->searchKnn(p, k); + auto res = alg_hnsw->searchKnnCloserFirst(p, k); + assert(gd.size() == res.size()); + size_t t = gd.size(); + while (!gd.empty()) { + assert(gd.top() == res[--t]); + gd.pop(); + } + } + + delete alg_brute; + delete alg_hnsw; +} + +} // namespace + +int main() { + std::cout << "Testing ..." 
<< std::endl; + test(); + std::cout << "Test ok" << std::endl; + + return 0; +} diff --git a/hnswlib/bruteforce.h b/hnswlib/bruteforce.h index 5b1bd655..24260400 100644 --- a/hnswlib/bruteforce.h +++ b/hnswlib/bruteforce.h @@ -111,24 +111,6 @@ namespace hnswlib { return topResults; }; - template - std::vector> - searchKnn(const void* query_data, size_t k, Comp comp) { - std::vector> result; - if (cur_element_count == 0) return result; - - auto ret = searchKnn(query_data, k); - - while (!ret.empty()) { - result.push_back(ret.top()); - ret.pop(); - } - - std::sort(result.begin(), result.end(), comp); - - return result; - } - void saveIndex(const std::string &location) { std::ofstream output(location, std::ios::binary); std::streampos position; diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index 13df46cb..025b55c1 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -1157,19 +1157,6 @@ namespace hnswlib { return result; }; - int searchKnn(const void* x, - int k, labeltype* labels, dist_t* dists = nullptr) const override { - if (labels == nullptr) return -1; - auto ret = searchKnn(x, k); - for (int i = k - 1; i >= 0; --i) { - if (dists) - dists[i] = ret.top().first; - labels[i] = ret.top().second; - } - return 0; - } - - void checkIntegrity(){ int connections_checked=0; std::vector inbound_connections_num(cur_element_count,0); diff --git a/hnswlib/hnswlib.h b/hnswlib/hnswlib.h index 6ef54495..9409c388 100644 --- a/hnswlib/hnswlib.h +++ b/hnswlib/hnswlib.h @@ -71,13 +71,34 @@ namespace hnswlib { public: virtual void addPoint(const void *datapoint, labeltype label)=0; virtual std::priority_queue> searchKnn(const void *, size_t) const = 0; - virtual int searchKnn(const void* x, - int k, labeltype* labels, dist_t* dists) const = 0; + + // Return k nearest neighbors in the order of closer first + virtual std::vector> + searchKnnCloserFirst(const void* query_data, size_t k) const; + virtual void saveIndex(const std::string &location)=0; virtual ~AlgorithmInterface(){ 
} }; + template + std::vector> + AlgorithmInterface::searchKnnCloserFirst(const void* query_data, size_t k) const { + std::vector> result; + + // here searchKnn returns the result in the order of further first + auto ret = searchKnn(query_data, k); + { + size_t sz = ret.size(); + result.resize(sz); + while (!ret.empty()) { + result[--sz] = ret.top(); + ret.pop(); + } + } + + return result; + } } From 21c1ad76640201a7bc1d2753cd562dd5979e86e8 Mon Sep 17 00:00:00 2001 From: uestc-lfs Date: Sun, 13 Dec 2020 14:07:10 +0800 Subject: [PATCH 4/4] minor fix --- CMakeLists.txt | 2 ++ hnswlib/hnswalg.h | 2 -- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index ebee6e6c..31935e0e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -23,4 +23,6 @@ endif() add_executable(test_updates examples/updates_test.cpp) +add_executable(searchKnnCloserFirst_test examples/searchKnnCloserFirst_test.cpp) + target_link_libraries(main sift_test) diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index 025b55c1..a2f72dc7 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -9,8 +9,6 @@ #include #include -#include - namespace hnswlib { typedef unsigned int tableint; typedef unsigned int linklistsizeint;