diff --git a/faiss/invlists/BlockInvertedLists.cpp b/faiss/invlists/BlockInvertedLists.cpp
index 6370d11871..dbdb0302dc 100644
--- a/faiss/invlists/BlockInvertedLists.cpp
+++ b/faiss/invlists/BlockInvertedLists.cpp
@@ -9,6 +9,7 @@
 
 #include <faiss/impl/CodePacker.h>
 #include <faiss/impl/FaissAssert.h>
+#include <faiss/impl/IDSelector.h>
 
 #include <faiss/impl/io.h>
 #include <faiss/impl/io_macros.h>
@@ -54,7 +55,9 @@ size_t BlockInvertedLists::add_entries(
     codes[list_no].resize(n_block * block_size);
     if (o % block_size == 0) {
         // copy whole blocks
-        memcpy(&codes[list_no][o * code_size], code, n_block * block_size);
+        memcpy(&codes[list_no][o * packer->code_size],
+               code,
+               n_block * block_size);
     } else {
         FAISS_THROW_IF_NOT_MSG(packer, "missing code packer");
         std::vector<uint8_t> buffer(packer->code_size);
@@ -76,6 +79,37 @@ const uint8_t* BlockInvertedLists::get_codes(size_t list_no) const {
     return codes[list_no].get();
 }
 
+/// Remove all entries whose id is selected by `sel`.
+/// A removed slot is refilled with the last entry of the same list, so the
+/// relative order of the surviving entries is not preserved.
+/// @return total number of entries removed over all lists
+size_t BlockInvertedLists::remove_ids(const IDSelector& sel) {
+    // codes are moved between slots through the packer, so it must be set
+    FAISS_THROW_IF_NOT_MSG(packer, "missing code packer");
+    size_t nremove = 0;
+    // each iteration touches only list i; the counter is the one shared value
+#pragma omp parallel for reduction(+ : nremove)
+    for (idx_t i = 0; i < nlist; i++) {
+        std::vector<uint8_t> buffer(packer->code_size);
+        idx_t l = ids[i].size(), j = 0;
+        while (j < l) {
+            if (sel.is_member(ids[i][j])) {
+                l--;
+                ids[i][j] = ids[i][l];
+                packer->unpack_1(codes[i].data(), l, buffer.data());
+                packer->pack_1(buffer.data(), j, codes[i].data());
+            } else {
+                j++;
+            }
+        }
+        // count before resize(), which truncates ids[i] to l entries
+        nremove += ids[i].size() - l;
+        resize(i, l);
+    }
+
+    return nremove;
+}
+
 const idx_t* BlockInvertedLists::get_ids(size_t list_no) const {
     assert(list_no < nlist);
     return ids[list_no].data();
@@ -102,12 +136,6 @@ void BlockInvertedLists::update_entries(
         const idx_t*,
         const uint8_t*) {
     FAISS_THROW_MSG("not impemented");
-    /*
-    assert (list_no < nlist);
-    assert (n_entry + offset <= ids[list_no].size());
-    memcpy (&ids[list_no][offset], ids_in, sizeof(ids_in[0]) * n_entry);
-    memcpy (&codes[list_no][offset * code_size], codes_in, code_size * n_entry);
-    */
 }
 
 BlockInvertedLists::~BlockInvertedLists() {
diff --git a/faiss/invlists/BlockInvertedLists.h b/faiss/invlists/BlockInvertedLists.h
index 8d8df720bf..2b9cbba455 100644
--- a/faiss/invlists/BlockInvertedLists.h
+++ b/faiss/invlists/BlockInvertedLists.h
@@ -15,6 +15,7 @@
 namespace faiss {
 
 struct CodePacker;
+struct IDSelector;
 
 /** Inverted Lists that are organized by blocks.
  *
@@ -47,6 +48,8 @@ struct BlockInvertedLists : InvertedLists {
     size_t list_size(size_t list_no) const override;
     const uint8_t* get_codes(size_t list_no) const override;
     const idx_t* get_ids(size_t list_no) const override;
+    /// remove ids from the InvertedLists, returns the nb of removed entries
+    size_t remove_ids(const IDSelector& sel);
 
     // works only on empty BlockInvertedLists
     // the codes should be of size ceil(n_entry / n_per_block) * block_size
diff --git a/faiss/invlists/DirectMap.cpp b/faiss/invlists/DirectMap.cpp
index 2b272922d5..dc2b92aa1c 100644
--- a/faiss/invlists/DirectMap.cpp
+++ b/faiss/invlists/DirectMap.cpp
@@ -15,6 +15,7 @@
 #include <faiss/impl/AuxIndexStructures.h>
 #include <faiss/impl/FaissAssert.h>
 #include <faiss/impl/IDSelector.h>
+#include <faiss/invlists/BlockInvertedLists.h>
 
 namespace faiss {
 
@@ -148,8 +149,12 @@ size_t DirectMap::remove_ids(const IDSelector& sel, InvertedLists* invlists) {
     std::vector<idx_t> toremove(nlist);
 
     size_t nremove = 0;
-
+    BlockInvertedLists* block_invlists =
+            dynamic_cast<BlockInvertedLists*>(invlists);
     if (type == NoMap) {
+        if (block_invlists != nullptr) {
+            return block_invlists->remove_ids(sel);
+        }
         // exhaustive scan of IVF
 #pragma omp parallel for
         for (idx_t i = 0; i < nlist; i++) {
@@ -178,6 +183,9 @@ size_t DirectMap::remove_ids(const IDSelector& sel, InvertedLists* invlists) {
             }
         }
     } else if (type == Hashtable) {
+        FAISS_THROW_IF_MSG(
+                block_invlists,
+                "remove with hashtable is not supported with BlockInvertedLists");
         const IDSelectorArray* sela =
                 dynamic_cast<const IDSelectorArray*>(&sel);
         FAISS_THROW_IF_NOT_MSG(
diff --git a/tests/test_merge_index.py b/tests/test_merge_index.py
index 8c4c1f0912..4417f57fe7 100644
--- a/tests/test_merge_index.py
+++ b/tests/test_merge_index.py
@@ -246,19 +246,48 @@ def test_merge_IDMap2(self):
 
 class TestRemoveFastScan(unittest.TestCase):
 
-    def do_fast_scan_test(self, factory_key, size1):
+    def do_fast_scan_test(self,
+                          factory_key,
+                          with_ids=False,
+                          direct_map_type=faiss.DirectMap.NoMap):
         ds = SyntheticDataset(110, 1000, 1000, 100)
-        index1 = faiss.index_factory(ds.d, factory_key)
-        index1.train(ds.get_train())
-        index1.reset()
+        index = faiss.index_factory(ds.d, factory_key)
+        index.train(ds.get_train())
+
+        index.reset()
         tokeep = [i % 3 == 0 for i in range(ds.nb)]
-        index1.add(ds.get_database()[tokeep])
-        _, Iref = index1.search(ds.get_queries(), 5)
-        index1.reset()
-        index1.add(ds.get_database())
-        index1.remove_ids(np.where(np.logical_not(tokeep))[0])
-        _, Inew = index1.search(ds.get_queries(), 5)
+        if with_ids:
+            index.add_with_ids(ds.get_database()[tokeep], np.arange(ds.nb)[tokeep])
+            faiss.extract_index_ivf(index).nprobe = 5
+        else:
+            index.add(ds.get_database()[tokeep])
+        _, Iref = index.search(ds.get_queries(), 5)
+
+        index.reset()
+        if with_ids:
+            index.add_with_ids(ds.get_database(), np.arange(ds.nb))
+            index.set_direct_map_type(direct_map_type)
+            faiss.extract_index_ivf(index).nprobe = 5
+        else:
+            index.add(ds.get_database())
+        nremove = index.remove_ids(np.where(np.logical_not(tokeep))[0])
+        # remove_ids must report how many vectors were actually dropped
+        self.assertEqual(nremove, ds.nb - sum(tokeep))
+        self.assertEqual(index.ntotal, sum(tokeep))
+        _, Inew = index.search(ds.get_queries(), 5)
         np.testing.assert_array_equal(Inew, Iref)
 
-    def test_remove(self):
-        self.do_fast_scan_test("PQ5x4fs", 320)
+    def test_remove_PQFastScan(self):
+        # with_ids is not supported for this type of index
+        self.do_fast_scan_test("PQ5x4fs", False)
+
+    def test_remove_IVFPQFastScan(self):
+        self.do_fast_scan_test("IVF20,PQ5x4fs", True)
+
+    def test_remove_IVFPQFastScan_2(self):
+        self.assertRaisesRegex(Exception,
+                               ".*not supported.*",
+                               self.do_fast_scan_test,
+                               "IVF20,PQ5x4fs",
+                               True,
+                               faiss.DirectMap.Hashtable)