From 82352dd4537afb8e8688f7332584e618fdd1d8e8 Mon Sep 17 00:00:00 2001 From: Matthijs Douze Date: Fri, 1 Sep 2023 02:37:33 -0700 Subject: [PATCH] make nbits configurable for graph indices based on PQ (#3031) Summary: Pull Request resolved: https://github.com/facebookresearch/faiss/pull/3031 As requested in https://github.com/facebookresearch/faiss/issues/3027 Indeed, PQ sizes with nbits > 8 are good tradeoffs, so it is interesting to support them. Reviewed By: pemazare Differential Revision: D48860659 fbshipit-source-id: 6f3c642e0902e1523bef36db6be3af3688d529a5 --- faiss/IndexHNSW.cpp | 16 ++++------------ faiss/IndexHNSW.h | 8 ++++---- faiss/IndexNSG.cpp | 22 +++------------------- faiss/IndexNSG.h | 20 ++++++++++---------- faiss/index_factory.cpp | 15 +++++++++------ tests/test_factory.py | 9 +++++++++ 6 files changed, 39 insertions(+), 51 deletions(-) diff --git a/faiss/IndexHNSW.cpp b/faiss/IndexHNSW.cpp index 78787753e1..52e4315c90 100644 --- a/faiss/IndexHNSW.cpp +++ b/faiss/IndexHNSW.cpp @@ -250,18 +250,10 @@ void hnsw_add_vertices( **************************************************************/ IndexHNSW::IndexHNSW(int d, int M, MetricType metric) - : Index(d, metric), - hnsw(M), - own_fields(false), - storage(nullptr), - reconstruct_from_neighbors(nullptr) {} + : Index(d, metric), hnsw(M) {} IndexHNSW::IndexHNSW(Index* storage, int M) - : Index(storage->d, storage->metric_type), - hnsw(M), - own_fields(false), - storage(storage), - reconstruct_from_neighbors(nullptr) {} + : Index(storage->d, storage->metric_type), hnsw(M), storage(storage) {} IndexHNSW::~IndexHNSW() { if (own_fields) { @@ -886,8 +878,8 @@ IndexHNSWFlat::IndexHNSWFlat(int d, int M, MetricType metric) IndexHNSWPQ::IndexHNSWPQ() {} -IndexHNSWPQ::IndexHNSWPQ(int d, int pq_m, int M) - : IndexHNSW(new IndexPQ(d, pq_m, 8), M) { +IndexHNSWPQ::IndexHNSWPQ(int d, int pq_m, int M, int pq_nbits) + : IndexHNSW(new IndexPQ(d, pq_m, pq_nbits), M) { own_fields = true; is_trained = false; } diff --git 
a/faiss/IndexHNSW.h b/faiss/IndexHNSW.h index f1ff609e94..13855d3037 100644 --- a/faiss/IndexHNSW.h +++ b/faiss/IndexHNSW.h @@ -74,10 +74,10 @@ struct IndexHNSW : Index { HNSW hnsw; // the sequential storage - bool own_fields; - Index* storage; + bool own_fields = false; + Index* storage = nullptr; - ReconstructFromNeighbors* reconstruct_from_neighbors; + ReconstructFromNeighbors* reconstruct_from_neighbors = nullptr; explicit IndexHNSW(int d = 0, int M = 32, MetricType metric = METRIC_L2); explicit IndexHNSW(Index* storage, int M = 32); @@ -152,7 +152,7 @@ struct IndexHNSWFlat : IndexHNSW { */ struct IndexHNSWPQ : IndexHNSW { IndexHNSWPQ(); - IndexHNSWPQ(int d, int pq_m, int M); + IndexHNSWPQ(int d, int pq_m, int M, int pq_nbits = 8); void train(idx_t n, const float* x) override; }; diff --git a/faiss/IndexNSG.cpp b/faiss/IndexNSG.cpp index a7cfd490a4..654b788186 100644 --- a/faiss/IndexNSG.cpp +++ b/faiss/IndexNSG.cpp @@ -29,32 +29,16 @@ using namespace nsg; * IndexNSG implementation **************************************************************/ -IndexNSG::IndexNSG(int d, int R, MetricType metric) - : Index(d, metric), - nsg(R), - own_fields(false), - storage(nullptr), - is_built(false), - GK(64), - build_type(0) { - nndescent_S = 10; - nndescent_R = 100; +IndexNSG::IndexNSG(int d, int R, MetricType metric) : Index(d, metric), nsg(R) { nndescent_L = GK + 50; - nndescent_iter = 10; } IndexNSG::IndexNSG(Index* storage, int R) : Index(storage->d, storage->metric_type), nsg(R), - own_fields(false), storage(storage), - is_built(false), - GK(64), build_type(1) { - nndescent_S = 10; - nndescent_R = 100; nndescent_L = GK + 50; - nndescent_iter = 10; } IndexNSG::~IndexNSG() { @@ -304,8 +288,8 @@ IndexNSGFlat::IndexNSGFlat(int d, int R, MetricType metric) IndexNSGPQ::IndexNSGPQ() {} -IndexNSGPQ::IndexNSGPQ(int d, int pq_m, int M) - : IndexNSG(new IndexPQ(d, pq_m, 8), M) { +IndexNSGPQ::IndexNSGPQ(int d, int pq_m, int M, int pq_nbits) + : IndexNSG(new IndexPQ(d, pq_m, 
pq_nbits), M) { own_fields = true; is_trained = false; } diff --git a/faiss/IndexNSG.h b/faiss/IndexNSG.h index 21c8239dc2..172b10c980 100644 --- a/faiss/IndexNSG.h +++ b/faiss/IndexNSG.h @@ -28,25 +28,25 @@ struct IndexNSG : Index { NSG nsg; /// the sequential storage - bool own_fields; - Index* storage; + bool own_fields = false; + Index* storage = nullptr; /// the index is built or not - bool is_built; + bool is_built = false; /// K of KNN graph for building - int GK; + int GK = 64; /// indicate how to build a knn graph /// - 0: build NSG with brute force search /// - 1: build NSG with NNDescent - char build_type; + char build_type = 0; /// parameters for nndescent - int nndescent_S; - int nndescent_R; - int nndescent_L; - int nndescent_iter; + int nndescent_S = 10; + int nndescent_R = 100; + int nndescent_L; // set to GK + 50 + int nndescent_iter = 10; explicit IndexNSG(int d = 0, int R = 32, MetricType metric = METRIC_L2); explicit IndexNSG(Index* storage, int R = 32); @@ -90,7 +90,7 @@ struct IndexNSGFlat : IndexNSG { */ struct IndexNSGPQ : IndexNSG { IndexNSGPQ(); - IndexNSGPQ(int d, int pq_m, int M); + IndexNSGPQ(int d, int pq_m, int M, int pq_nbits = 8); void train(idx_t n, const float* x) override; }; diff --git a/faiss/index_factory.cpp b/faiss/index_factory.cpp index 5dbf6094c4..5d7a505e09 100644 --- a/faiss/index_factory.cpp +++ b/faiss/index_factory.cpp @@ -440,11 +440,13 @@ IndexHNSW* parse_IndexHNSW( if (match("Flat|")) { return new IndexHNSWFlat(d, hnsw_M, mt); } - if (match("PQ([0-9]+)(np)?")) { + + if (match("PQ([0-9]+)(x[0-9]+)?(np)?")) { int M = std::stoi(sm[1].str()); - IndexHNSWPQ* ipq = new IndexHNSWPQ(d, M, hnsw_M); + int nbit = mres_to_int(sm[2], 8, 1); + IndexHNSWPQ* ipq = new IndexHNSWPQ(d, M, hnsw_M, nbit); dynamic_cast<IndexPQ*>(ipq->storage)->do_polysemous_training = - sm[2].str() != "np"; + sm[3].str() != "np"; return ipq; } if (match(sq_pattern)) { @@ -490,11 +492,12 @@ IndexNSG* parse_IndexNSG( if (match("Flat|")) { return new
IndexNSGFlat(d, nsg_R, mt); } - if (match("PQ([0-9]+)(np)?")) { + if (match("PQ([0-9]+)(x[0-9]+)?(np)?")) { int M = std::stoi(sm[1].str()); - IndexNSGPQ* ipq = new IndexNSGPQ(d, M, nsg_R); + int nbit = mres_to_int(sm[2], 8, 1); + IndexNSGPQ* ipq = new IndexNSGPQ(d, M, nsg_R, nbit); dynamic_cast<IndexPQ*>(ipq->storage)->do_polysemous_training = - sm[2].str() != "np"; + sm[3].str() != "np"; return ipq; } if (match(sq_pattern)) { diff --git a/tests/test_factory.py b/tests/test_factory.py index 511de0f499..79662af757 100644 --- a/tests/test_factory.py +++ b/tests/test_factory.py @@ -82,6 +82,9 @@ def test_factory_HNSW_newstyle(self): index = faiss.index_factory(12, "HNSW32,PQ4np") indexpq = faiss.downcast_index(index.storage) assert not indexpq.do_polysemous_training + index = faiss.index_factory(12, "HNSW32,PQ4x12np") + indexpq = faiss.downcast_index(index.storage) + self.assertEqual(indexpq.pq.nbits, 12) def test_factory_NSG(self): index = faiss.index_factory(12, "NSG64") @@ -97,6 +100,12 @@ def test_factory_NSG(self): assert isinstance(index, faiss.IndexNSGFlat) assert index.nsg.R == 64 + index = faiss.index_factory(12, "NSG64,PQ3x10") + assert isinstance(index, faiss.IndexNSGPQ) + assert index.nsg.R == 64 + indexpq = faiss.downcast_index(index.storage) + self.assertEqual(indexpq.pq.nbits, 10) + index = faiss.index_factory(12, "IVF65536_NSG64,Flat") index_nsg = faiss.downcast_index(index.quantizer) assert isinstance(index, faiss.IndexIVFFlat)