diff --git a/README.md b/README.md
index 559c5dfd..4d74b003 100644
--- a/README.md
+++ b/README.md
@@ -37,7 +37,7 @@ For other spaces use the nmslib library https://github.com/nmslib/nmslib.
 #### Short API description
 * `hnswlib.Index(space, dim)` creates a non-initialized HNSW index in space `space` with integer dimension `dim`.
 
-Index methods:
+`hnswlib.Index` methods:
 * `init_index(max_elements, ef_construction = 200, M = 16, random_seed = 100)` initializes the index with no elements.
     * `max_elements` defines the maximum number of elements that can be stored in the structure (can be increased/shrunk).
     * `ef_construction` defines a construction time/accuracy trade-off (see [ALGO_PARAMS.md](ALGO_PARAMS.md)).
@@ -76,14 +76,34 @@ Index methods:
 * `get_current_count()` - returns the current number of elements stored in the index
 
-
-
+Read-only properties of `hnswlib.Index` class:
+
+* `space` - name of the space (can be one of "l2", "ip", or "cosine").
+
+* `dim` - dimensionality of the space.
+
+* `M` - parameter that defines the maximum number of outgoing connections in the graph.
+
+* `ef_construction` - parameter that controls speed/accuracy trade-off during the index construction.
+
+* `max_elements` - current capacity of the index. Equivalent to `p.get_max_elements()`.
+
+* `element_count` - number of items in the index. Equivalent to `p.get_current_count()`.
+
+Properties of `hnswlib.Index` that support reading and writing:
+
+* `ef` - parameter controlling query time/accuracy trade-off.
+
+* `num_threads` - default number of threads to use in `add_items` or `knn_query`. Note that calling `p.set_num_threads(3)` is equivalent to `p.num_threads=3`.
+
+
 #### Python bindings examples
 
 ```python
 import hnswlib
 import numpy as np
+import pickle
 
 dim = 128
 num_elements = 10000
@@ -106,6 +126,18 @@ p.set_ef(50) # ef should always be > k
 
 # Query dataset, k - number of closest elements (returns 2 numpy arrays)
 labels, distances = p.knn_query(data, k = 1)
+
+# Index objects support pickling
+# WARNING: serialization via pickle.dumps(p) or p.__getstate__() is NOT thread-safe with p.add_items method!
+# Note: ef parameter is included in serialization; random number generator is initialized with random_seeed on Index load +p_copy = pickle.loads(pickle.dumps(p)) # creates a copy of index p using pickle round-trip + +### Index parameters are exposed as class properties: +print(f"Parameters passed to constructor: space={p_copy.space}, dim={p_copy.dim}") +print(f"Index construction: M={p_copy.M}, ef_construction={p_copy.ef_construction}") +print(f"Index size is {p_copy.element_count} and index capacity is {p_copy.max_elements}") +print(f"Search speed/quality trade-off parameter: ef={p_copy.ef}") + ``` An example with updates after serialization/deserialization: diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index 7d0eb443..7c2c01c3 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -637,7 +637,6 @@ namespace hnswlib { if (!input.is_open()) throw std::runtime_error("Cannot open file"); - // get file size: input.seekg(0,input.end); std::streampos total_filesize=input.tellg(); @@ -874,7 +873,7 @@ namespace hnswlib { for (auto&& cand : sCand) { if (cand == neigh) continue; - + dist_t distance = fstdistfunc_(getDataByInternalId(neigh), getDataByInternalId(cand), dist_func_param_); if (candidates.size() < elementsToKeep) { candidates.emplace(distance, cand); @@ -1137,7 +1136,7 @@ namespace hnswlib { } std::priority_queue, std::vector>, CompareByFirst> top_candidates; - if (has_deletions_) { + if (has_deletions_) { top_candidates=searchBaseLayerST( currObj, query_data, std::max(ef_, k)); } @@ -1186,19 +1185,19 @@ namespace hnswlib { std::unordered_set s; for (int j=0; j 0); - assert(data[j] < cur_element_count); + assert(data[j] < cur_element_count); assert (data[j] != i); inbound_connections_num[data[j]]++; s.insert(data[j]); connections_checked++; - + } assert(s.size() == size); } } if(cur_element_count > 1){ int min1=inbound_connections_num[0], max1=inbound_connections_num[0]; - for(int i=0; i < cur_element_count; i++){ + for(int i=0; i < cur_element_count; i++){ assert(inbound_connections_num[i] > 0); min1=std::min(inbound_connections_num[i],min1); max1=std::max(inbound_connections_num[i],max1); @@ -1206,7 +1205,7 @@ namespace hnswlib { std::cout << "Min inbound: " << min1 << ", Max inbound:" << max1 << "\n"; } std::cout << "integrity ok, checked " << connections_checked << " connections\n"; - + } }; diff --git a/python_bindings/bindings.cpp b/python_bindings/bindings.cpp index 1b88ca23..d9396247 100644 --- a/python_bindings/bindings.cpp +++ b/python_bindings/bindings.cpp @@ -5,6 +5,8 @@ #include "hnswlib/hnswlib.h" #include #include +#include +#include namespace py = pybind11; @@ -13,7 +15,7 @@ namespace py = pybind11; * only handles a subset of functionality (no reductions etc) * Process ids from start (inclusive) to end (EXCLUSIVE) * - * The method is borrowed from nmslib + * The method is borrowed from nmslib */ template inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn) { @@ -71,27 +73,52 @@ inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn } + + template class Index { public: - Index(const std::string &space_name, const int dim) : - space_name(space_name), dim(dim) { - normalize=false; - if(space_name=="l2") { - l2space = new hnswlib::L2Space(dim); - } - else if(space_name=="ip") { - l2space = new hnswlib::InnerProductSpace(dim); - } - else if(space_name=="cosine") { - l2space = new hnswlib::InnerProductSpace(dim); - normalize=true; - } - appr_alg = NULL; - ep_added = true; - index_inited = false; - 
num_threads_default = std::thread::hardware_concurrency(); + Index(const std::string &space_name, const int dim) : + space_name(space_name), dim(dim) { + normalize=false; + if(space_name=="l2") { + l2space = new hnswlib::L2Space(dim); } + else if(space_name=="ip") { + l2space = new hnswlib::InnerProductSpace(dim); + } + else if(space_name=="cosine") { + l2space = new hnswlib::InnerProductSpace(dim); + normalize=true; + } + appr_alg = NULL; + ep_added = true; + index_inited = false; + num_threads_default = std::thread::hardware_concurrency(); + + default_ef=10; + } + + static const int ser_version = 1; // serialization version + + std::string space_name; + int dim; + size_t seed; + size_t default_ef; + + bool index_inited; + bool ep_added; + bool normalize; + int num_threads_default; + hnswlib::labeltype cur_l; + hnswlib::HierarchicalNSW *appr_alg; + hnswlib::SpaceInterface *l2space; + + ~Index() { + delete l2space; + if (appr_alg) + delete appr_alg; + } void init_new_index(const size_t maxElements, const size_t M, const size_t efConstruction, const size_t random_seed) { if (appr_alg) { @@ -101,19 +128,17 @@ class Index { appr_alg = new hnswlib::HierarchicalNSW(l2space, maxElements, M, efConstruction, random_seed); index_inited = true; ep_added = false; + appr_alg->ef_ = default_ef; + seed=random_seed; } + void set_ef(size_t ef) { + default_ef=ef; + if (appr_alg) appr_alg->ef_ = ef; } - size_t get_ef_construction() { - return appr_alg->ef_construction_; - } - - size_t get_M() { - return appr_alg->M_; - } void set_num_threads(int num_threads) { this->num_threads_default = num_threads; @@ -124,21 +149,22 @@ class Index { } void loadIndex(const std::string &path_to_index, size_t max_elements) { - if (appr_alg) { - std::cerr<<"Warning: Calling load_index for an already inited index. Old index is being deallocated."; - delete appr_alg; - } - appr_alg = new hnswlib::HierarchicalNSW(l2space, path_to_index, false, max_elements); - cur_l = appr_alg->cur_element_count; + if (appr_alg) { + std::cerr<<"Warning: Calling load_index for an already inited index. Old index is being deallocated."; + delete appr_alg; + } + appr_alg = new hnswlib::HierarchicalNSW(l2space, path_to_index, false, max_elements); + cur_l = appr_alg->cur_element_count; + } + + void normalize_vector(float *data, float *norm_array){ + float norm=0.0f; + for(int i=0;i items(input); @@ -162,7 +188,6 @@ class Index { throw std::runtime_error("wrong dimensionality of the vectors"); // avoid using threads when the number of searches is small: - if(rows<=num_threads*4){ num_threads=1; } @@ -189,20 +214,19 @@ class Index { { - int start = 0; - if (!ep_added) { - size_t id = ids.size() ? ids.at(0) : (cur_l); - float *vector_data=(float *) items.data(0); - std::vector norm_array(dim); - if(normalize){ - normalize_vector(vector_data, norm_array.data()); - vector_data = norm_array.data(); - - } - appr_alg->addPoint((void *) vector_data, (size_t) id); - start = 1; - ep_added = true; + int start = 0; + if (!ep_added) { + size_t id = ids.size() ? 
ids.at(0) : (cur_l); + float *vector_data=(float *) items.data(0); + std::vector norm_array(dim); + if(normalize){ + normalize_vector(vector_data, norm_array.data()); + vector_data = norm_array.data(); } + appr_alg->addPoint((void *) vector_data, (size_t) id); + start = 1; + ep_added = true; + } py::gil_scoped_release l; if(normalize==false) { @@ -214,7 +238,7 @@ class Index { std::vector norm_array(num_threads * dim); ParallelFor(start, rows, num_threads, [&](size_t row, size_t threadId) { // normalize vector: - size_t start_idx = threadId * dim; + size_t start_idx = threadId * dim; normalize_vector((float *) items.data(row), (norm_array.data()+start_idx)); size_t id = ids.size() ? ids.at(row) : (cur_l+row); @@ -254,6 +278,247 @@ class Index { return ids; } + inline void assert_true(bool expr, const std::string & msg) { + if (expr == false) + throw std::runtime_error("assert failed: "+msg); + return; + } + + + py::tuple getAnnData() const { /* WARNING: Index::getAnnData is not thread-safe with Index::addItems */ + + std::unique_lock templock(appr_alg->global); + + unsigned int level0_npy_size = appr_alg->cur_element_count * appr_alg->size_data_per_element_; + unsigned int link_npy_size = 0; + std::vector link_npy_offsets(appr_alg->cur_element_count); + + for (size_t i = 0; i < appr_alg->cur_element_count; i++){ + unsigned int linkListSize = appr_alg->element_levels_[i] > 0 ? appr_alg->size_links_per_element_ * appr_alg->element_levels_[i] : 0; + link_npy_offsets[i]=link_npy_size; + if (linkListSize) + link_npy_size += linkListSize; + } + + char* data_level0_npy = (char *) malloc(level0_npy_size); + char* link_list_npy = (char *) malloc(link_npy_size); + int* element_levels_npy = (int *) malloc(appr_alg->element_levels_.size()*sizeof(int)); + + hnswlib::labeltype* label_lookup_key_npy = (hnswlib::labeltype *) malloc(appr_alg->label_lookup_.size()*sizeof(hnswlib::labeltype)); + hnswlib::tableint* label_lookup_val_npy = (hnswlib::tableint *) malloc(appr_alg->label_lookup_.size()*sizeof(hnswlib::tableint)); + + memset(label_lookup_key_npy, -1, appr_alg->label_lookup_.size()*sizeof(hnswlib::labeltype)); + memset(label_lookup_val_npy, -1, appr_alg->label_lookup_.size()*sizeof(hnswlib::tableint)); + + size_t idx=0; + for ( auto it = appr_alg->label_lookup_.begin(); it != appr_alg->label_lookup_.end(); ++it ){ + label_lookup_key_npy[idx]= it->first; + label_lookup_val_npy[idx]= it->second; + idx++; + } + + memset(link_list_npy, 0, link_npy_size); + + memcpy(data_level0_npy, appr_alg->data_level0_memory_, level0_npy_size); + memcpy(element_levels_npy, appr_alg->element_levels_.data(), appr_alg->element_levels_.size() * sizeof(int)); + + for (size_t i = 0; i < appr_alg->cur_element_count; i++){ + unsigned int linkListSize = appr_alg->element_levels_[i] > 0 ? 
appr_alg->size_links_per_element_ * appr_alg->element_levels_[i] : 0; + if (linkListSize){ + memcpy(link_list_npy+link_npy_offsets[i], appr_alg->linkLists_[i], linkListSize); + } + } + + py::capsule free_when_done_l0(data_level0_npy, [](void *f) { + delete[] f; + }); + py::capsule free_when_done_lvl(element_levels_npy, [](void *f) { + delete[] f; + }); + py::capsule free_when_done_lb(label_lookup_key_npy, [](void *f) { + delete[] f; + }); + py::capsule free_when_done_id(label_lookup_val_npy, [](void *f) { + delete[] f; + }); + py::capsule free_when_done_ll(link_list_npy, [](void *f) { + delete[] f; + }); + + return py::make_tuple(appr_alg->offsetLevel0_, + appr_alg->max_elements_, + appr_alg->cur_element_count, + appr_alg->size_data_per_element_, + appr_alg->label_offset_, + appr_alg->offsetData_, + appr_alg->maxlevel_, + appr_alg->enterpoint_node_, + appr_alg->maxM_, + appr_alg->maxM0_, + appr_alg->M_, + appr_alg->mult_, + appr_alg->ef_construction_, + appr_alg->ef_, + appr_alg->has_deletions_, + appr_alg->size_links_per_element_, + py::array_t( + {appr_alg->label_lookup_.size()}, // shape + {sizeof(hnswlib::labeltype)}, // C-style contiguous strides for double + label_lookup_key_npy, // the data pointer + free_when_done_lb), + py::array_t( + {appr_alg->label_lookup_.size()}, // shape + {sizeof(hnswlib::tableint)}, // C-style contiguous strides for double + label_lookup_val_npy, // the data pointer + free_when_done_id), + py::array_t( + {appr_alg->element_levels_.size()}, // shape + {sizeof(int)}, // C-style contiguous strides for double + element_levels_npy, // the data pointer + free_when_done_lvl), + py::array_t( + {level0_npy_size}, // shape + {sizeof(char)}, // C-style contiguous strides for double + data_level0_npy, // the data pointer + free_when_done_l0), + py::array_t( + {link_npy_size}, // shape + {sizeof(char)}, // C-style contiguous strides for double + link_list_npy, // the data pointer + free_when_done_ll) + ); + + } + + + py::tuple getIndexParams() const { + /* TODO: serialize state of random generators appr_alg->level_generator_ and appr_alg->update_probability_generator_ */ + /* for full reproducibility / to avoid re-initializing generators inside Index::createFromParams */ + + return py::make_tuple(py::int_(Index::ser_version), // serialization version + + /* TODO: convert the following two py::tuple's to py::dict */ + py::make_tuple(space_name, dim, index_inited, ep_added, normalize, num_threads_default, seed, default_ef), + index_inited == true ? 
getAnnData() : py::make_tuple()); /* WARNING: Index::getAnnData is not thread-safe with Index::addItems */ + + + + } + + + static Index * createFromParams(const py::tuple t) { + + if (py::int_(Index::ser_version) != t[0].cast()) // check serialization version + throw std::runtime_error("Serialization version mismatch!"); + + py::tuple index_params=t[1].cast(); /* TODO: convert index_params from py::tuple to py::dict */ + py::tuple ann_params=t[2].cast(); /* TODO: convert ann_params from py::tuple to py::dict */ + + auto space_name_=index_params[0].cast(); + auto dim_=index_params[1].cast(); + auto index_inited_=index_params[2].cast(); + + Index *new_index = new Index(index_params[0].cast(), index_params[1].cast()); + + /* TODO: deserialize state of random generators into new_index->level_generator_ and new_index->update_probability_generator_ */ + /* for full reproducibility / state of generators is serialized inside Index::getIndexParams */ + new_index->seed = index_params[6].cast(); + + if (index_inited_){ + new_index->appr_alg = new hnswlib::HierarchicalNSW(new_index->l2space, ann_params[1].cast(), ann_params[10].cast(), ann_params[12].cast(), new_index->seed); + new_index->cur_l = ann_params[2].cast(); + } + + new_index->index_inited = index_inited_; + new_index->ep_added=index_params[3].cast(); + new_index->num_threads_default=index_params[5].cast(); + new_index->default_ef=index_params[7].cast(); + + if (index_inited_) + new_index->setAnnData(ann_params); + + + return new_index; + } + + static Index * createFromIndex(const Index & index) { + /* WARNING: Index::getIndexParams is not thread-safe with Index::addItems */ + return createFromParams(index.getIndexParams()); + } + + + void setAnnData(const py::tuple t) { + /* WARNING: Index::setAnnData is not thread-safe with Index::addItems */ + + std::unique_lock templock(appr_alg->global); + + assert_true(appr_alg->offsetLevel0_ == t[0].cast(), "Invalid value of offsetLevel0_ "); + assert_true(appr_alg->max_elements_ == t[1].cast(), "Invalid value of max_elements_ "); + + appr_alg->cur_element_count = t[2].cast(); + + assert_true(appr_alg->size_data_per_element_ == t[3].cast(), "Invalid value of size_data_per_element_ "); + assert_true(appr_alg->label_offset_ == t[4].cast(), "Invalid value of label_offset_ "); + assert_true(appr_alg->offsetData_ == t[5].cast(), "Invalid value of offsetData_ "); + + appr_alg->maxlevel_ = t[6].cast(); + appr_alg->enterpoint_node_ = t[7].cast(); + + assert_true(appr_alg->maxM_ == t[8].cast(), "Invalid value of maxM_ "); + assert_true(appr_alg->maxM0_ == t[9].cast(), "Invalid value of maxM0_ "); + assert_true(appr_alg->M_ == t[10].cast(), "Invalid value of M_ "); + assert_true(appr_alg->mult_ == t[11].cast(), "Invalid value of mult_ "); + assert_true(appr_alg->ef_construction_ == t[12].cast(), "Invalid value of ef_construction_ "); + + appr_alg->ef_ = t[13].cast(); + appr_alg->has_deletions_=t[14].cast(); + assert_true(appr_alg->size_links_per_element_ == t[15].cast(), "Invalid value of size_links_per_element_ "); + + auto label_lookup_key_npy = t[16].cast >(); + auto label_lookup_val_npy = t[17].cast >(); + auto element_levels_npy = t[18].cast >(); + auto data_level0_npy = t[19].cast >(); + auto link_list_npy = t[20].cast >(); + + for (size_t i = 0; i < appr_alg->cur_element_count; i++){ + if (label_lookup_val_npy.data()[i] < 0){ + throw std::runtime_error("internal id cannot be negative!"); + } + else{ + appr_alg->label_lookup_.insert(std::make_pair(label_lookup_key_npy.data()[i], 
label_lookup_val_npy.data()[i])); + } + } + + memcpy(appr_alg->element_levels_.data(), element_levels_npy.data(), element_levels_npy.nbytes()); + + unsigned int link_npy_size = 0; + std::vector link_npy_offsets(appr_alg->cur_element_count); + + for (size_t i = 0; i < appr_alg->cur_element_count; i++){ + unsigned int linkListSize = appr_alg->element_levels_[i] > 0 ? appr_alg->size_links_per_element_ * appr_alg->element_levels_[i] : 0; + link_npy_offsets[i]=link_npy_size; + if (linkListSize) + link_npy_size += linkListSize; + } + + memcpy(appr_alg->data_level0_memory_, data_level0_npy.data(), data_level0_npy.nbytes()); + + for (size_t i = 0; i < appr_alg->max_elements_; i++) { + unsigned int linkListSize = appr_alg->element_levels_[i] > 0 ? appr_alg->size_links_per_element_ * appr_alg->element_levels_[i] : 0; + if (linkListSize == 0) { + appr_alg->linkLists_[i] = nullptr; + } else { + appr_alg->linkLists_[i] = (char *) malloc(linkListSize); + if (appr_alg->linkLists_[i] == nullptr) + throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklist"); + + memcpy(appr_alg->linkLists_[i], link_list_npy.data()+link_npy_offsets[i], linkListSize); + + } + } + +} + py::object knnQuery_return_numpy(py::object input, size_t k = 1, int num_threads = -1) { py::array_t < dist_t, py::array::c_style | py::array::forcecast > items(input); @@ -310,7 +575,7 @@ class Index { float *data= (float *) items.data(row); size_t start_idx = threadId * dim; - normalize_vector((float *) items.data(row), (norm_array.data()+start_idx)); + normalize_vector((float *) items.data(row), (norm_array.data()+start_idx)); std::priority_queue> result = appr_alg->searchKnn( (void *) (norm_array.data()+start_idx), k); @@ -367,50 +632,70 @@ class Index { return appr_alg->cur_element_count; } - std::string space_name; - int dim; +}; - bool index_inited; - bool ep_added; - bool normalize; - int num_threads_default; - hnswlib::labeltype cur_l; - hnswlib::HierarchicalNSW *appr_alg; - hnswlib::SpaceInterface *l2space; - - ~Index() { - delete l2space; - if (appr_alg) - delete appr_alg; - } -}; PYBIND11_PLUGIN(hnswlib) { py::module m("hnswlib"); py::class_>(m, "Index") + .def(py::init(&Index::createFromParams), py::arg("params")) + /* WARNING: Index::createFromIndex is not thread-safe with Index::addItems */ + .def(py::init(&Index::createFromIndex), py::arg("index")) .def(py::init(), py::arg("space"), py::arg("dim")) - .def("init_index", &Index::init_new_index, py::arg("max_elements"), py::arg("M")=16, - py::arg("ef_construction")=200, py::arg("random_seed")=100) + .def("init_index", &Index::init_new_index, py::arg("max_elements"), py::arg("M")=16, py::arg("ef_construction")=200, py::arg("random_seed")=100) .def("knn_query", &Index::knnQuery_return_numpy, py::arg("data"), py::arg("k")=1, py::arg("num_threads")=-1) .def("add_items", &Index::addItems, py::arg("data"), py::arg("ids") = py::none(), py::arg("num_threads")=-1) .def("get_items", &Index::getDataReturnList, py::arg("ids") = py::none()) .def("get_ids_list", &Index::getIdsList) .def("set_ef", &Index::set_ef, py::arg("ef")) - .def("get_ef_construction", &Index::get_ef_construction) - .def("get_M", &Index::get_M) .def("set_num_threads", &Index::set_num_threads, py::arg("num_threads")) .def("save_index", &Index::saveIndex, py::arg("path_to_index")) .def("load_index", &Index::loadIndex, py::arg("path_to_index"), py::arg("max_elements")=0) .def("mark_deleted", &Index::markDeleted, py::arg("label")) .def("resize_index", &Index::resizeIndex, py::arg("new_size")) - 
.def("get_max_elements", &Index::getMaxElements) - .def("get_current_count", &Index::getCurrentCount) - .def("__repr__", - [](const Index &a) { - return ""; - } - ); + .def_readonly("space_name", &Index::space_name) + .def_readonly("dim", &Index::dim) + .def_readwrite("num_threads", &Index::num_threads_default) + .def_property("ef", + [](const Index & index) { + return index.index_inited ? index.appr_alg->ef_ : index.default_ef; + }, + [](Index & index, const size_t ef_) { + index.default_ef=ef_; + if (index.appr_alg) + index.appr_alg->ef_ = ef_; + }) + .def_property_readonly("max_elements", [](const Index & index) { + return index.index_inited ? index.appr_alg->max_elements_ : 0; + }) + .def_property_readonly("element_count", [](const Index & index) { + return index.index_inited ? index.appr_alg->cur_element_count : 0; + }) + .def_property_readonly("ef_construction", [](const Index & index) { + return index.index_inited ? index.appr_alg->ef_construction_ : 0; + }) + .def_property_readonly("M", [](const Index & index) { + return index.index_inited ? index.appr_alg->M_ : 0; + }) + + .def(py::pickle( + [](const Index &ind) { // __getstate__ + /* Return a tuple that fully encodes the state of the object */ + /* WARNING: Index::getIndexParams is not thread-safe with Index::addItems */ + return ind.getIndexParams(); + }, + [](py::tuple t) { // __setstate__ + if (t.size() != 3) + throw std::runtime_error("Invalid state!"); + return Index::createFromParams(t); + } + )) + + .def("__repr__", [](const Index &a) { + return ""; + }); + return m.ptr(); } diff --git a/python_bindings/tests/bindings_test_pickle.py b/python_bindings/tests/bindings_test_pickle.py new file mode 100644 index 00000000..6c3a826a --- /dev/null +++ b/python_bindings/tests/bindings_test_pickle.py @@ -0,0 +1,153 @@ +import unittest + +import numpy as np +import hnswlib +import pickle + + +def get_dist(metric, pt1, pt2): + if metric == 'l2': + return np.sum((pt1-pt2)**2) + elif metric == 'ip': + return 1. - np.sum(np.multiply(pt1,pt2)) + elif metric == 'cosine': + return 1. 
- np.sum(np.multiply(pt1,pt2)) / (np.sum(pt1**2) * np.sum(pt2**2))**.5 + +def brute_force_distances(metric, items, query_items, k): + dists=np.zeros((query_items.shape[0], items.shape[0])) + for ii in range(items.shape[0]): + for jj in range(query_items.shape[0]): + dists[jj,ii]=get_dist(metric, items[ii, :], query_items[jj, :]) + + labels = np.argsort(dists, axis=1) # equivalent, but faster: np.argpartition(dists, range(k), axis=1) + dists = np.sort(dists, axis=1) # equivalent, but faster: np.partition(dists, range(k), axis=1) + + return labels[:,:k], dists[:,:k] + + +def check_ann_results(self, metric, items, query_items, k, ann_l, ann_d, err_thresh=0, total_thresh=0, dists_thresh=0): + brute_l, brute_d = brute_force_distances(metric, items, query_items, k) + err_total = 0 + for jj in range(query_items.shape[0]): + err = np.sum(np.isin(brute_l[jj, :], ann_l[jj, :], invert=True)) + if err > 0: + print(f"Warning: {err} labels are missing from ann results (k={k}, err_thresh={err_thresh})") + + if err > err_thresh: + err_total += 1 + + self.assertLessEqual( err_total, total_thresh, f"Error: knn_query returned incorrect labels for {err_total} items (k={k})") + + wrong_dists=np.sum(((brute_d- ann_d)**2.)>1e-3) + if wrong_dists > 0: + dists_count=brute_d.shape[0]*brute_d.shape[1] + print(f"Warning: {wrong_dists} ann distance values are different from brute-force values (total # of values={dists_count}, dists_thresh={dists_thresh})") + + self.assertLessEqual( wrong_dists, dists_thresh, msg=f"Error: {wrong_dists} ann distance values are different from brute-force values") + +def test_space_main(self, space, dim): + + # Generating sample data + data = np.float32(np.random.random((self.num_elements, dim))) + test_data = np.float32(np.random.random((self.num_test_elements, dim))) + + # Declaring index + p = hnswlib.Index(space=space, dim=dim) # possible options are l2, cosine or ip + print(f"Running pickle tests for {p}") + + p.num_threads=self.num_threads # by default using all available cores + + p0=pickle.loads(pickle.dumps(p)) ### pickle un-initialized Index + p.init_index(max_elements = self.num_elements, ef_construction = self.ef_construction, M = self.M) + p0.init_index(max_elements = self.num_elements, ef_construction = self.ef_construction, M = self.M) + + p.ef=self.ef + p0.ef=self.ef + + p1=pickle.loads(pickle.dumps(p)) ### pickle Index before adding items + + ### add items to ann index p,p0,p1 + p.add_items(data) + p1.add_items(data) + p0.add_items(data) + + p2=pickle.loads(pickle.dumps(p)) ### pickle Index before adding items + + self.assertTrue(np.allclose(p.get_items(), p0.get_items()), "items for p and p0 must be same") + self.assertTrue(np.allclose(p0.get_items(), p1.get_items()), "items for p0 and p1 must be same") + self.assertTrue(np.allclose(p1.get_items(), p2.get_items()), "items for p1 and p2 must be same") + + ### Test if returned distances are same + l, d = p.knn_query(test_data, k=self.k) + l0, d0 = p0.knn_query(test_data, k=self.k) + l1, d1 = p1.knn_query(test_data, k=self.k) + l2, d2 = p2.knn_query(test_data, k=self.k) + + self.assertLessEqual(np.sum(((d-d0)**2.)>1e-3), self.dists_err_thresh, msg=f"knn distances returned by p and p0 must match") + self.assertLessEqual(np.sum(((d0-d1)**2.)>1e-3), self.dists_err_thresh, msg=f"knn distances returned by p0 and p1 must match") + self.assertLessEqual(np.sum(((d1-d2)**2.)>1e-3), self.dists_err_thresh, msg=f"knn distances returned by p1 and p2 must match") + + ### check if ann results match brute-force search + ### allow for 2 
labels to be missing from ann results + check_ann_results(self, space, data, test_data, self.k, l, d, + err_thresh = self.label_err_thresh, + total_thresh = self.item_err_thresh, + dists_thresh = self.dists_err_thresh) + + check_ann_results(self, space, data, test_data, self.k, l2, d2, + err_thresh=self.label_err_thresh, + total_thresh=self.item_err_thresh, + dists_thresh=self.dists_err_thresh) + + ### Check ef parameter value + self.assertEqual(p.ef, self.ef, "incorrect value of p.ef") + self.assertEqual(p0.ef, self.ef, "incorrect value of p0.ef") + self.assertEqual(p2.ef, self.ef, "incorrect value of p2.ef") + self.assertEqual(p1.ef, self.ef, "incorrect value of p1.ef") + + ### Check M parameter value + self.assertEqual(p.M, self.M, "incorrect value of p.M") + self.assertEqual(p0.M, self.M, "incorrect value of p0.M") + self.assertEqual(p1.M, self.M, "incorrect value of p1.M") + self.assertEqual(p2.M, self.M, "incorrect value of p2.M") + + ### Check ef_construction parameter value + self.assertEqual(p.ef_construction, self.ef_construction, "incorrect value of p.ef_construction") + self.assertEqual(p0.ef_construction, self.ef_construction, "incorrect value of p0.ef_construction") + self.assertEqual(p1.ef_construction, self.ef_construction, "incorrect value of p1.ef_construction") + self.assertEqual(p2.ef_construction, self.ef_construction, "incorrect value of p2.ef_construction") + + + +class PickleUnitTests(unittest.TestCase): + + def setUp(self): + + self.ef_construction = 725 + self.M = 64 + self.ef = 725 + + self.num_elements = 5000 + self.num_test_elements = 200 + + self.num_threads = 4 + self.k = 25 + + self.label_err_thresh=5 ### max number of missing labels allowed per test item + self.item_err_thresh=5 ### max number of items allowed with incorrect labels + + self.dists_err_thresh=50 ### for two matrices, d1 and d2, dists_err_thresh controls max + ### number of value pairs that are allowed to be different in d1 and d2 + ### i.e., number of values that are (d1-d2)**2>1e-3 + + def test_inner_product_space(self): + test_space_main(self, 'ip', 48) + + def test_l2_space(self): + test_space_main(self, 'l2', 153) + + def test_cosine_space(self): + test_space_main(self, 'cosine', 512) + +if __name__ == "__main__": + unittest.main()
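
Beyond the in-memory copy shown in the README example above, the pickle support added by this patch also lets you persist an index to disk with the standard `pickle` module. A minimal sketch (the file name, dimensions, and parameter values below are illustrative, not part of the patch):

```python
import pickle

import hnswlib
import numpy as np

dim = 16                                    # illustrative dimension
p = hnswlib.Index(space='l2', dim=dim)
p.init_index(max_elements=1000, ef_construction=100, M=16)
p.add_items(np.float32(np.random.random((500, dim))))
p.ef = 20                                   # uses the new read/write `ef` property

# Persist to disk with the standard pickle module ("index.pkl" is an illustrative path).
with open("index.pkl", "wb") as f:
    pickle.dump(p, f)

with open("index.pkl", "rb") as f:
    p_restored = pickle.load(f)

# Construction parameters, the ef setting and the stored items survive the round trip.
assert p_restored.M == p.M
assert p_restored.ef_construction == p.ef_construction
assert p_restored.ef == p.ef
assert p_restored.element_count == p.element_count

labels, distances = p_restored.knn_query(np.float32(np.random.random((5, dim))), k=3)
```

`save_index`/`load_index` remain available for the native binary format; pickling additionally round-trips the Python-side settings serialized in `Index::getIndexParams` (such as `num_threads` and the default `ef`).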