Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python: filter elements with an optional filtering function #417

Merged
merged 4 commits into from
Nov 9, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
35 changes: 26 additions & 9 deletions python_bindings/bindings.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <iostream>
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
Expand Down Expand Up @@ -79,6 +80,20 @@ inline void assert_true(bool expr, const std::string & msg) {
}


class CustomFilterFunctor: public hnswlib::FilterFunctor {
std::function<int(int)> filter;
gtsoukas marked this conversation as resolved.
Show resolved Hide resolved

public:
gtsoukas marked this conversation as resolved.
Show resolved Hide resolved
explicit CustomFilterFunctor(const std::function<bool(unsigned int)> &f) {
gtsoukas marked this conversation as resolved.
Show resolved Hide resolved
filter = f;
}

bool operator()(unsigned int id) {
gtsoukas marked this conversation as resolved.
Show resolved Hide resolved
return filter(id);
}
};


inline void get_input_array_shapes(const py::buffer_info& buffer, size_t* rows, size_t* features) {
if (buffer.ndim != 2 && buffer.ndim != 1) {
char msg[256];
Expand Down Expand Up @@ -127,7 +142,7 @@ inline std::vector<size_t> get_input_ids_and_check_shapes(const py::object& ids_
}


template<typename dist_t, typename data_t = float>
template<typename dist_t, typename data_t=float, typename filter_func_t=CustomFilterFunctor>
gtsoukas marked this conversation as resolved.
Show resolved Hide resolved
class Index {
public:
static const int ser_version = 1; // serialization version
Expand All @@ -142,7 +157,7 @@ class Index {
bool normalize;
int num_threads_default;
hnswlib::labeltype cur_l;
hnswlib::HierarchicalNSW<dist_t>* appr_alg;
hnswlib::HierarchicalNSW<dist_t, filter_func_t>* appr_alg;
hnswlib::SpaceInterface<float>* l2space;


Expand Down Expand Up @@ -183,7 +198,7 @@ class Index {
throw std::runtime_error("The index is already initiated.");
}
cur_l = 0;
appr_alg = new hnswlib::HierarchicalNSW<dist_t>(l2space, maxElements, M, efConstruction, random_seed);
appr_alg = new hnswlib::HierarchicalNSW<dist_t, filter_func_t>(l2space, maxElements, M, efConstruction, random_seed);
index_inited = true;
ep_added = false;
appr_alg->ef_ = default_ef;
Expand Down Expand Up @@ -213,7 +228,7 @@ class Index {
std::cerr << "Warning: Calling load_index for an already inited index. Old index is being deallocated." << std::endl;
delete appr_alg;
}
appr_alg = new hnswlib::HierarchicalNSW<dist_t>(l2space, path_to_index, false, max_elements);
appr_alg = new hnswlib::HierarchicalNSW<dist_t, filter_func_t>(l2space, path_to_index, false, max_elements);
cur_l = appr_alg->cur_element_count;
index_inited = true;
}
Expand Down Expand Up @@ -468,7 +483,7 @@ class Index {
new_index->seed = d["seed"].cast<size_t>();

if (index_inited_) {
new_index->appr_alg = new hnswlib::HierarchicalNSW<dist_t>(
new_index->appr_alg = new hnswlib::HierarchicalNSW<dist_t, filter_func_t>(
new_index->l2space,
d["max_elements"].cast<size_t>(),
d["M"].cast<size_t>(),
Expand Down Expand Up @@ -573,7 +588,7 @@ class Index {
}


py::object knnQuery_return_numpy(py::object input, size_t k = 1, int num_threads = -1) {
py::object knnQuery_return_numpy(py::object input, size_t k = 1, int num_threads = -1, const std::function<bool(unsigned int)> &filter = nullptr) {
gtsoukas marked this conversation as resolved.
Show resolved Hide resolved
py::array_t < dist_t, py::array::c_style | py::array::forcecast > items(input);
auto buffer = items.request();
hnswlib::labeltype* data_numpy_l;
Expand All @@ -595,10 +610,12 @@ class Index {
data_numpy_l = new hnswlib::labeltype[rows * k];
data_numpy_d = new dist_t[rows * k];

CustomFilterFunctor idFilter((filter != nullptr)?filter:[](unsigned int id){return true;});
gtsoukas marked this conversation as resolved.
Show resolved Hide resolved
dyashuni marked this conversation as resolved.
Show resolved Hide resolved

if (normalize == false) {
ParallelFor(0, rows, num_threads, [&](size_t row, size_t threadId) {
std::priority_queue<std::pair<dist_t, hnswlib::labeltype >> result = appr_alg->searchKnn(
(void*)items.data(row), k);
(void*)items.data(row), k, idFilter);
if (result.size() != k)
throw std::runtime_error(
"Cannot return the results in a contigious 2D array. Probably ef or M is too small");
Expand All @@ -618,7 +635,7 @@ class Index {
normalize_vector((float*)items.data(row), (norm_array.data() + start_idx));

std::priority_queue<std::pair<dist_t, hnswlib::labeltype >> result = appr_alg->searchKnn(
(void*)(norm_array.data() + start_idx), k);
(void*)(norm_array.data() + start_idx), k, idFilter);
if (result.size() != k)
throw std::runtime_error(
"Cannot return the results in a contigious 2D array. Probably ef or M is too small");
Expand Down Expand Up @@ -844,7 +861,7 @@ PYBIND11_PLUGIN(hnswlib) {
.def(py::init(&Index<float>::createFromIndex), py::arg("index"))
.def(py::init<const std::string &, const int>(), py::arg("space"), py::arg("dim"))
.def("init_index", &Index<float>::init_new_index, py::arg("max_elements"), py::arg("M") = 16, py::arg("ef_construction") = 200, py::arg("random_seed") = 100)
.def("knn_query", &Index<float>::knnQuery_return_numpy, py::arg("data"), py::arg("k") = 1, py::arg("num_threads") = -1)
.def("knn_query", &Index<float>::knnQuery_return_numpy, py::arg("data"), py::arg("k") = 1, py::arg("num_threads") = -1, py::arg("filter") = py::none())
.def("add_items", &Index<float>::addItems, py::arg("data"), py::arg("ids") = py::none(), py::arg("num_threads") = -1)
.def("get_items", &Index<float, float>::getDataReturnList, py::arg("ids") = py::none())
.def("get_ids_list", &Index<float>::getIdsList)
Expand Down
51 changes: 51 additions & 0 deletions python_bindings/tests/bindings_test_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import os
import unittest

import numpy as np

import hnswlib


class RandomSelfTestCase(unittest.TestCase):
def testRandomSelf(self):

dim = 16
num_elements = 10000

# Generating sample data
data = np.float32(np.random.random((num_elements, dim)))

# Declaring index
p = hnswlib.Index(space='l2', dim=dim) # possible options are l2, cosine or ip

# Initiating index
# max_elements - the maximum number of elements, should be known beforehand
# (probably will be made optional in the future)
#
# ef_construction - controls index search speed/build speed tradeoff
# M - is tightly connected with internal dimensionality of the data
# strongly affects the memory consumption

p.init_index(max_elements=num_elements, ef_construction=100, M=16)

# Controlling the recall by setting ef:
# higher ef leads to better accuracy, but slower search
p.set_ef(10)

p.set_num_threads(4) # by default using all available cores

print("Adding %d elements" % (len(data)))
p.add_items(data)

# Query the elements for themselves and measure recall:
labels, distances = p.knn_query(data, k=1)
self.assertAlmostEqual(np.mean(labels.reshape(-1) == np.arange(len(data))), 1.0, 3)

print("Querying only even elements")
# Query the even elements for themselves and measure recall:
filter_function = lambda id: id%2 == 0
labels, distances = p.knn_query(data, k=1, filter=filter_function)
self.assertAlmostEqual(np.mean(labels.reshape(-1) == np.arange(len(data))), .5, 3)
# Verify that there are onle even elements:
self.assertTrue(np.max(np.mod(labels, 2)) == 0)