11 changes: 10 additions & 1 deletion CMakeLists.txt
@@ -5,13 +5,22 @@ include_directories("${PROJECT_BINARY_DIR}")



set(SOURCE_EXE main.cpp)

set(SOURCE_LIB sift_1b.cpp)

add_library(sift_test STATIC ${SOURCE_LIB})


add_executable(main ${SOURCE_EXE})
if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
    SET( CMAKE_CXX_FLAGS "-Ofast -DNDEBUG -std=c++11 -DHAVE_CXX0X -fopenmp -march=native -fpic -ftree-vectorize")
elseif (CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
    SET( CMAKE_CXX_FLAGS "-Ofast -lrt -DNDEBUG -std=c++11 -DHAVE_CXX0X -march=native -fpic -w -fopenmp -ftree-vectorize -ftree-vectorizer-verbose=0" )
elseif (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC")
    SET( CMAKE_CXX_FLAGS "/O2 /DNDEBUG /DHAVE_CXX0X /openmp /W0" )
endif()

add_executable(test_updates examples/updates_test.cpp)

target_link_libraries(main sift_test)
23 changes: 23 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -223,6 +223,29 @@ To run the test on 200M SIFT subset:

The size of the bigann subset (in millions) is controlled by the variable **subset_size_milllions** hardcoded in **sift_1b.cpp**.

### Updates test
To generate the test data (from the root directory):
```bash
cd examples
python update_gen_data.py
```
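This writes the dummy batches, the final batch, the query set, the ground truth, and `config.txt` into `examples/data/`.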
To compile (from the root directory):
```bash
mkdir build
cd build
cmake ..
make
```
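The build produces the `test_updates` binary used below (alongside the existing `main` target).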
To run the test **without** updates (from the `build` directory):
```bash
./test_updates
```

To run the test **with** updates (from the `build` directory):
```bash
./test_updates update
```
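In both modes the final batch is inserted and search quality is then measured over a sweep of the `ef` parameter; each output line reports `ef`, recall against the precomputed ground truth, per-query time in microseconds, and hops and distance computations per query. With the `update` argument the index is first built from the dummy batches and every element is re-inserted several times before the final batch, exercising the in-place update path.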

### HNSW example demos

- Visual search engine for 1M amazon products (MXNet + HNSW): [website](https://thomasdelteil.github.io/VisualSearch_MXNet/), [code](https://github.com/ThomasDelteil/VisualSearch_MXNet), demo by [@ThomasDelteil](https://github.com/ThomasDelteil)
37 changes: 37 additions & 0 deletions examples/update_gen_data.py
@@ -0,0 +1,37 @@
import numpy as np
import os

def normalized(a, axis=-1, order=2):
    # L2-normalize along the given axis; all-zero rows are left unchanged
    l2 = np.atleast_1d(np.linalg.norm(a, order, axis))
    l2[l2 == 0] = 1
    return a / np.expand_dims(l2, axis)

N = 100000                 # vectors per batch
dummy_data_multiplier = 3  # number of dummy batches inserted before the final batch
N_queries = 1000           # number of query vectors
d = 8                      # vector dimensionality
K = 5                      # ground-truth neighbours per query

np.random.seed(1)

print("Generating data...")
batches_dummy = [normalized(np.float32(np.random.random((N, d)))) for _ in range(dummy_data_multiplier)]
batch_final = normalized(np.float32(np.random.random((N, d))))
queries = normalized(np.float32(np.random.random((N_queries, d))))
print("Computing distances...")
# For unit-norm vectors, ranking by inner product is equivalent to ranking by L2 distance
dist = np.dot(queries, batch_final.T)
topk = np.argsort(-dist)[:, :K]
print("Saving...")

try:
    os.mkdir("data")
except OSError:
    pass  # the directory may already exist

for idx, batch_dummy in enumerate(batches_dummy):
    batch_dummy.tofile('data/batch_dummy_%02d.bin' % idx)
batch_final.tofile('data/batch_final.bin')
queries.tofile('data/queries.bin')
np.int32(topk).tofile('data/gt.bin')
with open("data/config.txt", "w") as file:
    file.write("%d %d %d %d %d" % (N, dummy_data_multiplier, N_queries, d, K))
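For a quick sanity check of the generated files, here is a minimal read-back sketch (an illustration only, assuming it is run from the `examples` directory and mirroring the shapes and dtypes written above):

```python
import numpy as np

# Dimensions recorded by update_gen_data.py in data/config.txt
with open("data/config.txt") as f:
    N, dummy_data_multiplier, N_queries, d, K = map(int, f.read().split())

# Vectors were written as flat float32 buffers, the ground truth as int32
batch_final = np.fromfile("data/batch_final.bin", dtype=np.float32).reshape(N, d)
queries = np.fromfile("data/queries.bin", dtype=np.float32).reshape(N_queries, d)
gt = np.fromfile("data/gt.bin", dtype=np.int32).reshape(N_queries, K)

print(batch_final.shape, queries.shape, gt.shape)
```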
298 changes: 298 additions & 0 deletions examples/updates_test.cpp
@@ -0,0 +1,298 @@
#include "../hnswlib/hnswlib.h"
#include <thread>
class StopW
{
std::chrono::steady_clock::time_point time_begin;

public:
StopW()
{
time_begin = std::chrono::steady_clock::now();
}

float getElapsedTimeMicro()
{
std::chrono::steady_clock::time_point time_end = std::chrono::steady_clock::now();
return (std::chrono::duration_cast<std::chrono::microseconds>(time_end - time_begin).count());
}

void reset()
{
time_begin = std::chrono::steady_clock::now();
}
};

/*
* replacement for the openmp '#pragma omp parallel for' directive
* only handles a subset of functionality (no reductions etc)
* Process ids from start (inclusive) to end (EXCLUSIVE)
*
* The method is borrowed from nmslib
*/
template<class Function>
inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn) {
    if (numThreads == 0) {
        numThreads = std::thread::hardware_concurrency();
    }

if (numThreads == 1) {
for (size_t id = start; id < end; id++) {
fn(id, 0);
}
} else {
std::vector<std::thread> threads;
std::atomic<size_t> current(start);

// keep track of exceptions in threads
// https://stackoverflow.com/a/32428427/1713196
std::exception_ptr lastException = nullptr;
std::mutex lastExceptMutex;

for (size_t threadId = 0; threadId < numThreads; ++threadId) {
threads.push_back(std::thread([&, threadId] {
while (true) {
size_t id = current.fetch_add(1);

if ((id >= end)) {
break;
}

try {
fn(id, threadId);
} catch (...) {
std::unique_lock<std::mutex> lastExcepLock(lastExceptMutex);
lastException = std::current_exception();
/*
* This will work even when current is the largest value that
* size_t can fit, because fetch_add returns the previous value
* before the increment (which will result in overflow
* and produce 0 instead of current + 1).
*/
current = end;
break;
}
}
}));
}
for (auto &thread : threads) {
thread.join();
}
if (lastException) {
std::rethrow_exception(lastException);
}
}


}
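// Example (a minimal usage sketch): fn is called as fn(id, threadId) for every id in [start, end), e.g.
//   ParallelFor(0, N, num_threads, [&](size_t id, size_t threadId) { /* process element id */ });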


template <typename datatype>
std::vector<datatype> load_batch(std::string path, int size)
{
std::cout << "Loading " << path << "...";
    // the batches are written by update_gen_data.py as 4-byte values (float32 or int32)
    assert(sizeof(datatype) == 4);

std::ifstream file;
    file.open(path, std::ios::binary);
if (!file.is_open())
{
std::cout << "Cannot open " << path << "\n";
exit(1);
}
std::vector<datatype> batch(size);

    file.read((char *)batch.data(), size * sizeof(datatype));
std::cout << " DONE\n";
return batch;
}

template <typename d_type>
static float
test_approx(std::vector<float> &queries, size_t qsize, hnswlib::HierarchicalNSW<d_type> &appr_alg, size_t vecdim,
std::vector<std::unordered_set<hnswlib::labeltype>> &answers, size_t K)
{
size_t correct = 0;
size_t total = 0;
//uncomment to test in parallel mode:


    for (size_t i = 0; i < qsize; i++)
{

std::priority_queue<std::pair<d_type, hnswlib::labeltype>> result = appr_alg.searchKnn((char *)(queries.data() + vecdim * i), K);
total += K;
        while (result.size())
        {
            if (answers[i].find(result.top().second) != answers[i].end())
            {
                correct++;
            }
            result.pop();
        }
}
return 1.0f * correct / total;
}

static void
test_vs_recall(std::vector<float> &queries, size_t qsize, hnswlib::HierarchicalNSW<float> &appr_alg, size_t vecdim,
std::vector<std::unordered_set<hnswlib::labeltype>> &answers, size_t k)
{
std::vector<size_t> efs = {1};
for (int i = k; i < 30; i++)
{
efs.push_back(i);
}
for (int i = 30; i < 400; i+=10)
{
efs.push_back(i);
}
for (int i = 1000; i < 100000; i += 5000)
{
efs.push_back(i);
}
std::cout << "ef\trecall\ttime\thops\tdistcomp\n";
for (size_t ef : efs)
{
appr_alg.setEf(ef);

appr_alg.metric_hops=0;
appr_alg.metric_distance_computations=0;
StopW stopw = StopW();

float recall = test_approx<float>(queries, qsize, appr_alg, vecdim, answers, k);
float time_us_per_query = stopw.getElapsedTimeMicro() / qsize;
float distance_comp_per_query = appr_alg.metric_distance_computations / (1.0f * qsize);
float hops_per_query = appr_alg.metric_hops / (1.0f * qsize);

std::cout << ef << "\t" << recall << "\t" << time_us_per_query << "us \t"<<hops_per_query<<"\t"<<distance_comp_per_query << "\n";
if (recall > 0.99)
{
std::cout << "Recall is over 0.99! "<<recall << "\t" << time_us_per_query << "us \t"<<hops_per_query<<"\t"<<distance_comp_per_query << "\n";
break;
}
}
}

int main(int argc, char **argv)
{

int M = 16;
int efConstruction = 200;
int num_threads = std::thread::hardware_concurrency();



bool update = false;

if (argc == 2)
{
if (std::string(argv[1]) == "update")
{
update = true;
std::cout << "Updates are on\n";
}
else {
            std::cout << "Usage: ./test_updates [update]\n";
exit(1);
}
}
else if (argc>2){
        std::cout << "Usage: ./test_updates [update]\n";
exit(1);
}

std::string path = "../examples/data/";


int N;
int dummy_data_multiplier;
int N_queries;
int d;
int K;
{
std::ifstream configfile;
configfile.open(path + "/config.txt");
if (!configfile.is_open())
{
std::cout << "Cannot open config.txt\n";
return 1;
}
configfile >> N >> dummy_data_multiplier >> N_queries >> d >> K;

printf("Loaded config: N=%d, d_mult=%d, Nq=%d, dim=%d, K=%d\n", N, dummy_data_multiplier, N_queries, d, K);
}

hnswlib::L2Space l2space(d);
hnswlib::HierarchicalNSW<float> appr_alg(&l2space, N + 1, M, efConstruction);

std::vector<float> dummy_batch = load_batch<float>(path + "batch_dummy_00.bin", N * d);

    // Add the entry point (the first element, label 0) before inserting the rest in parallel:
    appr_alg.addPoint((void *)dummy_batch.data(), (size_t)0);

StopW stopw = StopW();

if (update)
{
        std::cout << "Update iteration 0\n";

        // Insert elements 1..N-1, then insert them again with the same labels
        // so that the second pass exercises the update path:
        ParallelFor(1, N, num_threads, [&](size_t i, size_t threadId) {
appr_alg.addPoint((void *)(dummy_batch.data() + i * d), i);
});
appr_alg.checkIntegrity();

ParallelFor(1, N, num_threads, [&](size_t i, size_t threadId) {
appr_alg.addPoint((void *)(dummy_batch.data() + i * d), i);
});
appr_alg.checkIntegrity();

for (int b = 1; b < dummy_data_multiplier; b++)
{
std::cout << "Update iteration " << b << "\n";
            char cpath[1024];
            snprintf(cpath, sizeof(cpath), "batch_dummy_%02d.bin", b);
            std::vector<float> dummy_batchb = load_batch<float>(path + cpath, N * d);

            // Update every element with the data from this dummy batch
            ParallelFor(0, N, num_threads, [&](size_t i, size_t threadId) {
                appr_alg.addPoint((void *)(dummy_batchb.data() + i * d), i);
            });
appr_alg.checkIntegrity();
}
}

std::cout << "Inserting final elements\n";
std::vector<float> final_batch = load_batch<float>(path + "batch_final.bin", N * d);

stopw.reset();
ParallelFor(0, N, num_threads, [&](size_t i, size_t threadId) {
appr_alg.addPoint((void *)(final_batch.data() + i * d), i);
});
std::cout<<"Finished. Time taken:" << stopw.getElapsedTimeMicro()*1e-6 << " s\n";
std::cout << "Running tests\n";
std::vector<float> queries_batch = load_batch<float>(path + "queries.bin", N_queries * d);

std::vector<int> gt = load_batch<int>(path + "gt.bin", N_queries * K);

std::vector<std::unordered_set<hnswlib::labeltype>> answers(N_queries);
for (int i = 0; i < N_queries; i++)
{
for (int j = 0; j < K; j++)
{
answers[i].insert(gt[i * K + j]);
}
}

for (int i = 0; i < 3; i++)
{
std::cout << "Test iteration " << i << "\n";
test_vs_recall(queries_batch, N_queries, appr_alg, d, answers, K);
}

return 0;
}