Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 110 additions & 19 deletions include/svs/index/vamana/multi.h
Original file line number Diff line number Diff line change
Expand Up @@ -563,16 +563,8 @@ class MultiMutableVamanaIndex {
/// @brief Human-readable identifier of this index type.
///
/// The same string is stored under the "name" key of the metadata table written by
/// the stream-based ``save``.
constexpr std::string_view name() const {
    return std::string_view{"multi dynamic vamana index"};
}

static constexpr lib::Version save_version = lib::Version(0, 0, 0);
void save(
const std::filesystem::path& config_directory,
const std::filesystem::path& graph_directory,
const std::filesystem::path& data_directory
) {
// Post-consolidation, all entries should be "valid".
// Therefore, we don't need to save the slot metadata.
consolidate();
compact();

auto get_labels() const {
// Since data is in order of external ids,
// convert a map of external ids to label types into a sorted vector of labels based
// on external ids.
Expand All @@ -592,6 +584,34 @@ class MultiMutableVamanaIndex {
[](const auto& ext_lab) { return ext_lab.second; }
);

return labels;
}

/// @brief Collect the parameters that describe the current state of this index.
///
/// The result combines three pieces, in this order:
///   1. the first entry point of the wrapped index,
///   2. the build settings (alpha, max degree, construction window size,
///      max candidates, prune-to, full-search-history flag),
///   3. the current search parameters.
///
/// @return A ``VamanaIndexParameters`` value built from the getters above.
VamanaIndexParameters get_parameters() const {
    return VamanaIndexParameters{
        // Only the first entry point is recorded.
        index_->entry_point_.front(),
        // Build settings; the argument order must stay exactly as listed.
        {
            get_alpha(),
            max_degree(),
            get_construction_window_size(),
            get_max_candidates(),
            get_prune_to(),
            get_full_search_history()
        },
        // Search settings currently in effect.
        get_search_parameters()
    };
}

void save(
const std::filesystem::path& config_directory,
const std::filesystem::path& graph_directory,
const std::filesystem::path& data_directory
) {
// Post-consolidation, all entries should be "valid".
// Therefore, we don't need to save the slot metadata.
consolidate();
compact();

auto labels = get_labels();
size_t num_labels = labels.size();

// Save auxiliary data structures.
lib::save_to_disk(
lib::SaveOverride([&](const lib::SaveContext& ctx) {
Expand All @@ -601,16 +621,7 @@ class MultiMutableVamanaIndex {
lib::write_binary(stream, labels);

// Save the construction parameters.
auto parameters = VamanaIndexParameters{
index_->entry_point_.front(),
{get_alpha(),
max_degree(),
get_construction_window_size(),
get_max_candidates(),
get_prune_to(),
get_full_search_history()},
get_search_parameters()};

auto parameters = get_parameters();
return lib::SaveTable(
"multi_vamana_dynamic_auxiliary_parameters",
save_version,
Expand All @@ -628,6 +639,32 @@ class MultiMutableVamanaIndex {
// Graph
lib::save_to_disk(index_->graph_, graph_directory);
}

void save(std::ostream& os) {
consolidate();
compact();

auto labels = get_labels();
size_t num_labels = labels.size();

lib::begin_serialization(os);

auto parameters = get_parameters();
auto save_table = lib::SaveTable(
"multi_vamana_dynamic_auxiliary_parameters",
save_version,
{{"name", lib::save(name())},
{"parameters", lib::save(parameters)},
{"num_labels", lib::save(num_labels)}}
);
lib::save_to_stream(save_table, os);
lib::write_binary(os, labels);

// Save the dataset.
lib::save_to_stream(index_->data_, os);
// Save the graph.
lib::save_to_stream(index_->graph_, os);
}
};

///// Deduction Guides.
Expand Down Expand Up @@ -789,4 +826,58 @@ auto auto_multi_dynamic_assemble(
}
}

/// @brief Reassemble a ``MultiMutableVamanaIndex`` from a stream.
///
/// Reads, in order: the metadata table (from which the index parameters and the number
/// of labels are loaded), the raw label array, and then invokes ``data_loader`` followed
/// by ``graph_loader``. The loaders are expected to pull their content from the same
/// stream, so the data loader always runs before the graph loader.
///
/// @param is Stream to read the metadata and labels from.
/// @param graph_loader Callable returning the graph; invoked after ``data_loader``.
/// @param data_loader Callable returning the dataset; invoked first.
/// @param distance Distance functor forwarded to the index.
/// @param threadpool_proto Converted to a thread pool via ``threads::as_threadpool``.
/// @param logger Logger forwarded to the index.
///
/// @throws ANNEXCEPTION if the dataset and graph sizes differ, or if the number of
///     labels does not match the dataset size.
template <
    typename LazyGraphLoader,
    typename LazyDataLoader,
    typename Distance,
    typename ThreadPoolProto>
auto auto_multi_dynamic_assemble(
    std::istream& is,
    LazyGraphLoader graph_loader,
    LazyDataLoader data_loader,
    Distance distance,
    ThreadPoolProto threadpool_proto,
    svs::logging::logger_ptr logger = svs::logging::get()
) {
    using label_type = size_t;

    // Metadata table; both lookups below go through the same root-table view.
    auto table = lib::detail::read_metadata(is);
    auto&& root = table.template cast<toml::table>();

    auto parameters =
        lib::load<VamanaIndexParameters>(root.at("parameters").template cast<toml::table>());
    auto num_labels = lib::load<size_t>(root.at("num_labels"));

    // The label payload follows the metadata table directly in the stream.
    std::vector<label_type> labels(num_labels);
    lib::read_binary(is, labels);

    // Dataset first, graph second.
    auto data = data_loader();
    auto graph = graph_loader();

    // Sanity checks: data, graph and labels must all describe the same number of points.
    auto n_data = data.size();
    auto n_graph = graph.n_nodes();
    if (n_data != n_graph) {
        throw ANNEXCEPTION(
            "Reloaded data has {} nodes while the graph has {} nodes!", n_data, n_graph
        );
    }
    if (labels.size() != n_data) {
        throw ANNEXCEPTION("Labels has {} IDs but should have {}", labels.size(), n_data);
    }

    auto threadpool = threads::as_threadpool(std::move(threadpool_proto));
    return MultiMutableVamanaIndex{
        parameters,
        std::move(data),
        std::move(graph),
        std::move(distance),
        labels,
        std::move(threadpool),
        std::move(logger)};
}

} // namespace svs::index::vamana
145 changes: 145 additions & 0 deletions tests/svs/index/vamana/multi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "catch2/catch_test_macros.hpp"

// stl
#include <sstream>
#include <unordered_map>
#include <unordered_set>
#include <vector>
Expand Down Expand Up @@ -304,3 +305,147 @@ CATCH_TEMPLATE_TEST_CASE(
CATCH_REQUIRE(test_recall_2 > test_recall - epsilon);
}
}

// Round-trip test: build a multi-label dynamic Vamana index, save it through both the
// native stream path and the directory-based path, reload it, and require that the
// reloaded index reports the same properties and returns the same search results.
CATCH_TEST_CASE(
    "MultiMutableVamana Index Save and Load", "[index][vamana][multi][saveload]"
) {
    using Eltype = float;
    using Distance = svs::DistanceL2;
    const size_t N = 128;
    const size_t num_threads = 4;
    const size_t num_neighbors = 10;
    const size_t max_degree = 64;

    using Data_t = svs::data::SimpleData<Eltype, N>;
    using GraphType = svs::graphs::SimpleBlockedGraph<uint32_t>;

    const auto data = svs::data::SimpleData<Eltype, N>::load(test_dataset::data_svs_file());
    const auto num_points = data.size();
    const auto queries = test_dataset::queries();
    const auto groundtruth = test_dataset::load_groundtruth(svs::distance_type_v<Distance>);

    const svs::index::vamana::VamanaBuildParameters build_parameters{
        1.2, max_degree, 10, 20, 10, true};

    const auto search_parameters = svs::index::vamana::VamanaSearchParameters();

    const float epsilon = 0.05f;

    // Give every point a pseudo-random label in [0, num_labels).
    std::vector<size_t> test_indices(num_points);
    const size_t per_label = 2;
    const auto num_labels = num_points / per_label;
    for (auto& label : test_indices) {
        label = std::rand() % num_labels;
    }

    // Reference index and its search results.
    auto index = svs::index::vamana::MultiMutableVamanaIndex(
        build_parameters, data, test_indices, Distance(), num_threads
    );
    auto results = svs::QueryResult<size_t>(queries.size(), num_neighbors);
    index.search(results.view(), queries.view(), search_parameters);

    // Checks shared by both sections: the reloaded index must match the reference index
    // property by property, neighbor by neighbor, and in recall.
    const auto require_matches_reference = [&](auto& loaded) {
        CATCH_REQUIRE(loaded.size() == index.size());
        CATCH_REQUIRE(loaded.dimensions() == index.dimensions());
        CATCH_REQUIRE(loaded.get_alpha() == index.get_alpha());
        CATCH_REQUIRE(
            loaded.get_construction_window_size() == index.get_construction_window_size()
        );
        CATCH_REQUIRE(loaded.get_max_candidates() == index.get_max_candidates());
        CATCH_REQUIRE(loaded.max_degree() == index.max_degree());
        CATCH_REQUIRE(loaded.get_prune_to() == index.get_prune_to());
        CATCH_REQUIRE(loaded.get_full_search_history() == index.get_full_search_history());
        CATCH_REQUIRE(loaded.view_data() == index.view_data());

        auto loaded_results = svs::QueryResult<size_t>(queries.size(), num_neighbors);
        loaded.search(loaded_results.view(), queries.view(), search_parameters);
        for (size_t query = 0; query < results.n_queries(); ++query) {
            for (size_t rank = 0; rank < results.n_neighbors(); ++rank) {
                CATCH_REQUIRE(
                    results.indices().at(query, rank) ==
                    loaded_results.indices().at(query, rank)
                );
            }
        }

        auto loaded_recall = svs::k_recall_at_n(groundtruth, loaded_results);
        auto test_recall = svs::k_recall_at_n(groundtruth, results);
        CATCH_REQUIRE(loaded_recall > test_recall - epsilon);
    };

    CATCH_SECTION("Load MultiMutableVamana Index being serialized natively to stream") {
        std::stringstream stream;
        index.save(stream);

        // The stream must be recognized as the native format before loading.
        auto deserializer = svs::lib::detail::Deserializer::build(stream);
        CATCH_REQUIRE(deserializer.is_native());

        auto loaded = svs::index::vamana::auto_multi_dynamic_assemble(
            stream,
            [&]() -> GraphType { return GraphType::load(stream); },
            [&]() -> Data_t { return svs::lib::load_from_stream<Data_t>(stream); },
            Distance(),
            num_threads
        );

        require_matches_reference(loaded);
    }

    CATCH_SECTION("Load MultiMutableVamana Index being serialized with intermediate files"
    ) {
        std::stringstream stream;
        svs::lib::UniqueTempDirectory tempdir{"svs_multivamana_save"};
        const auto config_dir = tempdir.get() / "config";
        const auto graph_dir = tempdir.get() / "graph";
        const auto data_dir = tempdir.get() / "data";
        std::filesystem::create_directories(config_dir);
        std::filesystem::create_directories(graph_dir);
        std::filesystem::create_directories(data_dir);

        // Save into the directories, then pack them into the stream.
        index.save(config_dir, graph_dir, data_dir);
        svs::lib::DirectoryArchiver::pack(tempdir, stream);

        // The packed archive must NOT be detected as the native stream format.
        auto deserializer = svs::lib::detail::Deserializer::build(stream);
        CATCH_REQUIRE(!deserializer.is_native());
        svs::lib::DirectoryArchiver::unpack(stream, tempdir, deserializer.magic());

        auto loaded = svs::index::vamana::auto_multi_dynamic_assemble(
            config_dir,
            GraphType::load(graph_dir),
            Data_t::load(data_dir),
            Distance(),
            num_threads
        );

        require_matches_reference(loaded);
    }
}
Loading