Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 45 additions & 35 deletions include/svs/index/inverted/clustering.h
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,36 @@ template <std::integral I> class Clustering {
// Saving and Loading.
static constexpr lib::Version save_version{0, 0, 0};
static constexpr std::string_view serialization_schema = "clustering";

/// @brief Build the table entries shared by the serialization paths of this class.
///
/// The table carries ``serialization_schema`` and ``save_version`` together with two
/// entries: "integer_type" (the data type corresponding to ``I``) and "num_clusters"
/// (the current value of ``size()``).
lib::SaveTable metadata() const {
    auto table = lib::SaveTable(
        serialization_schema,
        save_version,
        {{"integer_type", lib::save(datatype_v<I>)}, {"num_clusters", lib::save(size())}}
    );
    return table;
}

/// @brief Serialize every cluster of this clustering into ``os``.
///
/// Clusters are written back-to-back in the iteration order of ``*this``; no header or
/// count is emitted here.
void save(std::ostream& os) const {
    for (const auto& entry : *this) {
        // Only the cluster half of each entry is written; the id is not serialized here.
        const auto& [entry_id, entry_cluster] = entry;
        entry_cluster.serialize(os);
    }
}

/// @brief Rebuild a clustering from a metadata table and a stream of serialized clusters.
///
/// @param table Table providing the "integer_type" and "num_clusters" entries.
/// @param stream Stream from which "num_clusters" clusters are deserialized one after
///     another.
///
/// Throws (via ``ANNEXCEPTION``) when the stored integer type differs from ``I``.
static Clustering<I>
load(const lib::ContextFreeLoadTable& table, std::istream& stream) {
    // Refuse to decode clusters that were written with a different integer type.
    auto stored_type = lib::load_at<DataType>(table, "integer_type");
    if (stored_type != datatype_v<I>) {
        throw ANNEXCEPTION(
            "Clustering was saved using {} but we're trying to reload it using {}!",
            stored_type,
            datatype_v<I>
        );
    }

    auto result = Clustering<I>();
    const auto cluster_count = lib::load_at<size_t>(table, "num_clusters");
    for (size_t remaining = cluster_count; remaining != 0; --remaining) {
        result.insert(Cluster<I>::deserialize(stream));
    }
    return result;
}

lib::SaveTable save(const lib::SaveContext& ctx) const {
// Serialize all clusters into an auxiliary file.
auto fullpath = ctx.generate_name("clustering", "bin");
Expand All @@ -582,48 +612,28 @@ template <std::integral I> class Clustering {
}
}

return lib::SaveTable(
serialization_schema,
save_version,
{{"filepath", lib::save(fullpath.filename())},
SVS_LIST_SAVE(filesize),
{"integer_type", lib::save(datatype_v<I>)},
{"num_clusters", lib::save(size())}}
);
auto table = metadata();
table.insert("filepath", lib::save(fullpath.filename()));
table.insert("filesize", lib::save(filesize));
return table;

return table;
}

/// @brief Load a clustering from the auxiliary file referenced by ``table``.
///
/// @param table Table providing the "filepath" and "filesize" entries, plus the entries
///     read by the stream-based ``load`` overload ("integer_type" and "num_clusters").
///
/// The size of the file on disk is compared against the recorded "filesize" before any
/// cluster is decoded; a mismatch throws via ``ANNEXCEPTION``. The integer type check and
/// the actual decoding are delegated to the stream-based overload.
static Clustering<I> load(const lib::LoadTable& table) {
    auto expected_filesize = lib::load_at<size_t>(table, "filesize");
    auto file = table.resolve_at("filepath");
    size_t actual_filesize = std::filesystem::file_size(file);
    if (actual_filesize != expected_filesize) {
        // The arguments follow the order of the placeholders in the message: the expected
        // size first, then the size actually found on disk.
        throw ANNEXCEPTION(
            "Expected cluster file size to be {}. Instead, it is {}!",
            expected_filesize,
            actual_filesize
        );
    }

    auto io = lib::open_read(file);
    return load(table, io);
}

private:
Expand Down
44 changes: 44 additions & 0 deletions include/svs/index/inverted/memory_based.h
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,8 @@ template <typename Index, typename Cluster> class InvertedIndex {
index_.save(index_config, graph, data);
}

/// @brief Save the primary index into ``os`` by delegating to ``index_.save``.
void save_primary_index(std::ostream& os) const {
    index_.save(os);
}

///// Accessors
/// @brief Getter method for logger
svs::logging::logger_ptr get_logger() const { return logger_; }
Expand Down Expand Up @@ -655,4 +657,46 @@ auto assemble_from_clustering(
);
}

template <
typename DataProto,
typename Distance,
StorageStrategy Strategy,
typename ThreadPoolProto>
auto assemble_from_clustering(
std::istream& is,
DataProto data_proto,
Distance distance,
Strategy strategy,
ThreadPoolProto threadpool_proto,
svs::logging::logger_ptr logger = svs::logging::get()
) {
auto threadpool = threads::as_threadpool(std::move(threadpool_proto));
auto original = svs::detail::dispatch_load(std::move(data_proto), threadpool);
auto clustering = lib::load_from_stream<Clustering<uint32_t>>(is);
auto ids = clustering.sorted_centroids();

// skip magic
svs::lib::detail::Deserializer::build(is);
auto index = index::vamana::auto_assemble(
is,
lib::Lazy([&]() { return GraphLoader<uint32_t>::return_type::load(is); }),
lib::Lazy([&]() {
using T = typename std::decay_t<decltype(original)>::element_type;
constexpr size_t Ext = std::decay_t<decltype(original)>::extent;
return lib::load_from_stream<data::SimpleData<T, Ext>>(is);
}),
distance,
1,
logger
);

return InvertedIndex(
std::move(index),
strategy(original, clustering, HugepageAllocator<std::byte>()),
std::move(ids),
std::move(threadpool),
std::move(logger)
);
}

} // namespace svs::index::inverted
31 changes: 31 additions & 0 deletions include/svs/orchestrators/inverted.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ class InvertedInterface {
const std::filesystem::path& primary_data,
const std::filesystem::path& primary_graph
) = 0;

///// Saving
/// @brief Stream-based counterpart of the path-based ``save_primary_index`` above.
///
/// Implementations are expected to write the primary index into ``os``.
virtual void save_primary_index(std::ostream& os) = 0;
};

template <lib::TypeList QueryTypes, typename Impl, typename IFace = InvertedInterface>
Expand Down Expand Up @@ -72,6 +75,8 @@ class InvertedImpl : public manager::ManagerImpl<QueryTypes, Impl, IFace> {
) override {
impl().save_primary_index(primary_config, primary_data, primary_graph);
}

/// Forward stream-based saving of the primary index to the wrapped implementation.
void save_primary_index(std::ostream& os) override {
    impl().save_primary_index(os);
}
};

/////
Expand Down Expand Up @@ -106,6 +111,8 @@ class Inverted : public manager::IndexManager<InvertedInterface> {
impl_->save_primary_index(primary_config, primary_data, primary_graph);
}

/// @brief Save the primary index into ``os`` through the type-erased implementation.
void save_primary_index(std::ostream& os) {
    impl_->save_primary_index(os);
}

///// Building
template <
manager::QueryTypeDefinition QueryTypes,
Expand Down Expand Up @@ -168,6 +175,30 @@ class Inverted : public manager::IndexManager<InvertedInterface> {
std::move(threadpool_proto)
)};
}
/// @brief Assemble an ``Inverted`` orchestrator from a stream.
///
/// All arguments are forwarded to ``index::inverted::assemble_from_clustering``; the
/// resulting index is wrapped together with the type list derived from ``QueryTypes``.
///
/// @param is Stream handed to the underlying assembly routine.
/// @param data_proto Dataset prototype, moved into the assembly routine.
/// @param distance Distance functor, moved into the assembly routine.
/// @param threadpool_proto Thread pool prototype, moved into the assembly routine.
/// @param strategy Storage strategy; value-initialized when omitted.
template <
    manager::QueryTypeDefinition QueryTypes,
    typename DataProto,
    typename Distance,
    typename ThreadPoolProto,
    typename StorageStrategy = index::inverted::SparseStrategy>
static Inverted assemble_from_clustering(
    std::istream& is,
    DataProto data_proto,
    Distance distance,
    ThreadPoolProto threadpool_proto,
    StorageStrategy strategy = {}
) {
    auto assembled = index::inverted::assemble_from_clustering(
        is,
        std::move(data_proto),
        std::move(distance),
        std::move(strategy),
        std::move(threadpool_proto)
    );
    return Inverted{
        std::in_place, manager::as_typelist<QueryTypes>{}, std::move(assembled)};
}
};

} // namespace svs
75 changes: 75 additions & 0 deletions tests/svs/index/inverted/memory_based.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@
#include "spdlog/sinks/callback_sink.h"
#include "svs-benchmark/datasets.h"
#include "svs/lib/timing.h"
#include "svs/orchestrators/inverted.h"
#include "tests/utils/inverted_reference.h"
#include "tests/utils/test_dataset.h"
#include <filesystem>
#include <sstream>

CATCH_TEST_CASE("InvertedIndex Logging Test", "[long][logging]") {
// Vector to store captured log messages
Expand Down Expand Up @@ -73,3 +75,76 @@ CATCH_TEST_CASE("InvertedIndex Logging Test", "[long][logging]") {
CATCH_REQUIRE(captured_logs[0].find("Vamana Build Parameters:") != std::string::npos);
CATCH_REQUIRE(captured_logs[1].find("Number of syncs") != std::string::npos);
}

namespace {
constexpr size_t NUM_NEIGHBORS = 10;

template <typename Strategy> void test_stream_save_load(Strategy strategy) {
auto distance = svs::DistanceL2();
constexpr auto distance_type = svs::distance_type_v<svs::DistanceL2>;
auto expected_results = test_dataset::inverted::expected_build_results(
distance_type, svsbenchmark::Uncompressed(svs::DataType::float32)
);
auto build_parameters = expected_results.build_parameters_.value();

// Capture the clustering during build.
svs::index::inverted::Clustering<uint32_t> clustering;
auto clustering_op = [&](const auto& c) { clustering = c; };

svs::Inverted index = svs::Inverted::build<float>(
build_parameters,
svs::data::SimpleData<float>::load(test_dataset::data_svs_file()),
distance,
2,
strategy,
svs::index::inverted::PickRandomly{},
clustering_op
);

auto queries = svs::data::SimpleData<float>::load(test_dataset::query_file());
auto parameters = index.get_search_parameters();
auto results = index.search(queries, NUM_NEIGHBORS);

// Serialize to stream.
std::stringstream ss;
svs::lib::save_to_stream(clustering, ss);
index.save_primary_index(ss);

// Load from stream.
svs::Inverted loaded = svs::Inverted::assemble_from_clustering<svs::lib::Types<float>>(
ss,
svs::data::SimpleData<float>::load(test_dataset::data_svs_file()),
distance,
2,
strategy
);
loaded.set_search_parameters(parameters);

// Compare basic properties.
CATCH_REQUIRE(loaded.size() == index.size());
CATCH_REQUIRE(loaded.dimensions() == index.dimensions());

// Compare search results element-wise.
auto loaded_results = loaded.search(queries, NUM_NEIGHBORS);
CATCH_REQUIRE(loaded_results.n_queries() == results.n_queries());
CATCH_REQUIRE(loaded_results.n_neighbors() == results.n_neighbors());
for (size_t q = 0; q < results.n_queries(); ++q) {
for (size_t i = 0; i < NUM_NEIGHBORS; ++i) {
CATCH_REQUIRE(loaded_results.index(q, i) == results.index(q, i));
CATCH_REQUIRE(
loaded_results.distance(q, i) ==
Catch::Approx(results.distance(q, i)).epsilon(1e-5)
);
}
}
}
} // namespace

// Run the stream save/load round trip once for each storage strategy.
CATCH_TEST_CASE("InvertedIndex Save and Load", "[saveload][inverted][index]") {
    namespace inverted = svs::index::inverted;

    CATCH_SECTION("SparseStrategy") { test_stream_save_load(inverted::SparseStrategy()); }
    CATCH_SECTION("DenseStrategy") { test_stream_save_load(inverted::DenseStrategy()); }
}
Loading