diff --git a/include/svs/index/inverted/clustering.h b/include/svs/index/inverted/clustering.h index db82e1593..de5585c29 100644 --- a/include/svs/index/inverted/clustering.h +++ b/include/svs/index/inverted/clustering.h @@ -571,6 +571,36 @@ template class Clustering { // Saving and Loading. static constexpr lib::Version save_version{0, 0, 0}; static constexpr std::string_view serialization_schema = "clustering"; + + lib::SaveTable metadata() const { + return lib::SaveTable( + serialization_schema, + save_version, + {{"integer_type", lib::save(datatype_v)}, + {"num_clusters", lib::save(size())}} + ); + } + + void save(std::ostream& os) const { + for (const auto& [id, cluster] : *this) { + cluster.serialize(os); + } + } + + static Clustering + load(const lib::ContextFreeLoadTable& table, std::istream& stream) { + auto saved_integer_type = lib::load_at(table, "integer_type"); + if (saved_integer_type != datatype_v) { + throw ANNEXCEPTION("Clustering was saved using {} but we're trying to reload it using {}!", saved_integer_type, datatype_v); + } + auto num_clusters = lib::load_at(table, "num_clusters"); + auto clustering = Clustering(); + for (size_t i = 0; i < num_clusters; ++i) { + clustering.insert(Cluster::deserialize(stream)); + } + return clustering; + } + lib::SaveTable save(const lib::SaveContext& ctx) const { // Serialize all clusters into an auxiliary file. auto fullpath = ctx.generate_name("clustering", "bin"); @@ -582,48 +612,28 @@ template class Clustering { } } - return lib::SaveTable( - serialization_schema, - save_version, - {{"filepath", lib::save(fullpath.filename())}, - SVS_LIST_SAVE(filesize), - {"integer_type", lib::save(datatype_v)}, - {"num_clusters", lib::save(size())}} - ); + auto table = metadata(); + table.insert("filepath", lib::save(fullpath.filename())); + table.insert("filesize", lib::save(filesize)); + return table; + + return table; } static Clustering load(const lib::LoadTable& table) { - // Ensure we have the correct integer type when decoding. - auto saved_integer_type = lib::load_at(table, "integer_type"); - if (saved_integer_type != datatype_v) { - auto type = datatype_v; + auto expected_filesize = lib::load_at(table, "filesize"); + auto file = table.resolve_at("filepath"); + size_t actual_filesize = std::filesystem::file_size(file); + if (actual_filesize != expected_filesize) { throw ANNEXCEPTION( - "Clustering was saved using {} but we're trying to reload it using {}!", - saved_integer_type, - type + "Expected cluster file size to be {}. Instead, it is {}!", + actual_filesize, + expected_filesize ); } - auto num_clusters = lib::load_at(table, "num_clusters"); - auto expected_filesize = lib::load_at(table, "filesize"); - auto clustering = Clustering(); - { - auto file = table.resolve_at("filepath"); - size_t actual_filesize = std::filesystem::file_size(file); - if (actual_filesize != expected_filesize) { - throw ANNEXCEPTION( - "Expected cluster file size to be {}. Instead, it is {}!", - actual_filesize, - expected_filesize - ); - } - - auto io = lib::open_read(file); - for (size_t i = 0; i < num_clusters; ++i) { - clustering.insert(Cluster::deserialize(io)); - } - } - return clustering; + auto io = lib::open_read(file); + return load(table, io); } private: diff --git a/include/svs/index/inverted/memory_based.h b/include/svs/index/inverted/memory_based.h index 3d3fc24c5..8b2e8c8e1 100644 --- a/include/svs/index/inverted/memory_based.h +++ b/include/svs/index/inverted/memory_based.h @@ -497,6 +497,8 @@ template class InvertedIndex { index_.save(index_config, graph, data); } + void save_primary_index(std::ostream& os) const { index_.save(os); } + ///// Accessors /// @brief Getter method for logger svs::logging::logger_ptr get_logger() const { return logger_; } @@ -655,4 +657,46 @@ auto assemble_from_clustering( ); } +template < + typename DataProto, + typename Distance, + StorageStrategy Strategy, + typename ThreadPoolProto> +auto assemble_from_clustering( + std::istream& is, + DataProto data_proto, + Distance distance, + Strategy strategy, + ThreadPoolProto threadpool_proto, + svs::logging::logger_ptr logger = svs::logging::get() +) { + auto threadpool = threads::as_threadpool(std::move(threadpool_proto)); + auto original = svs::detail::dispatch_load(std::move(data_proto), threadpool); + auto clustering = lib::load_from_stream>(is); + auto ids = clustering.sorted_centroids(); + + // skip magic + svs::lib::detail::Deserializer::build(is); + auto index = index::vamana::auto_assemble( + is, + lib::Lazy([&]() { return GraphLoader::return_type::load(is); }), + lib::Lazy([&]() { + using T = typename std::decay_t::element_type; + constexpr size_t Ext = std::decay_t::extent; + return lib::load_from_stream>(is); + }), + distance, + 1, + logger + ); + + return InvertedIndex( + std::move(index), + strategy(original, clustering, HugepageAllocator()), + std::move(ids), + std::move(threadpool), + std::move(logger) + ); +} + } // namespace svs::index::inverted diff --git a/include/svs/orchestrators/inverted.h b/include/svs/orchestrators/inverted.h index 6b6e50470..7fb7e60da 100644 --- a/include/svs/orchestrators/inverted.h +++ b/include/svs/orchestrators/inverted.h @@ -34,6 +34,9 @@ class InvertedInterface { const std::filesystem::path& primary_data, const std::filesystem::path& primary_graph ) = 0; + + ///// Saving + virtual void save_primary_index(std::ostream& os) = 0; }; template @@ -72,6 +75,8 @@ class InvertedImpl : public manager::ManagerImpl { ) override { impl().save_primary_index(primary_config, primary_data, primary_graph); } + + void save_primary_index(std::ostream& os) override { impl().save_primary_index(os); } }; ///// @@ -106,6 +111,8 @@ class Inverted : public manager::IndexManager { impl_->save_primary_index(primary_config, primary_data, primary_graph); } + void save_primary_index(std::ostream& os) { impl_->save_primary_index(os); } + ///// Building template < manager::QueryTypeDefinition QueryTypes, @@ -168,6 +175,30 @@ class Inverted : public manager::IndexManager { std::move(threadpool_proto) )}; } + template < + manager::QueryTypeDefinition QueryTypes, + typename DataProto, + typename Distance, + typename ThreadPoolProto, + typename StorageStrategy = index::inverted::SparseStrategy> + static Inverted assemble_from_clustering( + std::istream& is, + DataProto data_proto, + Distance distance, + ThreadPoolProto threadpool_proto, + StorageStrategy strategy = {} + ) { + return Inverted{ + std::in_place, + manager::as_typelist{}, + index::inverted::assemble_from_clustering( + is, + std::move(data_proto), + std::move(distance), + std::move(strategy), + std::move(threadpool_proto) + )}; + } }; } // namespace svs diff --git a/tests/svs/index/inverted/memory_based.cpp b/tests/svs/index/inverted/memory_based.cpp index d418dd83d..28ff6b269 100644 --- a/tests/svs/index/inverted/memory_based.cpp +++ b/tests/svs/index/inverted/memory_based.cpp @@ -19,9 +19,11 @@ #include "spdlog/sinks/callback_sink.h" #include "svs-benchmark/datasets.h" #include "svs/lib/timing.h" +#include "svs/orchestrators/inverted.h" #include "tests/utils/inverted_reference.h" #include "tests/utils/test_dataset.h" #include +#include CATCH_TEST_CASE("InvertedIndex Logging Test", "[long][logging]") { // Vector to store captured log messages @@ -73,3 +75,76 @@ CATCH_TEST_CASE("InvertedIndex Logging Test", "[long][logging]") { CATCH_REQUIRE(captured_logs[0].find("Vamana Build Parameters:") != std::string::npos); CATCH_REQUIRE(captured_logs[1].find("Number of syncs") != std::string::npos); } + +namespace { +constexpr size_t NUM_NEIGHBORS = 10; + +template void test_stream_save_load(Strategy strategy) { + auto distance = svs::DistanceL2(); + constexpr auto distance_type = svs::distance_type_v; + auto expected_results = test_dataset::inverted::expected_build_results( + distance_type, svsbenchmark::Uncompressed(svs::DataType::float32) + ); + auto build_parameters = expected_results.build_parameters_.value(); + + // Capture the clustering during build. + svs::index::inverted::Clustering clustering; + auto clustering_op = [&](const auto& c) { clustering = c; }; + + svs::Inverted index = svs::Inverted::build( + build_parameters, + svs::data::SimpleData::load(test_dataset::data_svs_file()), + distance, + 2, + strategy, + svs::index::inverted::PickRandomly{}, + clustering_op + ); + + auto queries = svs::data::SimpleData::load(test_dataset::query_file()); + auto parameters = index.get_search_parameters(); + auto results = index.search(queries, NUM_NEIGHBORS); + + // Serialize to stream. + std::stringstream ss; + svs::lib::save_to_stream(clustering, ss); + index.save_primary_index(ss); + + // Load from stream. + svs::Inverted loaded = svs::Inverted::assemble_from_clustering>( + ss, + svs::data::SimpleData::load(test_dataset::data_svs_file()), + distance, + 2, + strategy + ); + loaded.set_search_parameters(parameters); + + // Compare basic properties. + CATCH_REQUIRE(loaded.size() == index.size()); + CATCH_REQUIRE(loaded.dimensions() == index.dimensions()); + + // Compare search results element-wise. + auto loaded_results = loaded.search(queries, NUM_NEIGHBORS); + CATCH_REQUIRE(loaded_results.n_queries() == results.n_queries()); + CATCH_REQUIRE(loaded_results.n_neighbors() == results.n_neighbors()); + for (size_t q = 0; q < results.n_queries(); ++q) { + for (size_t i = 0; i < NUM_NEIGHBORS; ++i) { + CATCH_REQUIRE(loaded_results.index(q, i) == results.index(q, i)); + CATCH_REQUIRE( + loaded_results.distance(q, i) == + Catch::Approx(results.distance(q, i)).epsilon(1e-5) + ); + } + } +} +} // namespace + +CATCH_TEST_CASE("InvertedIndex Save and Load", "[saveload][inverted][index]") { + CATCH_SECTION("SparseStrategy") { + test_stream_save_load(svs::index::inverted::SparseStrategy()); + } + CATCH_SECTION("DenseStrategy") { + test_stream_save_load(svs::index::inverted::DenseStrategy()); + } +}