From 14b8af6e736fdfff636584841e61e0161d8ceadd Mon Sep 17 00:00:00 2001 From: Junjie Qi Date: Mon, 25 Mar 2024 11:19:40 -0700 Subject: [PATCH] Fix IVFPQFastScan decode function (#3312) Summary: Pull Request resolved: https://github.com/facebookresearch/faiss/pull/3312 as the [#issue3258](https://github.com/facebookresearch/faiss/issues/3258) mentioned, the IVFPQFastScan should have same decoding result as IVFPQ. However, current result is not as expected. In this PR/Diff, we are going to fix the decoding function Reviewed By: mdouze Differential Revision: D55264781 fbshipit-source-id: dfdae9eabceadfc5a3ebb851930d71ce3c1c654d --- faiss/IndexIVF.h | 8 +++++ faiss/IndexIVFPQFastScan.cpp | 23 +++++++++++-- tests/test_fast_scan_ivf.py | 63 +++++++++++++++++++++++++++--------- 3 files changed, 77 insertions(+), 17 deletions(-) diff --git a/faiss/IndexIVF.h b/faiss/IndexIVF.h index 45c65ef839..185561d086 100644 --- a/faiss/IndexIVF.h +++ b/faiss/IndexIVF.h @@ -433,6 +433,14 @@ struct IndexIVF : Index, IndexIVFInterface { /* The standalone codec interface (except sa_decode that is specific) */ size_t sa_code_size() const override; + + /** encode a set of vectors + * sa_encode will call encode_vector with include_listno=true + * @param n nb of vectors to encode + * @param x the vectors to encode + * @param bytes output array for the codes + * @return nb of bytes written to codes + */ void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override; IndexIVF(); diff --git a/faiss/IndexIVFPQFastScan.cpp b/faiss/IndexIVFPQFastScan.cpp index d069db1354..2844ae4936 100644 --- a/faiss/IndexIVFPQFastScan.cpp +++ b/faiss/IndexIVFPQFastScan.cpp @@ -286,9 +286,28 @@ void IndexIVFPQFastScan::compute_LUT( } } -void IndexIVFPQFastScan::sa_decode(idx_t n, const uint8_t* bytes, float* x) +void IndexIVFPQFastScan::sa_decode(idx_t n, const uint8_t* codes, float* x) const { - pq.decode(bytes, x, n); + size_t coarse_size = coarse_code_size(); + +#pragma omp parallel if (n > 1) + { + 
std::vector<float> residual(d); + +#pragma omp for + for (idx_t i = 0; i < n; i++) { + const uint8_t* code = codes + i * (code_size + coarse_size); + int64_t list_no = decode_listno(code); + float* xi = x + i * d; + pq.decode(code + coarse_size, xi); + if (by_residual) { + quantizer->reconstruct(list_no, residual.data()); + for (size_t j = 0; j < d; j++) { + xi[j] += residual[j]; + } + } + } + } } } // namespace faiss diff --git a/tests/test_fast_scan_ivf.py b/tests/test_fast_scan_ivf.py index d6dad8fec3..f48dd2e47a 100644 --- a/tests/test_fast_scan_ivf.py +++ b/tests/test_fast_scan_ivf.py @@ -84,9 +84,7 @@ def sp(x): b = btab[0] dis_new = self.compute_dis_quant(codes, LUTq, biasq, a, b) - # print(a, b, dis_ref.sum()) avg_realtive_error = np.abs(dis_new - dis_ref).sum() / dis_ref.sum() - # print('a=', a, 'avg_relative_error=', avg_realtive_error) self.assertLess(avg_realtive_error, 0.0005) def test_no_residual_ip(self): @@ -228,8 +226,6 @@ def eval_quant_loss(self, by_residual, metric=faiss.METRIC_L2): m3 = three_metrics(Da, Ia, Db, Ib) - - # print(by_residual, metric, recall_at_1, recall_at_10, intersection_at_10) ref_results = { (True, 1): [0.985, 1.0, 9.872], (True, 0): [ 0.987, 1.0, 9.914], @@ -261,6 +257,7 @@ class TestEquivPQ(unittest.TestCase): def test_equiv_pq(self): ds = datasets.SyntheticDataset(32, 2000, 200, 4) + xq = ds.get_queries() index = faiss.index_factory(32, "IVF1,PQ16x4np") index.by_residual = False @@ -268,7 +265,7 @@ def test_equiv_pq(self): index.quantizer.add(np.zeros((1, 32), dtype='float32')) index.train(ds.get_train()) index.add(ds.get_database()) - Dref, Iref = index.search(ds.get_queries(), 4) + Dref, Iref = index.search(xq, 4) index_pq = faiss.index_factory(32, "PQ16x4np") index_pq.pq = index.pq @@ -276,21 +273,64 @@ def test_equiv_pq(self): index_pq.codes = faiss. 
downcast_InvertedLists( index.invlists).codes.at(0) index_pq.ntotal = index.ntotal - Dnew, Inew = index_pq.search(ds.get_queries(), 4) + Dnew, Inew = index_pq.search(xq, 4) np.testing.assert_array_equal(Iref, Inew) np.testing.assert_array_equal(Dref, Dnew) index_pq2 = faiss.IndexPQFastScan(index_pq) index_pq2.implem = 12 - Dref, Iref = index_pq2.search(ds.get_queries(), 4) + Dref, Iref = index_pq2.search(xq, 4) index2 = faiss.IndexIVFPQFastScan(index) index2.implem = 12 - Dnew, Inew = index2.search(ds.get_queries(), 4) + Dnew, Inew = index2.search(xq, 4) np.testing.assert_array_equal(Iref, Inew) np.testing.assert_array_equal(Dref, Dnew) + # test encode and decode + + np.testing.assert_array_equal( + index_pq.sa_encode(xq), + index2.sa_encode(xq) + ) + + np.testing.assert_array_equal( + index_pq.sa_decode(index_pq.sa_encode(xq)), + index2.sa_decode(index2.sa_encode(xq)) + ) + + np.testing.assert_array_equal( + ((index_pq.sa_decode(index_pq.sa_encode(xq)) - xq) ** 2).sum(1), + ((index2.sa_decode(index2.sa_encode(xq)) - xq) ** 2).sum(1) + ) + + def test_equiv_pq_encode_decode(self): + ds = datasets.SyntheticDataset(32, 1000, 200, 10) + xq = ds.get_queries() + + index_ivfpq = faiss.index_factory(ds.d, "IVF10,PQ8x4np") + index_ivfpq.train(ds.get_train()) + + index_ivfpqfs = faiss.IndexIVFPQFastScan(index_ivfpq) + + np.testing.assert_array_equal( + index_ivfpq.sa_encode(xq), + index_ivfpqfs.sa_encode(xq) + ) + + np.testing.assert_array_equal( + index_ivfpq.sa_decode(index_ivfpq.sa_encode(xq)), + index_ivfpqfs.sa_decode(index_ivfpqfs.sa_encode(xq)) + ) + + np.testing.assert_array_equal( + ((index_ivfpq.sa_decode(index_ivfpq.sa_encode(xq)) - xq) ** 2) + .sum(1), + ((index_ivfpqfs.sa_decode(index_ivfpqfs.sa_encode(xq)) - xq) ** 2) + .sum(1) + ) + class TestIVFImplem12(unittest.TestCase): @@ -463,7 +503,6 @@ def do_test(self, by_residual=False, metric=faiss.METRIC_L2, d=32, bbs=32): Dnew, Inew = index2.search(ds.get_queries(), 10) m3 = three_metrics(Dref, Iref, Dnew, Inew) - 
# print((by_residual, metric, d), ":", m3) ref_m3_tab = { (True, 1, 32): (0.995, 1.0, 9.91), (True, 0, 32): (0.99, 1.0, 9.91), @@ -554,7 +593,6 @@ def subtest_accuracy(self, aq, st, by_residual, implem, metric_type='L2'): recall_ref = (Iref == gt).sum() / nq recall1 = (I1 == gt).sum() / nq - print(aq, st, by_residual, implem, metric_type, recall_ref, recall1) assert abs(recall_ref - recall1) < 0.051 def xx_test_accuracy(self): @@ -599,7 +637,6 @@ def subtest_rescale_accuracy(self, aq, st, by_residual, implem): recall_ref = (Iref == gt).sum() / nq recall1 = (I1 == gt).sum() / nq - print(aq, st, by_residual, implem, recall_ref, recall1) assert abs(recall_ref - recall1) < 0.05 def xx_test_rescale_accuracy(self): @@ -624,7 +661,6 @@ def subtest_from_ivfaq(self, implem): nq = Iref.shape[0] recall_ref = (Iref == gt).sum() / nq recall1 = (I1 == gt).sum() / nq - print(recall_ref, recall1) assert abs(recall_ref - recall1) < 0.02 def test_from_ivfaq(self): @@ -763,7 +799,6 @@ def subtest_accuracy(self, paq): recall_ref = (Iref == gt).sum() / nq recall1 = (I1 == gt).sum() / nq - print(paq, recall_ref, recall1) assert abs(recall_ref - recall1) < 0.05 def test_accuracy_PLSQ(self): @@ -847,7 +882,6 @@ def do_test(self, metric=faiss.METRIC_L2): # find a reasonable radius D, I = index.search(ds.get_queries(), 10) radius = np.median(D[:, -1]) - # print("radius=", radius) lims1, D1, I1 = index.range_search(ds.get_queries(), radius) index2 = faiss.IndexIVFPQFastScan(index) @@ -860,7 +894,6 @@ def do_test(self, metric=faiss.METRIC_L2): for i in range(ds.nq): ref = set(I1[lims1[i]: lims1[i + 1]]) new = set(I2[lims2[i]: lims2[i + 1]]) - print(ref, new) nmiss += len(ref - new) nextra += len(new - ref)