Skip to content
Permalink
Branch: master
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
386 lines (329 sloc) 13.8 KB
#include <iostream>
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
#include "../hnswlib/hnswlib.h"
#include <atomic>
#include <cmath>
#include <cstddef>
#include <exception>
#include <memory>
#include <mutex>
#include <queue>
#include <stdexcept>
#include <string>
#include <thread>
#include <utility>
#include <vector>
namespace py = pybind11;
/*
 * Replacement for the OpenMP '#pragma omp parallel for' directive.
 * Only handles a subset of the functionality (no reductions etc.).
 *
 * Calls fn(id, threadId) exactly once for every id in [start, end)
 * (start inclusive, end EXCLUSIVE). threadId is the index of the worker that
 * runs the call and is always smaller than the requested numThreads, so it can
 * be used to address per-thread scratch storage.
 *
 * numThreads == 0 means "use std::thread::hardware_concurrency()"; if that is
 * unknown (reported as 0) a single thread is used. Never more workers than
 * items are started. With one worker, fn runs inline on the calling thread.
 *
 * If fn throws, the remaining ids are abandoned, all workers are joined and
 * the last captured exception is rethrown on the calling thread.
 *
 * The method is borrowed from nmslib.
 */
template<class Function>
inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn) {
    if (start >= end) {
        return;  // empty range: nothing to do
    }

    if (numThreads == 0) {
        numThreads = std::thread::hardware_concurrency();
        if (numThreads == 0) {
            // hardware_concurrency() may return 0 when the value is not
            // computable; fall back to serial execution instead of silently
            // skipping all the work.
            numThreads = 1;
        }
    }

    // More workers than items would only start threads that exit immediately.
    const size_t numItems = end - start;
    if (numThreads > numItems) {
        numThreads = numItems;
    }

    if (numThreads == 1) {
        for (size_t id = start; id < end; id++) {
            fn(id, 0);
        }
        return;
    }

    std::vector<std::thread> threads;
    threads.reserve(numThreads);
    std::atomic<size_t> current(start);

    // keep track of exceptions in threads
    // https://stackoverflow.com/a/32428427/1713196
    std::exception_ptr lastException = nullptr;
    std::mutex lastExceptMutex;

    auto worker = [&](size_t threadId) {
        while (true) {
            // fetch_add returns the value held before the increment, so every
            // id is handed to exactly one worker.
            const size_t id = current.fetch_add(1);
            if (id >= end) {
                break;
            }
            try {
                fn(id, threadId);
            } catch (...) {
                std::lock_guard<std::mutex> lock(lastExceptMutex);
                lastException = std::current_exception();
                // Make every other worker see "no work left" on its next fetch.
                current = end;
                break;
            }
        }
    };

    try {
        for (size_t threadId = 0; threadId < numThreads; ++threadId) {
            threads.emplace_back(worker, threadId);
        }
    } catch (...) {
        // Starting a thread failed: stop and join the workers that did start,
        // otherwise destroying joinable std::thread objects calls std::terminate.
        current = end;
        for (auto &thread : threads) {
            thread.join();
        }
        throw;
    }

    for (auto &thread : threads) {
        thread.join();
    }
    if (lastException) {
        std::rethrow_exception(lastException);
    }
}
template<typename dist_t, typename data_t=float>
class Index {
public:
Index(const std::string &space_name, const int dim) :
space_name(space_name), dim(dim) {
normalize=false;
if(space_name=="l2") {
l2space = new hnswlib::L2Space(dim);
}
else if(space_name=="ip") {
l2space = new hnswlib::InnerProductSpace(dim);
}
else if(space_name=="cosine") {
l2space = new hnswlib::InnerProductSpace(dim);
normalize=true;
}
appr_alg = NULL;
ep_added = true;
index_inited = false;
num_threads_default = std::thread::hardware_concurrency();
}
void init_new_index(const size_t maxElements, const size_t M, const size_t efConstruction, const size_t random_seed) {
if (appr_alg) {
throw new std::runtime_error("The index is already initiated.");
}
cur_l = 0;
appr_alg = new hnswlib::HierarchicalNSW<dist_t>(l2space, maxElements, M, efConstruction, random_seed);
index_inited = true;
ep_added = false;
}
void set_ef(size_t ef) {
appr_alg->ef_ = ef;
}
void set_num_threads(int num_threads) {
this->num_threads_default = num_threads;
}
void saveIndex(const std::string &path_to_index) {
appr_alg->saveIndex(path_to_index);
}
void loadIndex(const std::string &path_to_index, size_t max_elements) {
if (appr_alg) {
std::cerr<<"Warning: Calling load_index for an already inited index. Old index is being deallocated.";
delete appr_alg;
}
appr_alg = new hnswlib::HierarchicalNSW<dist_t>(l2space, path_to_index, false, max_elements);
cur_l = appr_alg->cur_element_count;
}
void normalize_vector(float *data, float *norm_array){
float norm=0.0f;
for(int i=0;i<dim;i++)
norm+=data[i]*data[i];
norm= 1.0f / (sqrtf(norm) + 1e-30f);
for(int i=0;i<dim;i++)
norm_array[i]=data[i]*norm;
}
void addItems(py::object input, py::object ids_ = py::none(), int num_threads = -1) {
py::array_t < dist_t, py::array::c_style | py::array::forcecast > items(input);
auto buffer = items.request();
if (num_threads <= 0)
num_threads = num_threads_default;
size_t rows, features;
if (buffer.ndim != 2 && buffer.ndim != 1) throw std::runtime_error("data must be a 1d/2d array");
if (buffer.ndim == 2) {
rows = buffer.shape[0];
features = buffer.shape[1];
}
else{
rows = 1;
features = buffer.shape[0];
}
if (features != dim)
throw std::runtime_error("wrong dimensionality of the vectors");
// avoid using threads when the number of searches is small:
if(rows<=num_threads*4){
num_threads=1;
}
std::vector<size_t> ids;
if (!ids_.is_none()) {
py::array_t < size_t, py::array::c_style | py::array::forcecast > items(ids_);
auto ids_numpy = items.request();
if(ids_numpy.ndim==1 && ids_numpy.shape[0]==rows) {
std::vector<size_t> ids1(ids_numpy.shape[0]);
for (size_t i = 0; i < ids1.size(); i++) {
ids1[i] = items.data()[i];
}
ids.swap(ids1);
}
else if(ids_numpy.ndim==0 && rows==1) {
ids.push_back(*items.data());
}
else
throw std::runtime_error("wrong dimensionality of the labels");
}
{
int start = 0;
if (!ep_added) {
size_t id = ids.size() ? ids.at(0) : (cur_l);
float *vector_data=(float *) items.data(0);
std::vector<float> norm_array(dim);
if(normalize){
normalize_vector(vector_data, norm_array.data());
vector_data = norm_array.data();
}
appr_alg->addPoint((void *) vector_data, (size_t) id);
start = 1;
ep_added = true;
}
py::gil_scoped_release l;
if(normalize==false) {
ParallelFor(start, rows, num_threads, [&](size_t row, size_t threadId) {
size_t id = ids.size() ? ids.at(row) : (cur_l+row);
appr_alg->addPoint((void *) items.data(row), (size_t) id);
});
} else{
std::vector<float> norm_array(num_threads * dim);
ParallelFor(start, rows, num_threads, [&](size_t row, size_t threadId) {
// normalize vector:
size_t start_idx = threadId * dim;
normalize_vector((float *) items.data(row), (norm_array.data()+start_idx));
size_t id = ids.size() ? ids.at(row) : (cur_l+row);
appr_alg->addPoint((void *) (norm_array.data()+start_idx), (size_t) id);
});
};
cur_l+=rows;
}
}
std::vector<std::vector<data_t>> getDataReturnList(py::object ids_ = py::none()) {
std::vector<size_t> ids;
if (!ids_.is_none()) {
py::array_t < size_t, py::array::c_style | py::array::forcecast > items(ids_);
auto ids_numpy = items.request();
std::vector<size_t> ids1(ids_numpy.shape[0]);
for (size_t i = 0; i < ids1.size(); i++) {
ids1[i] = items.data()[i];
}
ids.swap(ids1);
}
std::vector<std::vector<data_t>> data;
for (auto id : ids) {
data.push_back(appr_alg->template getDataByLabel<data_t>(id));
}
return data;
}
std::vector<unsigned int> getIdsList() {
std::vector<unsigned int> ids;
for(auto kv : appr_alg->label_lookup_) {
ids.push_back(kv.first);
}
return ids;
}
py::object knnQuery_return_numpy(py::object input, size_t k = 1, int num_threads = -1) {
py::array_t < dist_t, py::array::c_style | py::array::forcecast > items(input);
auto buffer = items.request();
hnswlib::labeltype *data_numpy_l;
dist_t *data_numpy_d;
size_t rows, features;
if (num_threads <= 0)
num_threads = num_threads_default;
{
py::gil_scoped_release l;
if (buffer.ndim != 2 && buffer.ndim != 1) throw std::runtime_error("data must be a 1d/2d array");
if (buffer.ndim == 2) {
rows = buffer.shape[0];
features = buffer.shape[1];
}
else{
rows = 1;
features = buffer.shape[0];
}
// avoid using threads when the number of searches is small:
if(rows<=num_threads*4){
num_threads=1;
}
data_numpy_l = new hnswlib::labeltype[rows * k];
data_numpy_d = new dist_t[rows * k];
if(normalize==false) {
ParallelFor(0, rows, num_threads, [&](size_t row, size_t threadId) {
std::priority_queue<std::pair<dist_t, hnswlib::labeltype >> result = appr_alg->searchKnn(
(void *) items.data(row), k);
if (result.size() != k)
throw std::runtime_error(
"Cannot return the results in a contigious 2D array. Probably ef or M is to small");
for (int i = k - 1; i >= 0; i--) {
auto &result_tuple = result.top();
data_numpy_d[row * k + i] = result_tuple.first;
data_numpy_l[row * k + i] = result_tuple.second;
result.pop();
}
}
);
}
else{
std::vector<float> norm_array(num_threads*features);
ParallelFor(0, rows, num_threads, [&](size_t row, size_t threadId) {
float *data= (float *) items.data(row);
size_t start_idx = threadId * dim;
normalize_vector((float *) items.data(row), (norm_array.data()+start_idx));
std::priority_queue<std::pair<dist_t, hnswlib::labeltype >> result = appr_alg->searchKnn(
(void *) (norm_array.data()+start_idx), k);
if (result.size() != k)
throw std::runtime_error(
"Cannot return the results in a contigious 2D array. Probably ef or M is to small");
for (int i = k - 1; i >= 0; i--) {
auto &result_tuple = result.top();
data_numpy_d[row * k + i] = result_tuple.first;
data_numpy_l[row * k + i] = result_tuple.second;
result.pop();
}
}
);
}
}
py::capsule free_when_done_l(data_numpy_l, [](void *f) {
delete[] f;
});
py::capsule free_when_done_d(data_numpy_d, [](void *f) {
delete[] f;
});
return py::make_tuple(
py::array_t<hnswlib::labeltype>(
{rows, k}, // shape
{k * sizeof(hnswlib::labeltype),
sizeof(hnswlib::labeltype)}, // C-style contiguous strides for double
data_numpy_l, // the data pointer
free_when_done_l),
py::array_t<dist_t>(
{rows, k}, // shape
{k * sizeof(dist_t), sizeof(dist_t)}, // C-style contiguous strides for double
data_numpy_d, // the data pointer
free_when_done_d));
}
std::string space_name;
int dim;
bool index_inited;
bool ep_added;
bool normalize;
int num_threads_default;
hnswlib::labeltype cur_l;
hnswlib::HierarchicalNSW<dist_t> *appr_alg;
hnswlib::SpaceInterface<float> *l2space;
~Index() {
delete l2space;
if (appr_alg)
delete appr_alg;
}
};
// Python module definition: exposes Index<float> as hnswlib.Index.
// NOTE(review): newer pybind11 releases document PYBIND11_MODULE as the
// replacement for PYBIND11_PLUGIN - confirm the minimum supported pybind11
// version before migrating.
PYBIND11_PLUGIN(hnswlib) {
    using FloatIndex = Index<float>;

    py::module m("hnswlib");
    py::class_<FloatIndex> index_class(m, "Index");

    // Construction and index creation.
    index_class.def(py::init<const std::string &, const int>(),
                    py::arg("space"), py::arg("dim"));
    index_class.def("init_index", &FloatIndex::init_new_index,
                    py::arg("max_elements"),
                    py::arg("M")=16,
                    py::arg("ef_construction")=200,
                    py::arg("random_seed")=100);

    // Querying and inserting.
    index_class.def("knn_query", &FloatIndex::knnQuery_return_numpy,
                    py::arg("data"),
                    py::arg("k")=1,
                    py::arg("num_threads")=-1);
    index_class.def("add_items", &FloatIndex::addItems,
                    py::arg("data"),
                    py::arg("ids") = py::none(),
                    py::arg("num_threads")=-1);

    // Inspection.
    index_class.def("get_items", &FloatIndex::getDataReturnList,
                    py::arg("ids") = py::none());
    index_class.def("get_ids_list", &FloatIndex::getIdsList);

    // Tuning.
    index_class.def("set_ef", &FloatIndex::set_ef, py::arg("ef"));
    index_class.def("set_num_threads", &FloatIndex::set_num_threads,
                    py::arg("num_threads"));

    // Persistence.
    index_class.def("save_index", &FloatIndex::saveIndex,
                    py::arg("path_to_index"));
    index_class.def("load_index", &FloatIndex::loadIndex,
                    py::arg("path_to_index"),
                    py::arg("max_elements")=0);

    index_class.def("__repr__", [](const FloatIndex &) {
        return "<HNSW-lib index>";
    });

    return m.ptr();
}
You can’t perform that action at this time.