# Patch overview (free text before the first "diff --git" header).
#
# contrib/inspect_tools.py
#   Adds get_NSG_neighbors(nsg): calls nsg.get_final_graph(), allocates a
#   zero-filled (graph.N, graph.K) int32 array and memcpy's graph.data into
#   it, then returns that array (a copy, not a view).
#   NOTE(review): the int32 dtype presumes the graph's node_t is a 4-byte
#   int (the SWIG template is named NSG_Graph_int) -- confirm.
#
# faiss/impl/NNDescent.cpp
#   NNDescent::build now throws unless n > NUM_EVAL_POINTS, and the verbose
#   message typo "Addes" is corrected to "Added".
#   NOTE(review): the check is strict (>), while the message says "smaller
#   than %d"; n == NUM_EVAL_POINTS is rejected as well -- confirm intent.
#
# faiss/impl/NSG.h
#   Comment on Graph::data now states the N-by-K size.
#
# faiss/python/swigfaiss.swig
#   Adds a %warnfilter(509) for Graph< int >::at(int,int), a %template named
#   NSG_Graph_int, and an %extend of faiss::NSG with get_final_graph(), which
#   returns final_graph.get() (a raw pointer; the in-patch comment says
#   %shared_ptr is avoided because of memory leaks).
#   NOTE(review): the bare "%include" lines, "template" without parameters,
#   and "faiss::nsg::Graph" without template arguments look like <...>
#   operands were stripped when this patch text was extracted -- verify
#   against the upstream files before applying.
#
# tests/test_build_blocks.py
#   test_small_nndescent: adding 78 zero vectors to an IndexNNDescentFlat is
#   expected to raise RuntimeError (presumably 78 is below NUM_EVAL_POINTS;
#   that constant is not visible in this patch).
#
# tests/test_contrib.py
#   test_NSG_neighbors: every entry returned by get_NSG_neighbors must be
#   greater than -2 and less than ds.nb.
#
# NOTE(review): line breaks and indentation below were reconstructed from a
# whitespace-collapsed copy using the hunk line counts; blank context lines
# and leading indentation are inferred -- TODO confirm against upstream.
diff --git a/contrib/inspect_tools.py b/contrib/inspect_tools.py
index 87928f4bb9..cc22ff5368 100644
--- a/contrib/inspect_tools.py
+++ b/contrib/inspect_tools.py
@@ -96,3 +96,16 @@ def get_flat_data(index):
     """ copy and return the data matrix in an IndexFlat """
     xb = faiss.vector_to_array(index.codes).view("float32")
     return xb.reshape(index.ntotal, index.d)
+
+
+def get_NSG_neighbors(nsg):
+    """ get the neighbor list for the vectors stored in the NSG structure, as
+    a N-by-K matrix of indices """
+    graph = nsg.get_final_graph()
+    neighbors = np.zeros((graph.N, graph.K), dtype='int32')
+    faiss.memcpy(
+        faiss.swig_ptr(neighbors),
+        graph.data,
+        neighbors.nbytes
+    )
+    return neighbors
diff --git a/faiss/impl/NNDescent.cpp b/faiss/impl/NNDescent.cpp
index cc84b5d609..8878349ff6 100644
--- a/faiss/impl/NNDescent.cpp
+++ b/faiss/impl/NNDescent.cpp
@@ -374,6 +374,10 @@ void NNDescent::init_graph(DistanceComputer& qdis) {
 
 void NNDescent::build(DistanceComputer& qdis, const int n, bool verbose) {
     FAISS_THROW_IF_NOT_MSG(L >= K, "L should be >= K in NNDescent.build");
+    FAISS_THROW_IF_NOT_FMT(
+            n > NUM_EVAL_POINTS,
+            "NNDescent.build cannot build a graph smaller than %d",
+            int(NUM_EVAL_POINTS));
 
     if (verbose) {
         printf("Parameters: K=%d, S=%d, R=%d, L=%d, iter=%d\n",
@@ -403,7 +407,7 @@ void NNDescent::build(DistanceComputer& qdis, const int n, bool verbose) {
     has_built = true;
 
     if (verbose) {
-        printf("Addes %d points into the index\n", ntotal);
+        printf("Added %d points into the index\n", ntotal);
     }
 }
 
diff --git a/faiss/impl/NSG.h b/faiss/impl/NSG.h
index e115b317fb..641a42f8cf 100644
--- a/faiss/impl/NSG.h
+++ b/faiss/impl/NSG.h
@@ -54,7 +54,7 @@ namespace nsg {
 
 template
 struct Graph {
-    node_t* data; ///< the flattened adjacency matrix
+    node_t* data; ///< the flattened adjacency matrix, size N-by-K
     int K; ///< nb of neighbors per node
     int N; ///< total nb of nodes
     bool own_fields; ///< the underlying data owned by itself or not
diff --git a/faiss/python/swigfaiss.swig b/faiss/python/swigfaiss.swig
index 7ebc6624e5..852690622b 100644
--- a/faiss/python/swigfaiss.swig
+++ b/faiss/python/swigfaiss.swig
@@ -454,7 +454,20 @@ void gpu_sync_all_devices()
 
 %include
 %include
+
+%warnfilter(509) faiss::nsg::Graph< int >::at(int,int);
+
 %include
+
+%template(NSG_Graph_int) faiss::nsg::Graph;
+
+// not using %shared_ptr to avoid mem leaks
+%extend faiss::NSG {
+    faiss::nsg::Graph* get_final_graph() {
+        return $self->final_graph.get();
+    }
+}
+
 %include
 
 #ifndef SWIGWIN
diff --git a/tests/test_build_blocks.py b/tests/test_build_blocks.py
index 77ca92623b..77f022adf8 100644
--- a/tests/test_build_blocks.py
+++ b/tests/test_build_blocks.py
@@ -542,6 +542,23 @@ def subtest(self, d, K, metric):
         print('Metric: {}, knng accuracy: {}'.format(metric_names[metric], recall))
         assert recall > 0.99
 
+    def test_small_nndescent(self):
+        """ building a too small graph used to crash, make sure it raises
+        an exception instead.
+        TODO: build the exact knn graph for small cases
+        """
+        d = 32
+        K = 10
+        index = faiss.IndexNNDescentFlat(d, K, faiss.METRIC_L2)
+        index.nndescent.S = 10
+        index.nndescent.R = 32
+        index.nndescent.L = K + 20
+        index.nndescent.iter = 5
+        index.verbose = True
+
+        xb = np.zeros((78, d), dtype='float32')
+        self.assertRaises(RuntimeError, index.add, xb)
+
 
 class TestResultHeap(unittest.TestCase):
 
diff --git a/tests/test_contrib.py b/tests/test_contrib.py
index 1982241142..057b043573 100644
--- a/tests/test_contrib.py
+++ b/tests/test_contrib.py
@@ -219,6 +219,16 @@ def test_make_LT(self):
         Ynew = lt.apply(X)
         np.testing.assert_equal(Yref, Ynew)
 
+    def test_NSG_neighbors(self):
+        # FIXME number of elements to add should be >> 100
+        ds = datasets.SyntheticDataset(32, 0, 200, 10)
+        index = faiss.index_factory(ds.d, "NSG")
+        index.add(ds.get_database())
+        neighbors = inspect_tools.get_NSG_neighbors(index.nsg)
+        # neighbors should be either valid indexes or -1
+        np.testing.assert_array_less(-2, neighbors)
+        np.testing.assert_array_less(neighbors, ds.nb)
+
 
 class TestRangeEval(unittest.TestCase):
 