diff --git a/include/svs/index/flat/flat.h b/include/svs/index/flat/flat.h index d9b91edf..5deac303 100644 --- a/include/svs/index/flat/flat.h +++ b/include/svs/index/flat/flat.h @@ -17,6 +17,7 @@ #pragma once // Flat index utilities +#include "svs/core/logging.h" #include "svs/index/flat/inserters.h" #include "svs/index/index.h" @@ -145,6 +146,8 @@ class FlatIndex { data_storage_type data_; [[no_unique_address]] distance_type distance_; threads::ThreadPoolHandle threadpool_; + // SVS logger for per index logging + svs::logging::logger_ptr logger_; // Constructs controlling the iteration strategy over the data and queries. search_parameters_type search_parameters_{}; @@ -171,6 +174,9 @@ class FlatIndex { } public: + /// @brief Getter method for logger + svs::logging::logger_ptr get_logger() const { return logger_; } + search_parameters_type get_search_parameters() const { return search_parameters_; } void set_search_parameters(const search_parameters_type& search_parameters) { @@ -189,22 +195,35 @@ class FlatIndex { /// instance or an integer specifying the number of threads to use. In the latter /// case, a new default thread pool will be constructed using ``threadpool_proto`` /// as the number of threads to create. + /// @param logger spdlog logger for per-index logging customization. 
/// /// @copydoc threadpool_requirements /// template - FlatIndex(Data data, Dist distance, ThreadPoolProto threadpool_proto) + FlatIndex( + Data data, + Dist distance, + ThreadPoolProto threadpool_proto, + svs::logging::logger_ptr logger = svs::logging::get() + ) requires std::is_same_v : data_{std::move(data)} , distance_{std::move(distance)} - , threadpool_{threads::as_threadpool(std::move(threadpool_proto))} {} + , threadpool_{threads::as_threadpool(std::move(threadpool_proto))} + , logger_{std::move(logger)} {} template - FlatIndex(Data& data, Dist distance, ThreadPoolProto threadpool_proto) + FlatIndex( + Data& data, + Dist distance, + ThreadPoolProto threadpool_proto, + svs::logging::logger_ptr logger = svs::logging::get() + ) requires std::is_same_v : data_{data} , distance_{std::move(distance)} - , threadpool_{threads::as_threadpool(std::move(threadpool_proto))} {} + , threadpool_{threads::as_threadpool(std::move(threadpool_proto))} + , logger_{std::move(logger)} {} ////// Dataset Interface @@ -462,6 +481,7 @@ class FlatIndex { /// instance or an integer specifying the number of threads to use. In the latter case, /// a new default thread pool will be constructed using ``threadpool_proto`` as the /// number of threads to create. +/// @param logger spdlog logger for per-index logging customization. /// /// This method provides much of the heavy lifting for constructing a Flat index from /// a data file on disk or a dataset in memory. 
@@ -472,11 +492,16 @@ class FlatIndex { /// template auto auto_assemble( - DataProto&& data_proto, Distance distance, ThreadPoolProto threadpool_proto + DataProto&& data_proto, + Distance distance, + ThreadPoolProto threadpool_proto, + svs::logging::logger_ptr logger = svs::logging::get() ) { auto threadpool = threads::as_threadpool(std::move(threadpool_proto)); auto data = svs::detail::dispatch_load(std::forward(data_proto), threadpool); - return FlatIndex(std::move(data), std::move(distance), std::move(threadpool)); + return FlatIndex( + std::move(data), std::move(distance), std::move(threadpool), std::move(logger) + ); } /// @brief Alias for a short-lived flat index. diff --git a/include/svs/index/inverted/clustering.h b/include/svs/index/inverted/clustering.h index 2039b456..bb80e9bb 100644 --- a/include/svs/index/inverted/clustering.h +++ b/include/svs/index/inverted/clustering.h @@ -801,7 +801,8 @@ Clustering cluster_with( const Data& data, std::span centroid_ids, const ClusteringParameters& params, - Index& primary_index + Index& primary_index, + svs::logging::logger_ptr logger = svs::logging::get() ) { for (auto id : centroid_ids) { if (id >= data.size()) { @@ -820,7 +821,6 @@ Clustering cluster_with( size_t start = 0; size_t datasize = data.size(); auto timer = lib::Timer(); - auto logger = svs::logging::get(); while (start < datasize) { size_t stop = std::min(start + batchsize, datasize); diff --git a/include/svs/index/inverted/memory_based.h b/include/svs/index/inverted/memory_based.h index 2dc49d99..18e128b4 100644 --- a/include/svs/index/inverted/memory_based.h +++ b/include/svs/index/inverted/memory_based.h @@ -339,12 +339,17 @@ template class InvertedIndex { template InvertedIndex( - Index index, Cluster cluster, translator_type index_local_to_global, Pool threadpool + Index index, + Cluster cluster, + translator_type index_local_to_global, + Pool threadpool, + svs::logging::logger_ptr logger = svs::logging::get() ) : index_{std::move(index)} , 
cluster_{std::move(cluster)} , index_local_to_global_{std::move(index_local_to_global)} - , threadpool_{std::move(threadpool)} { + , threadpool_{std::move(threadpool)} + , logger_{std::move(logger)} { // Clear out the threadpool in the inner index - prefer to handle threading // ourselves. index_.set_threadpool(threads::SequentialThreadPool()); @@ -492,6 +497,10 @@ template class InvertedIndex { index_.save(index_config, graph, data); } + ///// Accessors + /// @brief Getter method for logger + svs::logging::logger_ptr get_logger() const { return logger_; } + private: // Tunable Parameters double refinement_epsilon_ = 10.0; @@ -503,6 +512,9 @@ template class InvertedIndex { // Transient parameters. threads::ThreadPoolHandle threadpool_; + + // SVS logger for per index logging + svs::logging::logger_ptr logger_; }; struct PickRandomly { @@ -548,7 +560,8 @@ auto auto_build( // Customizations Strategy strategy = {}, CentroidPicker centroid_picker = {}, - ClusteringOp clustering_op = {} + ClusteringOp clustering_op = {}, + svs::logging::logger_ptr logger = svs::logging::get() ) { // Perform clustering. auto threadpool = threads::as_threadpool(std::move(threadpool_proto)); @@ -569,7 +582,11 @@ auto auto_build( // Cluster the dataset with the help of the primary index. auto clustering = cluster_with( - data, lib::as_const_span(centroids), parameters.clustering_parameters_, index + data, + lib::as_const_span(centroids), + parameters.clustering_parameters_, + index, + logger ); // Perform any post-proceseccing on the clustering. @@ -585,7 +602,8 @@ auto auto_build( std::move(index), strategy(data, clustering, HugepageAllocator()), std::move(centroids), - std::move(primary_threadpool)}; + std::move(primary_threadpool), + std::move(logger)}; } ///// Auto Assembling. 
@@ -601,7 +619,8 @@ auto assemble_from_clustering( Strategy strategy, const std::filesystem::path& index_config, const std::filesystem::path& graph, - ThreadPoolProto threadpool_proto + ThreadPoolProto threadpool_proto, + svs::logging::logger_ptr logger = svs::logging::get() ) { auto threadpool = threads::as_threadpool(std::move(threadpool_proto)); auto original = svs::detail::dispatch_load(std::move(data_proto), threadpool); @@ -621,7 +640,8 @@ auto assemble_from_clustering( return local_data; }), distance, - 1 + 1, + logger ); // Create the clustering and return the final results. @@ -629,7 +649,8 @@ auto assemble_from_clustering( std::move(index), strategy(original, clustering, HugepageAllocator()), std::move(ids), - std::move(threadpool) + std::move(threadpool), + std::move(logger) ); } diff --git a/include/svs/index/vamana/calibrate.h b/include/svs/index/vamana/calibrate.h index 7cd1f86b..c6cf4147 100644 --- a/include/svs/index/vamana/calibrate.h +++ b/include/svs/index/vamana/calibrate.h @@ -176,9 +176,9 @@ VamanaSearchParameters optimize_split_buffer( double target_recall, VamanaSearchParameters current, const F& compute_recall, - const DoSearch& do_search + const DoSearch& do_search, + svs::logging::logger_ptr logger = svs::logging::get() ) { - auto logger = svs::logging::get(); svs::logging::trace(logger, "Entering split buffer optimization routine"); assert( current.buffer_config_.get_search_window_size() == @@ -252,11 +252,11 @@ std::pair optimize_search_buffer( size_t num_neighbors, double target_recall, const ComputeRecall& compute_recall, - const DoSearch& do_search + const DoSearch& do_search, + svs::logging::logger_ptr logger = svs::logging::get() ) { using enum CalibrationParameters::SearchBufferOptimization; using dataset_type = typename Index::data_type; - auto logger = svs::logging::get(); double max_recall = std::numeric_limits::lowest(); const size_t current_capacity = current.buffer_config_.get_total_capacity(); @@ -345,9 +345,9 @@ 
VamanaSearchParameters tune_prefetch( const CalibrationParameters& calibration_parameters, Index& index, VamanaSearchParameters search_parameters, - const DoSearch& do_search + const DoSearch& do_search, + svs::logging::logger_ptr logger = svs::logging::get() ) { - auto logger = svs::logging::get(); svs::logging::trace(logger, "Tuning prefetch parameters"); const auto& prefetch_steps = calibration_parameters.prefetch_steps_; size_t max_lookahead = index.max_degree(); diff --git a/include/svs/index/vamana/dynamic_index.h b/include/svs/index/vamana/dynamic_index.h index 2ee36d08..6a37778b 100644 --- a/include/svs/index/vamana/dynamic_index.h +++ b/include/svs/index/vamana/dynamic_index.h @@ -157,6 +157,9 @@ class MutableVamanaIndex { float alpha_ = 1.2; bool use_full_search_history_ = true; + // SVS logger for per index logging + svs::logging::logger_ptr logger_; + // Methods public: // Constructors @@ -167,7 +170,9 @@ class MutableVamanaIndex { Idx entry_point, Dist distance_function, const ExternalIds& external_ids, - ThreadPoolProto threadpool_proto + ThreadPoolProto threadpool_proto, + // Optional logger parameter + svs::logging::logger_ptr logger = svs::logging::get() ) : graph_{std::move(graph)} , data_{std::move(data)} @@ -178,7 +183,9 @@ class MutableVamanaIndex { , distance_{std::move(distance_function)} , threadpool_{threads::as_threadpool(std::move(threadpool_proto))} , search_parameters_{vamana::construct_default_search_parameters(data_)} - , construction_window_size_{2 * graph.max_degree()} { + , construction_window_size_{2 * graph.max_degree()} + // Ctor accept logger in parameter + , logger_{std::move(logger)} { translator_.insert(external_ids, threads::UnitRange(0, external_ids.size())); } @@ -191,7 +198,8 @@ class MutableVamanaIndex { Data data, const ExternalIds& external_ids, Dist distance_function, - ThreadPoolProto threadpool_proto + ThreadPoolProto threadpool_proto, + svs::logging::logger_ptr logger = svs::logging::get() ) : 
graph_(Graph{data.size(), parameters.graph_max_degree}) , data_(std::move(data)) @@ -206,7 +214,8 @@ class MutableVamanaIndex { , max_candidates_(parameters.max_candidate_pool_size) , prune_to_(parameters.prune_to) , alpha_(parameters.alpha) - , use_full_search_history_{parameters.use_full_search_history} { + , use_full_search_history_{parameters.use_full_search_history} + , logger_{std::move(logger)} { // Setup the initial translation of external to internal ids. translator_.insert(external_ids, threads::UnitRange(0, external_ids.size())); @@ -220,8 +229,8 @@ class MutableVamanaIndex { auto builder = VamanaBuilder( graph_, data_, distance_, parameters, threadpool_, prefetch_parameters ); - builder.construct(1.0f, entry_point_[0]); - builder.construct(parameters.alpha, entry_point_[0]); + builder.construct(1.0f, entry_point_[0], logging::Level::Info, logger_); + builder.construct(parameters.alpha, entry_point_[0], logging::Level::Info, logger_); } /// @brief Post re-load constructor. @@ -240,7 +249,8 @@ class MutableVamanaIndex { graph_type graph, const Dist& distance_function, IDTranslator translator, - Pool threadpool + Pool threadpool, + svs::logging::logger_ptr logger = svs::logging::get() ) : graph_{std::move(graph)} , data_{std::move(data)} @@ -255,7 +265,8 @@ class MutableVamanaIndex { , max_candidates_{config.build_parameters.max_candidate_pool_size} , prune_to_{config.build_parameters.prune_to} , alpha_{config.build_parameters.alpha} - , use_full_search_history_{config.build_parameters.use_full_search_history} {} + , use_full_search_history_{config.build_parameters.use_full_search_history} + , logger_{std::move(logger)} {} ///// Scratchspace scratchspace_type scratchspace(const search_parameters_type& sp) const { @@ -272,6 +283,8 @@ class MutableVamanaIndex { scratchspace_type scratchspace() const { return scratchspace(get_search_parameters()); } ///// Accessors + /// @brief Getter method for logger + svs::logging::logger_ptr get_logger() const { return 
logger_; } /// @brief Get the alpha value used for pruning while mutating the graph. float get_alpha() const { return alpha_; } @@ -1200,6 +1213,17 @@ template MutableVamanaIndex, Data, Dist>; +// Guide with logging +template +MutableVamanaIndex( + const VamanaBuildParameters&, + Data, + const ExternalIds&, + Dist, + Pool, + svs::logging::logger_ptr +) -> MutableVamanaIndex, Data, Dist>; + namespace detail { struct VamanaStateLoader { @@ -1251,7 +1275,8 @@ auto auto_dynamic_assemble( // to easily benchmark the static versus dynamic implementation. // // This is an internal API and should not be considered officially supported nor stable. - bool debug_load_from_static = false + bool debug_load_from_static = false, + svs::logging::logger_ptr logger = svs::logging::get() ) { // Load the dataset auto threadpool = threads::as_threadpool(std::move(threadpool_proto)); @@ -1317,7 +1342,8 @@ auto auto_dynamic_assemble( std::move(graph), std::move(distance), std::move(translator), - std::move(threadpool)}; + std::move(threadpool), + std::move(logger)}; } } // namespace svs::index::vamana diff --git a/include/svs/index/vamana/index.h b/include/svs/index/vamana/index.h index d89ac537..a50ce11d 100644 --- a/include/svs/index/vamana/index.h +++ b/include/svs/index/vamana/index.h @@ -302,6 +302,8 @@ class VamanaIndex { lib::ReadWriteProtected default_search_parameters_{}; // Construction parameters VamanaBuildParameters build_parameters_{}; + // SVS logger for per index logging + svs::logging::logger_ptr logger_; public: /// The type of the search resource used for external threading. @@ -326,6 +328,7 @@ class VamanaIndex { /// instance or an integer specifying the number of threads to use. In the latter /// case, a new default thread pool will be constructed using ``threadpool_proto`` /// as the number of threads to create. + /// @param logger spdlog logger for per-index logging customization. 
/// /// This is a lower-level function that is meant to take a collection of /// instantiated components and assemble the final index. For a more "hands-free" @@ -346,14 +349,16 @@ class VamanaIndex { Data data, Idx entry_point, Dist distance_function, - ThreadPoolProto threadpool_proto + ThreadPoolProto threadpool_proto, + svs::logging::logger_ptr logger = svs::logging::get() ) : graph_{std::move(graph)} , data_{std::move(data)} , entry_point_{entry_point} , distance_{std::move(distance_function)} , threadpool_{threads::as_threadpool(std::move(threadpool_proto))} - , default_search_parameters_{construct_default_search_parameters(data_)} {} + , default_search_parameters_{construct_default_search_parameters(data_)} + , logger_{std::move(logger)} {} /// /// @brief Build a VamanaIndex over the given dataset. @@ -366,6 +371,7 @@ class VamanaIndex { /// @param distance_function The distance function used to compare queries and /// elements of the dataset. /// @param threadpool The acceptable threadpool to use to conduct searches. + /// @param logger spdlog logger for per-index logging customization. /// /// This is a lower-level function that is meant to take a dataset and construct /// the graph-based index over the dataset. 
For a more "hands-free" approach, see @@ -385,14 +391,16 @@ class VamanaIndex { Data data, Idx entry_point, Dist distance_function, - Pool threadpool + Pool threadpool, + svs::logging::logger_ptr logger = svs::logging::get() ) : VamanaIndex{ std::move(graph), std::move(data), entry_point, std::move(distance_function), - std::move(threadpool)} { + std::move(threadpool), + logger} { if (graph_.n_nodes() != data_.size()) { throw ANNEXCEPTION("Wrong sizes!"); } @@ -407,10 +415,13 @@ class VamanaIndex { extensions::estimate_prefetch_parameters(data_) ); - builder.construct(1.0F, entry_point_[0]); - builder.construct(parameters.alpha, entry_point_[0]); + builder.construct(1.0F, entry_point_[0], logging::Level::Info, logger); + builder.construct(parameters.alpha, entry_point_[0], logging::Level::Info, logger); } + /// @brief Getter method for logger + svs::logging::logger_ptr get_logger() const { return logger_; } + /// @brief Apply the given configuration parameters to the index. void apply(const VamanaIndexParameters& parameters) { entry_point_.clear(); @@ -863,6 +874,7 @@ class VamanaIndex { /// a new default thread pool will be constructed using ``threadpool_proto`` as the /// number of threads to create. /// @param graph_allocator The allocator to use for the graph data structure. +/// @param logger spdlog logger for per-index logging customization. 
/// /// @copydoc threadpool_requirements /// @@ -876,7 +888,8 @@ auto auto_build( DataProto data_proto, Distance distance, ThreadPoolProto threadpool_proto, - const Allocator& graph_allocator = {} + const Allocator& graph_allocator = {}, + svs::logging::logger_ptr logger = svs::logging::get() ) { auto threadpool = threads::as_threadpool(std::move(threadpool_proto)); auto data = svs::detail::dispatch_load(std::move(data_proto), threadpool); @@ -891,7 +904,8 @@ auto auto_build( std::move(data), lib::narrow(entry_point), std::move(distance), - std::move(threadpool)}; + std::move(threadpool), + logger}; } /// @@ -909,6 +923,7 @@ auto auto_build( /// This method provides much of the heavy lifting for instantiating a Vamana index from /// a collection of files on disk (or perhaps a mix-and-match of existing data in-memory /// and on disk). +/// @param logger spdlog logger for per-index logging customization. /// /// Refer to the examples for use of this interface. /// @@ -924,7 +939,8 @@ auto auto_assemble( GraphProto graph_loader, DataProto data_proto, Distance distance, - ThreadPoolProto threadpool_proto + ThreadPoolProto threadpool_proto, + svs::logging::logger_ptr logger = svs::logging::get() ) { auto threadpool = threads::as_threadpool(std::move(threadpool_proto)); auto data = svs::detail::dispatch_load(std::move(data_proto), threadpool); @@ -933,8 +949,12 @@ auto auto_assemble( // Extract the index type of the provided graph. 
using I = typename decltype(graph)::index_type; auto index = VamanaIndex{ - std::move(graph), std::move(data), I{}, std::move(distance), std::move(threadpool)}; - + std::move(graph), + std::move(data), + I{}, + std::move(distance), + std::move(threadpool), + std::move(logger)}; auto config = lib::load_from_disk(config_path); index.apply(config); return index; diff --git a/include/svs/index/vamana/vamana_build.h b/include/svs/index/vamana/vamana_build.h index a174ac0a..b20f7bc5 100644 --- a/include/svs/index/vamana/vamana_build.h +++ b/include/svs/index/vamana/vamana_build.h @@ -202,9 +202,15 @@ class VamanaBuilder { } } - void - construct(float alpha, Idx entry_point, logging::Level level = logging::Level::Info) { - construct(alpha, entry_point, threads::UnitRange{0, data_.size()}, level); + void construct( + float alpha, + Idx entry_point, + logging::Level level = logging::Level::Info, + logging::logger_ptr logger = svs::logging::get() + ) { + construct( + alpha, entry_point, threads::UnitRange{0, data_.size()}, level, logger + ); } template @@ -212,9 +218,9 @@ class VamanaBuilder { float alpha, Idx entry_point, const R& range, - logging::Level level = logging::Level::Info + logging::Level level = logging::Level::Info, + logging::logger_ptr logger = svs::logging::get() ) { - auto logger = svs::logging::get(); size_t num_nodes = range.size(); size_t num_batches = std::max( size_t{40}, lib::div_round_up(num_nodes, lib::narrow_cast(64 * 64)) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index bbd8557f..b8352697 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -125,6 +125,7 @@ set(TEST_SOURCES # Index Specific Functionality ${TEST_DIR}/svs/index/index.cpp ${TEST_DIR}/svs/index/flat/inserters.cpp + ${TEST_DIR}/svs/index/flat/flat.cpp ${TEST_DIR}/svs/index/vamana/build_parameters.cpp ${TEST_DIR}/svs/index/vamana/consolidate.cpp ${TEST_DIR}/svs/index/vamana/filter.cpp @@ -138,6 +139,7 @@ set(TEST_SOURCES ${TEST_DIR}/svs/index/vamana/iterator.cpp 
# Inverted ${TEST_DIR}/svs/index/inverted/clustering.cpp + ${TEST_DIR}/svs/index/inverted/memory_based.cpp # # ${TEST_DIR}/svs/index/vamana/dynamic_index.cpp ) diff --git a/tests/integration/vamana/index_build.cpp b/tests/integration/vamana/index_build.cpp index 637eae33..ec04b156 100644 --- a/tests/integration/vamana/index_build.cpp +++ b/tests/integration/vamana/index_build.cpp @@ -40,6 +40,10 @@ #include #include +// logger +#include "spdlog/sinks/callback_sink.h" +#include "svs/core/logging.h" + namespace { template < @@ -199,3 +203,75 @@ CATCH_TEST_CASE( ++num_threads; } } + +// Helper function to create a logger with a callback +std::shared_ptr create_test_logger(std::vector& captured_logs +) { + auto callback_sink = std::make_shared( + [&captured_logs](const spdlog::details::log_msg& msg) { + captured_logs.emplace_back(msg.payload.data(), msg.payload.size()); + } + ); + callback_sink->set_level(spdlog::level::trace); + + auto logger = std::make_shared("test_logger", callback_sink); + logger->set_level(spdlog::level::trace); + return logger; +} + +CATCH_TEST_CASE("VamanaIndex Logging Tests", "[logging]") { + using namespace svs::index::vamana; + + // Test data setup + std::vector data = {1.0f, 2.0f}; + const size_t dim = 1; + auto graph = svs::graphs::SimpleGraph(1, 64); + auto data_view = svs::data::SimpleDataView(data.data(), 1, dim); + svs::distance::DistanceL2 distance_function; + uint32_t entry_point = 0; + auto threadpool = svs::threads::DefaultThreadPool(1); + VamanaBuildParameters buildParams(1.2, 64, 10, 20, 10, true); + + CATCH_SECTION("With Custom Logger") { + std::vector captured_logs; + auto custom_logger = create_test_logger(captured_logs); + + // Create VamanaIndex, which will call the builder and construct + VamanaIndex vamana_index( + buildParams, + std::move(graph), + std::move(data_view), + entry_point, + std::move(distance_function), + std::move(threadpool), + custom_logger + ); + + // Verify the custom logger captured the log messages + 
CATCH_REQUIRE(captured_logs[0].find("Number of syncs:") != std::string::npos); + CATCH_REQUIRE(captured_logs[1].find("Batch Size:") != std::string::npos); + auto default_logger = svs::logging::get(); + CATCH_REQUIRE(vamana_index.get_logger() != default_logger); + } + + CATCH_SECTION("With Default Logger") { + // Reset the test data setup + std::vector data = {1.0f, 2.0f}; + auto graph = svs::graphs::SimpleGraph(1, 64); + auto data_view = svs::data::SimpleDataView(data.data(), 1, dim); + auto threadpool = svs::threads::DefaultThreadPool(1); + + // Create VamanaIndex without passing a custom logger + VamanaIndex vamana_index( + buildParams, + std::move(graph), + std::move(data_view), + entry_point, + std::move(distance_function), + std::move(threadpool) + ); + + auto default_logger = svs::logging::get(); + CATCH_REQUIRE(vamana_index.get_logger() == default_logger); + } +} \ No newline at end of file diff --git a/tests/svs/index/flat/flat.cpp b/tests/svs/index/flat/flat.cpp new file mode 100644 index 00000000..e4dbf49b --- /dev/null +++ b/tests/svs/index/flat/flat.cpp @@ -0,0 +1,57 @@ +/* + * Copyright 2025 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "svs/index/flat/flat.h" +#include "svs/core/logging.h" + +// catch2 +#include "catch2/catch_test_macros.hpp" + +// spd log +#include "spdlog/sinks/callback_sink.h" + +CATCH_TEST_CASE("FlatIndex Logging Test", "[logging]") { + // Vector to store captured log messages + std::vector captured_logs; + + // Create a callback sink to capture log messages + auto callback_sink = std::make_shared( + [&captured_logs](const spdlog::details::log_msg& msg) { + captured_logs.emplace_back(msg.payload.data(), msg.payload.size()); + } + ); + callback_sink->set_level(spdlog::level::trace); + + // Set up the FlatIndex with the test logger + auto test_logger = std::make_shared("test_logger", callback_sink); + test_logger->set_level(spdlog::level::trace); + + std::vector data{1.0f, 2.0f}; + auto dataView = svs::data::SimpleDataView(data.data(), 2, 1); + svs::distance::DistanceL2 dist; + auto threadpool = svs::threads::DefaultThreadPool(1); + + svs::index::flat::FlatIndex index( + std::move(dataView), dist, std::move(threadpool), test_logger + ); + + // Log a message + test_logger->info("Test FlatIndex Logging"); + + // Verify the log output + CATCH_REQUIRE(captured_logs.size() == 1); + CATCH_REQUIRE(captured_logs[0] == "Test FlatIndex Logging"); +} diff --git a/tests/svs/index/inverted/clustering.cpp b/tests/svs/index/inverted/clustering.cpp index e8a9e062..77820880 100644 --- a/tests/svs/index/inverted/clustering.cpp +++ b/tests/svs/index/inverted/clustering.cpp @@ -27,6 +27,9 @@ // stl #include +// logging +#include "spdlog/sinks/callback_sink.h" + namespace { template @@ -389,3 +392,47 @@ CATCH_TEST_CASE("Random Clustering - End to End", "[inverted][random_clustering] test_end_to_end_clustering(data, svs::DistanceIP(), 0.9f); } } + +CATCH_TEST_CASE("Clustering with Logger", "[logging]") { + // Setup logger + std::vector captured_logs; + auto callback_sink = std::make_shared( + [&captured_logs](const spdlog::details::log_msg& msg) { + 
captured_logs.emplace_back(msg.payload.data(), msg.payload.size()); + } + ); + callback_sink->set_level(spdlog::level::trace); // Capture all log levels + auto test_logger = std::make_shared("test_logger", callback_sink); + test_logger->set_level(spdlog::level::trace); + + // Setup cluster + auto data = svs::data::SimpleData::load(test_dataset::data_svs_file()); + auto vamana_parameters = + svs::index::vamana::VamanaBuildParameters{1.2, 64, 200, 1000, 60, true}; + auto clustering_parameters = svs::index::inverted::ClusteringParameters() + .percent_centroids(svs::lib::Percent(0.1)) + .epsilon(0.05) + .max_replicas(8) + .max_cluster_size(200); + auto centroids = svs::index::inverted::randomly_select_centroids( + data.size(), + svs::lib::narrow_cast( + std::floor(data.size() * clustering_parameters.percent_centroids_.value()) + ), + clustering_parameters.seed_ + ); + auto threadpool = svs::threads::DefaultThreadPool(2); + auto index = svs::index::inverted::build_primary_index( + data, + svs::lib::as_const_span(centroids), + vamana_parameters, + svs::DistanceL2(), + std::move(threadpool) + ); + auto clustering = svs::index::inverted::cluster_with( + data, svs::lib::as_const_span(centroids), clustering_parameters, index, test_logger + ); + + // Verify the internal log messages + CATCH_REQUIRE(captured_logs[0].find("Processing batch") != std::string::npos); +} \ No newline at end of file diff --git a/tests/svs/index/inverted/memory_based.cpp b/tests/svs/index/inverted/memory_based.cpp new file mode 100644 index 00000000..e6c50925 --- /dev/null +++ b/tests/svs/index/inverted/memory_based.cpp @@ -0,0 +1,63 @@ +/* + * Copyright 2025 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "svs/index/inverted/memory_based.h" +#include "catch2/catch_test_macros.hpp" +#include "spdlog/sinks/callback_sink.h" +#include "svs-benchmark/datasets.h" +#include "svs/lib/timing.h" +#include "tests/utils/inverted_reference.h" +#include "tests/utils/test_dataset.h" +#include + +CATCH_TEST_CASE("InvertedIndex Logging Test", "[logging]") { + // Vector to store captured log messages + std::vector captured_logs; + + // Create a callback sink to capture log messages + auto callback_sink = std::make_shared( + [&captured_logs](const spdlog::details::log_msg& msg) { + captured_logs.emplace_back(msg.payload.data(), msg.payload.size()); + } + ); + callback_sink->set_level(spdlog::level::trace); + + // Create a logger with the callback sink + auto test_logger = std::make_shared("test_logger", callback_sink); + test_logger->set_level(spdlog::level::trace); + + // Setup index + auto distance = svs::DistanceL2(); + constexpr auto distance_type = svs::distance_type_v; + auto expected_results = test_dataset::inverted::expected_build_results( + distance_type, svsbenchmark::Uncompressed(svs::DataType::float32) + ); + auto data = svs::data::SimpleData::load(test_dataset::data_svs_file()); + auto threadpool = svs::threads::DefaultThreadPool(1); + auto invertedIndex = svs::index::inverted::auto_build( + expected_results.build_parameters_.value(), + data, + distance, + std::move(threadpool), + {}, + {}, + {}, + test_logger + ); + + // Verify the internal log messages + CATCH_REQUIRE(captured_logs[0].find("Processing batch") != std::string::npos); +} \ 
No newline at end of file diff --git a/tests/svs/index/vamana/dynamic_index_2.cpp b/tests/svs/index/vamana/dynamic_index_2.cpp index ea0e797c..a3acb7f0 100644 --- a/tests/svs/index/vamana/dynamic_index_2.cpp +++ b/tests/svs/index/vamana/dynamic_index_2.cpp @@ -24,7 +24,10 @@ #include "svs/misc/dynamic_helper.h" // tests +#include "spdlog/sinks/callback_sink.h" #include "tests/utils/test_dataset.h" +#include "tests/utils/utils.h" +#include "tests/utils/vamana_reference.h" // catch #include "catch2/catch_test_macros.hpp" @@ -416,3 +419,61 @@ CATCH_TEST_CASE("Testing Graph Index", "[graph_index][dynamic_index]") { // ID's preserved across runs. index.on_ids([&](size_t e) { CATCH_REQUIRE(reloaded.has_id(e)); }); } + +CATCH_TEST_CASE("Dynamic MutableVamanaIndex Per-Index Logging Test", "[logging]") { + // Vector to store captured log messages + std::vector captured_logs; + + // Create a callback sink to capture log messages + auto callback_sink = std::make_shared( + [&captured_logs](const spdlog::details::log_msg& msg) { + captured_logs.emplace_back(msg.payload.data(), msg.payload.size()); + } + ); + callback_sink->set_level(spdlog::level::trace); // Capture all log levels + + // Create a logger with the callback sink + auto test_logger = std::make_shared("test_logger", callback_sink); + test_logger->set_level(spdlog::level::trace); + + // Setup index + std::vector data = {1.0f, 2.0f}; + std::vector initial_indices(data.size()); + std::iota(initial_indices.begin(), initial_indices.end(), 0); + svs::index::vamana::VamanaBuildParameters buildParams(1.2, 64, 10, 20, 10, true); + auto data_view = svs::data::SimpleDataView(data.data(), 2, 1); + auto threadpool = svs::threads::DefaultThreadPool(1); + auto index = svs::index::vamana::MutableVamanaIndex( + buildParams, + std::move(data_view), + initial_indices, + svs::DistanceL2(), + std::move(threadpool), + test_logger + ); + + // Verify the internal log messages + CATCH_REQUIRE(captured_logs[0].find("Number of syncs:") != 
std::string::npos); + CATCH_REQUIRE(captured_logs[1].find("Batch Size:") != std::string::npos); +} + +CATCH_TEST_CASE("Dynamic MutableVamanaIndex Default Logger Test", "[logging]") { + // Setup index with default logger + std::vector data = {1.0f, 2.0f}; + std::vector initial_indices(data.size()); + std::iota(initial_indices.begin(), initial_indices.end(), 0); + svs::index::vamana::VamanaBuildParameters buildParams(1.2, 64, 10, 20, 10, true); + auto data_view = svs::data::SimpleDataView(data.data(), 2, 1); + auto threadpool = svs::threads::DefaultThreadPool(1); + auto index = svs::index::vamana::MutableVamanaIndex( + buildParams, + std::move(data_view), + initial_indices, + svs::DistanceL2(), + std::move(threadpool) + ); + + // Verify that the default logger is used + auto default_logger = svs::logging::get(); + CATCH_REQUIRE(index.get_logger() == default_logger); +} \ No newline at end of file diff --git a/tests/svs/index/vamana/index.cpp b/tests/svs/index/vamana/index.cpp index 3e1b3fba..cd549299 100644 --- a/tests/svs/index/vamana/index.cpp +++ b/tests/svs/index/vamana/index.cpp @@ -16,6 +16,8 @@ // Header under test #include "svs/index/vamana/index.h" +#include "spdlog/sinks/callback_sink.h" +#include "svs/core/logging.h" // catch2 #include "catch2/catch_test_macros.hpp" @@ -108,3 +110,44 @@ CATCH_TEST_CASE("Vamana Index Parameters", "[index][vamana]") { CATCH_REQUIRE(svs::lib::test_self_save_load_context_free(p)); } } + +CATCH_TEST_CASE("Static VamanaIndex Per-Index Logging", "[logging]") { + // Vector to store captured log messages + std::vector captured_logs; + + // Create a callback sink to capture log messages + auto callback_sink = std::make_shared( + [&captured_logs](const spdlog::details::log_msg& msg) { + captured_logs.emplace_back(msg.payload.data(), msg.payload.size()); + } + ); + callback_sink->set_level(spdlog::level::trace); // Capture all log levels + + // Create a logger with the callback sink + auto test_logger = std::make_shared("test_logger", 
callback_sink); + test_logger->set_level(spdlog::level::trace); + + // Create some minimal data + std::vector data = {1.0f, 2.0f}; + auto graph = svs::graphs::SimpleGraph(1, 64); + auto data_view = svs::data::SimpleDataView(data.data(), 1, 1); + svs::distance::DistanceL2 distance_function; + uint32_t entry_point = 0; + auto threadpool = svs::threads::DefaultThreadPool(1); + + // Build the VamanaIndex with the test logger + svs::index::vamana::VamanaBuildParameters buildParams(1.2, 64, 10, 20, 10, true); + svs::index::vamana::VamanaIndex index( + buildParams, + std::move(graph), + std::move(data_view), + entry_point, + distance_function, + std::move(threadpool), + test_logger + ); + + // Verify the internal log messages + CATCH_REQUIRE(captured_logs[0].find("Number of syncs:") != std::string::npos); + CATCH_REQUIRE(captured_logs[1].find("Batch Size:") != std::string::npos); +} \ No newline at end of file