Skip to content

Commit

Permalink
Python: filter elements with an optional filtering function (#417)
Browse files Browse the repository at this point in the history
* Add Python filter option for knn query.
* Implement review suggestions
* Removed template filter_func_t, add filter to brute force index and update tests (credits go to dyashuni)

Co-authored-by: Georgios Tsoukas <georgios.tsoukas@mgb.ch>
  • Loading branch information
gtsoukas and Georgios Tsoukas committed Nov 9, 2022
1 parent 687ca85 commit 983cea9
Show file tree
Hide file tree
Showing 6 changed files with 148 additions and 61 deletions.
67 changes: 37 additions & 30 deletions examples/searchKnnWithFilter_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,25 @@ namespace {

using idx_t = hnswlib::labeltype;

// Filter predicate: accepts a label only when it is an exact multiple of three.
bool pickIdsDivisibleByThree(unsigned int label_id) {
    const unsigned int remainder = label_id % 3;
    return remainder == 0;
}

// Filter predicate: accepts a label only when it is an exact multiple of seven.
bool pickIdsDivisibleBySeven(unsigned int label_id) {
    const unsigned int remainder = label_id % 7;
    return remainder == 0;
}
// Filter functor that accepts only labels which are exact multiples of a
// divisor fixed at construction time.
class PickDivisibleIds: public hnswlib::BaseFilterFunctor {
    unsigned int divisor = 1;

 public:
    // `divisor` must be non-zero: it is used as the right-hand side of `%`
    // in operator(). The check is an assert, so it is only active in builds
    // where assertions are enabled.
    // Marked `explicit` so an integer is never silently converted to a filter.
    explicit PickDivisibleIds(unsigned int divisor): divisor(divisor) {
        assert(divisor != 0);
    }

    // Returns true when `label_id` is divisible by the configured divisor.
    // `override` makes the compiler verify that this really replaces the
    // virtual operator() of the base class.
    bool operator()(idx_t label_id) override {
        return label_id % divisor == 0;
    }
};

// Filter predicate that rejects every label, whatever its value.
bool pickNothing(unsigned int label_id) {
    static_cast<void>(label_id);  // the argument is intentionally ignored
    return false;
}
// Filter functor that rejects every label, so a filtered search using it
// has no admissible candidates.
class PickNothing: public hnswlib::BaseFilterFunctor {
 public:
    // The label is intentionally ignored (parameter left unnamed to avoid
    // unused-parameter warnings). `override` makes the compiler verify that
    // this replaces the virtual operator() of the base class.
    bool operator()(idx_t /*label_id*/) override {
        return false;
    }
};

template<typename filter_func_t>
void test_some_filtering(filter_func_t& filter_func, size_t div_num, size_t label_id_start) {
void test_some_filtering(hnswlib::BaseFilterFunctor& filter_func, size_t div_num, size_t label_id_start) {
int d = 4;
idx_t n = 100;
idx_t nq = 10;
Expand All @@ -45,8 +50,8 @@ void test_some_filtering(filter_func_t& filter_func, size_t div_num, size_t labe
}

hnswlib::L2Space space(d);
hnswlib::AlgorithmInterface<float, filter_func_t>* alg_brute = new hnswlib::BruteforceSearch<float, filter_func_t>(&space, 2 * n);
hnswlib::AlgorithmInterface<float, filter_func_t>* alg_hnsw = new hnswlib::HierarchicalNSW<float, filter_func_t>(&space, 2 * n);
hnswlib::AlgorithmInterface<float>* alg_brute = new hnswlib::BruteforceSearch<float>(&space, 2 * n);
hnswlib::AlgorithmInterface<float>* alg_hnsw = new hnswlib::HierarchicalNSW<float>(&space, 2 * n);

for (size_t i = 0; i < n; ++i) {
// `label_id_start` is used to ensure that the returned IDs are labels and not internal IDs
Expand All @@ -57,8 +62,8 @@ void test_some_filtering(filter_func_t& filter_func, size_t div_num, size_t labe
// test searchKnnCloserFirst of BruteforceSearch with filtering
for (size_t j = 0; j < nq; ++j) {
const void* p = query.data() + j * d;
auto gd = alg_brute->searchKnn(p, k, filter_func);
auto res = alg_brute->searchKnnCloserFirst(p, k, filter_func);
auto gd = alg_brute->searchKnn(p, k, &filter_func);
auto res = alg_brute->searchKnnCloserFirst(p, k, &filter_func);
assert(gd.size() == res.size());
size_t t = gd.size();
while (!gd.empty()) {
Expand All @@ -71,8 +76,8 @@ void test_some_filtering(filter_func_t& filter_func, size_t div_num, size_t labe
// test searchKnnCloserFirst of hnsw with filtering
for (size_t j = 0; j < nq; ++j) {
const void* p = query.data() + j * d;
auto gd = alg_hnsw->searchKnn(p, k, filter_func);
auto res = alg_hnsw->searchKnnCloserFirst(p, k, filter_func);
auto gd = alg_hnsw->searchKnn(p, k, &filter_func);
auto res = alg_hnsw->searchKnnCloserFirst(p, k, &filter_func);
assert(gd.size() == res.size());
size_t t = gd.size();
while (!gd.empty()) {
Expand All @@ -86,8 +91,7 @@ void test_some_filtering(filter_func_t& filter_func, size_t div_num, size_t labe
delete alg_hnsw;
}

template<typename filter_func_t>
void test_none_filtering(filter_func_t& filter_func, size_t label_id_start) {
void test_none_filtering(hnswlib::BaseFilterFunctor& filter_func, size_t label_id_start) {
int d = 4;
idx_t n = 100;
idx_t nq = 10;
Expand All @@ -108,8 +112,8 @@ void test_none_filtering(filter_func_t& filter_func, size_t label_id_start) {
}

hnswlib::L2Space space(d);
hnswlib::AlgorithmInterface<float, filter_func_t>* alg_brute = new hnswlib::BruteforceSearch<float, filter_func_t>(&space, 2 * n);
hnswlib::AlgorithmInterface<float, filter_func_t>* alg_hnsw = new hnswlib::HierarchicalNSW<float, filter_func_t>(&space, 2 * n);
hnswlib::AlgorithmInterface<float>* alg_brute = new hnswlib::BruteforceSearch<float>(&space, 2 * n);
hnswlib::AlgorithmInterface<float>* alg_hnsw = new hnswlib::HierarchicalNSW<float>(&space, 2 * n);

for (size_t i = 0; i < n; ++i) {
// `label_id_start` is used to ensure that the returned IDs are labels and not internal IDs
Expand All @@ -120,17 +124,17 @@ void test_none_filtering(filter_func_t& filter_func, size_t label_id_start) {
// test searchKnnCloserFirst of BruteforceSearch with filtering
for (size_t j = 0; j < nq; ++j) {
const void* p = query.data() + j * d;
auto gd = alg_brute->searchKnn(p, k, filter_func);
auto res = alg_brute->searchKnnCloserFirst(p, k, filter_func);
auto gd = alg_brute->searchKnn(p, k, &filter_func);
auto res = alg_brute->searchKnnCloserFirst(p, k, &filter_func);
assert(gd.size() == res.size());
assert(0 == gd.size());
}

// test searchKnnCloserFirst of hnsw with filtering
for (size_t j = 0; j < nq; ++j) {
const void* p = query.data() + j * d;
auto gd = alg_hnsw->searchKnn(p, k, filter_func);
auto res = alg_hnsw->searchKnnCloserFirst(p, k, filter_func);
auto gd = alg_hnsw->searchKnn(p, k, &filter_func);
auto res = alg_hnsw->searchKnnCloserFirst(p, k, &filter_func);
assert(gd.size() == res.size());
assert(0 == gd.size());
}
Expand All @@ -141,13 +145,13 @@ void test_none_filtering(filter_func_t& filter_func, size_t label_id_start) {

} // namespace

class CustomFilterFunctor: public hnswlib::FilterFunctor {
std::unordered_set<unsigned int> allowed_values;
class CustomFilterFunctor: public hnswlib::BaseFilterFunctor {
std::unordered_set<idx_t> allowed_values;

public:
explicit CustomFilterFunctor(const std::unordered_set<unsigned int>& values) : allowed_values(values) {}
explicit CustomFilterFunctor(const std::unordered_set<idx_t>& values) : allowed_values(values) {}

bool operator()(unsigned int id) {
bool operator()(idx_t id) {
return allowed_values.count(id) != 0;
}
};
Expand All @@ -156,10 +160,13 @@ int main() {
std::cout << "Testing ..." << std::endl;

// some of the elements are filtered
PickDivisibleIds pickIdsDivisibleByThree(3);
test_some_filtering(pickIdsDivisibleByThree, 3, 17);
PickDivisibleIds pickIdsDivisibleBySeven(7);
test_some_filtering(pickIdsDivisibleBySeven, 7, 17);

// all of the elements are filtered
PickNothing pickNothing;
test_none_filtering(pickNothing, 17);

// functor style which can capture context
Expand Down
11 changes: 5 additions & 6 deletions hnswlib/bruteforce.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
#include <assert.h>

namespace hnswlib {
template<typename dist_t, typename filter_func_t = FilterFunctor>
class BruteforceSearch : public AlgorithmInterface<dist_t, filter_func_t> {
template<typename dist_t>
class BruteforceSearch : public AlgorithmInterface<dist_t> {
public:
char *data_;
size_t maxelements_;
Expand Down Expand Up @@ -98,15 +98,14 @@ class BruteforceSearch : public AlgorithmInterface<dist_t, filter_func_t> {


std::priority_queue<std::pair<dist_t, labeltype >>
searchKnn(const void *query_data, size_t k, filter_func_t& isIdAllowed = allowAllIds) const {
searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const {
assert(k <= cur_element_count);
std::priority_queue<std::pair<dist_t, labeltype >> topResults;
if (cur_element_count == 0) return topResults;
bool is_filter_disabled = std::is_same<filter_func_t, decltype(allowAllIds)>::value;
for (int i = 0; i < k; i++) {
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_));
if (is_filter_disabled || isIdAllowed(label)) {
if ((!isIdAllowed) || (*isIdAllowed)(label)) {
topResults.push(std::pair<dist_t, labeltype>(dist, label));
}
}
Expand All @@ -115,7 +114,7 @@ class BruteforceSearch : public AlgorithmInterface<dist_t, filter_func_t> {
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
if (dist <= lastdist) {
labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_));
if (is_filter_disabled || isIdAllowed(label)) {
if ((!isIdAllowed) || (*isIdAllowed)(label)) {
topResults.push(std::pair<dist_t, labeltype>(dist, label));
}
if (topResults.size() > k)
Expand Down
13 changes: 6 additions & 7 deletions hnswlib/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ namespace hnswlib {
typedef unsigned int tableint;
typedef unsigned int linklistsizeint;

template<typename dist_t, typename filter_func_t = FilterFunctor>
class HierarchicalNSW : public AlgorithmInterface<dist_t, filter_func_t> {
template<typename dist_t>
class HierarchicalNSW : public AlgorithmInterface<dist_t> {
public:
static const tableint max_update_element_locks = 65536;
static const unsigned char DELETE_MARK = 0x01;
Expand Down Expand Up @@ -268,7 +268,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t, filter_func_t> {

template <bool has_deletions, bool collect_metrics = false>
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, filter_func_t& isIdAllowed) const {
searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, BaseFilterFunctor* isIdAllowed = nullptr) const {
VisitedList *vl = visited_list_pool_->getFreeVisitedList();
vl_type *visited_array = vl->mass;
vl_type visited_array_tag = vl->curV;
Expand All @@ -277,8 +277,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t, filter_func_t> {
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidate_set;

dist_t lowerBound;
bool is_filter_disabled = std::is_same<filter_func_t, decltype(allowAllIds)>::value;
if ((!has_deletions || !isMarkedDeleted(ep_id)) && (is_filter_disabled || isIdAllowed(getExternalLabel(ep_id)))) {
if ((!has_deletions || !isMarkedDeleted(ep_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(ep_id)))) {
dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);
lowerBound = dist;
top_candidates.emplace(dist, ep_id);
Expand Down Expand Up @@ -336,7 +335,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t, filter_func_t> {
_MM_HINT_T0); ////////////////////////
#endif

if ((!has_deletions || !isMarkedDeleted(candidate_id)) && (is_filter_disabled || isIdAllowed(getExternalLabel(candidate_id))))
if ((!has_deletions || !isMarkedDeleted(candidate_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id))))
top_candidates.emplace(dist, candidate_id);

if (top_candidates.size() > ef)
Expand Down Expand Up @@ -1083,7 +1082,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t, filter_func_t> {


std::priority_queue<std::pair<dist_t, labeltype >>
searchKnn(const void *query_data, size_t k, filter_func_t& isIdAllowed = allowAllIds) const {
searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const {
std::priority_queue<std::pair<dist_t, labeltype >> result;
if (cur_element_count == 0) return result;

Expand Down
20 changes: 9 additions & 11 deletions hnswlib/hnswlib.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,11 @@ namespace hnswlib {
typedef size_t labeltype;

// This can be extended to store state for filtering (e.g. from a std::set)
struct FilterFunctor {
template<class...Args>
bool operator()(Args&&...) { return true; }
class BaseFilterFunctor {
 public:
    // Default behaviour: every label is allowed. Derived functors override
    // this to decide, per label, whether an element may appear in results.
    virtual bool operator()(hnswlib::labeltype id) { return true; }

    // Virtual destructor: this class is used as a polymorphic base, so
    // destroying a derived functor through a BaseFilterFunctor pointer must
    // run the derived destructor.
    virtual ~BaseFilterFunctor() = default;
};

static FilterFunctor allowAllIds;

template <typename T>
class pairGreater {
public:
Expand Down Expand Up @@ -157,27 +155,27 @@ class SpaceInterface {
virtual ~SpaceInterface() {}
};

template<typename dist_t, typename filter_func_t = FilterFunctor>
template<typename dist_t>
class AlgorithmInterface {
public:
virtual void addPoint(const void *datapoint, labeltype label) = 0;

virtual std::priority_queue<std::pair<dist_t, labeltype>>
searchKnn(const void*, size_t, filter_func_t& isIdAllowed = allowAllIds) const = 0;
searchKnn(const void*, size_t, BaseFilterFunctor* isIdAllowed = nullptr) const = 0;

// Return the k nearest neighbors, closest first
virtual std::vector<std::pair<dist_t, labeltype>>
searchKnnCloserFirst(const void* query_data, size_t k, filter_func_t& isIdAllowed = allowAllIds) const;
searchKnnCloserFirst(const void* query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const;

virtual void saveIndex(const std::string &location) = 0;
virtual ~AlgorithmInterface(){
}
};

template<typename dist_t, typename filter_func_t>
template<typename dist_t>
std::vector<std::pair<dist_t, labeltype>>
AlgorithmInterface<dist_t, filter_func_t>::searchKnnCloserFirst(const void* query_data, size_t k,
filter_func_t& isIdAllowed) const {
AlgorithmInterface<dist_t>::searchKnnCloserFirst(const void* query_data, size_t k,
BaseFilterFunctor* isIdAllowed) const {
std::vector<std::pair<dist_t, labeltype>> result;

// here searchKnn returns the result in the order of further first
Expand Down
42 changes: 35 additions & 7 deletions python_bindings/bindings.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <iostream>
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
Expand Down Expand Up @@ -79,6 +80,20 @@ inline void assert_true(bool expr, const std::string & msg) {
}


// Adapts a std::function<bool(labeltype)> to the hnswlib::BaseFilterFunctor
// interface so that it can be handed to searchKnn as a filter.
class CustomFilterFunctor: public hnswlib::BaseFilterFunctor {
    std::function<bool(hnswlib::labeltype)> filter;

 public:
    // Stores a copy of `f`. `f` may be empty; in that case operator() must not
    // be invoked, because calling an empty std::function throws
    // std::bad_function_call.
    explicit CustomFilterFunctor(const std::function<bool(hnswlib::labeltype)>& f) : filter(f) {}

    // Forwards the label to the wrapped callable and returns its verdict.
    // NOTE(review): when the callable wraps a Python function, confirm that
    // the calling thread is allowed to run Python code (GIL) before invoking
    // this from worker threads.
    bool operator()(hnswlib::labeltype id) override {
        return filter(id);
    }
};


inline void get_input_array_shapes(const py::buffer_info& buffer, size_t* rows, size_t* features) {
if (buffer.ndim != 2 && buffer.ndim != 1) {
char msg[256];
Expand Down Expand Up @@ -573,7 +588,11 @@ class Index {
}


py::object knnQuery_return_numpy(py::object input, size_t k = 1, int num_threads = -1) {
py::object knnQuery_return_numpy(
py::object input,
size_t k = 1,
int num_threads = -1,
const std::function<bool(hnswlib::labeltype)>& filter = nullptr) {
py::array_t < dist_t, py::array::c_style | py::array::forcecast > items(input);
auto buffer = items.request();
hnswlib::labeltype* data_numpy_l;
Expand All @@ -595,10 +614,13 @@ class Index {
data_numpy_l = new hnswlib::labeltype[rows * k];
data_numpy_d = new dist_t[rows * k];

CustomFilterFunctor idFilter(filter);
CustomFilterFunctor* p_idFilter = filter ? &idFilter : nullptr;

if (normalize == false) {
ParallelFor(0, rows, num_threads, [&](size_t row, size_t threadId) {
std::priority_queue<std::pair<dist_t, hnswlib::labeltype >> result = appr_alg->searchKnn(
(void*)items.data(row), k);
(void*)items.data(row), k, p_idFilter);
if (result.size() != k)
throw std::runtime_error(
"Cannot return the results in a contigious 2D array. Probably ef or M is too small");
Expand All @@ -618,7 +640,7 @@ class Index {
normalize_vector((float*)items.data(row), (norm_array.data() + start_idx));

std::priority_queue<std::pair<dist_t, hnswlib::labeltype >> result = appr_alg->searchKnn(
(void*)(norm_array.data() + start_idx), k);
(void*)(norm_array.data() + start_idx), k, p_idFilter);
if (result.size() != k)
throw std::runtime_error(
"Cannot return the results in a contigious 2D array. Probably ef or M is too small");
Expand Down Expand Up @@ -785,7 +807,10 @@ class BFIndex {
}


py::object knnQuery_return_numpy(py::object input, size_t k = 1) {
py::object knnQuery_return_numpy(
py::object input,
size_t k = 1,
const std::function<bool(hnswlib::labeltype)>& filter = nullptr) {
py::array_t < dist_t, py::array::c_style | py::array::forcecast > items(input);
auto buffer = items.request();
hnswlib::labeltype *data_numpy_l;
Expand All @@ -799,9 +824,12 @@ class BFIndex {
data_numpy_l = new hnswlib::labeltype[rows * k];
data_numpy_d = new dist_t[rows * k];

CustomFilterFunctor idFilter(filter);
CustomFilterFunctor* p_idFilter = filter ? &idFilter : nullptr;

for (size_t row = 0; row < rows; row++) {
std::priority_queue<std::pair<dist_t, hnswlib::labeltype >> result = alg->searchKnn(
(void *) items.data(row), k);
(void *) items.data(row), k, p_idFilter);
for (int i = k - 1; i >= 0; i--) {
auto &result_tuple = result.top();
data_numpy_d[row * k + i] = result_tuple.first;
Expand Down Expand Up @@ -844,7 +872,7 @@ PYBIND11_PLUGIN(hnswlib) {
.def(py::init(&Index<float>::createFromIndex), py::arg("index"))
.def(py::init<const std::string &, const int>(), py::arg("space"), py::arg("dim"))
.def("init_index", &Index<float>::init_new_index, py::arg("max_elements"), py::arg("M") = 16, py::arg("ef_construction") = 200, py::arg("random_seed") = 100)
.def("knn_query", &Index<float>::knnQuery_return_numpy, py::arg("data"), py::arg("k") = 1, py::arg("num_threads") = -1)
.def("knn_query", &Index<float>::knnQuery_return_numpy, py::arg("data"), py::arg("k") = 1, py::arg("num_threads") = -1, py::arg("filter") = py::none())
.def("add_items", &Index<float>::addItems, py::arg("data"), py::arg("ids") = py::none(), py::arg("num_threads") = -1)
.def("get_items", &Index<float, float>::getDataReturnList, py::arg("ids") = py::none())
.def("get_ids_list", &Index<float>::getIdsList)
Expand Down Expand Up @@ -899,7 +927,7 @@ PYBIND11_PLUGIN(hnswlib) {
py::class_<BFIndex<float>>(m, "BFIndex")
.def(py::init<const std::string &, const int>(), py::arg("space"), py::arg("dim"))
.def("init_index", &BFIndex<float>::init_new_index, py::arg("max_elements"))
.def("knn_query", &BFIndex<float>::knnQuery_return_numpy, py::arg("data"), py::arg("k") = 1)
.def("knn_query", &BFIndex<float>::knnQuery_return_numpy, py::arg("data"), py::arg("k") = 1, py::arg("filter") = py::none())
.def("add_items", &BFIndex<float>::addItems, py::arg("data"), py::arg("ids") = py::none())
.def("delete_vector", &BFIndex<float>::deleteVector, py::arg("label"))
.def("save_index", &BFIndex<float>::saveIndex, py::arg("path_to_index"))
Expand Down
Loading

0 comments on commit 983cea9

Please sign in to comment.