diff --git a/contrib/inspect_tools.py b/contrib/inspect_tools.py index 1182156a82..87928f4bb9 100644 --- a/contrib/inspect_tools.py +++ b/contrib/inspect_tools.py @@ -68,6 +68,20 @@ def get_LinearTransform_matrix(pca): return A, b +def make_LinearTransform_matrix(A, b=None): + """ make a linear transform from a matrix and a bias term (optional)""" + d_out, d_in = A.shape + if b is not None: + assert b.shape == (d_out, ) + lt = faiss.LinearTransform(d_in, d_out, b is not None) + faiss.copy_array_to_vector(A.ravel(), lt.A) + if b is not None: + faiss.copy_array_to_vector(b, lt.b) + lt.is_trained = True + lt.set_is_orthonormal() + return lt + + def get_additive_quantizer_codebooks(aq): """ return to codebooks of an additive quantizer """ codebooks = faiss.vector_to_array(aq.codebooks).reshape(-1, aq.d) diff --git a/faiss/CMakeLists.txt b/faiss/CMakeLists.txt index 5e635a53e8..f88907d397 100644 --- a/faiss/CMakeLists.txt +++ b/faiss/CMakeLists.txt @@ -37,6 +37,7 @@ set(FAISS_SRC IndexPQ.cpp IndexFastScan.cpp IndexAdditiveQuantizerFastScan.cpp + IndexIVFIndependentQuantizer.cpp IndexPQFastScan.cpp IndexPreTransform.cpp IndexRefine.cpp @@ -113,6 +114,7 @@ set(FAISS_HEADERS IndexIDMap.h IndexIVF.h IndexIVFAdditiveQuantizer.h + IndexIVFIndependentQuantizer.h IndexIVFFlat.h IndexIVFPQ.h IndexIVFFastScan.h diff --git a/faiss/IVFlib.cpp b/faiss/IVFlib.cpp index 88ac7c7a2f..8af652a103 100644 --- a/faiss/IVFlib.cpp +++ b/faiss/IVFlib.cpp @@ -12,6 +12,7 @@ #include #include +#include #include #include #include @@ -67,6 +68,10 @@ const IndexIVF* try_extract_index_ivf(const Index* index) { if (auto* idmap = dynamic_cast(index)) { index = idmap->index; } + if (auto* indep = + dynamic_cast(index)) { + index = indep->index_ivf; + } auto* ivf = dynamic_cast(index); diff --git a/faiss/IndexIVFIndependentQuantizer.cpp b/faiss/IndexIVFIndependentQuantizer.cpp new file mode 100644 index 0000000000..2073dd2ee4 --- /dev/null +++ b/faiss/IndexIVFIndependentQuantizer.cpp @@ -0,0 +1,172 @@ 
+/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +namespace faiss { + +IndexIVFIndependentQuantizer::IndexIVFIndependentQuantizer( + Index* quantizer, + IndexIVF* index_ivf, + VectorTransform* vt) + : Index(quantizer->d, index_ivf->metric_type), + quantizer(quantizer), + vt(vt), + index_ivf(index_ivf) { + if (vt) { + FAISS_THROW_IF_NOT_MSG( + vt->d_in == d && vt->d_out == index_ivf->d, + "invalid vector dimensions"); + } else { + FAISS_THROW_IF_NOT_MSG(index_ivf->d == d, "invalid vector dimensions"); + } + + if (quantizer->is_trained && quantizer->ntotal != 0) { + FAISS_THROW_IF_NOT(quantizer->ntotal == index_ivf->nlist); + } + if (index_ivf->is_trained && vt) { + FAISS_THROW_IF_NOT(vt->is_trained); + } + ntotal = index_ivf->ntotal; + is_trained = + (quantizer->is_trained && quantizer->ntotal == index_ivf->nlist && + (!vt || vt->is_trained) && index_ivf->is_trained); + + // disable precomputed tables because they use the distances that are + // provided by the coarse quantizer (that are out of sync with the IVFPQ) + if (auto index_ivfpq = dynamic_cast(index_ivf)) { + index_ivfpq->use_precomputed_table = -1; + } +} + +IndexIVFIndependentQuantizer::~IndexIVFIndependentQuantizer() { + if (own_fields) { + delete quantizer; + delete index_ivf; + delete vt; + } +} + +namespace { + +struct VTransformedVectors : TransformedVectors { + VTransformedVectors(const VectorTransform* vt, idx_t n, const float* x) + : TransformedVectors(x, vt ? 
vt->apply(n, x) : x) {} +}; + +struct SubsampledVectors : TransformedVectors { + SubsampledVectors(int d, idx_t* n, idx_t max_n, const float* x) + : TransformedVectors( + x, + fvecs_maybe_subsample(d, (size_t*)n, max_n, x, true)) {} +}; + +} // anonymous namespace + +void IndexIVFIndependentQuantizer::add(idx_t n, const float* x) { + std::vector D(n); + std::vector I(n); + quantizer->search(n, x, 1, D.data(), I.data()); + VTransformedVectors tv(vt, n, x); + index_ivf->add_core(n, tv.x, nullptr, I.data()); + // keep the vector count of this wrapper in sync with the wrapped IVF index + ntotal = index_ivf->ntotal; +} + +void IndexIVFIndependentQuantizer::search( + idx_t n, + const float* x, + idx_t k, + float* distances, + idx_t* labels, + const SearchParameters* params) const { + FAISS_THROW_IF_NOT_MSG(!params, "search parameters not supported"); + int nprobe = index_ivf->nprobe; + std::vector D(n * nprobe); + std::vector I(n * nprobe); + quantizer->search(n, x, nprobe, D.data(), I.data()); + + VTransformedVectors tv(vt, n, x); + + index_ivf->search_preassigned( + n, tv.x, k, I.data(), D.data(), distances, labels, false); +} + +void IndexIVFIndependentQuantizer::reset() { + index_ivf->reset(); + ntotal = 0; +} + +void IndexIVFIndependentQuantizer::train(idx_t n, const float* x) { + // quantizer training + size_t nlist = index_ivf->nlist; + Level1Quantizer l1(quantizer, nlist); + l1.train_q1(n, x, verbose, metric_type); + + // train the VectorTransform + if (vt && !vt->is_trained) { + if (verbose) { + printf("IndexIVFIndependentQuantizer: train the VectorTransform\n"); + } + vt->train(n, x); + } + + // get the centroids from the quantizer, transform them and + // add them to the index_ivf's quantizer + if (verbose) { + printf("IndexIVFIndependentQuantizer: extract the main quantizer centroids\n"); + } + std::vector centroids(nlist * d); + quantizer->reconstruct_n(0, nlist, centroids.data()); + VTransformedVectors tcent(vt, nlist, centroids.data()); + + if (verbose) { + printf("IndexIVFIndependentQuantizer: add centroids to the secondary quantizer\n"); + } + if 
(!index_ivf->quantizer->is_trained) { + index_ivf->quantizer->train(nlist, tcent.x); + } + index_ivf->quantizer->add(nlist, tcent.x); + + // train the payload + + // optional subsampling; x holds vectors of the input dimension d, not index_ivf->d + idx_t max_nt = index_ivf->train_encoder_num_vectors(); + if (max_nt <= 0) { + max_nt = (size_t)1 << 35; + } + SubsampledVectors sv(d, &n, max_nt, x); + + // transform subsampled vectors + VTransformedVectors tv(vt, n, sv.x); + + if (verbose) { + printf("IndexIVFIndependentQuantizer: train encoder\n"); + } + + if (index_ivf->by_residual) { + // assign with quantizer + std::vector assign(n); + quantizer->assign(n, sv.x, assign.data()); + + // compute residual with IVF quantizer + std::vector residuals(n * index_ivf->d); + index_ivf->quantizer->compute_residual_n( + n, tv.x, residuals.data(), assign.data()); + + index_ivf->train_encoder(n, residuals.data(), assign.data()); + } else { + index_ivf->train_encoder(n, tv.x, nullptr); + } + index_ivf->is_trained = true; + is_trained = true; +} + +} // namespace faiss diff --git a/faiss/IndexIVFIndependentQuantizer.h b/faiss/IndexIVFIndependentQuantizer.h new file mode 100644 index 0000000000..4fe1666616 --- /dev/null +++ b/faiss/IndexIVFIndependentQuantizer.h @@ -0,0 +1,56 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace faiss { + +/** An IVF index with a quantizer that has a different input dimension from the + * payload size. The vectors to encode are obtained from the input vectors by a + * VectorTransform. 
+ */ +struct IndexIVFIndependentQuantizer : Index { + /// quantizer is fed directly with the input vectors + Index* quantizer = nullptr; + + /// transform before the IVF vectors are applied + VectorTransform* vt = nullptr; + + /// the IVF index, controls nlist and nprobe + IndexIVF* index_ivf = nullptr; + + /// whether *this owns the 3 fields + bool own_fields = false; + + IndexIVFIndependentQuantizer( + Index* quantizer, + IndexIVF* index_ivf, + VectorTransform* vt = nullptr); + + IndexIVFIndependentQuantizer() {} + + void train(idx_t n, const float* x) override; + + void add(idx_t n, const float* x) override; + + void search( + idx_t n, + const float* x, + idx_t k, + float* distances, + idx_t* labels, + const SearchParameters* params = nullptr) const override; + + void reset() override; + + ~IndexIVFIndependentQuantizer() override; +}; + +} // namespace faiss diff --git a/faiss/IndexPreTransform.cpp b/faiss/IndexPreTransform.cpp index 3d9beb92ff..cde857c8ed 100644 --- a/faiss/IndexPreTransform.cpp +++ b/faiss/IndexPreTransform.cpp @@ -141,9 +141,8 @@ void IndexPreTransform::reverse_chain(idx_t n, const float* xt, float* x) void IndexPreTransform::add(idx_t n, const float* x) { FAISS_THROW_IF_NOT(is_trained); - const float* xt = apply_chain(n, x); - ScopeDeleter del(xt == x ? nullptr : xt); - index->add(n, xt); + TransformedVectors tv(x, apply_chain(n, x)); + index->add(n, tv.x); ntotal = index->ntotal; } @@ -152,9 +151,8 @@ void IndexPreTransform::add_with_ids( const float* x, const idx_t* xids) { FAISS_THROW_IF_NOT(is_trained); - const float* xt = apply_chain(n, x); - ScopeDeleter del(xt == x ? 
nullptr : xt); - index->add_with_ids(n, xt, xids); + TransformedVectors tv(x, apply_chain(n, x)); + index->add_with_ids(n, tv.x, xids); ntotal = index->ntotal; } @@ -190,10 +188,9 @@ void IndexPreTransform::range_search( RangeSearchResult* result, const SearchParameters* params) const { FAISS_THROW_IF_NOT(is_trained); - const float* xt = apply_chain(n, x); - ScopeDeleter del(xt == x ? nullptr : xt); + TransformedVectors tv(x, apply_chain(n, x)); index->range_search( - n, xt, radius, result, extract_index_search_params(params)); + n, tv.x, radius, result, extract_index_search_params(params)); } void IndexPreTransform::reset() { @@ -238,14 +235,13 @@ void IndexPreTransform::search_and_reconstruct( FAISS_THROW_IF_NOT(k > 0); FAISS_THROW_IF_NOT(is_trained); - const float* xt = apply_chain(n, x); - ScopeDeleter del((xt == x) ? nullptr : xt); + TransformedVectors trans(x, apply_chain(n, x)); float* recons_temp = chain.empty() ? recons : new float[n * k * index->d]; ScopeDeleter del2((recons_temp == recons) ? nullptr : recons_temp); index->search_and_reconstruct( n, - xt, + trans.x, k, distances, labels, @@ -262,13 +258,8 @@ size_t IndexPreTransform::sa_code_size() const { void IndexPreTransform::sa_encode(idx_t n, const float* x, uint8_t* bytes) const { - if (chain.empty()) { - index->sa_encode(n, x, bytes); - } else { - const float* xt = apply_chain(n, x); - ScopeDeleter del(xt == x ? 
nullptr : xt); - index->sa_encode(n, xt, bytes); - } + TransformedVectors tv(x, apply_chain(n, x)); + index->sa_encode(n, tv.x, bytes); } void IndexPreTransform::sa_decode(idx_t n, const uint8_t* bytes, float* x) diff --git a/faiss/VectorTransform.cpp b/faiss/VectorTransform.cpp index 252f4aa874..b0421bef9f 100644 --- a/faiss/VectorTransform.cpp +++ b/faiss/VectorTransform.cpp @@ -441,13 +441,10 @@ void eig(size_t d_in, double* cov, double* eigenvalues, int verbose) { } // namespace -void PCAMatrix::train(idx_t n, const float* x) { - const float* x_in = x; - - x = fvecs_maybe_subsample( - d_in, (size_t*)&n, max_points_per_d * d_in, x, verbose); - - ScopeDeleter del_x(x != x_in ? x : nullptr); +void PCAMatrix::train(idx_t n, const float* x_in) { + const float* x = fvecs_maybe_subsample( + d_in, (size_t*)&n, max_points_per_d * d_in, x_in, verbose); + TransformedVectors tv(x_in, x); // compute mean mean.clear(); @@ -884,14 +881,13 @@ ITQTransform::ITQTransform(int d_in, int d_out, bool do_pca) is_trained = false; } -void ITQTransform::train(idx_t n, const float* x) { +void ITQTransform::train(idx_t n, const float* x_in) { FAISS_THROW_IF_NOT(!is_trained); - const float* x_in = x; size_t max_train_points = std::max(d_in * max_train_per_dim, 32768); - x = fvecs_maybe_subsample(d_in, (size_t*)&n, max_train_points, x); - - ScopeDeleter del_x(x != x_in ? x : nullptr); + const float* x = + fvecs_maybe_subsample(d_in, (size_t*)&n, max_train_points, x_in); + TransformedVectors tv(x_in, x); std::unique_ptr x_norm(new float[n * d_in]); { // normalize @@ -988,25 +984,16 @@ void ITQTransform::check_identical(const VectorTransform& other_in) const { *********************************************/ OPQMatrix::OPQMatrix(int d, int M, int d2) - : LinearTransform(d, d2 == -1 ? d : d2, false), - M(M), - niter(50), - niter_pq(4), - niter_pq_0(40), - verbose(false), - pq(nullptr) { + : LinearTransform(d, d2 == -1 ? 
d : d2, false), M(M) { is_trained = false; // OPQ is quite expensive to train, so set this right. max_train_points = 256 * 256; - pq = nullptr; } -void OPQMatrix::train(idx_t n, const float* x) { - const float* x_in = x; - - x = fvecs_maybe_subsample(d_in, (size_t*)&n, max_train_points, x, verbose); - - ScopeDeleter del_x(x != x_in ? x : nullptr); +void OPQMatrix::train(idx_t n, const float* x_in) { + const float* x = fvecs_maybe_subsample( + d_in, (size_t*)&n, max_train_points, x_in, verbose); + TransformedVectors tv(x_in, x); // To support d_out > d_in, we pad input vectors with 0s to d_out size_t d = d_out <= d_in ? d_in : d_out; diff --git a/faiss/VectorTransform.h b/faiss/VectorTransform.h index c233bfae57..55e46e81d5 100644 --- a/faiss/VectorTransform.h +++ b/faiss/VectorTransform.h @@ -230,18 +230,18 @@ struct ProductQuantizer; * */ struct OPQMatrix : LinearTransform { - int M; ///< nb of subquantizers - int niter; ///< Number of outer training iterations - int niter_pq; ///< Number of training iterations for the PQ - int niter_pq_0; ///< same, for the first outer iteration + int M; ///< nb of subquantizers + int niter = 50; ///< Number of outer training iterations + int niter_pq = 4; ///< Number of training iterations for the PQ + int niter_pq_0 = 40; ///< same, for the first outer iteration /// if there are too many training points, resample - size_t max_train_points; - bool verbose; + size_t max_train_points = 256 * 256; + bool verbose = false; /// if non-NULL, use this product quantizer for training /// should be constructed with (d_out, M, _) - ProductQuantizer* pq; + ProductQuantizer* pq = nullptr; /// if d2 != -1, output vectors of this dimension explicit OPQMatrix(int d = 0, int M = 1, int d2 = -1); diff --git a/faiss/impl/index_read.cpp b/faiss/impl/index_read.cpp index 600d8df52b..423b22a9cc 100644 --- a/faiss/impl/index_read.cpp +++ b/faiss/impl/index_read.cpp @@ -33,6 +33,7 @@ #include #include #include +#include #include #include #include @@ 
-860,7 +861,22 @@ Index* read_index(IOReader* f, int io_flags) { h == fourcc("IvPQ") || h == fourcc("IvQR") || h == fourcc("IwPQ") || h == fourcc("IwQR")) { idx = read_ivfpq(f, h, io_flags); - + } else if (h == fourcc("IwIQ")) { + auto* indep = new IndexIVFIndependentQuantizer(); + indep->own_fields = true; + read_index_header(indep, f); + indep->quantizer = read_index(f, io_flags); + bool has_vt; + READ1(has_vt); + if (has_vt) { + indep->vt = read_VectorTransform(f); + } + indep->index_ivf = dynamic_cast(read_index(f, io_flags)); + FAISS_THROW_IF_NOT(indep->index_ivf); + if (auto index_ivfpq = dynamic_cast(indep->index_ivf)) { + READ1(index_ivfpq->use_precomputed_table); + } + idx = indep; } else if (h == fourcc("IxPT")) { IndexPreTransform* ixpt = new IndexPreTransform(); ixpt->own_fields = true; diff --git a/faiss/impl/index_write.cpp b/faiss/impl/index_write.cpp index d40f651c56..84484e799c 100644 --- a/faiss/impl/index_write.cpp +++ b/faiss/impl/index_write.cpp @@ -34,6 +34,7 @@ #include #include #include +#include #include #include #include @@ -702,7 +703,22 @@ void write_index(const Index* idx, IOWriter* f) { WRITEVECTOR(ivfpqr->refine_codes); WRITE1(ivfpqr->k_factor); } - + } else if ( + auto* indep = + dynamic_cast(idx)) { + uint32_t h = fourcc("IwIQ"); + WRITE1(h); + write_index_header(indep, f); + write_index(indep->quantizer, f); + bool has_vt = indep->vt != nullptr; + WRITE1(has_vt); + if (has_vt) { + write_VectorTransform(indep->vt, f); + } + write_index(indep->index_ivf, f); + if (auto index_ivfpq = dynamic_cast(indep->index_ivf)) { + WRITE1(index_ivfpq->use_precomputed_table); + } } else if ( const IndexPreTransform* ixpt = dynamic_cast(idx)) { diff --git a/faiss/python/__init__.py b/faiss/python/__init__.py index c0f0f9456d..27fb63bd90 100644 --- a/faiss/python/__init__.py +++ b/faiss/python/__init__.py @@ -194,6 +194,9 @@ def replacement_function(*args): add_ref_in_constructor(IDSelectorXOr, slice(2)) add_ref_in_constructor(IDSelectorTranslated, 
slice(2)) +add_ref_in_constructor(IDSelectorXOr, slice(2)) +add_ref_in_constructor(IndexIVFIndependentQuantizer, slice(3)) + # seems really marginal... # remove_ref_from_method(IndexReplicas, 'removeIndex', 0) diff --git a/faiss/python/swigfaiss.swig b/faiss/python/swigfaiss.swig index b49dcd80e2..4b187e5991 100644 --- a/faiss/python/swigfaiss.swig +++ b/faiss/python/swigfaiss.swig @@ -76,6 +76,7 @@ typedef uint64_t size_t; #include #include #include +#include #include #include @@ -137,7 +138,7 @@ typedef uint64_t size_t; #include #include #include -#include +#include #include @@ -470,6 +471,8 @@ void gpu_sync_all_devices() %include %include %include +%include + %include %include @@ -587,6 +590,7 @@ void gpu_sync_all_devices() DOWNCAST ( IndexShardsIVF ) DOWNCAST2 ( IndexShards, IndexShardsTemplateT_faiss__Index_t ) DOWNCAST2 ( IndexReplicas, IndexReplicasTemplateT_faiss__Index_t ) + DOWNCAST ( IndexIVFIndependentQuantizer) DOWNCAST ( IndexIVFPQR ) DOWNCAST ( IndexIVFPQ ) DOWNCAST ( IndexIVFPQFastScan ) @@ -656,6 +660,7 @@ void gpu_sync_all_devices() } } + %typemap(out) faiss::IndexBinary * { DOWNCAST2 ( IndexBinaryReplicas, IndexReplicasTemplateT_faiss__IndexBinary_t ) DOWNCAST2 ( IndexBinaryIDMap2, IndexIDMap2TemplateT_faiss__IndexBinary_t ) diff --git a/tests/test_contrib.py b/tests/test_contrib.py index ad85a3ddd0..cfeee6397b 100644 --- a/tests/test_contrib.py +++ b/tests/test_contrib.py @@ -208,6 +208,16 @@ def test_IndexFlat(self): xb, inspect_tools.get_flat_data(index) ) + def test_make_LT(self): + rs = np.random.RandomState(123) + X = rs.rand(13, 20).astype('float32') + A = rs.rand(5, 20).astype('float32') + b = rs.rand(5).astype('float32') + Yref = X @ A.T + b + lt = inspect_tools.make_LinearTransform_matrix(A, b) + Ynew = lt.apply(X) + np.testing.assert_allclose(Yref, Ynew, rtol=1e-6) + class TestRangeEval(unittest.TestCase): diff --git a/tests/test_index_composite.py b/tests/test_index_composite.py index ea6691cfd4..d4f99b92d0 100644 --- 
a/tests/test_index_composite.py +++ b/tests/test_index_composite.py @@ -15,6 +15,8 @@ import platform from common_faiss_tests import get_dataset_2 +from faiss.contrib.datasets import SyntheticDataset +from faiss.contrib.inspect_tools import make_LinearTransform_matrix class TestRemoveFastScan(unittest.TestCase): @@ -721,3 +723,117 @@ def test_Flat_subset_type_3(self): def test_Flat_subset_type_4(self): self.do_test("IVF30,Flat", subset_type=4) + + +class TestIndependentQuantizer(unittest.TestCase): + + def test_sidebyside(self): + """ provide double-sized vectors to the index, where each vector + is the concatenation of twice the same vector """ + ds = SyntheticDataset(32, 1000, 500, 50) + + index = faiss.index_factory(ds.d, "IVF32,SQ8") + index.train(ds.get_train()) + index.add(ds.get_database()) + index.nprobe = 4 + Dref, Iref = index.search(ds.get_queries(), 10) + + select32first = make_LinearTransform_matrix( + np.eye(64, dtype='float32')[:32]) + + select32last = make_LinearTransform_matrix( + np.eye(64, dtype='float32')[32:]) + + quantizer = faiss.IndexPreTransform( + select32first, + index.quantizer + ) + + index2 = faiss.IndexIVFIndependentQuantizer( + quantizer, + index, select32last + ) + + xq2 = np.hstack([ds.get_queries()] * 2) + quantizer.search(xq2, 30) + Dnew, Inew = index2.search(xq2, 10) + + np.testing.assert_array_equal(Dref, Dnew) + np.testing.assert_array_equal(Iref, Inew) + + # test add + index2.reset() + xb2 = np.hstack([ds.get_database()] * 2) + index2.add(xb2) + Dnew, Inew = index2.search(xq2, 10) + + np.testing.assert_array_equal(Dref, Dnew) + np.testing.assert_array_equal(Iref, Inew) + + def test_half_store(self): + """ the index stores only the first half of each vector + but the coarse quantizer sees them entirely """ + ds = SyntheticDataset(32, 1000, 500, 50) + gt = ds.get_groundtruth(10) + + select32first = make_LinearTransform_matrix( + np.eye(32, dtype='float32')[:16]) + + index_ivf = faiss.index_factory(ds.d // 2, "IVF32,Flat") + 
index_ivf.nprobe = 4 + index = faiss.IndexPreTransform(select32first, index_ivf) + index.train(ds.get_train()) + index.add(ds.get_database()) + + Dref, Iref = index.search(ds.get_queries(), 10) + perf_ref = faiss.eval_intersection(Iref, gt) + + index_ivf = faiss.index_factory(ds.d // 2, "IVF32,Flat") + index_ivf.nprobe = 4 + index = faiss.IndexIVFIndependentQuantizer( + faiss.IndexFlatL2(ds.d), + index_ivf, select32first + ) + index.train(ds.get_train()) + index.add(ds.get_database()) + + Dnew, Inew = index.search(ds.get_queries(), 10) + perf_new = faiss.eval_intersection(Inew, gt) + + self.assertLess(perf_ref, perf_new) + + def test_precomputed_tables(self): + """ see how precomputed tables behave with centroid distance estimates from a mismatching + coarse quantizer """ + ds = SyntheticDataset(48, 2000, 500, 250) + gt = ds.get_groundtruth(10) + + index = faiss.IndexIVFIndependentQuantizer( + faiss.IndexFlatL2(48), + faiss.index_factory(16, "IVF64,PQ4np"), + faiss.PCAMatrix(48, 16) + ) + index.train(ds.get_train()) + index.add(ds.get_database()) + + index_ivf = faiss.downcast_index(faiss.extract_index_ivf(index)) + index_ivf.nprobe = 4 + + Dref, Iref = index.search(ds.get_queries(), 10) + perf_ref = faiss.eval_intersection(Iref, gt) + + index_ivf.use_precomputed_table = 1 + index_ivf.precompute_table() + + Dnew, Inew = index.search(ds.get_queries(), 10) + perf_new = faiss.eval_intersection(Inew, gt) + + # to be honest, it is not clear which one is better... + self.assertNotEqual(perf_ref, perf_new) + + # check IO while we are at it + index2 = faiss.deserialize_index(faiss.serialize_index(index)) + D2, I2 = index2.search(ds.get_queries(), 10) + + np.testing.assert_array_equal(Dnew, D2) + np.testing.assert_array_equal(Inew, I2)