Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ wheelhouse/
tags
compile_commands.json
.python-version
.vscode

# Python related files
__pycache__/
Expand Down
97 changes: 97 additions & 0 deletions bindings/python/src/dynamic_vamana.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

// fmt
#include <fmt/format.h>

// stl
#include <span>

Expand Down Expand Up @@ -88,6 +91,75 @@ void add_build_specialization(py::class_<svs::DynamicVamana>& index) {
);
}

/////
///// Build from file (data loader)
/////

template <typename Q, typename T, typename Dist, size_t N>
svs::DynamicVamana dynamic_vamana_build_uncompressed(
const svs::index::vamana::VamanaBuildParameters& parameters,
svs::VectorDataLoader<T, N, RebindAllocator<T>> data_loader,
std::span<const size_t> ids,
svs::DistanceType distance_type,
size_t num_threads
) {
return svs::DynamicVamana::build<Q>(
parameters,
std::move(data_loader),
ids,
distance_type,
num_threads
);
}

using DynamicVamanaBuildFromFileDispatcher = svs::lib::Dispatcher<
svs::DynamicVamana,
const svs::index::vamana::VamanaBuildParameters&,
UnspecializedVectorDataLoader,
std::span<const size_t>,
svs::DistanceType,
size_t>;

DynamicVamanaBuildFromFileDispatcher dynamic_vamana_build_from_file_dispatcher() {
auto dispatcher = DynamicVamanaBuildFromFileDispatcher{};
// Register uncompressed specializations (Dynamic dimensionality only, similar to tests)
for_standard_specializations([&]<typename Q, typename T, typename D, size_t N>() {
// Only register when N is Dynamic (compile-time tag) - the pattern in static code
// registers all; here we directly register.
auto method = &dynamic_vamana_build_uncompressed<Q, T, D, N>;
dispatcher.register_target(svs::lib::dispatcher_build_docs, method);
});
return dispatcher;
}

svs::DynamicVamana dynamic_vamana_build_from_file(
const svs::index::vamana::VamanaBuildParameters& parameters,
UnspecializedVectorDataLoader data_loader,
const py_contiguous_array_t<size_t>& py_ids,
svs::DistanceType distance_type,
size_t num_threads
) {
auto ids = std::span<const size_t>(py_ids.data(), py_ids.size());
return dynamic_vamana_build_from_file_dispatcher().invoke(
parameters, std::move(data_loader), ids, distance_type, num_threads
);
}

constexpr std::string_view DYNAMIC_VAMANA_BUILD_FROM_FILE_DOCSTRING_PROTO = R"(
Construct a DynamicVamana index using a data loader, returning the index.

Args:
parameters: Build parameters controlling graph construction.
data_loader: Data loader (e.g., an VectorDataLoader instance).
ids: Vector of ids to assign to each row in the dataset; must match dataset length and contain unique values.
distance_type: The similarity function to use for this index.
num_threads: Number of threads to use for index construction. Default: 1.

Specializations compiled into the binary are listed below.

{} # (Method listing auto-generated)
)";

template <typename ElementType>
void add_points(
svs::DynamicVamana& index,
Expand Down Expand Up @@ -301,6 +373,31 @@ void wrap(py::module& m) {
// Index building.
add_build_specialization<float>(vamana);

// Build from file / data loader (dynamic docstring)
{
auto dispatcher = dynamic_vamana_build_from_file_dispatcher();
std::string dynamic;
for (size_t i = 0; i < dispatcher.size(); ++i) {
fmt::format_to(
std::back_inserter(dynamic),
R"(Method {}:\n - data_loader: {}\n - distance: {}\n)",
i,
dispatcher.description(i, 1),
dispatcher.description(i, 3)
);
}
vamana.def_static(
"build",
&dynamic_vamana_build_from_file,
py::arg("parameters"),
py::arg("data_loader"),
py::arg("ids"),
py::arg("distance_type"),
py::arg("num_threads") = 1,
fmt::format(DYNAMIC_VAMANA_BUILD_FROM_FILE_DOCSTRING_PROTO, dynamic).c_str()
);
}

// Index modification.
add_points_specialization<float>(vamana);

Expand Down
40 changes: 40 additions & 0 deletions bindings/python/tests/test_dynamic_vamana.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@

# unit under test
import svs
import numpy as np

# stdlib
import unittest
import os
from tempfile import TemporaryDirectory

# helpers
from .common import test_data_svs, test_data_dims, test_number_of_vectors, test_queries, test_groundtruth_l2
from .dynamic import ReferenceDataset

class DynamicVamanaTester(unittest.TestCase):
Expand Down Expand Up @@ -162,3 +164,41 @@ def test_loop(self):
)
consolidate_count = 0

def test_build_from_loader(self):
"""Test building DynamicVamana using a VectorDataLoader and explicit IDs."""

loader = svs.VectorDataLoader(test_data_svs, svs.DataType.float32, dims = test_data_dims)

# Sequential IDs
ids = np.arange(test_number_of_vectors, dtype = np.uint64)

params = svs.VamanaBuildParameters(
graph_max_degree = 64,
window_size = 128,
alpha = 1.2,
)

index = svs.DynamicVamana.build(
params,
loader,
ids,
svs.DistanceType.L2,
num_threads = 2,
)

# Basic invariants
self.assertEqual(index.size, test_number_of_vectors)
self.assertEqual(index.dimensions, test_data_dims)
self.assertTrue(index.has_id(0))
self.assertTrue(index.has_id(test_number_of_vectors - 1))

queries = svs.read_vecs(test_queries)
groundtruth = svs.read_vecs(test_groundtruth_l2)
k = 10
index.search_window_size = 20
I, D = index.search(queries, k)
self.assertEqual(I.shape[1], k)
recall = svs.k_recall_at(groundtruth, I, k, k)
# Recall in plausible range
self.assertTrue(0.5 < recall <= 1.0)

51 changes: 42 additions & 9 deletions include/svs/orchestrators/dynamic_vamana.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
#include "svs/orchestrators/manager.h"
#include "svs/orchestrators/vamana.h"

// stdlib
#include <type_traits>

namespace svs {

///
Expand Down Expand Up @@ -258,25 +261,55 @@ class DynamicVamana : public manager::IndexManager<DynamicVamanaInterface> {
}

// Building
///
/// @brief Construct a DynamicVamana index from a data loader or dataset.
///
/// @tparam QueryTypes The set of query element types supported by the resulting index.
/// @tparam DataLoader A data loader or dataset type.
/// @tparam Distance Distance functor or ``svs::DistanceType`` enum.
/// @tparam ThreadPoolProto Thread pool type or size_t).
///
/// @param parameters Build parameters controlling graph construction.
/// @param data_loader Loader (or dataset) from which to obtain the data.
/// @param ids External IDs to assign to each row; must be unique and have length ``data.size()``.
/// @param distance Distance functor or enum.
/// @param threadpool_proto Thread pool or number of threads to use.
///
template <
manager::QueryTypeDefinition QueryTypes,
data::ImmutableMemoryDataset Data,
typename DataLoader,
typename Distance,
typename ThreadPoolProto>
static DynamicVamana build(
const index::vamana::VamanaBuildParameters& parameters,
Data data,
DataLoader&& data_loader,
std::span<const size_t> ids,
Distance distance,
ThreadPoolProto threadpool_proto
) {
return make_dynamic_vamana<manager::as_typelist<QueryTypes>>(
parameters,
std::move(data),
ids,
std::move(distance),
threads::as_threadpool(std::move(threadpool_proto))
);
auto threadpool = threads::as_threadpool(std::move(threadpool_proto));
auto data = svs::detail::dispatch_load(std::forward<DataLoader>(data_loader), threadpool);
// If given a DistanceType enum, dispatch to a concrete distance functor first.
if constexpr (std::is_same_v<std::decay_t<Distance>, DistanceType>) {
auto dispatcher = DistanceDispatcher(distance);
return dispatcher([&](auto distance_function) {
return make_dynamic_vamana<manager::as_typelist<QueryTypes>>(
parameters,
std::move(data),
ids,
std::move(distance_function),
std::move(threadpool)
);
});
} else {
return make_dynamic_vamana<manager::as_typelist<QueryTypes>>(
parameters,
std::move(data),
ids,
std::move(distance),
std::move(threadpool)
);
}
}

// Assembly
Expand Down
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ SET(INTEGRATION_TESTS
${TEST_DIR}/svs/index/vamana/dynamic_index_2.cpp
# Higher level constructs
${TEST_DIR}/svs/orchestrators/vamana.cpp
${TEST_DIR}/svs/orchestrators/dynamic_vamana.cpp
# Integration Tests
${TEST_DIR}/integration/exhaustive.cpp
${TEST_DIR}/integration/vamana/index_search.cpp
Expand Down
125 changes: 125 additions & 0 deletions tests/svs/orchestrators/dynamic_vamana.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
/*
* Copyright 2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// Orchestrator under test
#include "svs/orchestrators/dynamic_vamana.h"

// Core helpers
#include "svs/core/recall.h"
#include "svs/core/data/simple.h"

// Distance dispatcher
#include "svs/core/distance.h"

// Test dataset utilities
#include "tests/utils/test_dataset.h"
#include "tests/utils/vamana_reference.h"
#include "tests/utils/utils.h"

// Catch2
#include "catch2/catch_test_macros.hpp"
#include "catch2/catch_approx.hpp"

// STL
#include <vector>
#include <numeric>

namespace {

template <typename DataLoaderT, typename DistanceT>
void test_build(
DataLoaderT&& data_loader,
DistanceT distance = DistanceT()
) {
auto expected_result = test_dataset::vamana::expected_build_results(
distance, svsbenchmark::Uncompressed(svs::DataType::float32)
);
auto build_params = expected_result.build_parameters_.value();
auto queries = svs::data::SimpleData<float>::load(test_dataset::query_file());
auto groundtruth = test_dataset::load_groundtruth(distance);

// Prepare IDs (0 .. N-1)
auto data = svs::data::SimpleData<float>::load(test_dataset::data_svs_file());
const size_t n = data.size();
std::vector<size_t> ids(n);
std::iota(ids.begin(), ids.end(), 0);

size_t num_threads = 2;
svs::DynamicVamana index = svs::DynamicVamana::build<float>(
build_params,
std::forward<DataLoaderT>(data_loader),
ids,
distance,
num_threads
);

// Basic invariants
CATCH_REQUIRE(index.get_alpha() == Catch::Approx(build_params.alpha));
CATCH_REQUIRE(index.get_construction_window_size() == build_params.window_size);
CATCH_REQUIRE(index.get_prune_to() == build_params.prune_to);
CATCH_REQUIRE(index.get_graph_max_degree() == build_params.graph_max_degree);
CATCH_REQUIRE(index.get_num_threads() == num_threads);

// ID checks (spot sample)
CATCH_REQUIRE(index.has_id(0));
CATCH_REQUIRE(index.has_id(n / 2));
CATCH_REQUIRE(index.has_id(n - 1));

const double epsilon = 0.01; // allow small deviation
for (const auto& expected : expected_result.config_and_recall_) {
auto these_queries = test_dataset::get_test_set(queries, expected.num_queries_);
auto these_groundtruth =
test_dataset::get_test_set(groundtruth, expected.num_queries_);
index.set_search_parameters(expected.search_parameters_);
auto results = index.search(these_queries, expected.num_neighbors_);
double recall = svs::k_recall_at_n(
these_groundtruth, results, expected.num_neighbors_, expected.recall_k_
);
CATCH_REQUIRE(recall > expected.recall_ - epsilon);
CATCH_REQUIRE(recall < expected.recall_ + epsilon);
}
}

} // namespace

CATCH_TEST_CASE("DynamicVamana Build", "[managers][dynamic_vamana][build]") {
for (auto distance_enum : test_dataset::vamana::available_build_distances()) {
// SimpleData and distance functor.
{
std::string section_name = std::string("SimpleData ") + std::string(svs::name(distance_enum));
CATCH_SECTION(section_name) {
svs::DistanceDispatcher dispatcher(distance_enum);
dispatcher([&](auto distance_functor) {
test_build(
svs::data::SimpleData<float>::load(test_dataset::data_svs_file()),
distance_functor
);
});
}
}

// VectorDataLoader and distance enum.
{
std::string section_name = std::string("VectorDataLoader ") + std::string(svs::name(distance_enum));
CATCH_SECTION(section_name) {
test_build(
svs::VectorDataLoader<float>(test_dataset::data_svs_file()),
distance_enum
);
}
}
}
}
Loading
Loading