Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions include/svs/index/vamana/dynamic_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -225,9 +225,7 @@ class MutableVamanaIndex {
, build_parameters_(parameters)
, logger_{std::move(logger)} {
// Verify and set defaults directly on the input parameters
verify_and_set_default_index_parameters(
build_parameters_, distance_function, logger_
);
verify_and_set_default_index_parameters(build_parameters_, distance_function);

// Set graph again as verify function might change graph_max_degree parameter
graph_ = Graph{data_.size(), build_parameters_.graph_max_degree};
Expand All @@ -248,7 +246,13 @@ class MutableVamanaIndex {
auto prefetch_parameters =
GreedySearchPrefetchParameters{sp.prefetch_lookahead_, sp.prefetch_step_};
auto builder = VamanaBuilder(
graph_, data_, distance_, build_parameters_, threadpool_, prefetch_parameters
graph_,
data_,
distance_,
build_parameters_,
threadpool_,
prefetch_parameters,
logger_
);
builder.construct(1.0f, entry_point_[0], logging::Level::Trace, logger_);
builder.construct(
Expand Down Expand Up @@ -699,7 +703,14 @@ class MutableVamanaIndex {
auto prefetch_parameters =
GreedySearchPrefetchParameters{sp.prefetch_lookahead_, sp.prefetch_step_};
VamanaBuilder builder{
graph_, data_, distance_, parameters, threadpool_, prefetch_parameters
graph_,
data_,
distance_,
parameters,
threadpool_,
prefetch_parameters,
logger_,
logging::Level::Trace
};
builder.construct(alpha_, entry_point(), slots, logging::Level::Trace, logger_);
// Mark all added entries as valid.
Expand Down
28 changes: 5 additions & 23 deletions include/svs/index/vamana/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -417,16 +417,15 @@ class VamanaIndex {
}
build_parameters_ = parameters;
// verify the parameters before set local var
verify_and_set_default_index_parameters(
build_parameters_, distance_function, logger
);
verify_and_set_default_index_parameters(build_parameters_, distance_function);
auto builder = VamanaBuilder(
graph_,
data_,
distance_,
build_parameters_,
threadpool_,
extensions::estimate_prefetch_parameters(data_)
extensions::estimate_prefetch_parameters(data_),
logger
);

builder.construct(1.0F, entry_point_[0], logging::Level::Trace, logger);
Expand Down Expand Up @@ -947,7 +946,7 @@ auto auto_build(

// Default graph.
auto verified_parameters = parameters;
verify_and_set_default_index_parameters(verified_parameters, distance, logger);
verify_and_set_default_index_parameters(verified_parameters, distance);
auto graph =
default_graph(data.size(), verified_parameters.graph_max_degree, graph_allocator);
using I = typename decltype(graph)::index_type;
Expand Down Expand Up @@ -1018,9 +1017,7 @@ auto auto_assemble(
/// @brief Verify parameters and set defaults if needed
template <typename Dist>
void verify_and_set_default_index_parameters(
VamanaBuildParameters& parameters,
Dist distance_function,
svs::logging::logger_ptr logger = svs::logging::get()
VamanaBuildParameters& parameters, Dist distance_function
) {
// Set default values
if (parameters.max_candidate_pool_size == svs::UNSIGNED_INTEGER_PLACEHOLDER) {
Expand Down Expand Up @@ -1068,20 +1065,5 @@ void verify_and_set_default_index_parameters(
if (parameters.prune_to > parameters.graph_max_degree) {
throw std::invalid_argument("prune_to must be <= graph_max_degree");
}

// Print all parameters
svs::logging::log(
logger,
logging::Level::Info,
"Vamana Build Parameters: alpha={}, graph_max_degree={}, "
"max_candidate_pool_size={}, prune_to={}, window_size={}, "
"use_full_search_history={}",
parameters.alpha,
parameters.graph_max_degree,
parameters.max_candidate_pool_size,
parameters.prune_to,
parameters.window_size,
parameters.use_full_search_history
);
}
} // namespace svs::index::vamana
43 changes: 28 additions & 15 deletions include/svs/index/vamana/vamana_build.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,11 @@ template <typename Idx> class BackedgeBuffer {
, bucket_locks_{parameters.num_buckets_} {}

BackedgeBuffer(size_t num_elements, size_t bucket_size)
: BackedgeBuffer(BackedgeBufferParameters{
bucket_size, lib::div_round_up(num_elements, bucket_size)}) {}
: BackedgeBuffer(
BackedgeBufferParameters{
bucket_size, lib::div_round_up(num_elements, bucket_size)
}
) {}

// Add a point.
void add_edge(Idx src, Idx dst) {
Expand Down Expand Up @@ -184,7 +187,9 @@ class VamanaBuilder {
Dist distance_function,
const VamanaBuildParameters& params,
Pool& threadpool,
GreedySearchPrefetchParameters prefetch_hint = {}
GreedySearchPrefetchParameters prefetch_hint = {},
svs::logging::logger_ptr logger = svs::logging::get(),
logging::Level level = logging::Level::Debug
)
: graph_{graph}
, data_{data}
Expand All @@ -194,6 +199,20 @@ class VamanaBuilder {
, threadpool_{threadpool}
, vertex_locks_(data.size())
, backedge_buffer_{data.size(), 1000} {
// Print all parameters
svs::logging::log(
logger,
level,
"Vamana Build Parameters: alpha={}, graph_max_degree={}, "
"max_candidate_pool_size={}, prune_to={}, window_size={}, "
"use_full_search_history={}",
params.alpha,
params.graph_max_degree,
params.max_candidate_pool_size,
params.prune_to,
params.window_size,
params.use_full_search_history
);
// Check class invariants.
if (graph_.n_nodes() != data_.size()) {
throw ANNEXCEPTION(
Expand Down Expand Up @@ -296,12 +315,9 @@ class VamanaBuilder {
}
}
svs::logging::log(
logger,
logging::Level::Debug,
"Completed pass using window size {}.",
params_.window_size
logger, level, "Completed pass using window size {}.", params_.window_size
);
svs::logging::log(logger, logging::Level::Debug, "{}", timer);
svs::logging::log(logger, level, "{}", timer);
}

///
Expand All @@ -323,9 +339,7 @@ class VamanaBuilder {
update_type updates{threadpool_.size()};
auto main = timer.push_back("main");
threads::parallel_for(
threadpool_,
range,
[&](const auto& local_indices, uint64_t tid) {
threadpool_, range, [&](const auto& local_indices, uint64_t tid) {
// Thread local variables
auto& thread_local_updates = updates.at(tid);

Expand Down Expand Up @@ -476,9 +490,7 @@ class VamanaBuilder {
auto range = threads::StaticPartition{indices};
backedge_buffer_.reset();
threads::parallel_for(
threadpool_,
range,
[&](const auto& is, uint64_t SVS_UNUSED(tid)) {
threadpool_, range, [&](const auto& is, uint64_t SVS_UNUSED(tid)) {
for (auto node_id : is) {
for (auto other_id : graph_.get_node(node_id)) {
std::lock_guard lock{vertex_locks_[other_id]};
Expand Down Expand Up @@ -527,7 +539,8 @@ class VamanaBuilder {
i,
distance::compute(
general_distance, src_data, general_accessor(data_, i)
)};
)
};
};

candidates.clear();
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/vamana/index_build.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ CATCH_TEST_CASE("VamanaIndex Logging Tests", "[logging]") {
CATCH_REQUIRE(
captured_logs[0].find("Vamana Build Parameters:") != std::string::npos
);
CATCH_REQUIRE(captured_levels[0] == svs::logging::Level::Info);
CATCH_REQUIRE(captured_levels[0] == svs::logging::Level::Debug);
CATCH_REQUIRE(captured_logs[1].find("Number of syncs:") != std::string::npos);
CATCH_REQUIRE(captured_levels[1] == svs::logging::Level::Trace);
CATCH_REQUIRE(captured_logs[2].find("Batch Size:") != std::string::npos);
Expand Down
7 changes: 4 additions & 3 deletions tests/svs/index/inverted/clustering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,8 @@ void test_end_to_end_clustering(
});

auto vamana_parameters = svs::index::vamana::VamanaBuildParameters{
construction_alpha, 64, 200, 1000, 60, true};
construction_alpha, 64, 200, 1000, 60, true
};

// Build the index once and reuse it multiple times to help speed up tests.
for (size_t max_replicas : {2, 8}) {
Expand Down Expand Up @@ -447,6 +448,6 @@ CATCH_TEST_CASE("Clustering with Logger", "[logging]") {

// Verify the internal log messages
CATCH_REQUIRE(global_captured_logs.empty());
CATCH_REQUIRE(captured_logs[1].find("Vamana Build Parameters:") != std::string::npos);
CATCH_REQUIRE(captured_logs[2].find("Number of syncs") != std::string::npos);
CATCH_REQUIRE(captured_logs[0].find("Vamana Build Parameters:") != std::string::npos);
CATCH_REQUIRE(captured_logs[1].find("Number of syncs") != std::string::npos);
}
4 changes: 2 additions & 2 deletions tests/svs/index/inverted/memory_based.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,6 @@ CATCH_TEST_CASE("InvertedIndex Logging Test", "[logging]") {

// Verify the internal log messages
CATCH_REQUIRE(global_captured_logs.empty());
CATCH_REQUIRE(captured_logs[1].find("Vamana Build Parameters:") != std::string::npos);
CATCH_REQUIRE(captured_logs[2].find("Number of syncs") != std::string::npos);
CATCH_REQUIRE(captured_logs[0].find("Vamana Build Parameters:") != std::string::npos);
CATCH_REQUIRE(captured_logs[1].find("Number of syncs") != std::string::npos);
}
5 changes: 2 additions & 3 deletions tests/svs/index/vamana/dynamic_index_2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -326,8 +326,7 @@ CATCH_TEST_CASE("Testing Graph Index", "[graph_index][dynamic_index]") {
}

svs::index::vamana::VamanaBuildParameters parameters{
1.2, max_degree, 2 * max_degree, 1000, max_degree - 4, true
};
1.2, max_degree, 2 * max_degree, 1000, max_degree - 4, true};

auto tic = svs::lib::now();
auto index = svs::index::vamana::MutableVamanaIndex(
Expand All @@ -344,7 +343,7 @@ CATCH_TEST_CASE("Testing Graph Index", "[graph_index][dynamic_index]") {
CATCH_REQUIRE(captured_logs[0].find("Total / % Measured:") != std::string::npos);
CATCH_REQUIRE(captured_levels[0] == svs::logging::Level::Debug);
CATCH_REQUIRE(captured_logs[1].find("Vamana Build Parameters:") != std::string::npos);
CATCH_REQUIRE(captured_levels[1] == svs::logging::Level::Info);
CATCH_REQUIRE(captured_levels[1] == svs::logging::Level::Debug);
CATCH_REQUIRE(captured_logs[2].find("Number of syncs:") != std::string::npos);
CATCH_REQUIRE(captured_levels[2] == svs::logging::Level::Trace);
CATCH_REQUIRE(captured_logs[3].find("Batch Size:") != std::string::npos);
Expand Down
Loading