Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cpp tests for Windows in CI #409

Merged
merged 1 commit into from
Sep 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
17 changes: 15 additions & 2 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ jobs:
run: python -m unittest discover --start-directory python_bindings/tests --pattern "*_test*.py"

test_cpp:
runs-on: ubuntu-latest
runs-on: ${{matrix.os}}
strategy:
matrix:
os: [ubuntu-latest, windows-latest]
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
Expand All @@ -34,17 +37,27 @@ jobs:
mkdir build
cd build
cmake ..
make
if [ "$RUNNER_OS" == "Linux" ]; then
make
elif [ "$RUNNER_OS" == "Windows" ]; then
cmake --build ./ --config Release
fi
shell: bash

- name: Prepare test data
run: |
pip install numpy
cd examples
python update_gen_data.py
shell: bash

- name: Test
run: |
cd build
if [ "$RUNNER_OS" == "Windows" ]; then
cp ./Release/* ./
fi
./searchKnnCloserFirst_test
./test_updates
./test_updates update
shell: bash
25 changes: 16 additions & 9 deletions examples/updates_test.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "../hnswlib/hnswlib.h"
#include <thread>


class StopW
{
std::chrono::steady_clock::time_point time_begin;
Expand All @@ -22,6 +24,7 @@ class StopW
}
};


/*
* replacement for the openmp '#pragma omp parallel for' directive
* only handles a subset of functionality (no reductions etc)
Expand Down Expand Up @@ -81,8 +84,6 @@ inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn
std::rethrow_exception(lastException);
}
}


}


Expand All @@ -94,7 +95,7 @@ std::vector<datatype> load_batch(std::string path, int size)
assert(sizeof(datatype) == 4);

std::ifstream file;
file.open(path);
file.open(path, std::ios::binary);
if (!file.is_open())
{
std::cout << "Cannot open " << path << "\n";
Expand All @@ -107,6 +108,7 @@ std::vector<datatype> load_batch(std::string path, int size)
return batch;
}


template <typename d_type>
static float
test_approx(std::vector<float> &queries, size_t qsize, hnswlib::HierarchicalNSW<d_type> &appr_alg, size_t vecdim,
Expand Down Expand Up @@ -137,6 +139,7 @@ test_approx(std::vector<float> &queries, size_t qsize, hnswlib::HierarchicalNSW<
return 1.0f * correct / total;
}


static void
test_vs_recall(std::vector<float> &queries, size_t qsize, hnswlib::HierarchicalNSW<float> &appr_alg, size_t vecdim,
std::vector<std::unordered_set<hnswlib::labeltype>> &answers, size_t k)
Expand All @@ -155,6 +158,8 @@ test_vs_recall(std::vector<float> &queries, size_t qsize, hnswlib::HierarchicalN
efs.push_back(i);
}
std::cout << "ef\trecall\ttime\thops\tdistcomp\n";

bool test_passed = false;
for (size_t ef : efs)
{
appr_alg.setEf(ef);
Expand All @@ -171,20 +176,24 @@ test_vs_recall(std::vector<float> &queries, size_t qsize, hnswlib::HierarchicalN
std::cout << ef << "\t" << recall << "\t" << time_us_per_query << "us \t"<<hops_per_query<<"\t"<<distance_comp_per_query << "\n";
if (recall > 0.99)
{
test_passed = true;
std::cout << "Recall is over 0.99! "<<recall << "\t" << time_us_per_query << "us \t"<<hops_per_query<<"\t"<<distance_comp_per_query << "\n";
break;
}
}
if (!test_passed)
{
std::cerr << "Test failed\n";
exit(1);
}
}


int main(int argc, char **argv)
{

int M = 16;
int efConstruction = 200;
int num_threads = std::thread::hardware_concurrency();



bool update = false;

Expand All @@ -207,7 +216,6 @@ int main(int argc, char **argv)

std::string path = "../examples/data/";


int N;
int dummy_data_multiplier;
int N_queries;
Expand Down Expand Up @@ -240,7 +248,6 @@ int main(int argc, char **argv)
if (update)
{
std::cout << "Update iteration 0\n";


ParallelFor(1, N, num_threads, [&](size_t i, size_t threadId) {
appr_alg.addPoint((void *)(dummy_batch.data() + i * d), i);
Expand Down Expand Up @@ -295,4 +302,4 @@ int main(int argc, char **argv)
}

return 0;
};
};