diff --git a/examples/searchKnnWithFilter_test.cpp b/examples/searchKnnWithFilter_test.cpp index 4aee49b0..6102323c 100644 --- a/examples/searchKnnWithFilter_test.cpp +++ b/examples/searchKnnWithFilter_test.cpp @@ -11,20 +11,25 @@ namespace { using idx_t = hnswlib::labeltype; -bool pickIdsDivisibleByThree(unsigned int label_id) { - return label_id % 3 == 0; -} - -bool pickIdsDivisibleBySeven(unsigned int label_id) { - return label_id % 7 == 0; -} +class PickDivisibleIds: public hnswlib::BaseFilterFunctor { +unsigned int divisor = 1; + public: + PickDivisibleIds(unsigned int divisor): divisor(divisor) { + assert(divisor != 0); + } + bool operator()(idx_t label_id) { + return label_id % divisor == 0; + } +}; -bool pickNothing(unsigned int label_id) { - return false; -} +class PickNothing: public hnswlib::BaseFilterFunctor { + public: + bool operator()(idx_t label_id) { + return false; + } +}; -template -void test_some_filtering(filter_func_t& filter_func, size_t div_num, size_t label_id_start) { +void test_some_filtering(hnswlib::BaseFilterFunctor& filter_func, size_t div_num, size_t label_id_start) { int d = 4; idx_t n = 100; idx_t nq = 10; @@ -45,8 +50,8 @@ void test_some_filtering(filter_func_t& filter_func, size_t div_num, size_t labe } hnswlib::L2Space space(d); - hnswlib::AlgorithmInterface* alg_brute = new hnswlib::BruteforceSearch(&space, 2 * n); - hnswlib::AlgorithmInterface* alg_hnsw = new hnswlib::HierarchicalNSW(&space, 2 * n); + hnswlib::AlgorithmInterface* alg_brute = new hnswlib::BruteforceSearch(&space, 2 * n); + hnswlib::AlgorithmInterface* alg_hnsw = new hnswlib::HierarchicalNSW(&space, 2 * n); for (size_t i = 0; i < n; ++i) { // `label_id_start` is used to ensure that the returned IDs are labels and not internal IDs @@ -57,8 +62,8 @@ void test_some_filtering(filter_func_t& filter_func, size_t div_num, size_t labe // test searchKnnCloserFirst of BruteforceSearch with filtering for (size_t j = 0; j < nq; ++j) { const void* p = query.data() + j * d; - auto gd = alg_brute->searchKnn(p, k, filter_func); - auto res = alg_brute->searchKnnCloserFirst(p, k, filter_func); + auto gd = alg_brute->searchKnn(p, k, &filter_func); + auto res = alg_brute->searchKnnCloserFirst(p, k, &filter_func); assert(gd.size() == res.size()); size_t t = gd.size(); while (!gd.empty()) { @@ -71,8 +76,8 @@ void test_some_filtering(filter_func_t& filter_func, size_t div_num, size_t labe // test searchKnnCloserFirst of hnsw with filtering for (size_t j = 0; j < nq; ++j) { const void* p = query.data() + j * d; - auto gd = alg_hnsw->searchKnn(p, k, filter_func); - auto res = alg_hnsw->searchKnnCloserFirst(p, k, filter_func); + auto gd = alg_hnsw->searchKnn(p, k, &filter_func); + auto res = alg_hnsw->searchKnnCloserFirst(p, k, &filter_func); assert(gd.size() == res.size()); size_t t = gd.size(); while (!gd.empty()) { @@ -86,8 +91,7 @@ void test_some_filtering(filter_func_t& filter_func, size_t div_num, size_t labe delete alg_hnsw; } -template -void test_none_filtering(filter_func_t& filter_func, size_t label_id_start) { +void test_none_filtering(hnswlib::BaseFilterFunctor& filter_func, size_t label_id_start) { int d = 4; idx_t n = 100; idx_t nq = 10; @@ -108,8 +112,8 @@ void test_none_filtering(filter_func_t& filter_func, size_t label_id_start) { } hnswlib::L2Space space(d); - hnswlib::AlgorithmInterface* alg_brute = new hnswlib::BruteforceSearch(&space, 2 * n); - hnswlib::AlgorithmInterface* alg_hnsw = new hnswlib::HierarchicalNSW(&space, 2 * n); + hnswlib::AlgorithmInterface* alg_brute = new hnswlib::BruteforceSearch(&space, 2 * n); + hnswlib::AlgorithmInterface* alg_hnsw = new hnswlib::HierarchicalNSW(&space, 2 * n); for (size_t i = 0; i < n; ++i) { // `label_id_start` is used to ensure that the returned IDs are labels and not internal IDs @@ -120,8 +124,8 @@ void test_none_filtering(filter_func_t& filter_func, size_t label_id_start) { // test searchKnnCloserFirst of BruteforceSearch with filtering for (size_t j = 0; j < nq; ++j) { const void* p = query.data() + j * d; - auto gd = alg_brute->searchKnn(p, k, filter_func); - auto res = alg_brute->searchKnnCloserFirst(p, k, filter_func); + auto gd = alg_brute->searchKnn(p, k, &filter_func); + auto res = alg_brute->searchKnnCloserFirst(p, k, &filter_func); assert(gd.size() == res.size()); assert(0 == gd.size()); } @@ -129,8 +133,8 @@ void test_none_filtering(filter_func_t& filter_func, size_t label_id_start) { // test searchKnnCloserFirst of hnsw with filtering for (size_t j = 0; j < nq; ++j) { const void* p = query.data() + j * d; - auto gd = alg_hnsw->searchKnn(p, k, filter_func); - auto res = alg_hnsw->searchKnnCloserFirst(p, k, filter_func); + auto gd = alg_hnsw->searchKnn(p, k, &filter_func); + auto res = alg_hnsw->searchKnnCloserFirst(p, k, &filter_func); assert(gd.size() == res.size()); assert(0 == gd.size()); } @@ -141,13 +145,13 @@ void test_none_filtering(filter_func_t& filter_func, size_t label_id_start) { } // namespace -class CustomFilterFunctor: public hnswlib::FilterFunctor { - std::unordered_set allowed_values; +class CustomFilterFunctor: public hnswlib::BaseFilterFunctor { + std::unordered_set allowed_values; public: - explicit CustomFilterFunctor(const std::unordered_set& values) : allowed_values(values) {} + explicit CustomFilterFunctor(const std::unordered_set& values) : allowed_values(values) {} - bool operator()(unsigned int id) { + bool operator()(idx_t id) { return allowed_values.count(id) != 0; } }; @@ -156,10 +160,13 @@ int main() { std::cout << "Testing ..." << std::endl; // some of the elements are filtered + PickDivisibleIds pickIdsDivisibleByThree(3); test_some_filtering(pickIdsDivisibleByThree, 3, 17); + PickDivisibleIds pickIdsDivisibleBySeven(7); test_some_filtering(pickIdsDivisibleBySeven, 7, 17); // all of the elements are filtered + PickNothing pickNothing; test_none_filtering(pickNothing, 17); // functor style which can capture context diff --git a/hnswlib/bruteforce.h b/hnswlib/bruteforce.h index ec2ef350..21130090 100644 --- a/hnswlib/bruteforce.h +++ b/hnswlib/bruteforce.h @@ -6,8 +6,8 @@ #include namespace hnswlib { -template -class BruteforceSearch : public AlgorithmInterface { +template +class BruteforceSearch : public AlgorithmInterface { public: char *data_; size_t maxelements_; @@ -98,15 +98,14 @@ class BruteforceSearch : public AlgorithmInterface { std::priority_queue> - searchKnn(const void *query_data, size_t k, filter_func_t& isIdAllowed = allowAllIds) const { + searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const { assert(k <= cur_element_count); std::priority_queue> topResults; if (cur_element_count == 0) return topResults; - bool is_filter_disabled = std::is_same::value; for (int i = 0; i < k; i++) { dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_)); - if (is_filter_disabled || isIdAllowed(label)) { + if ((!isIdAllowed) || (*isIdAllowed)(label)) { topResults.push(std::pair(dist, label)); } } @@ -115,7 +114,7 @@ class BruteforceSearch : public AlgorithmInterface { dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); if (dist <= lastdist) { labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_)); - if (is_filter_disabled || isIdAllowed(label)) { + if ((!isIdAllowed) || (*isIdAllowed)(label)) { topResults.push(std::pair(dist, label)); } if (topResults.size() > k) diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index 32b173e1..25995134 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -13,8 +13,8 @@ namespace hnswlib { typedef unsigned int tableint; typedef unsigned int linklistsizeint; -template -class HierarchicalNSW : public AlgorithmInterface { +template +class HierarchicalNSW : public AlgorithmInterface { public: static const tableint max_update_element_locks = 65536; static const unsigned char DELETE_MARK = 0x01; @@ -268,7 +268,7 @@ class HierarchicalNSW : public AlgorithmInterface { template std::priority_queue, std::vector>, CompareByFirst> - searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, filter_func_t& isIdAllowed) const { + searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, BaseFilterFunctor* isIdAllowed = nullptr) const { VisitedList *vl = visited_list_pool_->getFreeVisitedList(); vl_type *visited_array = vl->mass; vl_type visited_array_tag = vl->curV; @@ -277,8 +277,7 @@ class HierarchicalNSW : public AlgorithmInterface { std::priority_queue, std::vector>, CompareByFirst> candidate_set; dist_t lowerBound; - bool is_filter_disabled = std::is_same::value; - if ((!has_deletions || !isMarkedDeleted(ep_id)) && (is_filter_disabled || isIdAllowed(getExternalLabel(ep_id)))) { + if ((!has_deletions || !isMarkedDeleted(ep_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(ep_id)))) { dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_); lowerBound = dist; top_candidates.emplace(dist, ep_id); @@ -336,7 +335,7 @@ class HierarchicalNSW : public AlgorithmInterface { _MM_HINT_T0); //////////////////////// #endif - if ((!has_deletions || !isMarkedDeleted(candidate_id)) && (is_filter_disabled || isIdAllowed(getExternalLabel(candidate_id)))) + if ((!has_deletions || !isMarkedDeleted(candidate_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id)))) top_candidates.emplace(dist, candidate_id); if (top_candidates.size() > ef) @@ -1083,7 +1082,7 @@ class HierarchicalNSW : public AlgorithmInterface { std::priority_queue> - searchKnn(const void *query_data, size_t k, filter_func_t& isIdAllowed = allowAllIds) const { + searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const { std::priority_queue> result; if (cur_element_count == 0) return result; diff --git a/hnswlib/hnswlib.h b/hnswlib/hnswlib.h index f11fd373..72c955dc 100644 --- a/hnswlib/hnswlib.h +++ b/hnswlib/hnswlib.h @@ -116,13 +116,11 @@ namespace hnswlib { typedef size_t labeltype; // This can be extended to store state for filtering (e.g. from a std::set) -struct FilterFunctor { - template - bool operator()(Args&&...) { return true; } +class BaseFilterFunctor { + public: + virtual bool operator()(hnswlib::labeltype id) { return true; } }; -static FilterFunctor allowAllIds; - template class pairGreater { public: @@ -157,27 +155,27 @@ class SpaceInterface { virtual ~SpaceInterface() {} }; -template +template class AlgorithmInterface { public: virtual void addPoint(const void *datapoint, labeltype label) = 0; virtual std::priority_queue> - searchKnn(const void*, size_t, filter_func_t& isIdAllowed = allowAllIds) const = 0; + searchKnn(const void*, size_t, BaseFilterFunctor* isIdAllowed = nullptr) const = 0; // Return k nearest neighbor in the order of closer fist virtual std::vector> - searchKnnCloserFirst(const void* query_data, size_t k, filter_func_t& isIdAllowed = allowAllIds) const; + searchKnnCloserFirst(const void* query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const; virtual void saveIndex(const std::string &location) = 0; virtual ~AlgorithmInterface(){ } }; -template +template std::vector> -AlgorithmInterface::searchKnnCloserFirst(const void* query_data, size_t k, - filter_func_t& isIdAllowed) const { +AlgorithmInterface::searchKnnCloserFirst(const void* query_data, size_t k, + BaseFilterFunctor* isIdAllowed) const { std::vector> result; // here searchKnn returns the result in the order of further first diff --git a/python_bindings/bindings.cpp b/python_bindings/bindings.cpp index 85751c0b..3da8dbba 100644 --- a/python_bindings/bindings.cpp +++ b/python_bindings/bindings.cpp @@ -1,4 +1,5 @@ #include +#include #include #include #include @@ -79,6 +80,20 @@ inline void assert_true(bool expr, const std::string & msg) { } +class CustomFilterFunctor: public hnswlib::BaseFilterFunctor { + std::function filter; + + public: + explicit CustomFilterFunctor(const std::function& f) { + filter = f; + } + + bool operator()(hnswlib::labeltype id) { + return filter(id); + } +}; + + inline void get_input_array_shapes(const py::buffer_info& buffer, size_t* rows, size_t* features) { if (buffer.ndim != 2 && buffer.ndim != 1) { char msg[256]; @@ -573,7 +588,11 @@ class Index { } - py::object knnQuery_return_numpy(py::object input, size_t k = 1, int num_threads = -1) { + py::object knnQuery_return_numpy( + py::object input, + size_t k = 1, + int num_threads = -1, + const std::function& filter = nullptr) { py::array_t < dist_t, py::array::c_style | py::array::forcecast > items(input); auto buffer = items.request(); hnswlib::labeltype* data_numpy_l; @@ -595,10 +614,13 @@ class Index { data_numpy_l = new hnswlib::labeltype[rows * k]; data_numpy_d = new dist_t[rows * k]; + CustomFilterFunctor idFilter(filter); + CustomFilterFunctor* p_idFilter = filter ? &idFilter : nullptr; + if (normalize == false) { ParallelFor(0, rows, num_threads, [&](size_t row, size_t threadId) { std::priority_queue> result = appr_alg->searchKnn( - (void*)items.data(row), k); + (void*)items.data(row), k, p_idFilter); if (result.size() != k) throw std::runtime_error( "Cannot return the results in a contigious 2D array. Probably ef or M is too small"); @@ -618,7 +640,7 @@ class Index { normalize_vector((float*)items.data(row), (norm_array.data() + start_idx)); std::priority_queue> result = appr_alg->searchKnn( - (void*)(norm_array.data() + start_idx), k); + (void*)(norm_array.data() + start_idx), k, p_idFilter); if (result.size() != k) throw std::runtime_error( "Cannot return the results in a contigious 2D array. Probably ef or M is too small"); @@ -785,7 +807,10 @@ class BFIndex { } - py::object knnQuery_return_numpy(py::object input, size_t k = 1) { + py::object knnQuery_return_numpy( + py::object input, + size_t k = 1, + const std::function& filter = nullptr) { py::array_t < dist_t, py::array::c_style | py::array::forcecast > items(input); auto buffer = items.request(); hnswlib::labeltype *data_numpy_l; @@ -799,9 +824,12 @@ class BFIndex { data_numpy_l = new hnswlib::labeltype[rows * k]; data_numpy_d = new dist_t[rows * k]; + CustomFilterFunctor idFilter(filter); + CustomFilterFunctor* p_idFilter = filter ? &idFilter : nullptr; + for (size_t row = 0; row < rows; row++) { std::priority_queue> result = alg->searchKnn( - (void *) items.data(row), k); + (void *) items.data(row), k, p_idFilter); for (int i = k - 1; i >= 0; i--) { auto &result_tuple = result.top(); data_numpy_d[row * k + i] = result_tuple.first; @@ -844,7 +872,7 @@ PYBIND11_PLUGIN(hnswlib) { .def(py::init(&Index::createFromIndex), py::arg("index")) .def(py::init(), py::arg("space"), py::arg("dim")) .def("init_index", &Index::init_new_index, py::arg("max_elements"), py::arg("M") = 16, py::arg("ef_construction") = 200, py::arg("random_seed") = 100) - .def("knn_query", &Index::knnQuery_return_numpy, py::arg("data"), py::arg("k") = 1, py::arg("num_threads") = -1) + .def("knn_query", &Index::knnQuery_return_numpy, py::arg("data"), py::arg("k") = 1, py::arg("num_threads") = -1, py::arg("filter") = py::none()) .def("add_items", &Index::addItems, py::arg("data"), py::arg("ids") = py::none(), py::arg("num_threads") = -1) .def("get_items", &Index::getDataReturnList, py::arg("ids") = py::none()) .def("get_ids_list", &Index::getIdsList) @@ -899,7 +927,7 @@ PYBIND11_PLUGIN(hnswlib) { py::class_>(m, "BFIndex") .def(py::init(), py::arg("space"), py::arg("dim")) .def("init_index", &BFIndex::init_new_index, py::arg("max_elements")) - .def("knn_query", &BFIndex::knnQuery_return_numpy, py::arg("data"), py::arg("k") = 1) + .def("knn_query", &BFIndex::knnQuery_return_numpy, py::arg("data"), py::arg("k") = 1, py::arg("filter") = py::none()) .def("add_items", &BFIndex::addItems, py::arg("data"), py::arg("ids") = py::none()) .def("delete_vector", &BFIndex::deleteVector, py::arg("label")) .def("save_index", &BFIndex::saveIndex, py::arg("path_to_index")) diff --git a/python_bindings/tests/bindings_test_filter.py b/python_bindings/tests/bindings_test_filter.py new file mode 100644 index 00000000..a0715d7c --- /dev/null +++ b/python_bindings/tests/bindings_test_filter.py @@ -0,0 +1,56 @@ +import os +import unittest + +import numpy as np + +import hnswlib + + +class RandomSelfTestCase(unittest.TestCase): + def testRandomSelf(self): + + dim = 16 + num_elements = 10000 + + # Generating sample data + data = np.float32(np.random.random((num_elements, dim))) + + # Declaring index + hnsw_index = hnswlib.Index(space='l2', dim=dim) # possible options are l2, cosine or ip + bf_index = hnswlib.BFIndex(space='l2', dim=dim) + + # Initiating index + # max_elements - the maximum number of elements, should be known beforehand + # (probably will be made optional in the future) + # + # ef_construction - controls index search speed/build speed tradeoff + # M - is tightly connected with internal dimensionality of the data + # strongly affects the memory consumption + + hnsw_index.init_index(max_elements=num_elements, ef_construction=100, M=16) + bf_index.init_index(max_elements=num_elements) + + # Controlling the recall by setting ef: + # higher ef leads to better accuracy, but slower search + hnsw_index.set_ef(10) + + hnsw_index.set_num_threads(4) # by default using all available cores + + print("Adding %d elements" % (len(data))) + hnsw_index.add_items(data) + bf_index.add_items(data) + + # Query the elements for themselves and measure recall: + labels, distances = hnsw_index.knn_query(data, k=1) + self.assertAlmostEqual(np.mean(labels.reshape(-1) == np.arange(len(data))), 1.0, 3) + + print("Querying only even elements") + # Query the even elements for themselves and measure recall: + filter_function = lambda id: id%2 == 0 + labels, distances = hnsw_index.knn_query(data, k=1, filter=filter_function) + self.assertAlmostEqual(np.mean(labels.reshape(-1) == np.arange(len(data))), .5, 3) + # Verify that there are onle even elements: + self.assertTrue(np.max(np.mod(labels, 2)) == 0) + + labels, distances = bf_index.knn_query(data, k=1, filter=filter_function) + self.assertEqual(np.mean(labels.reshape(-1) == np.arange(len(data))), .5)