Skip to content

Commit

Permalink
add getting a list of all labels
Browse files Browse the repository at this point in the history
  • Loading branch information
yurymalkov committed Dec 5, 2018
1 parent dbb4f01 commit 5ba4c4c
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
15 changes: 13 additions & 2 deletions python_bindings/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ class Index {
}
}

std::vector<std::vector<data_t>> GetDataReturnList(py::object ids_ = py::none()) {
std::vector<std::vector<data_t>> getDataReturnList(py::object ids_ = py::none()) {
std::vector<size_t> ids;
if (!ids_.is_none()) {
py::array_t < size_t, py::array::c_style | py::array::forcecast > items(ids_);
Expand All @@ -235,6 +235,16 @@ class Index {
return data;
}

std::vector<unsigned int> getIdsList() {

std::vector<unsigned int> ids;

for(auto kv : appr_alg->label_lookup_) {
ids.push_back(kv.first);
}
return ids;
}

py::object knnQuery_return_numpy(py::object input, size_t k = 1, int num_threads = -1) {

py::array_t < dist_t, py::array::c_style | py::array::forcecast > items(input);
Expand Down Expand Up @@ -360,7 +370,8 @@ PYBIND11_PLUGIN(hnswlib) {
py::arg("ef_construction")=200, py::arg("random_seed")=100)
.def("knn_query", &Index<float>::knnQuery_return_numpy, py::arg("data"), py::arg("k")=1, py::arg("num_threads")=-1)
.def("add_items", &Index<float>::addItems, py::arg("data"), py::arg("ids") = py::none(), py::arg("num_threads")=-1)
.def("get_items", &Index<float, float>::GetDataReturnList, py::arg("ids") = py::none())
.def("get_items", &Index<float, float>::getDataReturnList, py::arg("ids") = py::none())
.def("get_ids_list", &Index<float>::getIdsList)
.def("set_ef", &Index<float>::set_ef, py::arg("ef"))
.def("set_num_threads", &Index<float>::set_num_threads, py::arg("num_threads"))
.def("save_index", &Index<float>::saveIndex, py::arg("path_to_index"))
Expand Down
3 changes: 3 additions & 0 deletions python_bindings/tests/bindings_test_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ def testRandomSelf(self):
diff_with_gt_labels=np.max(np.abs(data-items))
self.assertAlmostEqual(diff_with_gt_labels,0,1e-4)

# Checking that all labels are returned correcly:
sorted_labels=sorted(p.get_ids_list())
self.assertEqual(np.sum(~np.asarray(sorted_labels)==np.asarray(range(num_elements))),0)



Expand Down

0 comments on commit 5ba4c4c

Please sign in to comment.