diff --git a/CMakeLists.txt b/CMakeLists.txt index 79ab30b3..ebee6e6c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -5,7 +5,7 @@ include_directories("${PROJECT_BINARY_DIR}") -set(SOURCE_EXE main.cpp) +set(SOURCE_EXE main.cpp) set(SOURCE_LIB sift_1b.cpp) @@ -13,5 +13,14 @@ add_library(sift_test STATIC ${SOURCE_LIB}) add_executable(main ${SOURCE_EXE}) +if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang") + SET( CMAKE_CXX_FLAGS "-Ofast -DNDEBUG -std=c++11 -DHAVE_CXX0X -openmp -march=native -fpic -ftree-vectorize") +elseif (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") SET( CMAKE_CXX_FLAGS "-Ofast -lrt -DNDEBUG -std=c++11 -DHAVE_CXX0X -openmp -march=native -fpic -w -fopenmp -ftree-vectorize -ftree-vectorizer-verbose=0" ) +elseif (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") + SET( CMAKE_CXX_FLAGS "-Ofast -lrt -DNDEBUG -std=c++11 -DHAVE_CXX0X -openmp -march=native -fpic -w -fopenmp -ftree-vectorize" ) +endif() + +add_executable(test_updates examples/updates_test.cpp) + target_link_libraries(main sift_test) diff --git a/README.md b/README.md index c79e24c1..70b03cad 100644 --- a/README.md +++ b/README.md @@ -223,6 +223,29 @@ To run the test on 200M SIFT subset: The size of the bigann subset (in millions) is controlled by the variable **subset_size_milllions** hardcoded in **sift_1b.cpp**. +### Updates test +To generate testing data (from root directory): +```bash +cd examples +python update_gen_data.py +``` +To compile (from root directory): +```bash +mkdir build +cd build +cmake .. 
+make +``` +To run test **without** updates (from `build` directory) +```bash +./test_updates +``` + +To run test **with** updates (from `build` directory) +```bash +./test_updates update +``` + ### HNSW example demos - Visual search engine for 1M amazon products (MXNet + HNSW): [website](https://thomasdelteil.github.io/VisualSearch_MXNet/), [code](https://github.com/ThomasDelteil/VisualSearch_MXNet), demo by [@ThomasDelteil](https://github.com/ThomasDelteil) diff --git a/examples/update_gen_data.py b/examples/update_gen_data.py new file mode 100644 index 00000000..6f51bbbe --- /dev/null +++ b/examples/update_gen_data.py @@ -0,0 +1,37 @@ +import numpy as np +import os + +def normalized(a, axis=-1, order=2): + l2 = np.atleast_1d(np.linalg.norm(a, order, axis)) + l2[l2==0] = 1 + return a / np.expand_dims(l2, axis) + +N=100000 +dummy_data_multiplier=3 +N_queries = 1000 +d=8 +K=5 + +np.random.seed(1) + +print("Generating data...") +batches_dummy= [ normalized(np.float32(np.random.random( (N,d)))) for _ in range(dummy_data_multiplier)] +batch_final = normalized (np.float32(np.random.random( (N,d)))) +queries = normalized(np.float32(np.random.random( (N_queries,d)))) +print("Computing distances...") +dist=np.dot(queries,batch_final.T) +topk=np.argsort(-dist)[:,:K] +print("Saving...") + +try: + os.mkdir("data") +except OSError as e: + pass + +for idx, batch_dummy in enumerate(batches_dummy): + batch_dummy.tofile('data/batch_dummy_%02d.bin' % idx) +batch_final.tofile('data/batch_final.bin') +queries.tofile('data/queries.bin') +np.int32(topk).tofile('data/gt.bin') +with open("data/config.txt", "w") as file: + file.write("%d %d %d %d %d" %(N, dummy_data_multiplier, N_queries, d, K)) \ No newline at end of file diff --git a/examples/updates_test.cpp b/examples/updates_test.cpp new file mode 100644 index 00000000..c8775877 --- /dev/null +++ b/examples/updates_test.cpp @@ -0,0 +1,298 @@ +#include "../hnswlib/hnswlib.h" +#include +class StopW +{ + 
std::chrono::steady_clock::time_point time_begin; + +public: + StopW() + { + time_begin = std::chrono::steady_clock::now(); + } + + float getElapsedTimeMicro() + { + std::chrono::steady_clock::time_point time_end = std::chrono::steady_clock::now(); + return (std::chrono::duration_cast(time_end - time_begin).count()); + } + + void reset() + { + time_begin = std::chrono::steady_clock::now(); + } +}; + +/* + * replacement for the openmp '#pragma omp parallel for' directive + * only handles a subset of functionality (no reductions etc) + * Process ids from start (inclusive) to end (EXCLUSIVE) + * + * The method is borrowed from nmslib + */ +template +inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn) { + if (numThreads <= 0) { + numThreads = std::thread::hardware_concurrency(); + } + + if (numThreads == 1) { + for (size_t id = start; id < end; id++) { + fn(id, 0); + } + } else { + std::vector threads; + std::atomic current(start); + + // keep track of exceptions in threads + // https://stackoverflow.com/a/32428427/1713196 + std::exception_ptr lastException = nullptr; + std::mutex lastExceptMutex; + + for (size_t threadId = 0; threadId < numThreads; ++threadId) { + threads.push_back(std::thread([&, threadId] { + while (true) { + size_t id = current.fetch_add(1); + + if ((id >= end)) { + break; + } + + try { + fn(id, threadId); + } catch (...) { + std::unique_lock lastExcepLock(lastExceptMutex); + lastException = std::current_exception(); + /* + * This will work even when current is the largest value that + * size_t can fit, because fetch_add returns the previous value + * before the increment (what will result in overflow + * and produce 0 instead of current + 1). 
+ */ + current = end; + break; + } + } + })); + } + for (auto &thread : threads) { + thread.join(); + } + if (lastException) { + std::rethrow_exception(lastException); + } + } + + +} + + +template +std::vector load_batch(std::string path, int size) +{ + std::cout << "Loading " << path << "..."; + // Reads `size` raw 4-byte elements (float or int32, as written by the python generator's tofile) + assert(sizeof(datatype) == 4); + + std::ifstream file; + file.open(path, std::ios::binary); // raw bytes: text mode may translate newlines and corrupt the data + if (!file.is_open()) + { + std::cout << "Cannot open " << path << "\n"; + exit(1); + } + std::vector batch(size); + + file.read((char *)batch.data(), size * sizeof(datatype)); // byte count follows the element type, not a hard-coded float + std::cout << " DONE\n"; + return batch; +} + +template +static float +test_approx(std::vector &queries, size_t qsize, hnswlib::HierarchicalNSW &appr_alg, size_t vecdim, + std::vector> &answers, size_t K) +{ + size_t correct = 0; + size_t total = 0; + //uncomment to test in parallel mode: + + + for (int i = 0; i < qsize; i++) + { + + std::priority_queue> result = appr_alg.searchKnn((char *)(queries.data() + vecdim * i), K); + total += K; + while (result.size()) + { + if (answers[i].find(result.top().second) != answers[i].end()) + { + correct++; + } + else + { + } + result.pop(); + } + } + return 1.0f * correct / total; +} + +static void +test_vs_recall(std::vector &queries, size_t qsize, hnswlib::HierarchicalNSW &appr_alg, size_t vecdim, + std::vector> &answers, size_t k) +{ + std::vector efs = {1}; + for (int i = k; i < 30; i++) + { + efs.push_back(i); + } + for (int i = 30; i < 400; i+=10) + { + efs.push_back(i); + } + for (int i = 1000; i < 100000; i += 5000) + { + efs.push_back(i); + } + std::cout << "ef\trecall\ttime\thops\tdistcomp\n"; + for (size_t ef : efs) + { + appr_alg.setEf(ef); + + appr_alg.metric_hops=0; + appr_alg.metric_distance_computations=0; + StopW stopw = StopW(); + + float recall = test_approx(queries, qsize, appr_alg, vecdim, answers, k); + float time_us_per_query = stopw.getElapsedTimeMicro() / qsize; + float distance_comp_per_query = appr_alg.metric_distance_computations / 
(1.0f * qsize); + float hops_per_query = appr_alg.metric_hops / (1.0f * qsize); + + std::cout << ef << "\t" << recall << "\t" << time_us_per_query << "us \t"< 0.99) + { + std::cout << "Recall is over 0.99! "<2){ + std::cout<<"Usage ./test_updates [update]\n"; + exit(1); + } + + std::string path = "../examples/data/"; + + + int N; + int dummy_data_multiplier; + int N_queries; + int d; + int K; + { + std::ifstream configfile; + configfile.open(path + "/config.txt"); + if (!configfile.is_open()) + { + std::cout << "Cannot open config.txt\n"; + return 1; + } + configfile >> N >> dummy_data_multiplier >> N_queries >> d >> K; + + printf("Loaded config: N=%d, d_mult=%d, Nq=%d, dim=%d, K=%d\n", N, dummy_data_multiplier, N_queries, d, K); + } + + hnswlib::L2Space l2space(d); + hnswlib::HierarchicalNSW appr_alg(&l2space, N + 1, M, efConstruction); + + std::vector dummy_batch = load_batch(path + "batch_dummy_00.bin", N * d); + + // Adding enterpoint: + + appr_alg.addPoint((void *)dummy_batch.data(), (size_t)0); + + StopW stopw = StopW(); + + if (update) + { + std::cout << "Update iteration 0\n"; + + + ParallelFor(1, N, num_threads, [&](size_t i, size_t threadId) { + appr_alg.addPoint((void *)(dummy_batch.data() + i * d), i); + }); + appr_alg.checkIntegrity(); + + ParallelFor(1, N, num_threads, [&](size_t i, size_t threadId) { + appr_alg.addPoint((void *)(dummy_batch.data() + i * d), i); + }); + appr_alg.checkIntegrity(); + + for (int b = 1; b < dummy_data_multiplier; b++) + { + std::cout << "Update iteration " << b << "\n"; + char cpath[1024]; + sprintf(cpath, "batch_dummy_%02d.bin", b); + std::vector dummy_batchb = load_batch(path + cpath, N * d); + + ParallelFor(0, N, num_threads, [&](size_t i, size_t threadId) { + appr_alg.addPoint((void *)(dummy_batchb.data() + i * d), i); // insert the batch loaded for this iteration, not the initial batch again + }); + appr_alg.checkIntegrity(); + } + } + + std::cout << "Inserting final elements\n"; + std::vector final_batch = load_batch(path + "batch_final.bin", N * d); + + stopw.reset(); + ParallelFor(0, N, 
num_threads, [&](size_t i, size_t threadId) { + appr_alg.addPoint((void *)(final_batch.data() + i * d), i); + }); + std::cout<<"Finished. Time taken:" << stopw.getElapsedTimeMicro()*1e-6 << " s\n"; + std::cout << "Running tests\n"; + std::vector queries_batch = load_batch(path + "queries.bin", N_queries * d); + + std::vector gt = load_batch(path + "gt.bin", N_queries * K); + + std::vector> answers(N_queries); + for (int i = 0; i < N_queries; i++) + { + for (int j = 0; j < K; j++) + { + answers[i].insert(gt[i * K + j]); + } + } + + for (int i = 0; i < 3; i++) + { + std::cout << "Test iteration " << i << "\n"; + test_vs_recall(queries_batch, N_queries, appr_alg, d, answers, K); + } + + return 0; +}; \ No newline at end of file diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index afc1222d..97bdcd18 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -2,6 +2,7 @@ #include "visited_list_pool.h" #include "hnswlib.h" +#include #include #include #include @@ -15,7 +16,7 @@ namespace hnswlib { template class HierarchicalNSW : public AlgorithmInterface { public: - + static const tableint max_update_element_locks = 65536; HierarchicalNSW(SpaceInterface *s) { } @@ -25,7 +26,7 @@ namespace hnswlib { } HierarchicalNSW(SpaceInterface *s, size_t max_elements, size_t M = 16, size_t ef_construction = 200, size_t random_seed = 100) : - link_list_locks_(max_elements), element_levels_(max_elements) { + link_list_locks_(max_elements), element_levels_(max_elements), link_list_update_locks_(max_update_element_locks) { max_elements_ = max_elements; has_deletions_=false; @@ -39,6 +40,7 @@ namespace hnswlib { ef_ = 10; level_generator_.seed(random_seed); + update_probability_generator_.seed(random_seed + 1); size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); size_data_per_element_ = size_links_level0_ + data_size_ + sizeof(labeltype); @@ -104,6 +106,10 @@ namespace hnswlib { std::mutex cur_element_count_guard_; std::vector link_list_locks_; + + // Locks to 
prevent race condition during update/insert of an element at same time. + // Note: Locks for additions can also be used to prevent this race condition if the querying of KNN is not exposed along with update/inserts i.e multithread insert/update/query in parallel. + std::vector link_list_update_locks_; tableint enterpoint_node_; @@ -126,6 +132,7 @@ namespace hnswlib { std::unordered_map label_lookup_; std::default_random_engine level_generator_; + std::default_random_engine update_probability_generator_; inline labeltype getExternalLabel(tableint internal_id) const { labeltype return_label; @@ -151,6 +158,7 @@ namespace hnswlib { return (int) r; } + std::priority_queue, std::vector>, CompareByFirst> searchBaseLayer(tableint ep_id, const void *data_point, int layer) { VisitedList *vl = visited_list_pool_->getFreeVisitedList(); @@ -233,7 +241,10 @@ namespace hnswlib { return top_candidates; } - template + mutable std::atomic metric_distance_computations; + mutable std::atomic metric_hops; + + template std::priority_queue, std::vector>, CompareByFirst> searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef) const { VisitedList *vl = visited_list_pool_->getFreeVisitedList(); @@ -269,6 +280,10 @@ namespace hnswlib { int *data = (int *) get_linklist0(current_node_id); size_t size = getListCount((linklistsizeint*)data); // bool cur_node_deleted = isMarkedDeleted(current_node_id); + if(collect_metrics){ + metric_hops++; + metric_distance_computations+=size; + } #ifdef USE_SSE _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0); @@ -319,10 +334,11 @@ namespace hnswlib { void getNeighborsByHeuristic2( std::priority_queue, std::vector>, CompareByFirst> &top_candidates, - const size_t M) { + const size_t M) { if (top_candidates.size() < M) { return; } + std::priority_queue> queue_closest; std::vector> return_list; while (top_candidates.size() > 0) { @@ -337,6 +353,7 @@ namespace hnswlib { dist_t dist_to_query = -curent_pair.first; queue_closest.pop(); 
bool good = true; + for (std::pair second_pair : return_list) { dist_t curdist = fstdistfunc_(getDataByInternalId(second_pair.second), @@ -350,12 +367,9 @@ namespace hnswlib { if (good) { return_list.push_back(curent_pair); } - - } for (std::pair curent_pair : return_list) { - top_candidates.emplace(-curent_pair.first, curent_pair.second); } } @@ -373,10 +387,13 @@ namespace hnswlib { return (linklistsizeint *) (linkLists_[internal_id] + (level - 1) * size_links_per_element_); }; - void mutuallyConnectNewElement(const void *data_point, tableint cur_c, - std::priority_queue, std::vector>, CompareByFirst> top_candidates, - int level) { + linklistsizeint *get_linklist_at_level(tableint internal_id, int level) const { + return level == 0 ? get_linklist0(internal_id) : get_linklist(internal_id, level); + }; + tableint mutuallyConnectNewElement(const void *data_point, tableint cur_c, + std::priority_queue, std::vector>, CompareByFirst> &top_candidates, + int level, bool isUpdate) { size_t Mcurmax = level ? 
maxM_ : maxM0_; getNeighborsByHeuristic2(top_candidates, M_); if (top_candidates.size() > M_) @@ -389,6 +406,8 @@ namespace hnswlib { top_candidates.pop(); } + tableint next_closest_entry_point = selectedNeighbors[0]; + { linklistsizeint *ll_cur; if (level == 0) @@ -396,15 +415,13 @@ namespace hnswlib { else ll_cur = get_linklist(cur_c, level); - if (*ll_cur) { + if (*ll_cur && !isUpdate) { throw std::runtime_error("The newly inserted element should have blank link list"); } setListCount(ll_cur,selectedNeighbors.size()); tableint *data = (tableint *) (ll_cur + 1); - - for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) { - if (data[idx]) + if (data[idx] && !isUpdate) throw std::runtime_error("Possible memory corruption"); if (level > element_levels_[selectedNeighbors[idx]]) throw std::runtime_error("Trying to make a link on a non-existent level"); @@ -413,11 +430,11 @@ namespace hnswlib { } } + for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) { std::unique_lock lock(link_list_locks_[selectedNeighbors[idx]]); - linklistsizeint *ll_other; if (level == 0) ll_other = get_linklist0(selectedNeighbors[idx]); @@ -434,47 +451,63 @@ namespace hnswlib { throw std::runtime_error("Trying to make a link on a non-existent level"); tableint *data = (tableint *) (ll_other + 1); - if (sz_link_list_other < Mcurmax) { - data[sz_link_list_other] = cur_c; - setListCount(ll_other, sz_link_list_other + 1); - } else { - // finding the "weakest" element to replace it with the new one - dist_t d_max = fstdistfunc_(getDataByInternalId(cur_c), getDataByInternalId(selectedNeighbors[idx]), - dist_func_param_); - // Heuristic: - std::priority_queue, std::vector>, CompareByFirst> candidates; - candidates.emplace(d_max, cur_c); + bool is_cur_c_present = false; + if (isUpdate) { for (size_t j = 0; j < sz_link_list_other; j++) { - candidates.emplace( - fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(selectedNeighbors[idx]), - dist_func_param_), data[j]); + if 
(data[j] == cur_c) { + is_cur_c_present = true; + break; + } } + } - getNeighborsByHeuristic2(candidates, Mcurmax); + // If cur_c is already present in the neighboring connections of `selectedNeighbors[idx]` then no need to modify any connections or run the heuristics. + if (!is_cur_c_present) { + if (sz_link_list_other < Mcurmax) { + data[sz_link_list_other] = cur_c; + setListCount(ll_other, sz_link_list_other + 1); + } else { + // finding the "weakest" element to replace it with the new one + dist_t d_max = fstdistfunc_(getDataByInternalId(cur_c), getDataByInternalId(selectedNeighbors[idx]), + dist_func_param_); + // Heuristic: + std::priority_queue, std::vector>, CompareByFirst> candidates; + candidates.emplace(d_max, cur_c); + + for (size_t j = 0; j < sz_link_list_other; j++) { + candidates.emplace( + fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(selectedNeighbors[idx]), + dist_func_param_), data[j]); + } - int indx = 0; - while (candidates.size() > 0) { - data[indx] = candidates.top().second; - candidates.pop(); - indx++; - } - setListCount(ll_other, indx); - // Nearest K: - /*int indx = -1; - for (int j = 0; j < sz_link_list_other; j++) { - dist_t d = fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(rez[idx]), dist_func_param_); - if (d > d_max) { - indx = j; - d_max = d; + getNeighborsByHeuristic2(candidates, Mcurmax); + + int indx = 0; + while (candidates.size() > 0) { + data[indx] = candidates.top().second; + candidates.pop(); + indx++; + } + + setListCount(ll_other, indx); + // Nearest K: + /*int indx = -1; + for (int j = 0; j < sz_link_list_other; j++) { + dist_t d = fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(rez[idx]), dist_func_param_); + if (d > d_max) { + indx = j; + d_max = d; + } } + if (indx >= 0) { + data[indx] = cur_c; + } */ } - if (indx >= 0) { - data[indx] = cur_c; - } */ } - } + + return next_closest_entry_point; } std::mutex global; @@ -516,15 +549,15 @@ namespace hnswlib { if 
(has_deletions_) { std::priority_queue> top_candidates1=searchBaseLayerST(currObj, query_data, - ef_); + ef_); top_candidates.swap(top_candidates1); } else{ std::priority_queue> top_candidates1=searchBaseLayerST(currObj, query_data, - ef_); + ef_); top_candidates.swap(top_candidates1); } - + while (top_candidates.size() > k) { top_candidates.pop(); } @@ -545,7 +578,6 @@ namespace hnswlib { std::vector(new_max_elements).swap(link_list_locks_); - // Reallocate base layer char * data_level0_memory_new = (char *) malloc(new_max_elements * size_data_per_element_); if (data_level0_memory_new == nullptr) @@ -636,8 +668,8 @@ namespace hnswlib { dist_func_param_ = s->get_dist_func_param(); auto pos=input.tellg(); - - + + /// Optional - check if index is ok: input.seekg(cur_element_count * size_data_per_element_,input.cur); @@ -669,7 +701,7 @@ namespace hnswlib { throw std::runtime_error("Not enough memory: loadIndex failed to allocate level0"); input.read(data_level0_memory_, cur_element_count * size_data_per_element_); - + size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); @@ -677,6 +709,7 @@ namespace hnswlib { size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); std::vector(max_elements).swap(link_list_locks_); + std::vector(max_update_element_locks).swap(link_list_update_locks_); visited_list_pool_ = new VisitedListPool(1, max_elements); @@ -711,7 +744,7 @@ namespace hnswlib { if(isMarkedDeleted(i)) has_deletions_=true; } - + input.close(); return; @@ -795,26 +828,185 @@ namespace hnswlib { addPoint(data_point, label,-1); } + void updatePoint(const void *dataPoint, tableint internalId, float updateNeighborProbability) { + // update the feature vector associated with existing point with new vector + memcpy(getDataByInternalId(internalId), dataPoint, data_size_); + + int maxLevelCopy = maxlevel_; + tableint entryPointCopy = enterpoint_node_; + // If point to be updated is entry point and graph just contains single element 
then just return. + if (entryPointCopy == internalId && cur_element_count == 1) + return; + + int elemLevel = element_levels_[internalId]; + std::uniform_real_distribution distribution(0.0, 1.0); + for (int layer = 0; layer <= elemLevel; layer++) { + std::unordered_set sCand; + std::unordered_set sNeigh; + std::vector listOneHop = getConnectionsWithLock(internalId, layer); + if (listOneHop.size() == 0) + continue; + + sCand.insert(internalId); + + for (auto&& elOneHop : listOneHop) { + sCand.insert(elOneHop); + + if (distribution(update_probability_generator_) > updateNeighborProbability) + continue; + + sNeigh.insert(elOneHop); + + std::vector listTwoHop = getConnectionsWithLock(elOneHop, layer); + for (auto&& elTwoHop : listTwoHop) { + sCand.insert(elTwoHop); + } + } + + for (auto&& neigh : sNeigh) { +// if (neigh == internalId) +// continue; + + std::priority_queue, std::vector>, CompareByFirst> candidates; + int size = sCand.find(neigh) == sCand.end() ? sCand.size() : sCand.size() - 1; + int elementsToKeep = std::min(int(ef_construction_), size); + for (auto&& cand : sCand) { + if (cand == neigh) + continue; + + dist_t distance = fstdistfunc_(getDataByInternalId(neigh), getDataByInternalId(cand), dist_func_param_); + if (candidates.size() < elementsToKeep) { + candidates.emplace(distance, cand); + } else { + if (distance < candidates.top().first) { + candidates.pop(); + candidates.emplace(distance, cand); + } + } + } + + // Retrieve neighbours using heuristic and set connections. + getNeighborsByHeuristic2(candidates, layer == 0 ? 
maxM0_ : maxM_); + + { + std::unique_lock lock(link_list_locks_[neigh]); + linklistsizeint *ll_cur; + ll_cur = get_linklist_at_level(neigh, layer); + int candSize = candidates.size(); + setListCount(ll_cur, candSize); + tableint *data = (tableint *) (ll_cur + 1); + for (size_t idx = 0; idx < candSize; idx++) { + data[idx] = candidates.top().second; + candidates.pop(); + } + } + } + } + + repairConnectionsForUpdate(dataPoint, entryPointCopy, internalId, elemLevel, maxLevelCopy); + }; + + void repairConnectionsForUpdate(const void *dataPoint, tableint entryPointInternalId, tableint dataPointInternalId, int dataPointLevel, int maxLevel) { + tableint currObj = entryPointInternalId; + if (dataPointLevel < maxLevel) { + dist_t curdist = fstdistfunc_(dataPoint, getDataByInternalId(currObj), dist_func_param_); + for (int level = maxLevel; level > dataPointLevel; level--) { + bool changed = true; + while (changed) { + changed = false; + unsigned int *data; + std::unique_lock lock(link_list_locks_[currObj]); + data = get_linklist_at_level(currObj,level); + int size = getListCount(data); + tableint *datal = (tableint *) (data + 1); +#ifdef USE_SSE + _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0); +#endif + for (int i = 0; i < size; i++) { +#ifdef USE_SSE + _mm_prefetch(getDataByInternalId(*(datal + i + 1)), _MM_HINT_T0); +#endif + tableint cand = datal[i]; + dist_t d = fstdistfunc_(dataPoint, getDataByInternalId(cand), dist_func_param_); + if (d < curdist) { + curdist = d; + currObj = cand; + changed = true; + } + } + } + } + } + + if (dataPointLevel > maxLevel) + throw std::runtime_error("Level of item to be updated cannot be bigger than max level"); + + for (int level = dataPointLevel; level >= 0; level--) { + std::priority_queue, std::vector>, CompareByFirst> topCandidates = searchBaseLayer( + currObj, dataPoint, level); + + std::priority_queue, std::vector>, CompareByFirst> filteredTopCandidates; + while (topCandidates.size() > 0) { + if 
(topCandidates.top().second != dataPointInternalId) + filteredTopCandidates.push(topCandidates.top()); + + topCandidates.pop(); + } + + // Since element_levels_ is being used to get `dataPointLevel`, there could be cases where `topCandidates` could just contains entry point itself. + // To prevent self loops, the `topCandidates` is filtered and thus can be empty. + if (filteredTopCandidates.size() > 0) { + bool epDeleted = isMarkedDeleted(entryPointInternalId); + if (epDeleted) { + filteredTopCandidates.emplace(fstdistfunc_(dataPoint, getDataByInternalId(entryPointInternalId), dist_func_param_), entryPointInternalId); + if (filteredTopCandidates.size() > ef_construction_) + filteredTopCandidates.pop(); + } + + currObj = mutuallyConnectNewElement(dataPoint, dataPointInternalId, filteredTopCandidates, level, true); + } + } + } + + std::vector getConnectionsWithLock(tableint internalId, int level) { + std::unique_lock lock(link_list_locks_[internalId]); + unsigned int *data = get_linklist_at_level(internalId, level); + int size = getListCount(data); + std::vector result(size); + tableint *ll = (tableint *) (data + 1); + memcpy(result.data(), ll,size * sizeof(tableint)); + return result; + }; + tableint addPoint(const void *data_point, labeltype label, int level) { + tableint cur_c = 0; { - std::unique_lock lock(cur_element_count_guard_); + // Checking if the element with the same label already exists + // if so, updating it *instead* of creating a new element. 
+ std::unique_lock templock_curr(cur_element_count_guard_); + auto search = label_lookup_.find(label); + if (search != label_lookup_.end()) { + tableint existingInternalId = search->second; + + templock_curr.unlock(); + + std::unique_lock lock_el_update(link_list_update_locks_[(existingInternalId & (max_update_element_locks - 1))]); + updatePoint(data_point, existingInternalId, 1.0); + return existingInternalId; + } + if (cur_element_count >= max_elements_) { throw std::runtime_error("The number of elements exceeds the specified limit"); }; cur_c = cur_element_count; cur_element_count++; - - auto search = label_lookup_.find(label); - if (search != label_lookup_.end()) { - std::unique_lock lock_el(link_list_locks_[search->second]); - has_deletions_ = true; - markDeletedInternal(search->second); - } label_lookup_[label] = cur_c; } + // Take update lock to prevent race conditions on an element with insertion/update at the same time. + std::unique_lock lock_el_update(link_list_update_locks_[(cur_c & (max_update_element_locks - 1))]); std::unique_lock lock_el(link_list_locks_[cur_c]); int curlevel = getRandomLevel(mult_); if (level > 0) @@ -889,9 +1081,7 @@ namespace hnswlib { if (top_candidates.size() > ef_construction_) top_candidates.pop(); } - mutuallyConnectNewElement(data_point, cur_c, top_candidates, level); - - currObj = top_candidates.top().second; + currObj = mutuallyConnectNewElement(data_point, cur_c, top_candidates, level, false); } @@ -926,6 +1116,9 @@ namespace hnswlib { data = (unsigned int *) get_linklist(currObj, level); int size = getListCount(data); + metric_hops++; + metric_distance_computations+=size; + tableint *datal = (tableint *) (data + 1); for (int i = 0; i < size; i++) { tableint cand = datal[i]; @@ -943,16 +1136,15 @@ namespace hnswlib { } std::priority_queue, std::vector>, CompareByFirst> top_candidates; - if (has_deletions_) { - std::priority_queue, std::vector>, CompareByFirst> top_candidates1=searchBaseLayerST( + if (has_deletions_) { + 
top_candidates=searchBaseLayerST( currObj, query_data, std::max(ef_, k)); - top_candidates.swap(top_candidates1); } else{ - std::priority_queue, std::vector>, CompareByFirst> top_candidates1=searchBaseLayerST( + top_candidates=searchBaseLayerST( currObj, query_data, std::max(ef_, k)); - top_candidates.swap(top_candidates1); } + while (top_candidates.size() > k) { top_candidates.pop(); } @@ -982,6 +1174,40 @@ namespace hnswlib { return result; } + void checkIntegrity(){ + int connections_checked=0; + std::vector inbound_connections_num(cur_element_count,0); + for(int i = 0;i < cur_element_count; i++){ + for(int l = 0;l <= element_levels_[i]; l++){ + linklistsizeint *ll_cur = get_linklist_at_level(i,l); + int size = getListCount(ll_cur); + tableint *data = (tableint *) (ll_cur + 1); + std::unordered_set s; + for (int j=0; j 0); + assert(data[j] < cur_element_count); + assert (data[j] != i); + inbound_connections_num[data[j]]++; + s.insert(data[j]); + connections_checked++; + + } + assert(s.size() == size); + } + } + if(cur_element_count > 1){ + int min1=inbound_connections_num[0], max1=inbound_connections_num[0]; + for(int i=0; i < cur_element_count; i++){ + assert(inbound_connections_num[i] > 0); + min1=std::min(inbound_connections_num[i],min1); + max1=std::max(inbound_connections_num[i],max1); + } + std::cout << "Min inbound: " << min1 << ", Max inbound:" << max1 << "\n"; + } + std::cout << "integrity ok, checked " << connections_checked << " connections\n"; + + } + }; } diff --git a/hnswlib/hnswlib.h b/hnswlib/hnswlib.h index dbfb1656..c26f80b5 100644 --- a/hnswlib/hnswlib.h +++ b/hnswlib/hnswlib.h @@ -25,7 +25,7 @@ #include #include - +#include #include namespace hnswlib {