Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 50 additions & 1 deletion include/svs/index/ivf/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,9 @@ class IVFIndex {
, cluster_{std::move(cluster)}
, cluster0_{cluster_.view_cluster(0)}
, distance_{std::move(distance_function)}
, inter_query_threadpool_{threads::as_threadpool(std::move(threadpool_proto))}
, inter_query_threadpool_{make_inter_query_threadpool(
std::move(threadpool_proto), centroids_.size(), logger
)}
Comment thread
rfsaliev marked this conversation as resolved.
, intra_query_thread_count_{intra_query_thread_count}
, logger_{std::move(logger)} {
validate_thread_configuration();
Expand Down Expand Up @@ -572,6 +574,53 @@ class IVFIndex {

///// Initialization Methods /////

static auto make_inter_query_threadpool(
ThreadPoolProto proto, size_t num_centroids, svs::logging::logger_ptr& logger
) -> decltype(threads::as_threadpool(std::move(proto))) {
if constexpr (std::is_same_v<ThreadPoolProto, size_t>) {
// Specialization for size_t thread pool prototype to allow automatic resizing
// and logging of adjustments.
if (proto > num_centroids) {
svs::logging::warn(
logger,
"Provided thread pool has {} threads, but there are only {} centroids. "
"Reducing thread pool size to match number of centroids.",
proto,
num_centroids
);
proto = num_centroids;
}
} else if constexpr (requires { proto.resize(num_centroids); }) {
// Specialization for thread pool prototypes that support resizing.
if (proto.size() > num_centroids) {
svs::logging::warn(
logger,
"Provided thread pool has {} threads, but there are only {} centroids. "
"Reducing thread pool size to match number of centroids.",
proto.size(),
num_centroids
);
proto.resize(num_centroids);
}
} else {
// Generic inter-query thread pool adjustment which just validates the thread
// count against the number of centroids.
if (proto.size() > num_centroids) {
svs::logging::error(
logger,
"Provided thread pool has {} threads, but there are only {} centroids. "
"This configuration is not supported.",
proto.size(),
num_centroids
);
throw std::invalid_argument(
"Number of inter-query threads cannot exceed number of centroids"
);
}
}
return threads::as_threadpool(std::move(proto));
}

void validate_thread_configuration() {
if (intra_query_thread_count_ < 1) {
throw std::invalid_argument("Intra-query thread count must be at least 1");
Expand Down
2 changes: 1 addition & 1 deletion include/svs/index/ivf/sorted_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ template <typename Idx, typename Cmp = std::less<>> class SortedBuffer {
/// @brief Return ``true`` if a neighbor with the given distance can be skipped.
///
bool can_skip(float distance) const {
return compare_(back().distance(), distance) && full();
return full() && compare_(back().distance(), distance);
}

///
Expand Down
125 changes: 125 additions & 0 deletions tests/svs/index/ivf/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
// svs
#include "svs/core/data.h"
#include "svs/core/distance.h"
#include "svs/core/logging.h"
#include "svs/index/ivf/clustering.h"
#include "svs/index/ivf/hierarchical_kmeans.h"
#include "svs/lib/saveload.h"
Expand All @@ -36,6 +37,9 @@
#include <numeric>
#include <sstream>

// third-party
#include <spdlog/sinks/callback_sink.h>

CATCH_TEST_CASE("IVF Index Single Search", "[ivf][index][single_search]") {
namespace ivf = svs::index::ivf;

Expand Down Expand Up @@ -417,3 +421,124 @@ CATCH_TEST_CASE("IVF Index Save and Load", "[ivf][index][saveload]") {
svs_test::cleanup_temp_directory();
}
}

CATCH_TEST_CASE("IVF Index Inter-Query Thread Count Boundaries", "[ivf][index][threads]") {
namespace ivf = svs::index::ivf;

auto make_test_logger = [](std::vector<std::string>& captured_logs,
std::vector<svs::logging::Level>& captured_levels) {
auto callback_sink = std::make_shared<spdlog::sinks::callback_sink_mt>(
[&captured_logs, &captured_levels](const spdlog::details::log_msg& msg) {
captured_logs.emplace_back(msg.payload.data(), msg.payload.size());
captured_levels.push_back(svs::logging::detail::from_spdlog(msg.level));
}
);
callback_sink->set_level(spdlog::level::trace);
auto logger =
std::make_shared<spdlog::logger>("ivf_threads_test_logger", callback_sink);
logger->set_level(spdlog::level::trace);
return logger;
};

auto build_components = []() {
auto data = svs::data::SimpleData<float>::load(test_dataset::data_svs_file());
auto distance = svs::distance::DistanceL2();
auto build_params = ivf::IVFBuildParameters(2, 5, false);
auto build_threadpool = svs::threads::SequentialThreadPool();

auto clustering = ivf::build_clustering<float>(
build_params, data, distance, build_threadpool, false
);

auto centroids = clustering.centroids();
using Idx = uint32_t;
auto cluster = ivf::DenseClusteredDataset<decltype(centroids), Idx, decltype(data)>(
clustering, data, build_threadpool, svs::lib::Allocator<std::byte>()
);

return std::make_tuple(
std::move(centroids), std::move(cluster), std::move(distance)
);
};

CATCH_SECTION("size_t thread prototype is clamped and warns") {
auto [centroids, cluster, distance] = build_components();
CATCH_REQUIRE(centroids.size() == 2);

std::vector<std::string> logs;
std::vector<svs::logging::Level> levels;
auto logger = make_test_logger(logs, levels);

using IndexType = ivf::
IVFIndex<decltype(centroids), decltype(cluster), decltype(distance), size_t>;

IndexType index(
std::move(centroids), std::move(cluster), distance, size_t{4}, 1, logger
);

CATCH_REQUIRE(index.get_num_threads() == 2);
CATCH_REQUIRE(
std::find(levels.begin(), levels.end(), svs::logging::Level::Warn) !=
levels.end()
);
}

CATCH_SECTION("resizable thread prototype is clamped and warns") {
auto [centroids, cluster, distance] = build_components();
CATCH_REQUIRE(centroids.size() == 2);

std::vector<std::string> logs;
std::vector<svs::logging::Level> levels;
auto logger = make_test_logger(logs, levels);

auto threadpool_proto = svs::threads::NativeThreadPool(4);
using IndexType = ivf::IVFIndex<
decltype(centroids),
decltype(cluster),
decltype(distance),
decltype(threadpool_proto)>;

IndexType index(
std::move(centroids),
std::move(cluster),
distance,
std::move(threadpool_proto),
1,
logger
);

CATCH_REQUIRE(index.get_num_threads() == 2);
CATCH_REQUIRE(
std::find(levels.begin(), levels.end(), svs::logging::Level::Warn) !=
levels.end()
);
}

CATCH_SECTION("non-resizable thread prototype throws") {
auto [centroids, cluster, distance] = build_components();
CATCH_REQUIRE(centroids.size() == 2);

std::vector<std::string> logs;
std::vector<svs::logging::Level> levels;
auto logger = make_test_logger(logs, levels);

auto threadpool_proto = svs::threads::QueueThreadPoolWrapper(4);
using IndexType = ivf::IVFIndex<
decltype(centroids),
decltype(cluster),
decltype(distance),
decltype(threadpool_proto)>;

CATCH_REQUIRE_THROWS_AS(
IndexType(
std::move(centroids),
std::move(cluster),
distance,
std::move(threadpool_proto),
1,
logger
),
std::invalid_argument
);
}
}
Loading